<a href="https://colab.research.google.com/github/MattiaBrazzale/HuBMAP-22/blob/main/model/CoaT_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## HuBMAP-22 Challenge

This notebook contains the training and submission pipeline for the HuBMAP-22 Challenge.

Training uses a 5-fold split: four folds are used for training and the remaining fold for validation.


### Requirements
Installing the required packages:

In [None]:
!pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -qq git+https://github.com/qubvel/segmentation_models.pytorch
!pip install timm==0.4.12
!pip install einops

Loading libraries and utilities:

In [None]:
import gc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random
from sklearn.model_selection import KFold
import segmentation_models_pytorch as smp
import tifffile
import torchvision
import cv2
import torch
import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

from coat import *
from daformer import *
from utils import seed_everything, make_fold, get_mask, rle_encode
import config
seed_everything(config.SEED)

Loading the dataframes and splitting the train set into training and validation sets:

In [None]:
# Read the train/test metadata tables. The file names are appended directly to
# the configured prefixes, so those prefixes must already end with a separator.
train = pd.read_csv(f"{config.TRAIN_PATH}train.csv")
test_df = pd.read_csv(f"{config.TEST_PATH}test.csv")

# Split the training table into the training and validation dataframes
# according to the fold settings in the config file.
train_df, val_df = make_fold(
    df=train,
    num_fold=config.NUM_FOLD,
    val_fold=config.VAL_FOLD,
)

Defining the PyTorch dataset for the images:

In [None]:
class HuBMAPData(Dataset):
    """Dataset yielding ``(image, mask, organ)`` for each row of a dataframe.

    Args:
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
            When ``None`` the raw image and mask are returned unchanged.
        df: Dataframe providing at least the ``id`` and ``organ`` columns.
    """

    def __init__(self, transform=None, df=train_df):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        image_id, organ = row['id'], row['organ']

        image = tifffile.imread(config.TRAIN_PATH + str(image_id) + '.tiff')
        mask = get_mask(image_id, self.df)
        # Pixels stored as 255 are rewritten to 1 (in place).
        mask[mask == 255.0] = 1.0

        # Without a transform the raw arrays are handed back as they are.
        if self.transform is None:
            return image, mask, organ

        augmented = self.transform(image=image, mask=mask)
        # The transformed mask is cast to float before being returned.
        return augmented["image"], augmented["mask"].float(), organ

Defining the augmentations for the training set, which we need because only a small number of images is available:

In [None]:
# ---- Training pipeline ------------------------------------------------------
# The three `OneOf` groups are named so the overall pipeline reads top to bottom.

# Flip / 90-degree rotation group (the group itself runs with p=1.0).
_flip_group = A.OneOf(
    [
        A.HorizontalFlip(p=0.6),
        A.VerticalFlip(p=0.6),
        A.RandomRotate90(p=0.6),
    ],
    p=1.0,
)

# Non-rigid geometric distortion group (the group runs with p=0.8).
_distortion_group = A.OneOf(
    [
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
        A.GridDistortion(p=0.5),
        A.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=0.5),
    ],
    p=0.8,
)

# Colour perturbation group (the group runs with p=0.9).
_color_group = A.OneOf(
    [
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.75),
        A.RandomBrightnessContrast(p=0.7),
        A.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25, p=0.75),
    ],
    p=0.9,
)

# Order: resize -> geometric -> colour/noise -> tensor conversion.
train_transform = A.Compose(
    [
        A.Resize(height=config.IMG_SIZE, width=config.IMG_SIZE),
        A.Rotate(limit=35, p=0.8),
        _flip_group,
        _distortion_group,
        A.ChannelShuffle(p=0.4),
        A.GaussNoise(var_limit=(10.0, 50.0), mean=0, per_channel=True, always_apply=False, p=0.4),
        _color_group,
        A.RandomGamma(p=0.6),
        ToTensorV2(transpose_mask=True),
    ]
)

# ---- Validation pipeline ----------------------------------------------------
# Only resizing and tensor conversion: no random augmentation.
val_transform = A.Compose(
    [
        A.Resize(height=config.IMG_SIZE, width=config.IMG_SIZE),
        ToTensorV2(transpose_mask=True),
    ]
)

Initializing the datasets and the dataloaders:

In [None]:
# One dataset per split, each with its own transform pipeline.
train_dataset = HuBMAPData(df=train_df, transform=train_transform)
val_dataset = HuBMAPData(df=val_df, transform=val_transform)

