In [None]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.cuda import amp
import timm
from pytorch_toolbelt.inference import tta
from pytorch_toolbelt import losses as L

import random
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import cv2
from matplotlib import pyplot as plt
import os
import json
import pandas as pd
import segmentation_models_pytorch as smp
from augmentation import *
from monai.inferers import sliding_window_inference
from tifffile import imread

# Expose only GPU 0 to this process.
# NOTE(review): presumably this has to run before CUDA is first initialised to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def seed_everything(seed=123):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode with
    autotuning disabled.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick different kernels from run to run, so pin it off.
    torch.backends.cudnn.benchmark = False
seed_everything()

In [None]:
# Rotation-only augmentation: shift and scale limits are zero, rotate_limit is 180
# and p=1.0, so every call rotates the sample; the result is then converted to
# tensors by ToTensorV2. custom_dataset.__getitem__ applies this as its last
# augmentation step for every sample, regardless of the dataset mode.
# NOTE(review): border_mode=0 is presumably cv2.BORDER_CONSTANT (empty corners
# filled with a constant) - confirm against the albumentations version in use.
rot_aug= A.Compose([
            A.ShiftScaleRotate(shift_limit=0., scale_limit=0., rotate_limit= 180,
                                            interpolation=cv2.INTER_LINEAR, border_mode=0, p=1.0),
            ToTensorV2(p=1.0),
        ])

def get_train_transform():
    """Build the training augmentation pipeline.

    The pipeline pads/crops to ``CFG['window_size']``, applies colour and
    noise perturbations, random flips/rotations/transposition, a mild
    shift-scale-rotate, coarse dropout, and finally converts to tensors.

    Returns:
        An ``A.Compose`` instance to be called as ``pipeline(image=..., mask=...)``.
    """
    window = CFG['window_size']
    hole_size = window // 20

    # Bring every sample to exactly window x window pixels.
    crop_stage = [
        A.PadIfNeeded(min_height=window, min_width=window, p=1),
        A.RandomCrop(window, window, p=1),
    ]

    # Exactly one colour perturbation is drawn (with p=0.9), then blur / noise.
    colour_choices = [
        A.HueSaturationValue(hue_shift_limit=100, sat_shift_limit=15, val_shift_limit=10, p=0.7),
        A.CLAHE(clip_limit=2, p=0.3),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
    ]
    colour_stage = [
        A.OneOf(colour_choices, p=0.9),
        A.Blur(blur_limit=2, p=0.3),
        A.GaussNoise(p=0.3),
    ]

    flip_stage = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
    ]

    # Disabled experiment, kept for reference:
    #     A.OneOf([
    #         A.GridDistortion(num_steps=5, distort_limit=0.5, p=1.0),
    #         A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
    #     ], p=0.7),
    geometry_stage = [
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.2, rotate_limit=15,
                           interpolation=cv2.INTER_LINEAR, border_mode=0, p=0.9),
        A.CoarseDropout(max_holes=8, max_height=hole_size, max_width=hole_size,
                        min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
    ]

    to_tensor_stage = [ToTensorV2(p=1.0)]
    return A.Compose(crop_stage + colour_stage + flip_stage + geometry_stage + to_tensor_stage)


