## Imports

In [1]:
import numpy as np
import pandas as pd

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
import cv2
import matplotlib.pyplot as plt

import torchinfo

from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

## Config

In [2]:
device = "cuda" if torch.cuda.is_available() else 'cpu'

# Hyper-parameters for one training run.
#   fold       : which KFold split is held out for validation
#   epoches    : number of training epochs (key spelling kept; the training loop reads it)
#   batch_size : images per optimisation step
#   start_lr   : Adam learning rate
#   seed       : RNG seed so weight init and DataLoader shuffling are reproducible
args = {
    "fold" : 0,
    "epoches" : 20,
    "batch_size" : 1,
    "start_lr" : 4e-5,
    "seed" : 42,
}

# Seed the RNGs this notebook relies on (torch.manual_seed also seeds CUDA devices)
# so that Restart & Run All reproduces the same initial weights and shuffle order.
np.random.seed(args["seed"])
torch.manual_seed(args["seed"])


## Prepare Dataset

In [3]:
# Competition metadata for the training images (id, organ, image size, `rle` target mask, ...).
train_csv_path = '/kaggle/input/hubmap-organ-segmentation/train.csv'
df = pd.read_csv(train_csv_path)
df.sample(5)

Unnamed: 0,id,organ,data_source,img_height,img_width,pixel_size,tissue_thickness,rle,age,sex
214,28940,kidney,HPA,3000,3000,0.4,4,1136037 3 1139019 29 1142004 44 1144998 50 114...,56.0,Female
41,13396,largeintestine,HPA,3000,3000,0.4,4,1306024 34 1309000 2 1309015 50 1311998 6 1312...,83.0,Male
185,26982,prostate,HPA,3000,3000,0.4,4,1638919 12 1641918 14 1644910 24 1647909 27 16...,55.0,Male
75,16728,largeintestine,HPA,3000,3000,0.4,4,856875 11 859864 38 862861 46 865859 53 868855...,84.0,Male
165,24833,spleen,HPA,3000,3000,0.4,4,1956905 6 1959904 7 1962903 8 1965902 9 196890...,50.0,Female


In [4]:
# Assign every row to one of 5 cross-validation folds. KFold is unshuffled, so each
# fold is a contiguous slice of the rows in file order.
kf = KFold(n_splits=5)
df['fold'] = -1
for fold_idx, (_, valid_idx) in enumerate(kf.split(X=df)):
    # kf.split yields row *positions*; translate them to index labels so the
    # assignment stays correct even when df does not have a default 0..n-1 index.
    df.loc[df.index[valid_idx], 'fold'] = fold_idx

# Every row must end up in exactly one validation fold.
assert (df['fold'] >= 0).all()

In [5]:
# Rows per fold; with 5 splits each fold should hold roughly len(df) / 5 rows.
df.loc[:, 'fold'].value_counts()

0    71
1    70
2    70
3    70
4    70
Name: fold, dtype: int64

In [6]:
class segDataset(Dataset):
    """Map-style dataset that reads training images and binary masks from PNG files.

    Item ``index`` corresponds to row ``index`` (positional) of ``df`` and is a dict:
      * 'image': float32 tensor (3, H, W), pixel values scaled to [0, 1].
        cv2.imread returns channels in BGR order and that order is kept as-is.
      * 'mask' : float32 tensor (1, H, W) with values in {0., 1.}.

    Args:
        df: DataFrame with an 'id' column; files are looked up as '<id>.png'.
        augment: stored on the instance but NOT applied by __getitem__ yet.
        image_dir: folder containing the image PNGs.
        mask_dir: folder containing the mask PNGs (read as grayscale).

    Raises (from __getitem__):
        FileNotFoundError: if an image or mask file cannot be read.
    """

    def __init__(self, df, augment,
                 image_dir='/kaggle/input/hubmap-data/train_image',
                 mask_dir='/kaggle/input/hubmap-data/train_mask'):
        self.df = df
        self.augment = augment  # NOTE: reserved for augmentation; currently unused
        self.image_dir = image_dir
        self.mask_dir = mask_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        image_id = self.df.iloc[index]['id']
        image_path = os.path.join(self.image_dir, f'{image_id}.png')
        mask_path = os.path.join(self.mask_dir, f'{image_id}.png')

        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports a missing/unreadable file by returning None instead of
        # raising; fail here with the offending path rather than a TypeError below.
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_path}')
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')

        image = image / 255.
        mask = mask[np.newaxis] / 255.  # (H, W) -> (1, H, W)

        out = {}
        out['image'] = torch.tensor(image).permute(2, 0, 1).float()  # h, w, c -> c, h, w
        out['mask'] = torch.tensor(mask > 0.5).float()
        return out