# Only the training loader shuffles; both use the configured batch size.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
)

### CoaT 

Loading the pretrained weights:



In [None]:
!mkdir -p ./coat-pretrained
!wget http://vcl.ucsd.edu/coat/pretrained/coat_lite_medium_a750cd63.pth -P ./coat-pretrained
!sha256sum ./coat-pretrained/coat_lite_medium_a750cd63.pth

Defining the model architecture and the function to initialize the model:

In [None]:
class RGB(nn.Module):
    """Channel-wise normalisation layer computing ``(x - mean) / std``.

    ``mean`` and ``std`` are registered as buffers of shape ``(1, 3, 1, 1)``,
    so they broadcast over a 4-D input with three channels in dimension 1,
    follow the module across devices and are stored in its ``state_dict``,
    but are not trainable parameters.
    """

    IMAGE_RGB_MEAN = [0.485, 0.456, 0.406]
    IMAGE_RGB_STD = [0.229, 0.224, 0.225]

    def __init__(self, ):
        super().__init__()
        stats_shape = (1, 3, 1, 1)
        # Register each buffer directly with its final float32 values.
        self.register_buffer('mean', torch.FloatTensor(self.IMAGE_RGB_MEAN).view(stats_shape))
        self.register_buffer('std', torch.FloatTensor(self.IMAGE_RGB_STD).view(stats_shape))

    def forward(self, x):
        """Return ``x`` with the per-channel mean subtracted and divided by the std."""
        return (x - self.mean) / self.std

    
    
class Net(nn.Module):
    """Encoder-decoder segmentation network with a one-channel probability output.

    Args:
        encoder: Either an already-built encoder module exposing ``embed_dims``,
            or a factory (such as ``coat_lite_medium``) which is then called as
            ``encoder(**encoder_cfg)`` to build one.
        decoder: Factory called as ``decoder(encoder_dim=..., decoder_dim=...)``.
        encoder_cfg: Keyword arguments for the encoder factory; ignored when
            ``encoder`` is already a module.
        decoder_cfg: Optional settings; only ``'decoder_dim'`` is read
            (default 320).
    """

    # NOTE: the dict defaults are only read (``.get`` / ``**``), never mutated,
    # so sharing them between calls is harmless here.
    def __init__(self,
                 encoder=coat_lite_medium,
                 decoder=daformer_conv3x3,
                 encoder_cfg={},
                 decoder_cfg={},
                 ):
        super(Net, self).__init__()
        decoder_dim = decoder_cfg.get('decoder_dim', 320)

        self.rgb = RGB()

        # The default argument is a factory, not a module: build the encoder
        # in that case so that ``embed_dims`` below is an instance attribute.
        if not isinstance(encoder, nn.Module):
            encoder = encoder(**encoder_cfg)
        self.encoder = encoder
        encoder_dim = self.encoder.embed_dims

        self.decoder = decoder(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
        )
        # 1x1 convolution down to a single channel, then 4x bilinear upsampling.
        self.logit = nn.Sequential(
            nn.Conv2d(decoder_dim, 1, kernel_size=1),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
        )

    def forward(self, batch):
        """Run the network on an image batch.

        Args:
            batch: 4-D image tensor with three channels in dimension 1
                (it is normalised by the ``RGB`` layer first).

        Returns:
            dict with a single key ``'probability'`` holding the sigmoid of
            the upsampled logits.
        """
        x = self.rgb(batch)
        features = self.encoder(x)

        # The decoder returns a pair; only its first element feeds the head.
        last, _ = self.decoder(features)
        logit = self.logit(last)

        return {'probability': torch.sigmoid(logit)}


def init_model():
    """Build a CoaT segmentation model with a pretrained encoder.

    The encoder is created with ``coat_lite_medium()`` and initialised from
    the checkpoint file in ``./coat-pretrained`` (pretrained weight available
    at the CoaT repository). Loading is non-strict, so keys that do not match
    between the checkpoint and the encoder are skipped.

    Returns:
        A ``Net`` wrapping that encoder, moved to ``config.DEVICE``.
    """
    encoder = coat_lite_medium()
    checkpoint_path = './coat-pretrained/coat_lite_medium_a750cd63.pth'
    # Deserialize onto the CPU regardless of where the checkpoint was saved.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    encoder.load_state_dict(checkpoint['model'], strict=False)
    # Use the configured device (as every caller does) instead of assuming CUDA.
    return Net(encoder=encoder).to(config.DEVICE)

