In [None]:
# for Colab
# !pip install -q timm
# !pip install -q scikit-image==0.19.*

In [None]:
# download data here
# https://www.kaggle.com/competitions/hubmap-organ-segmentation/data

# pvt_v2_b4 weights can be downloaded here
# !wget -q https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth -O ../pvt_v2_weights/pvt_v2_b4.pth

In [None]:
%load_ext autoreload
%autoreload 2

import datetime
import os
import random

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
from skimage import io
from skimage.transform import rescale
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2

import pvt_v2_kaggle as pvt_v2
import daformer
import rle_format



DATA_FOLDER = '../input/hubmap-organ-segmentation'

# FIX SEED -- seed every RNG the notebook touches.
seed = 442
random.seed(seed)
# PYTHONHASHSEED is read at interpreter start-up, so this only affects Python
# processes launched after this point, not the current kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Deterministic cuDNN kernels; autotuning (benchmark mode) is disabled
# explicitly because it can select different algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# plots
plt.rcParams['figure.figsize'] = (20, 10)

# threading: number of DataLoader worker processes (0 = load in the main process)
N_JOBS = 0

# model input resolution: images are resized to H x W before the network
H = 704
W = H


# Data sources: train = HPA images
# public test = HuBMAP + HPA images
# private test = HuBMAP images

### Train/validation split

In [None]:
# Hold out 15% of train.csv for validation, stratified by organ so both
# splits keep the same organ mix.
df = pd.read_csv(os.path.join(DATA_FOLDER, 'train.csv'))
train_df, val_df = train_test_split(df, test_size=0.15, random_state=seed, stratify=df['organ'])

train_ids, val_ids = train_df['id'].values, val_df['id'].values

# Sanity checks: no image appears in both splits and every row is used once.
assert set(train_ids).isdisjoint(val_ids), "train/val ids overlap"
assert len(train_ids) + len(val_ids) == len(df)

### Augmentations

In [None]:
# Rigid geometric augmentations, applied jointly to image and mask.
# NOTE(review): albumentations only uses `value` / `mask_value` when
# `border_mode=cv2.BORDER_CONSTANT`; no border_mode is passed here, so confirm
# the intended constant fill is actually applied.
_rigid_augs = [
    A.Rotate(p=0.5, value=1, mask_value=0),
    A.RandomRotate90(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
]

# Half of the time, exactly one non-rigid distortion (chosen with probability
# proportional to each child's p).
_distortion_aug = A.OneOf(
    [
        A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        A.GridDistortion(p=0.5),
        A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5),
    ],
    p=0.5,
)

# Half of the time, exactly one photometric (colour / contrast) change.
_color_aug = A.OneOf(
    [
        A.HueSaturationValue(10, 15, 10),
        A.CLAHE(clip_limit=4),
        A.RandomGamma(p=0.2),
        A.RandomBrightnessContrast(p=0.5),
    ],
    p=0.5,
)

train_transform = A.Compose(_rigid_augs + [_distortion_aug, _color_aug])

In [None]:
# Image-only post-processing: resize to the network input size, normalise with
# ImageNet channel statistics (presumably what the pretrained encoder expects),
# and convert HWC numpy -> CHW tensor.
resize_normalize = A.Compose(
    [
        A.Resize(height=H, width=W, p=1),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
)

# Masks are only converted to tensors here; they are not resized.
mask_transform = A.Compose([ToTensorV2()])

### Dataset

In [None]:
# Per-organ pixel sizes used to rescale the training images (presumably the
# HuBMAP pixel sizes in micrometres -- confirm against the competition data
# description).
hbmp_pix = dict(
    kidney=0.5,
    largeintestine=0.229,
    lung=0.7562,
    spleen=0.4945,
    prostate=6.263,
)

class CustomDataset(Dataset):
    """Training/validation dataset reading `train_images/<id>.tiff` plus its RLE mask.

    For each sample the image AND its mask are rescaled by
    ``target_pixel_size / hbmp_pix[organ]``. ``main_transform`` is then applied
    jointly to both. The image goes through ``resize_normalize``; the mask goes
    through ``mask_transform`` and is cast to float.

    Args:
        idxs: sequence of image ids (values of the ``id`` column of train.csv).
        main_transform: albumentations transform applied jointly to image and mask.
        resize_normalize: albumentations transform applied to the image only.
        mask_transform: albumentations transform applied to the mask only.
        target_pixel_size: pixel size the per-organ scale factor is computed
            against. The default 0.4 is presumably the HPA pixel size in
            micrometres -- confirm.

    Returns (per item):
        dict with ``'pixel_values'`` (image tensor) and ``'labels'`` (float mask tensor).
    """

    def __init__(
        self,
        idxs,
        main_transform,
        resize_normalize,
        mask_transform,
        target_pixel_size=0.4,
    ):
        self.idxs = idxs
        self.df = pd.read_csv(os.path.join(DATA_FOLDER, 'train.csv')).set_index('id').loc[self.idxs]

        self.main_transform = main_transform
        self.resize_normalize = resize_normalize
        self.mask_transform = mask_transform
        self.target_pixel_size = target_pixel_size

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, idx):
        idx = self.idxs[idx]
        img_name = os.path.join(DATA_FOLDER, 'train_images', f'{idx}.tiff')

        image = io.imread(img_name)
        mask = rle_format.rle2mask(
            self.df.loc[idx, 'rle'],
            shape=image.shape[:2]
        )

        scale = self.target_pixel_size / hbmp_pix[self.df.loc[idx, 'organ']]

        # Assumes uint8 TIFFs: rescale (preserve_range=False) returns floats in
        # [0, 1]. Round rather than truncate when converting back to uint8.
        image_scaled = rescale(image, scale, order=1, anti_aliasing=True, channel_axis=2)
        image = np.clip(np.rint(image_scaled * 255), 0, 255).astype(np.uint8)

        # Rescale the mask by the same factor (nearest neighbour, no smoothing)
        # so it keeps the image's spatial shape. The joint augmentations below
        # (e.g. ElasticTransform) need image and mask aligned pixel-for-pixel.
        mask = rescale(
            mask, scale, order=0, preserve_range=True, anti_aliasing=False
        ).astype(mask.dtype)

        transformed = self.main_transform(image=image, mask=mask)
        sample = {
            'pixel_values': self.resize_normalize(image=transformed['image'])['image'],
            'labels': self.mask_transform(image=transformed['mask'])['image'].float()
        }

        return sample