## Model

In [7]:
class UNetEncoder(nn.Module):
    """Contracting path of a plain U-Net (no skip outputs are exposed).

    Four stages, each two 3x3 'same' convolutions followed by ReLU, widening the
    channels 3 -> 64 -> 128 -> 256 -> 512. A 2x2 max-pool sits between consecutive
    stages (three pools in total), so the spatial size shrinks to H // 8, W // 8.

    forward(batch) reads batch['image'] (assumed (N, 3, H, W) float - TODO confirm)
    and returns the deepest (N, 512, H // 8, W // 8) feature map.
    """

    def __init__(self):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.block1 = self._double_conv(3, 64)
        self.block2 = self._double_conv(64, 128)
        self.block3 = self._double_conv(128, 256)
        self.block4 = self._double_conv(256, 512)

    @staticmethod
    def _double_conv(c_in, c_out):
        """conv3x3 -> ReLU -> conv3x3 -> ReLU, keeping H and W unchanged."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding="same"),
            nn.ReLU(),
        )

    def forward(self, batch):
        feats = self.block1(batch['image'])
        # Every later stage is preceded by a 2x down-sampling.
        for stage in (self.block2, self.block3, self.block4):
            feats = stage(self.maxpool(feats))
        return feats


In [8]:
class UNetDecoder(nn.Module):
    """Expanding path paired with UNetEncoder, currently without skip connections.

    Three stages of (2x bilinear upsample -> conv3x3 + ReLU -> conv3x3 + ReLU) narrow
    the channels 512 -> 256 -> 128 -> 64, then a 1x1 conv maps to a single channel.
    The result is a raw logit map (no sigmoid): input (N, 512, h, w) -> (N, 1, 8h, 8w).
    """

    def __init__(self):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.block1 = self._double_conv(512, 256)
        self.block2 = self._double_conv(256, 128)
        self.block3 = self._double_conv(128, 64)
        self.last_conv = nn.Conv2d(64, 1, kernel_size=1)

    @staticmethod
    def _double_conv(c_in, c_out):
        """conv3x3 -> ReLU -> conv3x3 -> ReLU, keeping H and W unchanged."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding="same"),
            nn.ReLU(),
        )

    def forward(self, x):
        # TODO: add skip connections from the encoder stages (a true U-Net).
        for stage in (self.block1, self.block2, self.block3):
            x = stage(self.upsample(x))
        return self.last_conv(x)

In [9]:
class Net(nn.Module):
    """Segmentation model: UNetEncoder followed by UNetDecoder.

    forward(batch) takes a dict with 'image' (assumed (N, 3, H, W) float - TODO confirm)
    and optionally 'mask' (target with the image's spatial size, values in {0, 1}).
    It returns a dict containing:
      * 'bce_loss'    : mean binary cross-entropy on the logits, whenever 'mask' is given;
      * 'probability' : sigmoid of the logits, only in eval mode.
    """

    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder = UNetDecoder()

    def forward(self, batch):
        image = batch['image']
        logit = self.decoder(self.encoder(batch))

        # The encoder halves H/W three times (flooring) and the decoder doubles them
        # three times, so the logit map only matches the input when H and W are
        # multiples of 8. Resize otherwise, instead of failing on a shape mismatch.
        if logit.shape[-2:] != image.shape[-2:]:
            logit = F.interpolate(logit, size=tuple(image.shape[-2:]),
                                  mode='bilinear', align_corners=True)

        out = {}
        # The loss is the same in train and eval mode; it is skipped only when no
        # target is provided (e.g. inference on unlabeled images).
        if 'mask' in batch:
            out['bce_loss'] = F.binary_cross_entropy_with_logits(input=logit, target=batch['mask'])
        if not self.training:
            out['probability'] = torch.sigmoid(logit)

        return out

In [10]:
# Build the network, then move its parameters to the configured device.
model = Net()
model = model.to(device)

In [11]:
# Hold out the configured fold for validation and train on all remaining rows.
is_valid_row = df['fold'] == args['fold']
train_df = df[~is_valid_row]
valid_df = df[is_valid_row]

In [12]:
# Number of training rows, number of validation rows.
print(train_df.shape[0], valid_df.shape[0])

280 71


In [13]:
# Pinned host memory only speeds up host->GPU copies; without CUDA it is useless
# (and newer PyTorch versions warn about it).
pin_memory = device == 'cuda'