Defining the classes for the custom loss and the evaluation metric:

In [None]:
class CustomLoss(nn.Module):
    """Weighted sum of Dice and Jaccard losses: ``0.3 * dice + 0.7 * jaccard``.

    Args:
        from_logits: Forwarded to the Dice and Jaccard losses. The default
            ``False`` matches this notebook, where the model output passed to
            the loss is already a sigmoid probability; with ``True`` the losses
            would apply a sigmoid to their input themselves.
    """

    def __init__(self, from_logits=False):
        super(CustomLoss, self).__init__()
        self.diceloss = smp.losses.DiceLoss(mode='binary', from_logits=from_logits)
        # Kept as an attribute for compatibility; it is not part of the loss.
        self.binloss = smp.losses.SoftBCEWithLogitsLoss(reduction='mean', smooth_factor=0.1)
        self.jaccardloss = smp.losses.JaccardLoss(mode='binary', from_logits=from_logits)

    def forward(self, output, mask):
        """Return the combined loss for predictions ``output`` against ``mask``."""
        # Use the ``output`` argument itself rather than a same-looking global.
        dice = self.diceloss(output, mask)
        jaccard = self.jaccardloss(output, mask)
        return dice * 0.3 + jaccard * 0.7

class DiceCoef(nn.Module):
    """Dice coefficient between a binarised prediction and a target.

    ``weight`` and ``size_average`` are accepted for signature compatibility
    and are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, y_pred, y_true, smooth=1.):
        """Compute ``(2 * |P & T| + smooth) / (|P| + |T| + smooth)``.

        Args:
            y_pred: Prediction tensor of any shape. It is min-max normalised
                to [0, 1] and rounded to obtain the binary prediction ``P``.
            y_true: Target tensor with the same number of elements.
            smooth: Additive smoothing term for numerator and denominator.

        Returns:
            A scalar tensor holding the Dice coefficient.
        """
        # ``reshape`` also accepts non-contiguous inputs (``view`` would raise).
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1)

        lowest = y_pred.min()
        span = y_pred.max() - lowest
        if span > 0:
            y_pred = torch.round((y_pred - lowest) / span)
        else:
            # Constant prediction: min-max scaling would divide by zero and
            # yield NaN, so round the raw values (limited to [0, 1]) instead.
            y_pred = torch.round(y_pred.clamp(0, 1))

        intersection = (y_true * y_pred).sum()
        dice = (2.0 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)

        return dice

### Training Loop

In [None]:
# set the TRAIN variable to True in the config file to perform training
if config.TRAIN:

    train_losses = []
    val_losses = []
    train_scores = []
    val_scores = []
    best_loss = float("inf")
    best_score = 0

    # The validation fold identifies this run in the logs and checkpoint names.
    fold = config.VAL_FOLD

    model = init_model().to(config.DEVICE)

    # The 1x1 `logit` head is trained together with the decoder; without it in
    # the optimizer its weights would stay at their random initialisation.
    optimizer = torch.optim.Adam([
        {'params': list(model.decoder.parameters()) + list(model.logit.parameters()),
         'lr': config.DECODER_LEARNING_RATE},
        {'params': model.encoder.parameters(), 'lr': config.ENCODER_LEARNING_RATE},
    ])

    num_batches = len(train_loader)

    def _is_update_step(batch_idx):
        """Return True when the accumulated gradients are applied after this batch."""
        samples_seen = (batch_idx + 1) * config.BATCH_SIZE
        # Always flush on the last batch of the epoch so no gradient is left over.
        return samples_seen % config.ACCUMULATION == 0 or (batch_idx + 1) == num_batches

    # The scheduler is stepped once per optimizer update, so its length must be
    # counted in updates (not batches) for the one-cycle schedule to complete.
    updates_per_epoch = sum(_is_update_step(b) for b in range(num_batches))

    # NOTE(review): `MAX_LEARNING_RATE` was referenced as a bare name that this
    # notebook never defines; it is read from `config` like the other
    # hyper-parameters - confirm the attribute exists there.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer,
                                                    max_lr=config.MAX_LEARNING_RATE,
                                                    epochs=config.EPOCHS,
                                                    steps_per_epoch=updates_per_epoch)

    loss_func = CustomLoss()
    dice_coe = DiceCoef()

    for epoch in tqdm(range(config.EPOCHS)):

        # ------- Train ------- #

        model.train()
        train_loss = 0
        score = 0

        for batch_idx, (img, mask, organ) in enumerate(train_loader):

            # NOTE(review): pixels are cast to float without rescaling here;
            # confirm the value range is what the model's RGB layer expects.
            img = img.float().to(device=config.DEVICE)
            mask = mask.float().to(device=config.DEVICE)
            outputs = model(img)['probability']

            loss = loss_func(outputs, mask)
            loss.backward()

            if _is_update_step(batch_idx):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Store plain floats so the history does not keep tensors alive.
            train_loss += loss.item()
            score += dice_coe(outputs.detach(), mask).item()

        train_loss /= len(train_loader)
        score /= len(train_loader)
        train_losses.append(train_loss)
        train_scores.append(score)
        print(f"FOLD: {fold}, EPOCH: {epoch+1}, Train_Loss: {train_loss} , Dice Value: {score}") #


        gc.collect()
        torch.cuda.empty_cache()

        # ------ Validation ------ #

        model.eval()

        with torch.no_grad():

            val_loss = 0
            val_score = 0

            for i, (img, mask, organ) in enumerate(val_loader):

                img = img.float().to(device=config.DEVICE)
                mask = mask.float().to(device=config.DEVICE)
                outputs = model(img)['probability']

                loss = loss_func(outputs, mask)
                val_loss += loss.item()
                val_score += dice_coe(outputs, mask).item()

            val_loss /= len(val_loader)
            val_losses.append(val_loss)

            val_score /= len(val_loader)
            val_scores.append(val_score)

            print(f"FOLD: {fold}, EPOCH: {epoch+1}, Val_Loss: {val_loss} , Valid Dice Value: {val_score}") 

            gc.collect()
            torch.cuda.empty_cache()

        if val_score > best_score:
            best_score = val_score
            torch.save(model.state_dict(), f"./FOLD{fold}_best_score_epoch{epoch+1}.pth")
            print(f"Saved model for best score : FOLD{fold}_best_score_epoch{epoch+1}.pth")

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), f"./FOLD{fold}_best_loss_epoch{epoch+1}.pth")
            print(f"Saved model for best loss : FOLD{fold}_best_loss_epoch{epoch+1}.pth")    

        # Periodic snapshot, written next to the other checkpoints rather than
        # to a platform-specific absolute directory.
        if epoch % 20 == 0:
            torch.save(model.state_dict(), f"./FOLD{fold}_epoch_{epoch+1}.pth")
            print(f"Saved model for current epoch: FOLD{fold}_epoch_{epoch+1}.pth")


### Inference on new data

The following augmentations are used to perform Test Time Augmentation (TTA), i.e. performing inference on slightly different versions of each image and then averaging the predictions:

In [None]:
def _tta_pipeline(size, *middle):
    """Build a pipeline: square resize to ``size``, then ``middle``, then tensor conversion."""
    return A.Compose([
        A.Resize(height=size, width=size),
        *middle,
        ToTensorV2(transpose_mask=True),
    ])


# Geometric variants at the training resolution.
horizontal_flip = _tta_pipeline(config.IMG_SIZE, A.HorizontalFlip(p = 1.))
vertical_flip = _tta_pipeline(config.IMG_SIZE, A.VerticalFlip(p = 1.))
rotate_cw = _tta_pipeline(config.IMG_SIZE, A.Rotate(limit = (-90, -90), p = 1.))
rotate_acw = _tta_pipeline(config.IMG_SIZE, A.Rotate(limit = (90, 90), p = 1.))

# A flip with probability zero: does nothing
identity_trfm = _tta_pipeline(config.IMG_SIZE, A.HorizontalFlip(p = 0.))

# Pixel-level (colour / contrast) variants.
pixel_level_trfms1 = _tta_pipeline(config.IMG_SIZE, A.HueSaturationValue(10,15,10))
pixel_level_trfms2 = _tta_pipeline(config.IMG_SIZE, A.CLAHE(clip_limit=2))

# Multi-scale variants: resize only, 32 or 64 pixels above / below IMG_SIZE.
increase_size1 = _tta_pipeline(config.IMG_SIZE+32)
reduce_size1 = _tta_pipeline(config.IMG_SIZE-32)
increase_size2 = _tta_pipeline(config.IMG_SIZE+64)
reduce_size2 = _tta_pipeline(config.IMG_SIZE-64)


# List of augmentations for TTA
tta_augs = [
    identity_trfm,
    horizontal_flip,
    vertical_flip,
    pixel_level_trfms1,
    pixel_level_trfms2,
    increase_size1,
    reduce_size1,
    increase_size2,
    reduce_size2,
]

# List of deaugmentations corresponding to the above augmentation list:
# a flip undoes itself, every other entry needs no de-augmentation (None).
_self_inverse = (horizontal_flip, vertical_flip)
tta_deaugs = [
    aug if any(aug is flip for flip in _self_inverse) else None
    for aug in tta_augs
]

Per-organ probability thresholds, for each data source, at which the model performance is optimized:

In [None]:
def _organ_thresholds(default, lung):
    """Return the per-organ threshold table for one data source.

    Every organ except ``'lung'`` gets ``default``; ``'lung'`` gets ``lung``.
    """
    table = dict.fromkeys(('kidney', 'prostate', 'largeintestine', 'spleen'), default)
    table['lung'] = lung
    return table


# Looked up as organ_threshold[data_source][organ].
organ_threshold = {
    'Hubmap': _organ_thresholds(default=0.40, lung=0.10),
    'HPA': _organ_thresholds(default=0.50, lung=0.10),
}

In [None]:
#set the SUBMIT variable to True in the config file to perform inference on the test images
if config.SUBMIT:

    #loading the model
    model = init_model().to(config.DEVICE)
    model.output_type = ["inference"]
    # map_location lets weights saved on another device load on this one.
    model.load_state_dict(
        torch.load(config.WEIGHTS_PATH, map_location=config.DEVICE),
        strict=False,
    )
    model.float()
    model.eval()

    ids = []
    rles = []

    # Inference only: without autograd no graph is kept for each forward pass.
    with torch.no_grad():
        for _, row in test_df.iterrows():

            image_id = row['id']
            organ = row['organ']
            data_source = row['data_source']
            image = tifffile.imread(config.TEST_IMG+str(image_id)+'.tiff')
            image_shape = image.shape[:2]

            # Both branches binarise with the same threshold and resize the
            # prediction back to the original image height and width.
            threshold = organ_threshold[data_source][organ]
            resize_to_original = torchvision.transforms.Resize(image_shape)

            # NOTE(review): the flag was referenced as a bare `TTA`, which this
            # notebook never defines; it is read from `config` like the other
            # switches - confirm the attribute exists there.
            if config.TTA:
                tta_pred = None
                for tta_aug, tta_deaug in zip(tta_augs, tta_deaugs):

                    aug_img = tta_aug(image=image)["image"]

                    x_tensor = aug_img.to(config.DEVICE).unsqueeze(0).float()
                    pr_mask = model(x_tensor)['probability']

                    if tta_deaug is not None:
                        mask_np = pr_mask.squeeze().cpu().detach().numpy()
                        # The pipeline requires an `image` argument; a blank
                        # one with the mask's height/width keeps both inputs
                        # the same size and avoids re-processing the full image.
                        placeholder = np.zeros(mask_np.shape + (3,), dtype=np.uint8)
                        pr_mask = tta_deaug(image=placeholder, mask=mask_np)['mask']
                        # Restore the two leading dimensions removed by squeeze().
                        pr_mask = pr_mask.unsqueeze(0).unsqueeze(0)

                    resized_pr_mask = resize_to_original(pr_mask)
                    resized_pr_mask = resized_pr_mask.squeeze().cpu().detach().numpy()

                    if tta_pred is None:
                        tta_pred = resized_pr_mask
                    else:
                        tta_pred = tta_pred + resized_pr_mask

                # Average of the per-augmentation probability maps.
                tta_pred = tta_pred / len(tta_augs)
                th_mask = (tta_pred > threshold).astype(int)

            else:
                tensor_image = val_transform(image=image)["image"]

                x_tensor = tensor_image.to(config.DEVICE).unsqueeze(0).float()
                pr_mask = model(x_tensor)['probability']
                pr_mask = resize_to_original(pr_mask).squeeze().cpu().detach().numpy()

                th_mask = (pr_mask > threshold).astype(int)

            rle = rle_encode(th_mask)
            ids.append(image_id)
            rles.append(rle)

    submission_df = pd.DataFrame({'id':ids,'rle':rles})
    submission_df.to_csv('submission.csv', index=False)