In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install rasterio
!pip install -U albumentations
!pip install fvcore

In [3]:
!pip install -qq torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -qq git+https://github.com/qubvel/segmentation_models.pytorch
!pip install -qq timm==0.4.12
!pip install -qq einops

[K     |███████████████████████         | 834.1 MB 1.4 MB/s eta 0:03:56tcmalloc: large alloc 1147494400 bytes == 0x39614000 @  0x7f45875b5615 0x58e046 0x4f2e5e 0x4d19df 0x51b31c 0x5b41c5 0x58f49e 0x51b221 0x5b41c5 0x58f49e 0x51837f 0x4cfabb 0x517aa0 0x4cfabb 0x517aa0 0x4cfabb 0x517aa0 0x4ba70a 0x538136 0x590055 0x51b180 0x5b41c5 0x58f49e 0x51837f 0x5b41c5 0x58f49e 0x51740e 0x58f2a7 0x517947 0x5b41c5 0x58f49e
[K     |█████████████████████████████▏  | 1055.7 MB 1.2 MB/s eta 0:01:24tcmalloc: large alloc 1434370048 bytes == 0x7dc6a000 @  0x7f45875b5615 0x58e046 0x4f2e5e 0x4d19df 0x51b31c 0x5b41c5 0x58f49e 0x51b221 0x5b41c5 0x58f49e 0x51837f 0x4cfabb 0x517aa0 0x4cfabb 0x517aa0 0x4cfabb 0x517aa0 0x4ba70a 0x538136 0x590055 0x51b180 0x5b41c5 0x58f49e 0x51837f 0x5b41c5 0x58f49e 0x51740e 0x58f2a7 0x517947 0x5b41c5 0x58f49e
[K     |████████████████████████████████| 1156.7 MB 1.2 MB/s eta 0:00:01tcmalloc: large alloc 1445945344 bytes == 0xd3456000 @  0x7f45875b5615 0x58e046 0x4f2e5e 0x4d19df 0x

In [5]:
!cp /content/drive/MyDrive/Segmentation/hubmap-organ-segmentation.zip /content
!unzip /content/hubmap-organ-segmentation.zip > /dev/null

In [6]:
!cp /content/drive/MyDrive/UneSt101/train_256_12.zip /content
!cp /content/drive/MyDrive/UneSt101/masks_256_12.zip /content

!unzip /content/train_256_12.zip -d /content/train_256_12 > /dev/null
!unzip /content/masks_256_12.zip -d /content/masks_256_12 > /dev/null

In [63]:
!cp /content/drive/MyDrive/UneSt101/train_512_6.zip /content
!cp /content/drive/MyDrive/UneSt101/masks_512_6.zip /content

!unzip /content/train_512_6.zip -d /content/train_512_6 > /dev/null
!unzip /content/masks_512_6.zip -d /content/masks_512_6 > /dev/null

In [1]:
!cp /content/drive/MyDrive/UneSt101/train_256_6.zip /content
!cp /content/drive/MyDrive/UneSt101/masks_256_6.zip /content

!unzip /content/train_256_6.zip -d /content/train_256_6 > /dev/null
!unzip /content/masks_256_6.zip -d /content/masks_256_6 > /dev/null

In [20]:
!cp /content/drive/MyDrive/Coat/coat.py /content
!cp /content/drive/MyDrive/Coat/daformer.py /content
!cp /content/drive/MyDrive/Coat/helper.py /content
!cp /content/drive/MyDrive/Coat/coat_lite_medium_384x384_f9129688.pth /content

# config

In [5]:
import os
import gc
import sys
import glob

import torch
import torch.nn as nn
import albumentations as A

import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tqdm
import segmentation_models_pytorch as smp
from sklearn.model_selection import StratifiedKFold

import tifffile as tiff
import shutil

torch.backends.cudnn.benchmark = True

In [6]:
fold = 0
nfolds = 10
imsize = 256
train_csv = '/content/train.csv'
BATCH_SIZE = 12
DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 50
NUM_WORKERS = 1
SEED = 24
TRAIN_PATH = '/content/train_256_6/'
MASK_PATH = '/content/masks_256_6/'

In [7]:
def set_seed(seed=12):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
set_seed(12)

In [8]:
class HuBMAPDataset(torch.utils.data.Dataset):
    def __init__(self, fold=fold, train=True, tfms=None):
        self.train = train
        ids = pd.read_csv(train_csv).id.values
        labels = pd.read_csv(train_csv).organ.values
        kf = StratifiedKFold(n_splits=nfolds,random_state=SEED,shuffle=True)
        ids = (ids[list(kf.split(ids,labels))[fold][0 if train else 1]]).tolist()
        self.fnames = [fname for fname in os.listdir(TRAIN_PATH) if int(fname.split('_')[0]) in ids]
        self.image_size = imsize
        self.tfms = tfms
        
    def img2tensor(self, img,dtype:np.dtype=np.float32):
        if img.ndim==2 : img = np.expand_dims(img,2)
        img = np.transpose(img,(2,0,1)) # C , H , W
        return torch.from_numpy(img.astype(dtype, copy=False))
    
    def __len__(self):
        return len(self.fnames)
    
    def resize(self, img, interp):
        return  cv2.resize(
            img, (self.image_size, self.image_size), interpolation=interp)
    
    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.cvtColor(cv2.imread(TRAIN_PATH + fname), cv2.COLOR_BGR2RGB)
        mask = cv2.imread((MASK_PATH + fname),cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            augmented = self.tfms(image=img,mask=mask)
            img,mask = augmented['image'],augmented['mask']
        return self.img2tensor(self.resize(img , cv2.INTER_NEAREST)) , self.img2tensor(self.resize(mask , cv2.INTER_NEAREST))

In [9]:
def transformer(p=1.0):
    return A.Compose([
        A.GridDropout(ratio=0.4, unit_size_min=None, unit_size_max=None, holes_number_x=None, holes_number_y=None, shift_x=0, shift_y=0, 
                     random_offset=True, fill_value=0, mask_fill_value=0, always_apply=False, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # Morphology
        A.ShiftScaleRotate(shift_limit=(-0.1, 0.1), scale_limit=(-0.2, 0.2), rotate_limit=(-30, 30), interpolation=1, border_mode=0, value=(0, 0, 0), p=0.4),
        A.GaussNoise(var_limit=(5.0, 50.0), mean=10, p=0.5),
        A.GaussianBlur(blur_limit=(3, 7), p=0.5),
        # Color
        A.RandomBrightnessContrast(brightness_limit=0.35, contrast_limit=0.5,
                                 brightness_by_max=True, p=0.5),
        A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30,
                           val_shift_limit=30, p=0.5),
        A.OneOf([
            A.OpticalDistortion(p=0.5),
            A.GridDistortion(p=0.5),
            A.PiecewiseAffine(p=0.5),
        ], p=0.5),
    ], p=p)

In [10]:
# ds = HuBMAPDataset(tfms=transformer())
# dl = torch.utils.data.DataLoader(ds,batch_size=64,shuffle=False,num_workers=NUM_WORKERS)
# it = iter(dl)
# imgs,masks = next(it)

In [11]:
# plt.figure(figsize=(16,16))
# for i,(img,mask) in enumerate(zip(imgs,masks)):
#     img = ((img.permute(1,2,0))).numpy().astype(np.uint8)  # H , W , C
#     plt.subplot(8,8,i+1)
#     plt.imshow(img,vmin=0,vmax=255)
#     plt.imshow(mask.squeeze().numpy(), alpha=0.2)
#     plt.axis('off')
#     plt.subplots_adjust(wspace=None, hspace=None)
    
# del ds,dl,imgs,masks

# model

In [12]:
from coat import *
from daformer import *
from helper import *

In [13]:
class Net(nn.Module):
    
    def __init__(self,
                 encoder=coat_lite_medium,
                 decoder=daformer_conv3x3,
                 encoder_cfg={},
                 decoder_cfg={},
                 ):
        
        super(Net, self).__init__()
        decoder_dim = decoder_cfg.get('decoder_dim', 320)

        self.encoder = encoder

        self.rgb = RGB()

        encoder_dim = self.encoder.embed_dims
        # [64, 128, 320, 512]

        self.decoder = decoder(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
        )
        self.logit = nn.Sequential(
            nn.Conv2d(decoder_dim, 1, kernel_size=1),
            nn.Upsample(scale_factor = 4, mode='bilinear', align_corners=False),
        )

    def forward(self, batch):

        x = self.rgb(batch)

        B, C, H, W = x.shape
        encoder = self.encoder(x)

        last, decoder = self.decoder(encoder)
        logit = self.logit(last)

        output = {}
        probability_from_logit = torch.sigmoid(logit)
        output['probability'] = probability_from_logit

        return output

In [14]:
def init_model():
    encoder = coat_lite_medium()
    checkpoint = '/content/coat_lite_medium_384x384_f9129688.pth'
    checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    state_dict = checkpoint['model']
    encoder.load_state_dict(state_dict,strict=False)
    
    net = Net(encoder=encoder).cuda()
    
    return net

# metric

In [15]:
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss,self).__init__()
        self.diceloss = smp.losses.DiceLoss(mode='binary')
        self.binloss = smp.losses.SoftBCEWithLogitsLoss(reduction = 'mean' , smooth_factor = 0.1)

    def forward(self, output, mask):
        dice = self.diceloss(outputs,mask)
        bce = self.binloss(outputs , mask)
        loss = dice * 0.7 + bce * 0.3
        return loss

In [16]:
class DiceCoef(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, y_pred, y_true, smooth=1.):
        y_true = y_true.view(-1)
        y_pred = y_pred.view(-1)
        
        #Round off y_pred
        y_pred = torch.round((y_pred - y_pred.min()) / (y_pred.max() - y_pred.min()))
        
        intersection = (y_true * y_pred).sum()
        dice = (2.0*intersection + smooth)/(y_true.sum() + y_pred.sum() + smooth)
        
        return dice

In [17]:
def plot_df(df):
    fig,ax = plt.subplots(1,2,figsize=(15,5))
    ax[0].plot(df['Train_loss'])
    ax[0].plot(df['Val_loss'])
    ax[0].legend()
    ax[0].set_title('Loss')
    ax[1].plot(df['Train_Dice'])
    ax[1].plot(df['Val_Dice'])
    ax[1].legend()
    ax[1].set_title('Dice')

# train

In [18]:
import shutil

shutil.rmtree('/content/models')

In [None]:
print(f"Running on device :  {DEVICE}" )
if not os.path.exists("models"):
        os.mkdir("models")
for fold in tqdm.notebook.tqdm(range(0, 1)):

    if not os.path.exists(f"models/fold_{fold}"):
        os.mkdir(f"models/fold_{fold}")
    
    val_losses = []
    losses = []
    train_scores=[]
    val_scores = []
    best_loss = 999
    best_score = 0
    
    ds_train = HuBMAPDataset(fold=fold, train=True, tfms=transformer())
    ds_val = HuBMAPDataset(fold=fold, train=False)
    
    dataloader_train = torch.utils.data.DataLoader(ds_train,batch_size=BATCH_SIZE, shuffle=True,num_workers=NUM_WORKERS)
    dataloader_val = torch.utils.data.DataLoader(ds_val,batch_size=BATCH_SIZE, shuffle=False,num_workers=NUM_WORKERS)
    
    model = init_model().to(DEVICE)
    
    optimizer = torch.optim.Adam([
        {'params': model.decoder.parameters(), 'lr': 5e-5}, 
        {'params': model.encoder.parameters(), 'lr': 8e-5},  
    ])
    
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=1e3, 
                                              max_lr=1e-3, epochs=EPOCHS, steps_per_epoch=len(dataloader_train))
    
    loss_func = CustomLoss()
    dice_coe = DiceCoef()
    
    print(f"######## FOLD: {fold} ##############")
    
    for epoch in tqdm.notebook.tqdm(range(EPOCHS)):


        
        ### Train ###########################################################################################
        
        model.train()
        train_loss = 0
        score = 0
        
        for data in tqdm.notebook.tqdm(dataloader_train ,total = len(dataloader_train)):
            optimizer.zero_grad()
            img, mask = data
            img = img.to(DEVICE)
            mask = mask.to(DEVICE)
        
            outputs = model(img)['probability']    

            loss = loss_func(outputs, mask)
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()
            score += dice_coe(outputs,mask).item()
            
        train_loss /= len(dataloader_train)
        score /= len(dataloader_train)
        losses.append(train_loss)
        train_scores.append(score)
        print(f"FOLD: {fold}, EPOCH: {epoch + 1}, train_loss: {train_loss} , Dice coe : {score} ") #
        
        
        gc.collect()
        torch.cuda.empty_cache()
        
        ### Validation ####################################################################################
        
        model.eval()
        
        with torch.no_grad():
            
            valid_loss = 0
            val_score = 0
            
            for data in dataloader_val:
                
                img, mask = data
                img = img.to(DEVICE)
                mask = mask.to(DEVICE)

                outputs = model(img)['probability']

                loss = loss_func(outputs, mask)
                valid_loss += loss.item()
                val_score += dice_coe(outputs,mask).item()
                
            valid_loss /= len(dataloader_val)
            val_losses.append(valid_loss)
            
            val_score /= len(dataloader_val)
            val_scores.append(val_score)
            
            print(f"FOLD: {fold}, EPOCH: {epoch + 1}, valid_loss: {valid_loss} , Val Dice COE : {val_score}") #
            
            gc.collect()
            torch.cuda.empty_cache()





            
        if val_score > best_score:
            best_score = val_score
            torch.save(model.state_dict(), f"models/fold_{fold}/FOLD{fold}_best_score.pth")
            print(f"Saved model for best score : FOLD{fold}_best_score.pth")
            
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), f"models/fold_{fold}/FOLD{fold}_best_loss.pth")
            print(f"Saved model for best loss : FOLD{fold}_best_loss.pth")


    save_path = '/content/drive/MyDrive/Coat/256_6/'
    
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    shutil.copy(f"models/fold_{fold}/FOLD{fold}_best_score.pth", save_path)
    shutil.copy(f"models/fold_{fold}/FOLD{fold}_best_loss.pth", save_path)

    column_names = ['Train_loss','Val_loss','Train_Dice','Val_Dice']
    df = pd.DataFrame(np.stack([losses,val_losses,train_scores,val_scores],axis=1),columns=column_names)
    print(f" ################# FOLD {fold} #####################")
    plot_df(df)
    plt.show(block=False)
    df.to_csv(f"logs_fold{fold}.csv")

    shutil.copy(f"logs_fold{fold}.csv", save_path)