def get_test_transform():
    """Build the evaluation pipeline: tensor conversion only, no augmentation."""
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps)

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    Expand a run-length encoded string into a 2-D float32 mask.

    mask_rle: whitespace separated "start length start length ..." pairs,
              with 1-based start positions into the flattened mask
    shape:    two-element sequence; the flat buffer holds shape[0]*shape[1]
              values, is reshaped to `shape` and then transposed, so the
              returned array has shape (shape[1], shape[0])
    color:    value written into the encoded runs (background stays 0)
    Returns a numpy float32 array
    '''
    tokens = mask_rle.split()
    # Even tokens are 1-based starts, odd tokens are run lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, stop in zip(run_starts, run_ends):
        flat[begin : stop] = color
    return flat.reshape(shape).T

def read_data(data, mode):
    """Load one sample's image, decoded mask and organ label.

    Args:
        data: mapping-like row providing 'image_path', 'rle' and 'organ'.
        mode: dataset mode; only 'train' enables the domain-shift path swap.

    Returns:
        Tuple ``(img, mask, organ)`` where ``mask`` is the decoded RLE with a
        trailing channel axis added (``H x W x 1`` for an ``H x W`` image).
    """
    ## read img
    img_path = data['image_path']
    # In train mode, with probability CFG['domain_shift'], read the image from
    # the parallel 'train_images_domain_shift' directory instead.
    if CFG['domain_shift'] and mode=='train' and np.random.rand()<CFG['domain_shift']:
        img_path = img_path.replace('train_images', 'train_images_domain_shift')
    try:
        img = imread(img_path)
    except Exception:
        # The path is not readable by tifffile; fall back to PIL. Catching
        # Exception (not a bare except) lets KeyboardInterrupt / SystemExit
        # propagate instead of being mistaken for a decoding failure.
        img = np.array(Image.open(img_path))

    ## read mask
    # rle_decode transposes its result, so it is handed (width, height).
    mask = rle_decode(data['rle'], img.shape[:2][::-1])
    mask = np.expand_dims(mask, axis=2)

    ## organ type
    organ = data['organ']
    return img, mask, organ


class custom_dataset(Dataset):
    """Segmentation dataset producing augmented ``(image, mask[, aux_label])`` samples.

    Every item is read with ``read_data``, resized to a square, passed through
    ``transforms`` and then, with the configured probabilities, through the
    mosaic / mixup / cutmix / copy-paste augmentations, followed by the
    module-level ``rot_aug`` rotation.

    Args:
        dataset: table indexed via ``.loc[int]`` whose rows are forwarded to
            ``read_data``. Random partner samples are looked up with
            ``.loc[randint(len(dataset))]``, so the index is assumed to be
            0..n-1 (NOTE(review): confirm callers always reset the index).
        transforms: callable used as ``transforms(image=..., mask=...)`` that
            returns a dict with tensor entries ``'image'`` and ``'mask'``.
        mode: forwarded to ``read_data`` for every sample that is read.
        cutmix, mixup, mosaic, copypaste: probability of applying the
            respective augmentation; a falsy value disables it.
    """

    def __init__(self, dataset, transforms,
                 mode='valid',
                 cutmix=False,
                 mixup=False,
                 mosaic=False,
                 copypaste=False):
        self.dataset= dataset
        self.transforms= transforms
        self.mode= mode
        self.cutmix= cutmix
        self.mixup= mixup
        self.mosaic= mosaic
        self.copypaste= copypaste

    @staticmethod
    def _not_lung(img, mask, organ):
        """Accept any sample whose organ is not 'lung'."""
        return organ != 'lung'

    @staticmethod
    def _has_mask(img, mask, organ):
        """Accept any sample whose mask contains at least one non-zero pixel."""
        return mask.any()

    def _read_random_sample(self, accept=None):
        """Read uniformly drawn rows until ``accept(img, mask, organ)`` is truthy.

        Returns the ``(img, mask, organ)`` tuple from ``read_data``; with
        ``accept=None`` the first draw is returned.
        """
        while True:
            indx= np.random.randint(len(self.dataset))
            sample= read_data(self.dataset.loc[indx], self.mode)
            if accept is None or accept(*sample):
                return sample

    def __getitem__(self, index):
        """Return the augmented sample stored at ``index``.

        Returns:
            ``(img/255, mask, aux_label)`` when ``CFG['amp']`` is truthy,
            otherwise ``(img/255, mask)``. ``mask`` is a float tensor permuted
            to channel-first; ``aux_label`` holds one 0/1 flag per mask channel
            when ``CFG['aux']`` is enabled and ``[-1]`` otherwise.
        """
        data= self.dataset.loc[index]
        img, mask, organ= read_data(data, self.mode)

        ## scale adjust
        # Square side = shorter image side // img_scale, but never below the window size.
        img_size= min(img.shape[0], img.shape[1])
        img_size= img_size//CFG['img_scale']
        img_size= CFG['window_size'] if img_size < CFG['window_size'] else img_size
        img= cv2.resize(img, (img_size, img_size))
        # NOTE(review): cv2.resize presumably drops the singleton channel axis,
        # which is why it is restored right after.
        mask= cv2.resize(mask, (img_size, img_size))
        mask= np.expand_dims(mask, axis= 2)

        ## augmentation
        # Re-draw until the augmented mask keeps some foreground. If the source
        # mask is already empty no draw can succeed, so accept the first one
        # instead of looping forever.
        has_foreground= bool(mask.any())
        while True:
            aug= self.transforms(image= img, mask= mask)
            aug_img, aug_mask= aug['image'], aug['mask']
            if not has_foreground or aug_mask.numpy().any(): break
        img, mask= aug_img, aug_mask

        # use mosaic
        if self.mosaic and np.random.rand() >= (1-self.mosaic) and organ!='lung':
            imgs= [img.permute(1,2,0).numpy()]
            masks= [np.array(mask)]
            for _ in range(3):
                img_2, mask_2, _organ_2= self._read_random_sample(self._not_lung)
                imgs.append(img_2)
                masks.append(mask_2)
            img, mask= mosaic_aug(imgs[0].shape[0],
                                  CFG['window_size'],
                                  imgs[0], masks[0],
                                  imgs[1], masks[1],
                                  imgs[2], masks[2],
                                  imgs[3], masks[3])
            img= torch.tensor(img, dtype= torch.float32).permute(2,0,1)

        # use mixup
        if self.mixup and np.random.rand() >= (1-self.mixup):
            img_1= img.permute(1,2,0).numpy()
            mask_1= np.array(mask)
            # read_data needs the mode and returns three values; the organ of
            # the partner sample is not used here.
            img_2, mask_2, _organ_2= self._read_random_sample()
            img, mask= mixup_aug(img_1.shape[0],
                                 CFG['window_size'],
                                 img_1, mask_1,
                                 img_2, mask_2)
            img= torch.tensor(img, dtype= torch.float32).permute(2,0,1)

        # use cutmix
        if self.cutmix and np.random.rand() >= (1-self.cutmix) and organ!='lung':
            img_1= img.permute(1,2,0).numpy()
            mask_1= np.array(mask)
            img_2, mask_2, _organ_2= self._read_random_sample(self._not_lung)
            img, mask= cutmix_aug(img_1.shape[0],
                                  CFG['window_size'],
                                  img_1, mask_1,
                                  img_2, mask_2)
            img= torch.tensor(img, dtype= torch.float32).permute(2,0,1)

        # use copypaste
        if self.copypaste and np.random.rand() >= (1-self.copypaste):
            img_1= img.permute(1,2,0).numpy()
            mask_1= np.array(mask)
            # The pasted sample must carry some foreground.
            img_2, mask_2, _organ_2= self._read_random_sample(self._has_mask)
            img, mask= copy_paste(img_1.shape[0],
                                  CFG['window_size'],
                                  img_1, mask_1,
                                  img_2, mask_2)
            img= torch.tensor(img, dtype= torch.float32).permute(2,0,1)

        ## rot aug
        # NOTE(review): this random rotation also runs when mode is 'valid',
        # which makes validation inputs non-deterministic - confirm intended.
        img= img.permute(1,2,0).numpy()
        mask= np.array(mask)
        aug= rot_aug(image= img, mask= mask)
        img, mask= aug['image'], aug['mask']

        # as_tensor accepts both arrays and tensors and avoids the copy
        # warning torch.tensor() raises when handed an existing tensor.
        mask= torch.as_tensor(mask, dtype= torch.float)
        mask= mask.permute(2,0,1)

        ## aux label
        if CFG['aux']:
            # One presence flag per mask channel.
            aux_label= np.zeros((len(mask),))
            for i in range(len(mask)):
                if mask[i].numpy().any():
                    aux_label[i]= 1
        else:
            aux_label= [-1]

        if CFG['amp']: return img/255, mask, aux_label
        else: return img/255, mask

    def __len__(self):
        """Number of rows in the underlying table."""
        return len(self.dataset)

In [None]:
class custom_loss(nn.Module):
    """Training criterion: equal-weight sum of Dice loss and BCE-with-logits.

    Several other segmentation losses are instantiated as attributes but only
    ``DiceLoss`` and ``BCELoss`` take part in ``forward``.
    """

    def __init__(self):
        super().__init__()
        losses = smp.losses
        self.JaccardLoss = losses.JaccardLoss(mode='multilabel')
        self.DiceLoss = losses.DiceLoss(mode='binary', from_logits=True)
        self.BCELoss = losses.SoftBCEWithLogitsLoss()
        self.LovaszLoss = losses.LovaszLoss(mode='multilabel', per_image=False)
        self.TverskyLoss = losses.TverskyLoss(mode='multilabel', log_loss=False, from_logits=True)

    def forward(self, y_pred, y_true):
        """Return ``0.5 * dice(y_pred, y_true) + 0.5 * bce(y_pred, y_true)``."""
        dice_term = self.DiceLoss(y_pred, y_true)
        bce_term = self.BCELoss(y_pred, y_true)
        return 0.5*dice_term + 0.5*bce_term

In [None]:
def train_epoch_amp(dataloader, model, criterion, optimizer, scaler=None):
    """Run one mixed-precision training epoch with gradient accumulation.

    Args:
        dataloader: iterable of ``(imgs, masks, aux_labels)`` batches.
        model: network to train; called as ``model(imgs)``.
        criterion: loss called as ``criterion(prediction, target)``.
        optimizer: optimizer stepped once per ``CFG['gradient_accumulation']``
            batches (and once more for a trailing partial window).
        scaler: optional ``GradScaler`` to reuse across epochs; a fresh one is
            created when omitted, as before.

    Returns:
        Mean of the per-batch (un-divided) loss values.
    """
    if scaler is None:
        scaler= amp.GradScaler()
    accum_steps= CFG['gradient_accumulation']
    model.train()
    # Start from clean gradients so nothing from a previous epoch leaks in.
    optimizer.zero_grad()

    ep_loss= []
    pending= False  # True while accumulated gradients have not been applied yet
    for i, (imgs, masks, aux_labels) in enumerate(tqdm(dataloader)):

        imgs= imgs.to('cuda')
        masks= masks.to('cuda')
        if CFG['aux']: aux_labels= aux_labels.to('cuda')

        # Only the forward pass and the loss run under autocast.
        with amp.autocast():
            preds= model(imgs)
            if CFG['aux']:
                loss_1= criterion(preds[0], masks)
                loss_2= criterion(preds[1], aux_labels)
                loss= (1-CFG['aux'])*loss_1 + CFG['aux']*loss_2
            else:
                loss= criterion(preds, masks)

        ep_loss.append(loss.item())
        # Average over the accumulation window; backward runs outside autocast.
        scaler.scale(loss / accum_steps).backward()
        pending= True

        if (i+1) % accum_steps== 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            pending= False

    # Apply the gradients of a trailing, incomplete accumulation window rather
    # than dropping them or carrying them into the next epoch.
    if pending:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    return np.mean(ep_loss)

In [None]:
def valid_epoch_amp(dataloader, model, criterion):
    """Evaluate ``model`` over ``dataloader`` using sliding-window inference.

    Args:
        dataloader: iterable of ``(imgs, masks, aux_label)`` batches.
        model: network used as the sliding-window predictor.
        criterion: loss called as ``criterion(prediction, target)``.

    Returns:
        ``(mean_loss, dice_mask, dice_aux)`` where the dice values are
        ``1 - mean(dice(...))`` using the module-level ``dice`` callable and
        ``dice_aux`` is ``None`` when ``CFG['aux']`` is disabled.
    """
    model.eval()
    use_aux = CFG['aux']
    window = (CFG['window_size'], CFG['window_size'])

    batch_losses, mask_scores, aux_scores = [], [], []
    for imgs, masks, aux_label in tqdm(dataloader):
        imgs = imgs.to('cuda')
        masks = masks.to('cuda')
        if use_aux:
            aux_label = aux_label.to('cuda')

        with torch.no_grad():
            preds = sliding_window_inference(imgs,
                                             window,
                                             sw_batch_size=2,
                                             predictor=model,
                                             mode='gaussian',
                                             overlap=0.25)
            if use_aux:
                # Weighted blend of the mask loss and the auxiliary loss.
                total = (1 - use_aux) * criterion(preds[0], masks) + use_aux * criterion(preds[1], aux_label)
            else:
                total = criterion(preds, masks)
            batch_losses.append(total.item())

        ## evaluation (on CPU copies)
        mask_pred = preds[0] if use_aux else preds
        mask_scores.append(dice(mask_pred.cpu(), masks.cpu()))
        if use_aux:
            aux_scores.append(dice(preds[1].cpu(), aux_label.cpu()))

    ## metrics: `dice` returns a loss, so the score is its complement
    dice_mask = 1 - np.mean(mask_scores)
    dice_aux = 1 - np.mean(aux_scores) if use_aux else None

    return np.mean(batch_losses), dice_mask, dice_aux

In [None]:
from transformers import SegformerForSemanticSegmentation

class customize_model(nn.Module):
    """SegFormer wrapper whose classification head emits one logit channel.

    The pretrained decode head's classifier is replaced by two 1x1
    convolutions, each followed by 2x nearest-neighbour upsampling, so the
    head output is upsampled 4x and has a single channel.

    Args:
        model_name: checkpoint identifier passed to
            ``SegformerForSemanticSegmentation.from_pretrained``.
    """

    def __init__(self, model_name):
        super(customize_model, self).__init__()

        self.model = SegformerForSemanticSegmentation.from_pretrained(model_name)
        # Take the head width from the pretrained classifier instead of
        # hard-coding 768 -> 384, so checkpoints with a different decoder
        # width work as well (a 768-wide head still yields 768 -> 384).
        in_channels = self.model.decode_head.classifier.in_channels
        mid_channels = in_channels // 2
        self.model.decode_head.classifier= nn.Sequential(
                                                nn.Conv2d(in_channels, mid_channels, kernel_size=(1, 1), stride=(1, 1)),
                                                nn.Upsample(scale_factor=2, mode='nearest'),
                                                nn.Conv2d(mid_channels, 1, kernel_size=(1, 1), stride=(1, 1)),
                                                nn.Upsample(scale_factor=2, mode='nearest'),
                                            )

    def forward(self, images):
        """Return the ``'logits'`` entry of the wrapped model's output."""
        out= self.model(images)['logits']
        return out

# Configuration (CFG)

In [None]:
# Global experiment configuration, read by the transforms, the dataset, the
# train/valid loops and the training script.
CFG= {
    'fold': 0,  # rows whose 'fold' column equals this value form the validation split
    'epoch': 150,  # number of iterations of the training loop
    'img_scale': 3,  # the shorter image side is integer-divided by this before resizing
    'window_size': 1024,  # training crop size and sliding-window size at validation
    'model_name': 'nvidia/segformer-b5-finetuned-ade-640-640',  # passed to customize_model
    'aux': False,  # falsy disables the auxiliary loss; otherwise used as its blend weight
    'classes_balance': True,  # oversample each organ up to the 'kidney' sample count
    'finetune': False,  # when truthy, the overrides below the dict are applied
    'domain_shift': 0.5,  # train-mode probability of reading from 'train_images_domain_shift'
    
    'lr': 6e-5,  # AdamW learning rate
    'weight_decay': 0,  # AdamW weight decay
    'batch_size': 1,  # batch size of the training DataLoader
    'gradient_accumulation': 8,  # batches accumulated per optimizer step in train_epoch_amp
    'amp': True,  # selects the train_epoch_amp / valid_epoch_amp loops; dataset items become 3-tuples
    
    # Probability of each extra augmentation in custom_dataset; False disables it.
    'cutmix': 0.5,
    'mixup': False,
    'mosaic': 0.5,
    'copypaste': False,
    
    'load_model': False,  # False builds a new model; otherwise a path given to torch.load
    'save_model': 'train_model',  # directory prefix used for checkpoint paths
}
# The auxiliary loss is only handled by the *_amp loops, so enabling it forces 'amp' on.
if CFG['aux']!=False: CFG['amp']= True
# Fine-tuning: shorter schedule, lower learning rate, start from this fold's best checkpoint.
if CFG['finetune']:
    print('finetune')
    CFG['epoch']= 20
    CFG['lr']= 3e-5
    CFG['load_model']= f"train_model/model_cv{CFG['fold']}_best.pth"
    
# CFG['load_model']= f"train_model/model_cv{CFG['fold']}_best.pth"

# Prepare Dataset

In [None]:
# Load the sample table, split it by fold, optionally oversample the training
# split per organ, then build the datasets and loaders.
df= pd.read_csv('./Data/train.csv')
# ex_df= pd.read_csv('./Data/train_ex_largeintestine.csv').fillna('')
# ex_df= ex_df[ex_df['rle']!='']
# ex_df['fold']= -1
# df= pd.concat([ex_df, df], axis=0)

## choose organ
choose_organ= [
    'lung',
    'spleen',
    'prostate',
    'kidney',
    'largeintestine'
]
# .copy() so the fold assignment below writes to an independent frame rather
# than a slice of the frame that was read from disk.
df= df[df['organ'].isin(choose_organ)].copy()
# fold -1 never equals CFG['fold'], so lung samples always land in the training split.
df.loc[ df['organ']=='lung', 'fold' ]= -1

train_df= df[df['fold']!=CFG['fold']].reset_index()
valid_df= df[df['fold']==CFG['fold']].reset_index()

## classes balance
if CFG['classes_balance']:
    print('balance classes')
    # Every organ is topped up to the number of kidney samples.
    max_sample= len(train_df[train_df['organ']=='kidney'])
    for organ in choose_organ:
        organ_df= train_df[train_df['organ']==organ].reset_index(drop=True)
        fill_num= max_sample - len(organ_df)
        # Nothing to add when the organ already has enough samples (fill_num
        # may be negative) or has no samples to draw from.
        if fill_num <= 0 or organ_df.empty:
            continue
        # Draw without replacement when possible, with replacement otherwise.
        sample_df= organ_df.sample(n= fill_num, replace= fill_num > len(organ_df))
        train_df= pd.concat([train_df, sample_df], axis=0).reset_index(drop=True)

print(f'train dataset: {len(train_df)}')
print(f'valid dataset: {len(valid_df)}')

train_dataset= custom_dataset(train_df,
                              get_train_transform(),
                              mode= 'train',
                              cutmix= CFG['cutmix'],
                              mixup= CFG['mixup'],
                              mosaic= CFG['mosaic'],
                              copypaste= CFG['copypaste'])
valid_dataset= custom_dataset(valid_df, get_test_transform())

train_loader = DataLoader(train_dataset, batch_size= CFG['batch_size'], shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size= 1, shuffle=False, num_workers=0)
train_df.head()

# Training

In [None]:
# Build (or load) the model, then train for CFG['epoch'] epochs, keeping the
# best-scoring checkpoint as well as one checkpoint per epoch.
DEVICE = 'cuda:0'

if not CFG['load_model']:
    model= customize_model(CFG['model_name'])
else:
    print(f"load_model: {CFG['load_model']}")
    # NOTE(review): torch.load unpickles a whole model object, which can
    # execute arbitrary code - only load checkpoints you created yourself.
    model= torch.load(CFG['load_model'])
# Move the model explicitly: the *_amp loops send batches to the GPU and
# should not depend on the smp epoch runners below to relocate the model.
model.to(DEVICE)

# torch.save does not create missing directories.
os.makedirs(CFG['save_model'], exist_ok=True)

## loss
loss= custom_loss()
loss.__name__ = 'custom_loss'
# `dice` is also read as a module-level name by valid_epoch_amp.
dice= L.DiceLoss(mode= 'multilabel')
dice.__name__= 'dice_loss'

metrics = [
#     smp.utils.metrics.Fscore(threshold= 0.5, activation= 'sigmoid'),
    smp.utils.metrics.IoU(threshold=0.5, activation= 'sigmoid'),
    dice,
]
optimizer = torch.optim.AdamW([ 
    dict(params=model.parameters(), lr=CFG['lr']),
], weight_decay= CFG['weight_decay'])


# smp epoch runners; only used when CFG['amp'] is falsy.
model.train()
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

model.eval()
valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)


max_score = 0
for i in range(1, CFG['epoch']+1):
    print('\nEpoch: {}'.format(i))
    
    ## train
    if CFG['amp']: 
        train_loss= train_epoch_amp(train_loader, model, loss, optimizer)
        valid_loss, dice_mask, dice_aux= valid_epoch_amp(valid_loader, model, loss)
        score= dice_mask
        print(f'train_loss: {round(train_loss, 5)}')
        print(f'valid_loss: {round(valid_loss, 5)},\
                vali_score: {round(score, 5)},\
                valid_aux: {dice_aux}')
    else: 
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        # The metric is a dice loss, so the score is its complement.
        score= 1-valid_logs['dice_loss']
        print('vali_score: ', round(score, 5))

    ## save best model
    if max_score < score:
        max_score = score
        torch.save(model, f"{CFG['save_model']}/model_cv{CFG['fold']}_best.pth")
        print('Model saved at: ', round(max_score, 5))
        
    ## save model every epoch
    torch.save(model, f"{CFG['save_model']}/model_cv{CFG['fold']}_ep{i}.pth")

#     ##adjust lr
#     if i == 70:
#         model= torch.load(f"{CFG['save_model']}/model_cv{CFG['fold']}_best.pth")
#         optimizer.param_groups[0]['lr'] = 1e-4
#         print(f"Decrease decoder learning rate to {optimizer.param_groups[0]['lr']}!")