In [None]:
TRAIN_BATCH_SIZE = 1
VAL_BATCH_SIZE = 1

# Validation gets an empty augmentation pipeline; training uses the random one.
# Positional arguments: (idxs, main_transform, resize_normalize, mask_transform).
val_dataset = CustomDataset(val_ids, A.Compose([]), resize_normalize, mask_transform)
train_dataset = CustomDataset(train_ids, train_transform, resize_normalize, mask_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=N_JOBS,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=N_JOBS,
)

### Model

In [None]:
class Decoder(nn.Module):
    """DAFormer-style decoder head that produces a single-channel logit map.

    Args:
        encoder_dim: channel widths of the encoder feature maps passed to
            ``daformer.daformer_conv3x3``. The default is presumably the
            pvt_v2_b4 stage widths.
        decoder_dim: channel width of the decoder output. The 1x1 logit conv
            reads exactly this many channels, so one parameter drives both.
    """

    def __init__(self, encoder_dim=(64, 128, 320, 512), decoder_dim=320):
        super().__init__()
        self.decoder = daformer.daformer_conv3x3(
            encoder_dim=list(encoder_dim),
            decoder_dim=decoder_dim,
            dilation=None
        )

        # Kept as nn.Sequential so state_dict keys ('logit.0.*') match
        # previously saved checkpoints.
        self.logit = nn.Sequential(
            nn.Conv2d(decoder_dim, 1, kernel_size=1)
        )

    def forward(self, x):
        # The decoder returns an indexable result; element 0 is used as the
        # fused feature map (verify against daformer.daformer_conv3x3).
        x = self.decoder(x)[0]
        return self.logit(x)


class EncDec(nn.Module):
    """PVT-v2-b4 encoder + ``Decoder`` head; returns logits upsampled 4x (nearest).

    Args:
        weights_path: path to the pretrained pvt_v2_b4 state dict. The weights
            can be downloaded from https://github.com/whai362/PVT/releases/tag/v2
    """

    def __init__(self, weights_path='../pvt_v2_weights/pvt_v2_b4.pth'):
        super().__init__()
        encoder = pvt_v2.pvt_v2_b4()
        # map_location='cpu' lets the checkpoint load even on machines without
        # the GPU it was saved from. The caller moves the model to its device.
        encoder.load_state_dict(torch.load(weights_path, map_location='cpu'))
        self.encoder = encoder
        self.decoder = Decoder()
        # Attribute name kept as-is (sic) for compatibility; it has no parameters.
        self.upsamle = nn.Upsample(
            scale_factor=4,
            mode='nearest'
        )

    def forward(self, x):
        features = self.encoder(x)
        logits = self.decoder(features)
        return self.upsamle(logits)