Running on device :  cuda


  0%|          | 0/1 [00:00<?, ?it/s]

######## FOLD: 0 ##############


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

FOLD: 0, EPOCH: 1, train_loss: 0.841938020501818 , Dice coe : 0.24039449784017744 
FOLD: 0, EPOCH: 1, valid_loss: 0.819345697760582 , Val Dice COE : 0.36771376555164653
Saved model for best score : FOLD0_best_score.pth
Saved model for best loss : FOLD0_best_loss.pth


  0%|          | 0/105 [00:00<?, ?it/s]

FOLD: 0, EPOCH: 2, train_loss: 0.8117676133201236 , Dice coe : 0.4027974687871479 
FOLD: 0, EPOCH: 2, valid_loss: 0.7734880795081457 , Val Dice COE : 0.6384510397911072
Saved model for best score : FOLD0_best_score.pth
Saved model for best loss : FOLD0_best_loss.pth


  0%|          | 0/105 [00:00<?, ?it/s]

FOLD: 0, EPOCH: 3, train_loss: 0.8002212734449478 , Dice coe : 0.5444512161470595 
FOLD: 0, EPOCH: 3, valid_loss: 0.772734671831131 , Val Dice COE : 0.6920697639385859
Saved model for best score : FOLD0_best_score.pth
Saved model for best loss : FOLD0_best_loss.pth


  0%|          | 0/105 [00:00<?, ?it/s]

