In [None]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [None]:
class TGSSaltTrainDataset(Dataset):
    """Dataset yielding ``(image, depth, mask)`` for every id listed in ``train_csv``.

    Images are read from ``<image_dir>/<id>.png`` and masks from
    ``<mask_dir>/<id>.png``; both go through the same ``Grayscale`` +
    ``ToTensor`` transform. Depth is the ``z`` column of ``depth_csv``,
    min-max scaled over the whole depth file.
    """

    def __init__(self, image_dir,mask_dir,depth_csv,train_csv):
        """
        Args:
            image_dir: directory containing the input images as ``<id>.png``.
            mask_dir: directory containing the masks as ``<id>.png``.
            depth_csv: path to a CSV with at least the columns ``id`` and ``z``.
            train_csv: path to a CSV with at least the columns ``id`` and ``rle_mask``.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir

        # Min-max scale the depth column before joining it onto the sample table
        depth = pd.read_csv(depth_csv)
        z_min, z_max = depth["z"].min(), depth["z"].max()
        depth["z"] = (depth["z"] - z_min) / (z_max - z_min)

        # One row per training id; the run-length encoded mask is not needed
        # because masks are loaded from the PNG files instead.
        self.input = (
            pd.read_csv(train_csv)
            .drop(columns=["rle_mask"])
            .merge(depth, how="left", on="id")
        )

        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of training samples (rows of the merged table)."""
        return len(self.input)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple ``(image, depth, mask)``: ``image`` and ``mask`` are the
            transformed PNGs, ``depth`` is a NumPy array of shape ``(1,)``.
        """
        # Look columns up by name so extra or reordered CSV columns cannot
        # silently shift which value is used as the id or the depth.
        file_name = self.input["id"].iloc[idx] + ".png"

        # Context managers release the file handles as soon as the pixel data
        # has been converted to tensors.
        with Image.open(os.path.join(self.image_dir, file_name)) as img:
            image = self.transform(img)
        with Image.open(os.path.join(self.mask_dir, file_name)) as msk:
            mask = self.transform(msk)

        depth = np.asarray(self.input["z"].iloc[idx]).reshape(1)
        return image,depth,mask

In [None]:
# Build the training dataset from the files under ../input
train_dataset = TGSSaltTrainDataset(
    image_dir="../input/train/images",
    mask_dir="../input/train/masks",
    depth_csv="../input/depths.csv",
    train_csv="../input/train.csv",
)

In [None]:
# --- Split / loader configuration ---
batch_size = 50
validation_split = 0.2
shuffle_dataset = True
random_seed= 42

# Enumerate every sample index; the first `split` of them (after the optional
# seeded shuffle) become the validation set, the remainder the training set.
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

# Both loaders read from the same dataset; the samplers restrict each one to
# its own subset of indices and draw from it in random order.
train_sampler = SubsetRandomSampler(train_indices)
validation_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)
validation_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=validation_sampler,
)

In [17]:
class TGSUNetModel(nn.Module):
    """U-Net style encoder/decoder with a per-pixel two-class head.

    The layer sizes are tied to single-channel 101x101 inputs: the four
    2x2 poolings take the spatial size 101 -> 50 -> 25 -> 12 -> 6 and the
    transposed convolutions (kernel sizes 2, 3, 2, 3 with stride 2) take it
    back 6 -> 12 -> 25 -> 50 -> 101 so every skip connection lines up.

    ``forward`` returns a ``(batch * 101 * 101, 2)`` tensor of softmax
    probabilities, one row per pixel in (batch, row, column) order.
    """

    # Channel count produced by the last decoder stage (``up42``)
    _FEATURES = 64
    # Number of pixels per feature map that the ``linear`` head is sized for
    _NUM_PIXELS = 101 * 101

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Two (3x3 conv -> batch norm -> ReLU) stages; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    @staticmethod
    def _single_conv(in_channels, out_channels):
        """One (3x3 conv -> batch norm -> ReLU) stage; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    @staticmethod
    def _up_conv(in_channels, out_channels, kernel_size):
        """Stride-2 transposed conv -> batch norm -> ReLU (upsampling stage)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    @staticmethod
    def _to_pixel_rows(channel_rows, num_channels):
        """Regroup ``(batch * channels, pixels)`` rows into ``(batch * pixels, channels)``.

        A plain ``view(-1, num_channels)`` on the channel-major input would
        put ``num_channels`` neighbouring pixels of a single channel in each
        row; the transpose makes each row the feature vector of one pixel.
        """
        num_pixels = channel_rows.size(1)
        return (channel_rows.reshape(-1, num_channels, num_pixels)
                .transpose(1, 2)
                .reshape(-1, num_channels))

    def __init__(self):
        super(TGSUNetModel,self).__init__()
        # ----- Contracting path -----
        # The pooling attributes are named ``maxpool*`` but hold average
        # pools; the names are kept so existing references keep working.
        self.down1 = self._double_conv(1, 64)
        self.maxpool1 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)

        self.down2 = self._double_conv(64, 128)
        self.maxpool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)

        self.down3 = self._double_conv(128, 256)
        self.maxpool3 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)

        self.down4 = self._double_conv(256, 512)
        self.maxpool4 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)

        self.center = self._double_conv(512, 1024)

        # ----- Expanding path -----
        # Each stage: upsample, concatenate the matching encoder output
        # (doubling the channels), then reduce the channels with a 3x3 conv.
        self.up11 = self._up_conv(1024, 512, kernel_size=(2, 2))
        self.cropcat1 = self.CropCat
        self.up12 = self._single_conv(1024, 512)

        # kernel_size=3 here (and in up41) produces the odd sizes 25 and 101
        self.up21 = self._up_conv(512, 256, kernel_size=3)
        self.cropcat2 = self.CropCat
        self.up22 = self._single_conv(512, 256)

        self.up31 = self._up_conv(256, 128, kernel_size=2)
        self.cropcat3 = self.CropCat
        self.up32 = self._single_conv(256, 128)

        self.up41 = self._up_conv(128, 64, kernel_size=3)
        self.cropcat4 = self.CropCat
        self.up42 = self._single_conv(128, 64)

        # 3x3 convolution + sigmoid segmentation head. It is constructed (and
        # therefore part of the state dict) but forward() does not apply it.
        self.output_seg_map = nn.Sequential(
            nn.Conv2d(64,1, kernel_size=(3, 3), padding=1, stride=1),
            nn.Sigmoid())

        # Fully connected mixing over all pixel positions of one feature map
        self.linear = nn.Sequential(
            nn.BatchNorm1d(self._NUM_PIXELS),
            nn.Linear(self._NUM_PIXELS, self._NUM_PIXELS),
            nn.ReLU())

        # Per-pixel two-class head. dim=1 is what the implicit-dim rule picks
        # for 2-D input; spelling it out avoids the deprecation warning.
        # NOTE(review): the notebook trains with nn.CrossEntropyLoss, which
        # applies log-softmax itself, so these probabilities get a second
        # softmax inside the loss - consider returning logits instead.
        self.classifier = nn.Sequential(
            nn.Linear(self._FEATURES, 2),
            nn.BatchNorm1d(2),
            nn.Softmax(dim=1))

    def CropCat(self, upsampled, bypass):
        """Center-crop ``bypass`` to the spatial size of ``upsampled`` and concatenate.

        Used on the expanding path to join a skip connection with the
        upsampled features.

        Args:
            upsampled: tensor of shape (N, C1, H, W).
            bypass: tensor of shape (N, C2, H', W').
        Returns:
            Tensor of shape (N, C1 + C2, H, W) with ``upsampled`` first along
            the channel dimension.
        """
        diff_h = bypass.size(2) - upsampled.size(2)
        diff_w = bypass.size(3) - upsampled.size(3)
        top, left = diff_h // 2, diff_w // 2
        # Negative padding crops. Height and width are handled separately and
        # the remainder of an odd difference is taken from the bottom/right,
        # so the result always matches ``upsampled`` exactly.
        bypass = F.pad(bypass, (-left, -(diff_w - left), -top, -(diff_h - top)))
        return torch.cat((upsampled, bypass), 1)

    def forward(self,x):
        """Run the network on ``x`` of shape (batch, 1, 101, 101).

        Returns:
            Tensor of shape (batch * 101 * 101, 2) with per-pixel class
            probabilities.
        """
        # Encoder; keep each stage's output for the skip connections
        down1 = self.down1(x)
        output = self.maxpool1(down1)
        down2 = self.down2(output)
        output = self.maxpool2(down2)
        down3 = self.down3(output)
        output = self.maxpool3(down3)
        down4 = self.down4(output)
        output = self.maxpool4(down4)

        output = self.center(output)

        # Decoder
        output = self.up11(output)
        output = self.cropcat1(output,down4)
        output = self.up12(output)
        output = self.up21(output)
        output = self.cropcat2(output,down3)
        output = self.up22(output)
        output = self.up31(output)
        output = self.cropcat3(output,down2)
        output = self.up32(output)
        output = self.up41(output)
        output = self.cropcat4(output,down1)
        output = self.up42(output)

        # Head: one row per (sample, channel) feature map for the linear layer...
        output = output.reshape(-1, self._NUM_PIXELS)
        output = self.linear(output)
        # ...then one row per pixel holding its 64 channel features, so the
        # classifier rows line up with a flattened (batch, 1, H, W) target.
        output = self._to_pixel_rows(output, self._FEATURES)
        output = self.classifier(output)
        return output

In [None]:
class BinaryCrossEntropyLoss2d(nn.Module):
    """Binary cross entropy over all elements of a prediction/target pair.

    Both inputs are flattened before being handed to ``nn.BCELoss``, so any
    matching shapes are accepted. ``pred`` must already contain
    probabilities; no sigmoid is applied here.
    """

    def __init__(self, weight=None, size_average=True):
        """
        Args:
            weight: optional rescaling weight forwarded to ``nn.BCELoss``.
            size_average: when True (or None) the loss is averaged over all
                elements, when False it is summed.
        """
        super(BinaryCrossEntropyLoss2d, self).__init__()
        # `size_average` is deprecated in nn.BCELoss; translate it to the
        # equivalent `reduction` string instead of passing it positionally.
        reduction = 'mean' if (size_average is None or size_average) else 'sum'
        self.bce_loss = nn.BCELoss(weight=weight, reduction=reduction)
        if torch.cuda.is_available():
            self.bce_loss = self.bce_loss.cuda()

    def forward(self, pred, target):
        """Return the BCE between flattened ``pred`` and ``target``."""
        # reshape (unlike view) also accepts non-contiguous tensors
        pred = pred.reshape(-1)
        target = target.reshape(-1)
        return self.bce_loss(pred, target)

In [None]:
class SoftDiceLoss(nn.Module):
    """One minus the mean smoothed Dice score between arg-max predictions and targets.

    NOTE(review): the prediction is reduced with an arg-max over dim 1, which
    has no gradient, so this term does not contribute to backpropagation.
    Using the class probabilities directly would make it trainable - confirm
    the intended behaviour before changing it.
    """

    def __init__(self):
        super(SoftDiceLoss, self).__init__()

    def forward(self, pred, target):
        """
        Args:
            pred: class scores with the classes along dim 1.
            target: ground-truth labels; its first dimension defines the
                number of samples the score is averaged over.
        Returns:
            Scalar tensor ``1 - mean(dice)``.
        """
        smooth = 1.
        num = target.size(0)
        # Index of the highest-scoring class along dim 1, as float
        pred = pred.max(1,keepdim=True)[1].float()
        pred = pred.reshape(num, -1)
        target = target.reshape(num, -1).float()
        intersection = (pred * target).sum(1)
        # The smoothing term is added once to numerator and denominator so a
        # perfect match (including empty vs. empty) scores exactly 1; the
        # earlier form doubled it in the numerator and could exceed 1.
        score = (2. * intersection + smooth) / (pred.sum(1) + target.sum(1) + smooth)
        return 1 - score.sum() / num

In [None]:
def dice_coeff(pred, target):
    """Mean smoothed Dice coefficient over the first dimension of ``target``.

    Both tensors are flattened per sample; the score for each sample is
    ``(2 * sum(pred * target) + 1) / (sum(pred) + sum(target) + 1)``.
    """
    batch = target.size(0)
    eps = 1.
    flat_pred = pred.view(batch, -1)
    flat_target = target.view(batch, -1)
    overlap = torch.sum(flat_pred * flat_target, dim=1)
    numerator = (2. * overlap + eps).float()
    denominator = (torch.sum(flat_pred, dim=1) + torch.sum(flat_target, dim=1) + eps).float()
    per_sample = numerator / denominator
    return per_sample.sum() / batch

In [None]:
def bce_dice_loss(y_pred,y_true):
    """Cross-entropy loss plus half of the soft Dice loss.

    Despite the name, the first term is ``nn.CrossEntropyLoss`` (multi-class
    cross entropy), not binary cross entropy.
    """
    cross_entropy = nn.CrossEntropyLoss()
    soft_dice = SoftDiceLoss()
    ce_term = cross_entropy(y_pred, y_true)
    dice_term = soft_dice(y_pred, y_true.float())
    return ce_term + 0.5 * dice_term

In [18]:
# Build the network and loss, and move them to the GPU (when one is present)
# BEFORE constructing the optimizer, as the torch.optim documentation advises,
# so the optimizer is created from the parameters that will actually be trained.
model = TGSUNetModel()
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 every 20 scheduler steps
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [None]:
def validate():
    """Evaluate the global ``model`` on ``validation_loader`` and print mean loss and Dice.

    Uses the module-level ``model``, ``criterion``, ``validation_loader`` and
    ``dice_coeff``. Leaves the model in eval mode.
    """
    total_loss = 0
    total_accuracy = 0
    num_batches = 0
    model.eval()
    # No gradients are needed for evaluation; skipping them avoids building
    # the autograd graph for every validation batch.
    with torch.no_grad():
        for data, _, target in validation_loader:
            target = target.long()
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            # forward
            output = model(data)
            predict = output.max(1,keepdim=True)[1]
            loss = criterion(output, target.view(-1))
            # accumulate statistics
            accuracy = dice_coeff(predict.view_as(target),target)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1
    # Average over the batches actually seen; the previous
    # batch_size/len(val_indices) factor was only exact when the split size
    # was a multiple of the batch size. max() guards an empty loader.
    num_batches = max(num_batches, 1)
    print('Validation Loss: {:.5f} Validation Accuracy: {:.5f}'.format(total_loss/num_batches,total_accuracy/num_batches))

In [None]:
def train(num_epochs=None):
    """Train the global ``model``, validating after every epoch.

    Uses the module-level ``model``, ``criterion``, ``optimizer``,
    ``exp_lr_scheduler``, ``train_loader``, ``train_indices``,
    ``dice_coeff`` and ``validate``.

    Args:
        num_epochs: number of epochs to run. ``None`` (the default) keeps the
            original behaviour of looping until the kernel is interrupted.
    """
    epoch = 1
    while num_epochs is None or epoch <= num_epochs:
        total_loss = 0
        total_accuracy = 0
        num_batches = 0
        samples_seen = 0
        model.train()
        # Read the learning rate straight from the optimizer; this is the
        # value used for the upcoming epoch.
        print([group['lr'] for group in optimizer.param_groups])
        for batch_idx, (data, _, target) in enumerate(train_loader):
            target = target.long()
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            # forward
            output = model(data)
            predict = output.max(1,keepdim=True)[1]
            # backward + optimize
            loss = criterion(output, target.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # print statistics
            accuracy = dice_coeff(predict.view_as(target),target)
            total_accuracy += accuracy.item()
            total_loss += loss.item()
            num_batches += 1
            # Count real samples so a short last batch does not push the
            # progress figure past 100%.
            samples_seen += len(data)
            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.5f}\tAccuracy: {:.5f}'.format(epoch, samples_seen, len(train_indices),100*samples_seen / len(train_indices), loss.item(),accuracy.item()))
        # Average over the batches actually run (exact even when the last
        # batch is short); max() guards an empty loader.
        num_batches = max(num_batches, 1)
        print('Train Loss: {:.5f} Train Accuracy: {:.5f}'.format(total_loss/num_batches,total_accuracy/num_batches))
        validate()
        # Step the scheduler after the epoch's optimizer steps, as the
        # torch.optim.lr_scheduler documentation requires.
        exp_lr_scheduler.step()
        epoch += 1

In [None]:
# Start training. With no arguments train() keeps looping over epochs until
# the kernel is interrupted.
train()

[0.001]


In [None]:
# Bundle everything needed to resume into one checkpoint file.
# The epoch number is entered by hand (108 + 1) because train() does not
# return it; 'losslogger' stores the criterion module itself.
state = dict(
    epoch=108 + 1,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
    losslogger=criterion,
)
torch.save(state, 'checkpoint.pth.tar')