## Imports

In [None]:
!pip install transformers

In [None]:
import numpy as np
import pandas as pd

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
import cv2
import matplotlib.pyplot as plt
import albumentations as A
import torchinfo
import timm

from transformers import SegformerModel, SegformerConfig, SegformerForSemanticSegmentation
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

## Config

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Run-wide hyperparameters, looked up by key in the cells below.
args = dict(
    fold=0,          # fold index held out for validation
    epoches=20,      # number of training epochs
    batch_size=2,    # batch size used by both data loaders
    start_lr=4e-5,   # learning rate handed to the optimizer
    image_size=720,  # side length the training augmentation resizes to
)


## Prepare Dataset

In [None]:
# Load the training metadata table; the dataset class below uses each row's
# 'id' to locate the matching image and mask files.
df = pd.read_csv('/kaggle/input/hubmap-organ-segmentation/train.csv')
# Notebook display: peek at five random rows.
df.sample(5)

In [None]:
# Assign every row to one of 5 cross-validation folds: `fold` holds the index
# of the split in which that row belongs to the validation part.
# NOTE: KFold is used without shuffling, so folds follow the CSV row order.
kf = KFold(n_splits=5)
# kf.split yields *positional* indices, so collect them in a NumPy array
# instead of writing through label-based `df.loc`, which would assign the
# wrong rows if `df` ever carried a non-default index (e.g. after filtering).
fold_ids = np.full(len(df), -1, dtype=np.int64)
for idx, (train_idx, valid_idx) in enumerate(kf.split(X=df)):
    fold_ids[valid_idx] = idx
df['fold'] = fold_ids

In [None]:
# Notebook display: number of rows assigned to each fold.
df['fold'].value_counts()

In [None]:
# Training-time augmentation pipeline, applied jointly to an image and its mask.
train_augment = A.Compose(
    [
        # Sometimes zoom in: take a 512x512 crop, which the resize below rescales.
        A.RandomCrop(height=512, width=512, p=0.3),
        # Always bring the sample to the configured square size.
        A.Resize(height=args['image_size'], width=args['image_size'], p=1),
        # Geometric augmentations.
        A.HorizontalFlip(p=0.3),
        A.VerticalFlip(p=0.3),
        A.RandomRotate90(p=0.3),
        # Colour augmentations.
        A.HueSaturationValue(
            hue_shift_limit=20,
            sat_shift_limit=30,
            val_shift_limit=20,
            p=0.3,
        ),
        # With probability 0.5, apply one of the two tonal adjustments.
        A.OneOf(
            [
                A.RandomGamma(gamma_limit=(80, 120), p=0.3),
                A.RandomBrightnessContrast(
                    brightness_limit=0.2,
                    contrast_limit=0.2,
                    brightness_by_max=True,
                    p=0.3,
                ),
            ],
            p=0.5,
        ),
    ]
)

In [None]:
class segDataset(Dataset):
    """Dataset yielding image / binary-mask tensor pairs for segmentation.

    Each sample is read from ``{image_dir}/{id}.png`` and
    ``{mask_dir}/{id}.png``, where ``id`` comes from the matching row of ``df``.

    Args:
        df: Table with an ``id`` column naming each sample.
        augment: Optional callable invoked as ``augment(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
            ``None`` disables augmentation.
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the ground-truth masks.
    """

    def __init__(self, df, augment,
                 image_dir='/kaggle/input/hubmap-data/train_image',
                 mask_dir='/kaggle/input/hubmap-data/train_mask'):
        self.df = df
        self.augment = augment
        self.image_dir = image_dir
        self.mask_dir = mask_dir

    def __len__(self):
        """Return the number of rows (samples) in the table."""
        return len(self.df)

    def __getitem__(self, index):
        """Load, optionally augment, and tensorise the sample at ``index``.

        Returns:
            Dict with ``'image'``: float tensor in channel-first layout whose
            values were divided by 255, and ``'mask'``: float tensor of zeros
            and ones with a leading singleton channel axis.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        sample_id = self.df.iloc[index]['id']
        image_path = f'{self.image_dir}/{sample_id}.png'
        mask_path = f'{self.mask_dir}/{sample_id}.png'

        # NOTE(review): cv2.imread yields channels in BGR order and no
        # conversion is applied here; confirm that is what the model expects.
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals an unreadable or missing file by returning None
        # instead of raising, so fail here with a message naming the path.
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_path}')
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')

        if self.augment is not None:
            aug = self.augment(image=image, mask=mask)
            image = aug['image']
            mask = aug['mask']

        image = image / 255.
        # Add a leading channel axis: (h, w) -> (1, h, w).
        mask = np.stack([mask], axis=0) / 255.

        return {
            'image': torch.tensor(image).permute(2, 0, 1).float(),  # h, w, c -> c, h, w
            'mask': torch.tensor(mask > 0.5).float(),  # binarise at the midpoint
        }


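# NOTE: albumentations works on channel-last images (H, W, 3) and 2-D masks (H, W),
# which is why the dataset above only moves the channel axis to the front after augmenting.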

## Model

In [None]:
# class UNetEncoder(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.maxpool = nn.MaxPool2d(2)
#         self.block1 = nn.Sequential(
#             nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = 64, out_channels= 64, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.block2 = nn.Sequential(
#             nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = 128, out_channels= 128, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.block3 = nn.Sequential(
#             nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.block4 = nn.Sequential(
#             nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )

#     def forward(self, batch):
#         x = batch['image']
#         x = self.block1(x)
#         x = self.maxpool(x)
#         x = self.block2(x)
#         x = self.maxpool(x)
#         x = self.block3(x)
#         x = self.maxpool(x)
#         x = self.block4(x)

#         return x


In [None]:
# class UNetDecoder(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.upsample = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners=True)
#         self.block1 = nn.Sequential(
#             nn.Conv2d(in_channels = dim, out_channels = dim // 2, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = dim // 2 , out_channels = dim // 2, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.block2 = nn.Sequential(
#             nn.Conv2d(in_channels = dim // 2, out_channels= dim // 4, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = dim // 4, out_channels= dim // 4, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.block3 = nn.Sequential(
#             nn.Conv2d(in_channels = dim // 4, out_channels = dim // 8, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = dim // 8, out_channels= dim // 8, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.last_conv = nn.Conv2d(in_channels = dim // 8, out_channels = 1, kernel_size = 1)
#     def forward(self, x):
#     #TODO Skip Connection
#         x = self.upsample(x)
#         x = self.block1(x)
#         x = self.upsample(x)
#         x = self.block2(x)
#         x = self.upsample(x)
#         x = self.block3(x)
#         x = self.last_conv(x)
#         x = F.interpolate(x, size=(720,720))
#         return x

In [None]:
# class Net(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.encoder = timm.create_model('tf_efficientnet_b6', 
#                                          pretrained = True, 
#                                          num_classes = 0,   
#                                          global_pool = '') 
        
#         dim = self.encoder.conv_head.out_channels # effnet_b4 = 1792; tf_efficientnet_b6 (the model created above) = 2304
#         self.decoder = UNetDecoder(dim = dim)

#     def forward(self, batch):
#         x = self.encoder(batch['image'])
#         logit = self.decoder(x)
        
#         out = {}
        
#         if self.training :
#             out['bce_loss'] = F.binary_cross_entropy_with_logits(input=logit, target = batch['mask'])

#         else :
#             out['bce_loss'] = F.binary_cross_entropy_with_logits(input=logit, target = batch['mask'])
#             out['probability'] = torch.sigmoid(logit)
        
#         return out

In [None]:
#model = Net().to(device)
# Load the pretrained SegFormer-B4 checkpoint named below and move it to `device`.
# num_labels=1 requests a single output channel (one logit map for the binary mask).
# ignore_mismatched_sizes=True lets loading continue when checkpoint weights do not
# fit that shape (presumably the checkpoint's classification head; TODO confirm).
_model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b4-finetuned-ade-512-512",
                                                        num_labels = 1,
                                                        ignore_mismatched_sizes=True).to(device)

In [None]:
class Net(nn.Module):
    """Wrap a segmentation model so it consumes batch dicts and returns a loss.

    Args:
        model: Module called as ``model(images)`` whose result exposes a
            ``.logits`` tensor.
        image_size: Spatial size the logits are resized to before the loss is
            computed: an int for a square output, or a ``(height, width)``
            pair. Defaults to 720, the value that used to be hard-coded.
    """

    def __init__(self, model, image_size=720):
        super().__init__()
        self.model = model
        self.image_size = image_size

    def forward(self, batch):
        """Run the wrapped model on ``batch['image']``.

        Args:
            batch: Dict with an ``'image'`` tensor and, when a loss is wanted,
                a ``'mask'`` tensor shaped like the resized logits.

        Returns:
            Dict with ``'bce_loss'`` (whenever ``batch`` contains ``'mask'``)
            and, in eval mode only, ``'probability'`` (sigmoid of the logits).
        """
        logit = self.model(batch['image']).logits

        # Resize to the configured size instead of a hard-coded (720, 720).
        size = self.image_size
        if not isinstance(size, (tuple, list)):
            size = (size, size)
        logit = F.interpolate(logit, size)

        out = {}
        # The loss is the same in train and eval mode. Skipping it when no
        # mask is supplied lets the module be used for pure inference.
        if 'mask' in batch:
            out['bce_loss'] = F.binary_cross_entropy_with_logits(input=logit, target=batch['mask'])
        if not self.training:
            out['probability'] = torch.sigmoid(logit)

        return out

In [None]:
# Hold out the configured fold for validation; train on every other row.
valid_df = df.loc[df['fold'].eq(args['fold'])]
train_df = df.loc[df['fold'].ne(args['fold'])]

In [None]:
# Report the number of rows in the training and validation splits.
print(len(train_df), len(valid_df))

In [None]:
 # Net takes the target image size alongside the wrapped model.
 model = Net(model=_model, image_size=args['image_size'])

In [None]:
# Datasets: augmentation is applied to the training split only.
train_ds = segDataset(df=train_df, augment=train_augment)
valid_ds = segDataset(df=valid_df, augment=None)

# Loaders: training batches are reshuffled on every pass, validation keeps
# the row order. Neither loader drops a final partial batch.
train_dl = DataLoader(
    train_ds,
    batch_size=args['batch_size'],
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)
valid_dl = DataLoader(
    valid_ds,
    batch_size=args['batch_size'],
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)


In [None]:
# Adam over every parameter of the wrapped model, at the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr = args['start_lr'])
# No learning-rate scheduler is configured; the line below is an unfinished placeholder.
#scheduler =

In [None]:
def compute_dice_score(probability, mask):
    """Return one Dice score per sample for thresholded predictions vs. masks.

    Both inputs are flattened per sample (the first axis indexes samples) and
    binarised at 0.5. The score is ``2 * |P & T| / (|P| + |T| + 0.0001)``.
    The small constant keeps the division defined, so a sample where both the
    prediction and the mask are empty scores 0.
    """
    batch = len(probability)
    pred = probability.reshape(batch, -1) > 0.5
    truth = mask.reshape(batch, -1) > 0.5

    intersection = (pred & truth).sum(-1)
    total = pred.sum(-1) + truth.sum(-1)
    return 2 * intersection / (total + 0.0001)

## Training

In [None]:
# Gradient scaler used by the mixed-precision (autocast) training loop below.
scaler = torch.cuda.amp.GradScaler()

In [None]:
def valid(model, valid_dl):
    """Evaluate ``model`` on ``valid_dl`` and return ``(dice_score, loss)``.

    Args:
        model: Module called with a batch dict that returns a dict holding
            ``'bce_loss'`` and, in eval mode, ``'probability'``.
        valid_dl: Sized iterable of batch dicts with ``'image'`` and ``'mask'``
            tensors.

    Returns:
        Tuple ``(score, loss)``: the mean per-sample Dice score and the mean
        of the per-batch losses.

    Note:
        The model is left in eval mode when this function returns.
    """
    # Switching to eval mode is loop-invariant, so do it once up front.
    model.eval()

    dice_scores = []
    losses = []
    with torch.no_grad():
        for batch in tqdm(valid_dl, total=len(valid_dl)):
            batch['image'] = batch['image'].to(device)
            batch['mask'] = batch['mask'].to(device)

            out = model(batch)
            losses.append(out['bce_loss'].mean().cpu().numpy())

            # Score each batch right away so only the per-sample scores are
            # kept, not every full-resolution probability map and mask.
            dice_scores.append(
                compute_dice_score(
                    out['probability'].cpu().numpy(),
                    batch['mask'].cpu().numpy(),
                )
            )

    score = np.concatenate(dice_scores).mean()
    losses = np.asarray(losses).mean()

    return score, losses
        

In [None]:
# Directory that receives model checkpoints; created if it does not exist yet.
save_path = '/kaggle/working/model/'
os.makedirs(save_path, exist_ok=True)

In [None]:
# Main loop: one mixed-precision pass over train_dl per epoch, then validation.
# Weights are written out whenever the validation Dice score improves.
best_score = 0
for ep in tqdm(range(args['epoches'])):
    losses = []

    for batch in tqdm(train_dl, total=len(train_dl)):
        # valid() leaves the model in eval mode, so re-enter train mode here.
        model.train()
        optimizer.zero_grad()

        batch['image'] = batch['image'].to(device)
        batch['mask'] = batch['mask'].to(device)

        # Forward pass and loss under autocast.
        with torch.cuda.amp.autocast():
            out = model(batch)
            loss = out['bce_loss'].mean()

        # Scaled backward pass, optimizer step, then scale-factor update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        #scheduler.step()

        losses.append(loss.detach().cpu().numpy())

    train_loss = np.asarray(losses).mean()
    score, valid_loss = valid(model, valid_dl)

    # Keep a checkpoint only for epochs that beat the best score so far.
    if score > best_score:
        torch.save(model.state_dict(), save_path + f'ep_{ep}_unet_model.pt')
        best_score = score

    print(f'ep_{ep} train_loss : {train_loss}, valid_loss : {valid_loss}, dice_score : {score}')
    #torch.save(optimizer.state_dict(), save_path + f'ep_{ep}_unet_optimizer.pt')
    #torch.save(scheduler.state_dict(), save_path + f'ep_{ep}_unet_scheduler.pt')

print('train finished')