train_ds = segDataset(df = train_df, augment = None)
train_dl = DataLoader(train_ds,
                      batch_size = args['batch_size'],
                      shuffle = True,
                      pin_memory = pin_memory,
                      drop_last = False)

# Validation does not benefit from random order; a fixed order keeps every
# evaluation pass identical and reproducible.
valid_ds = segDataset(df = valid_df, augment = None)
valid_dl = DataLoader(valid_ds,
                      batch_size = args['batch_size'],
                      shuffle = False,
                      pin_memory = pin_memory,
                      drop_last = False)


In [14]:
# Adam over all model parameters. No LR scheduler is attached yet, so the learning
# rate stays at args['start_lr'] for the whole run (TODO: add a scheduler).
optimizer = torch.optim.Adam(params=model.parameters(), lr=args['start_lr'])

In [15]:
def compute_dice_score(probability, mask, threshold=0.5, eps=0.0001):
    """Per-image Dice coefficient between thresholded predictions and binary targets.

    Each image is flattened over all non-batch axes before scoring. Without this, a
    per-row score would count foreground-free rows as 0 and bias the mean downward.

    Args:
        probability: array/tensor of predicted probabilities, batch on axis 0
            (numpy arrays and torch tensors both work).
        mask: ground-truth array/tensor with the same shape; values > 0.5 are foreground.
        threshold: probability above which a pixel is predicted as foreground.
        eps: added to the denominator to avoid division by zero. Note that an image
            with no predicted and no true foreground therefore scores 0.

    Returns:
        Array/tensor of shape (N,) with one Dice score per image.
    """
    n = len(probability)
    if n == 0:
        # reshape(0, -1) is ambiguous for empty input; return an empty score vector.
        return probability.reshape(0)

    pred = probability.reshape(n, -1) > threshold
    target = mask.reshape(n, -1) > 0.5

    union = pred.sum(-1) + target.sum(-1)
    overlap = (pred & target).sum(-1)
    return 2 * overlap / (union + eps)

## Training

In [16]:
# Gradient scaler for mixed-precision training. NOTE: the training loop in this
# notebook does not call it yet. It only works on CUDA, so enable it only when CUDA
# is available instead of letting it warn and self-disable on CPU runtimes.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

In [17]:
def valid(model, valid_dl):
    """Evaluate ``model`` on ``valid_dl``; return (mean dice score, mean BCE loss).

    Dice is computed batch by batch with compute_dice_score and then averaged over
    all scores, so full-resolution probability maps are never all held in memory at
    once. The loss is the mean of the per-batch 'bce_loss' values. The model's
    train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()

    dice_scores = []
    losses = []
    with torch.no_grad():
        for batch in tqdm(valid_dl, total=len(valid_dl)):
            batch['image'] = batch['image'].to(device)
            batch['mask'] = batch['mask'].to(device)

            out = model(batch)
            losses.append(out['bce_loss'].item())
            # Score this batch right away and keep only the scores, not the maps.
            dice_scores.append(compute_dice_score(
                out['probability'].cpu().numpy(),
                batch['mask'].cpu().numpy(),
            ))

    model.train(was_training)

    score = np.concatenate(dice_scores).mean()
    losses = np.asarray(losses).mean()
    return score, losses
        

In [18]:
# Sanity check: the first training mask should be binary (only 0. and 1.).
torch.unique(train_ds[0]['mask'])

tensor([0., 1.])

In [19]:
# Directory for per-epoch checkpoints. The trailing slash matters: file names are
# appended to this string directly.
save_path = '/kaggle/working/model/'
os.makedirs(name=save_path, exist_ok=True)

In [None]:
for ep in tqdm(range(args['epoches'])):
    # Reset per epoch so train_loss is this epoch's mean loss, not a running mean
    # over every epoch trained so far.
    losses = []
    # valid() evaluates in eval mode, so (re-)enter train mode at the start of each epoch.
    model.train()

    for batch in tqdm(train_dl, total = len(train_dl)):
        batch['image'] = batch['image'].to(device)
        batch['mask']  = batch['mask'].to(device)

        out = model(batch)
        loss = out['bce_loss']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    train_loss = np.mean(losses)
    score, valid_loss = valid(model, valid_dl)

    print(f'ep_{ep} train_loss : {train_loss}, valid_loss : {valid_loss}, dice_score : {score}')

    # Checkpoint model and optimizer every epoch so training can be resumed.
    torch.save(model.state_dict(), save_path + f'ep_{ep}_unet_model.pt')
    torch.save(optimizer.state_dict(), save_path + f'ep_{ep}_unet_optimizer.pt')
print('train finished')

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/280 [00:00<?, ?it/s]