FOLD: 0, EPOCH: 4, train_loss: 0.7999696277436756 , Dice coe : 0.5136589241879327 
FOLD: 0, EPOCH: 4, valid_loss: 0.7663244853417078 , Val Dice COE : 0.7065657675266266
Saved model for best score : FOLD0_best_score.pth
Saved model for best loss : FOLD0_best_loss.pth


  0%|          | 0/105 [00:00<?, ?it/s]

FOLD: 0, EPOCH: 5, train_loss: 0.7967713055156526 , Dice coe : 0.5588312750770932 
FOLD: 0, EPOCH: 5, valid_loss: 0.7708547860383987 , Val Dice COE : 0.6394923776388168


  0%|          | 0/105 [00:00<?, ?it/s]

FOLD: 0, EPOCH: 6, train_loss: 0.7932513969285148 , Dice coe : 0.5767405116841906 
FOLD: 0, EPOCH: 6, valid_loss: 0.763994445403417 , Val Dice COE : 0.6955267091592153
Saved model for best loss : FOLD0_best_loss.pth


  0%|          | 0/105 [00:00<?, ?it/s]

FOLD: 0, EPOCH: 7, train_loss: 0.7932063119752066 , Dice coe : 0.5927230351027988 
FOLD: 0, EPOCH: 7, valid_loss: 0.7643148452043533 , Val Dice COE : 0.7129199703534445
Saved model for best score : FOLD0_best_score.pth


  0%|          | 0/105 [00:00<?, ?it/s]