In [None]:
class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss on logits.

    loss = 1 - (sum(p*t) + smooth) / (sum(p) + sum(t) - sum(p*t) + smooth),
    where p = sigmoid(inputs). The sums run over every element of the batch at
    once (one IoU for the whole batch, not a per-image mean).
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute the soft IoU loss.

        Args:
            inputs: raw logits of any shape.
            targets: tensor with the same number of elements as ``inputs``.
            smooth: additive smoothing; keeps the ratio finite when both are empty.

        Returns:
            Scalar tensor.
        """
        probs = torch.sigmoid(inputs).reshape(-1)
        # reshape (unlike view) also accepts non-contiguous tensors.
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        union = (probs + targets).sum() - intersection

        iou = (intersection + smooth) / (union + smooth)
        return 1 - iou

    
def dice_metric(outputs, mask, threshold=0.5):
    """Mean per-image Dice coefficient between thresholded predictions and a mask.

    ``outputs`` is resized to the mask's spatial size (nearest neighbour) and
    binarised with ``outputs > threshold``. Dice = 2|P*T| / (|P| + |T|) is
    computed separately for each image in the batch and then averaged.

    Args:
        outputs: (B, 1, h, w) single-channel probabilities (the caller applies sigmoid).
        mask: target mask with B*H*W elements, e.g. (B, 1, H, W) or (1, H, W) for B == 1.
        threshold: probability above which a pixel counts as foreground.

    Returns:
        Python float. An image whose prediction and mask are both empty scores
        1.0 instead of producing NaN.
    """
    pred = (F.interpolate(outputs, size=mask.shape[-2:]) > threshold).to(torch.float32)
    batch = pred.shape[0]
    pred = pred.reshape(batch, -1)
    target = mask.reshape(batch, -1).to(pred.dtype)

    intersection = (pred * target).sum(dim=1)
    denom = pred.sum(dim=1) + target.sum(dim=1)
    # clamp only guards the branch torch.where discards when denom == 0.
    dice = torch.where(
        denom > 0,
        2 * intersection / denom.clamp_min(1e-12),
        torch.ones_like(denom),
    )
    return dice.mean().item()

In [None]:
# Training configuration
EPOCHS = 500
# The optimizer steps once gradients from this many images have accumulated.
NUM_ACCUMULATION_IMAGES = 16

model = EncDec().to(device)
loss_fn = IoULoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate every 50 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# History lists and the gradient-accumulation counter.
train_losses, val_losses, val_losses_bce = [], [], []
accum_im = 0
start_time = datetime.datetime.now()

In [None]:
# Per-epoch training / validation loop with gradient accumulation.
# The visual sanity-check image never changes, so load and preprocess it once.
# This also fails fast, before any training, if the file is missing.
test_img_input = io.imread(f'{DATA_FOLDER}/test_images/10078.tiff')
test_img = resize_normalize(image=test_img_input)['image'].unsqueeze(0).to(device)

# torch.save does not create directories; make sure the target exists.
checkpoint_dir = os.path.join('..', 'model_checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(EPOCHS):
    # TRAIN
    running_loss = []
    val_loss = 0
    model.train()
    for data in train_loader:
        inputs, mask = data['pixel_values'], data['labels']

        outputs = model(inputs.to(device))
        outputs = F.interpolate(outputs, mask.shape[-2:])

        batch_images = outputs.shape[0]
        loss = loss_fn(
            outputs,
            mask.to(device)
        )
        # Scale only the gradient, so one optimizer step averages over
        # NUM_ACCUMULATION_IMAGES images. `loss` itself stays unscaled so the
        # logged curve is in the loss's own units.
        (loss * batch_images / NUM_ACCUMULATION_IMAGES).backward()

        # accum_im is not reset per epoch: a partial accumulation carries over
        # into the next epoch.
        accum_im += batch_images
        if accum_im >= NUM_ACCUMULATION_IMAGES:
            accum_im = 0
            optimizer.step()
            optimizer.zero_grad()

        running_loss.append(loss.item())
        torch.cuda.empty_cache()

    scheduler.step()
    train_losses.append(np.mean(running_loss))

    # VAL -- val_losses holds the mean validation Dice (higher is better).
    model.eval()
    with torch.no_grad():
        for data in val_loader:
            inputs, mask = data['pixel_values'], data['labels']
            outputs = model(inputs.to(device))

            # Weight by batch size so that dividing by len(val_ids) below gives
            # a per-image mean for any VAL_BATCH_SIZE. This assumes dice_metric
            # returns the batch-mean Dice; with batch size 1 it is unchanged.
            val_loss += dice_metric(
                torch.sigmoid(outputs),
                mask.to(device)
            ) * inputs.shape[0]
            torch.cuda.empty_cache()

    val_losses.append(val_loss / len(val_ids))

    # SAVE BEST MODEL (latest epoch strictly beats every earlier one)
    if np.argmax(val_losses) == (len(val_losses) - 1):
        path = os.path.join(
            checkpoint_dir,
            f'pvt_v2_b4_{epoch}_{H}_{loss_fn.__class__.__name__}_{val_losses[-1]}.pth'
        )
        torch.save(model.state_dict(), path)

    # OUTPUT / PLOT
    clear_output()
    print(datetime.datetime.now() - start_time, val_losses[-1], max(val_losses))

    with torch.no_grad():
        test_img_out = torch.sigmoid(model(test_img)).cpu()[0][0]

    plt.subplot(221)
    plt.grid()
    plt.plot(train_losses)

    plt.subplot(222)
    plt.grid()
    plt.plot(val_losses)

    plt.subplot(223)
    io.imshow(test_img_out.numpy())

    plt.subplot(224)
    io.imshow(test_img_input)

    plt.show()