In [None]:
import os
import pandas as pd
import numpy as np

import warnings
warnings.filterwarnings('ignore')

import cv2

import matplotlib.pyplot as plt

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from apex import amp

from torchcontrib.optim import SWA

from fastai.vision.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, CyclicLR
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, Dataset

# Locations of the image tiles and of their segmentation masks.
train_path = './input/train'
mask_path = './input/masks'

# Report how many entries each directory holds (each directory is listed once).
n_train_files = len(os.listdir(train_path))
print('Image: ', n_train_files)
n_mask_files = len(os.listdir(mask_path))
print('Mask: ', n_mask_files)
print('Total: ', n_mask_files + n_train_files)

In [None]:
# os.listdir() returns entries in arbitrary order, so two independent listings
# are not guaranteed to line up row by row. Sorting both makes the pairing
# deterministic.
# NOTE(review): this assumes image and mask file names sort into the same
# order (e.g. identical names in both folders) — confirm against the data.
d = {
    'image_id': sorted(os.listdir(train_path)),
    'mask_id': sorted(os.listdir(mask_path)),
}
df = pd.DataFrame(data=d)
print(df.shape)
df.head()

In [None]:
class HuBMAPDataset(Dataset):
    """Dataset yielding (image, mask) tensor pairs read from disk.

    Params:
        df: frame with an ``image_id`` and a ``mask_id`` column holding file
            names; rows are addressed by position.
        transforms: optional callable invoked as
            ``transforms(image=image, mask=mask)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries. Skipped when falsy.
        image_dir: directory containing the image files.
        mask_dir: directory containing the mask files.
    """

    def __init__(self, df, transforms=None,
                 image_dir='./input/train', mask_dir='./input/masks'):
        self.df = df
        self.transforms = transforms
        self.image_dir = image_dir
        self.mask_dir = mask_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load, augment and convert the sample at position ``index``.

        Returns:
            Tuple ``(image, mask)``: the image as a float32 tensor scaled by
            1/255 with the channel axis moved first, and the mask as read
            from disk (after the optional transform).
        Raises:
            FileNotFoundError: if either file cannot be read by OpenCV.
        """
        image_id = self.df['image_id'].values[index]
        image_path = f'{self.image_dir}/{image_id}'
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_id = self.df['mask_id'].values[index]
        mask_path = f'{self.mask_dir}/{mask_id}'
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # HWC uint8 -> CHW float32 in [0, 1].
        image = image.astype(np.float32)
        image /= 255.0
        image = image.transpose(2, 0, 1)

        # NOTE(review): mask values are passed through unscaled; the loss
        # presumably expects 0/1 targets — confirm how the masks are stored.
        return torch.tensor(image), torch.tensor(mask)

In [None]:
# Training augmentation, assembled from named stages and applied in this order.
# NOTE(review): the IAA* transforms and RandomBrightness/RandomContrast are,
# as far as I know, deprecated in newer albumentations releases — confirm the
# pinned version before upgrading.

# Flips, transpose, 90-degree rotation and a small affine jitter.
_geometric_stage = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.2,
        rotate_limit=15,
        p=0.9,
        border_mode=cv2.BORDER_REFLECT,
    ),
]

# Additive noise and perspective warp.
_noise_stage = [
    A.IAAAdditiveGaussianNoise(p=0.2),
    A.IAAPerspective(p=0.5),
]

# One of: local contrast, brightness or gamma change.
_tone_stage = A.OneOf(
    [A.CLAHE(p=1), A.RandomBrightness(p=1), A.RandomGamma(p=1)],
    p=0.9,
)

# One of: sharpen, blur or motion blur.
_sharpness_stage = A.OneOf(
    [A.IAASharpen(p=1), A.Blur(blur_limit=3, p=1), A.MotionBlur(blur_limit=3, p=1)],
    p=0.9,
)

# One of: contrast or hue/saturation/value shift.
_colour_stage = A.OneOf(
    [A.RandomContrast(p=1), A.HueSaturationValue(p=1)],
    p=0.9,
)

# A second vertical-flip / 90-degree-rotation pass, kept from the original.
_second_flip_stage = A.Compose([A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5)])

transforms_train = A.Compose([
    *_geometric_stage,
    *_noise_stage,
    _tone_stage,
    _sharpness_stage,
    _colour_stage,
    _second_flip_stage,
])

# Validation uses an empty pipeline (no augmentation).
transforms_valid = A.Compose([])

In [None]:
# Preview dataset built from the first 1000 rows with the training augmentations.
train_image = HuBMAPDataset(df[:1000].reset_index(drop=True), transforms=transforms_train)

import matplotlib.pyplot as plt
from pylab import rcParams
rcParams['figure.figsize'] = 20, 10

# Two rows of five randomly drawn, augmented samples.
n_preview_rows, n_preview_cols = 2, 5
for _ in range(n_preview_rows):
    f, axarr = plt.subplots(1, n_preview_cols)
    for ax in axarr:
        idx = np.random.randint(0, len(train_image))
        img, mask = train_image[idx]
        # The dataset returns channel-first tensors; imshow needs channel-last.
        ax.imshow(img.permute(1, 2, 0))

In [None]:
def dice_coef_metric(probabilities: torch.Tensor,
                     truth: torch.Tensor,
                     treshold: float = 0.5,
                     eps: float = 1e-9) -> float:
    """
    Calculate the mean Dice score (aka F1) for a data batch.
    Params:
        probabilities: model outputs after the activation function; the first
            dimension is iterated sample by sample.
        truth: truth values, same shape as `probabilities`.
        treshold: cut-off for `probabilities` (values >= cut-off count as
            positive). The spelling is kept so keyword callers keep working.
        eps: small constant added to numerator and denominator.
    Returns:
        Mean per-sample Dice score as a Python float. A sample whose truth and
        prediction are both empty scores 1.0; an empty batch yields nan.
    """
    predictions = (probabilities >= treshold).float()
    assert predictions.shape == truth.shape

    scores = []
    for prediction, truth_ in zip(predictions, truth):
        # .item() yields plain Python numbers, so the final mean does not
        # depend on converting tensors (possibly on another device) to numpy.
        truth_sum = truth_.sum().item()
        prediction_sum = prediction.sum().item()
        if truth_sum == 0 and prediction_sum == 0:
            scores.append(1.0)
            continue
        overlap = (truth_ * prediction).sum().item()
        # eps on both sides keeps a perfect match at exactly 1.0.
        scores.append((2.0 * overlap + eps) / (truth_sum + prediction_sum + eps))
    return float(np.mean(scores))


def jaccard_coef_metric(probabilities: torch.Tensor,
               truth: torch.Tensor,
               treshold: float = 0.5,
               eps: float = 1e-9) -> float:
    """
    Calculate the mean Jaccard index (aka IoU) for a data batch.
    Params:
        probabilities: model outputs after the activation function; the first
            dimension is iterated sample by sample.
        truth: truth values, same shape as `probabilities`.
        treshold: cut-off for `probabilities` (values >= cut-off count as
            positive). The spelling is kept so keyword callers keep working.
        eps: small constant added to numerator and denominator.
    Returns:
        Mean per-sample Jaccard score as a Python float. A sample whose truth
        and prediction are both empty scores 1.0; an empty batch yields nan.
    """
    predictions = (probabilities >= treshold).float()
    assert predictions.shape == truth.shape

    scores = []
    for prediction, truth_ in zip(predictions, truth):
        # .item() yields plain Python numbers, so the final mean does not
        # depend on converting tensors (possibly on another device) to numpy.
        truth_sum = truth_.sum().item()
        prediction_sum = prediction.sum().item()
        if truth_sum == 0 and prediction_sum == 0:
            scores.append(1.0)
            continue
        overlap = (prediction * truth_).sum().item()
        union = prediction_sum + truth_sum - overlap + eps
        scores.append((overlap + eps) / union)
    return float(np.mean(scores))

In [None]:
class SoftDiceLoss(nn.Module):
    """Soft Dice loss: one minus the mean smoothed Dice coefficient.

    Params:
        smooth: constant added to numerator and denominator of the coefficient.
        dims: dimensions summed over when counting overlaps.
    """

    def __init__(self, smooth=1., dims=(-2,-1)):
        super().__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        """Return ``1 - mean(dice)`` for predictions ``x`` and targets ``y``."""
        true_pos = (x * y).sum(self.dims)
        false_pos = (x * (1 - y)).sum(self.dims)
        false_neg = ((1 - x) * y).sum(self.dims)

        numerator = 2 * true_pos + self.smooth
        denominator = 2 * true_pos + false_pos + false_neg + self.smooth
        coefficient = (numerator / denominator).mean()
        return 1 - coefficient

In [None]:
# Building blocks of the combined segmentation loss.
bce_fn = nn.BCEWithLogitsLoss()
dice_fn = SoftDiceLoss()


def loss_fn(y_pred, y_true):
    """Weighted sum of BCE-with-logits (0.8) and soft Dice (0.2).

    ``y_pred`` holds raw logits: BCE consumes them directly, while the Dice
    term is evaluated on their sigmoid.
    """
    bce_weight, dice_weight = 0.8, 0.2
    bce_term = bce_fn(y_pred, y_true)
    dice_term = dice_fn(y_pred.sigmoid(), y_true)
    return bce_weight*bce_term + dice_weight*dice_term

In [None]:
def train_loop_fn(model, loader, optimizer, loss_func, dice_coef_metric, jaccard_coef_metric, device):
    """Run one training epoch and return the epoch-average metrics.

    Params:
        model: network called as ``model(data)``; its output is treated as logits.
        loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer used with ``amp.scale_loss`` and stepped once per batch.
        loss_func: callable ``loss_func(outputs, target.unsqueeze(1))``.
        dice_coef_metric: callable ``(probs, target, threshold) -> score``.
        jaccard_coef_metric: callable ``(probs, target, threshold) -> score``.
        device: device the batches are moved to.
    Returns:
        Tuple ``(mean loss, mean dice, mean jaccard)`` over the batches.
    """
    model.train()

    losses = []
    dice = []
    jaccard = []

    for (data, target) in tqdm(loader):
        data = data.to(device, dtype=torch.float)
        target = target.to(device, dtype=torch.float)

        optimizer.zero_grad()

        outputs = model(data)
        probs = torch.sigmoid(outputs)

        # The target gets a singleton dim at position 1 to line up with the model output.
        loss = loss_func(outputs, target.unsqueeze(1))

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        # Detach and copy to host once; both metrics reuse the same tensors.
        probs_cpu = probs.squeeze(1).detach().cpu()
        target_cpu = target.detach().cpu()
        dice_scores = dice_coef_metric(probs_cpu, target_cpu, 0.5)
        jaccard_scores = jaccard_coef_metric(probs_cpu, target_cpu, 0.5)

        optimizer.step()

        losses.append(loss.item())
        dice.append(dice_scores)
        jaccard.append(jaccard_scores)

    return np.array(losses).mean(), np.array(dice).mean(), np.array(jaccard).mean()


def val_loop_fn(model, loader, optimizer, loss_func, dice_coef_metric, jaccard_coef_metric, device):
    """Evaluate the model on a loader without gradient tracking.

    Params:
        model: network called as ``model(data)``; its output is treated as logits.
        loader: iterable of ``(data, target)`` batches.
        optimizer: unused here; accepted so the signature mirrors ``train_loop_fn``.
        loss_func: callable ``loss_func(outputs, target.unsqueeze(1))``.
        dice_coef_metric: callable ``(probs, target, threshold) -> score``.
        jaccard_coef_metric: callable ``(probs, target, threshold) -> score``.
        device: device the batches are moved to.
    Returns:
        Tuple ``(mean loss, mean dice, mean jaccard)`` over the batches.
    """
    model.eval()

    losses = []
    dice = []
    jaccard = []

    with torch.no_grad():
        for (data, target) in tqdm(loader):
            data = data.to(device, dtype=torch.float)
            target = target.to(device, dtype=torch.float)

            outputs = model(data)
            probs = torch.sigmoid(outputs)

            loss = loss_func(outputs, target.unsqueeze(1))

            # Copy to host once; both metrics reuse the same tensors.
            probs_cpu = probs.squeeze(1).detach().cpu()
            target_cpu = target.detach().cpu()
            dice_scores = dice_coef_metric(probs_cpu, target_cpu, 0.5)
            jaccard_scores = jaccard_coef_metric(probs_cpu, target_cpu, 0.5)

            losses.append(loss.item())
            dice.append(dice_scores)
            jaccard.append(jaccard_scores)

    return np.array(losses).mean(), np.array(dice).mean(), np.array(jaccard).mean()

In [None]:
# https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter

class FPN(nn.Module):
    """Per-level conv heads whose upsampled outputs are concatenated.

    Each ``(in_ch, out_ch)`` pair gets a head of conv3x3 -> ReLU -> BatchNorm ->
    conv3x3 that widens to ``2*out_ch`` and then narrows to ``out_ch``.
    """

    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        heads = []
        for n_in, n_out in zip(input_channels, output_channels):
            n_hidden = n_out*2
            heads.append(nn.Sequential(
                nn.Conv2d(n_in, n_hidden, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(n_hidden),
                nn.Conv2d(n_hidden, n_out, kernel_size=3, padding=1),
            ))
        self.convs = nn.ModuleList(heads)

    def forward(self, xs:list, last_layer):
        """Upsample every head output, append ``last_layer``, concat on dim 1.

        The head at position ``i`` is upsampled by ``2**(n_heads - i)``.
        """
        n_heads = len(self.convs)
        pieces = []
        for level, (head, feature) in enumerate(zip(self.convs, xs)):
            factor = 2**(n_heads-level)
            pieces.append(F.interpolate(head(feature), scale_factor=factor, mode='bilinear'))
        pieces.append(last_layer)
        return torch.cat(pieces, dim=1)

class UnetBlock(Module):
    """Decoder block: upsample, merge with a skip connection, then two convs.

    ``up_in`` goes through ``PixelShuffle_ICNR`` (channels halved), the skip
    input is batch-normalised, both are concatenated on dim 1, passed through
    ReLU and then through two ``ConvLayer`` stages.
    NOTE(review): ``Module``, ``PixelShuffle_ICNR``, ``ConvLayer`` and
    ``SelfAttention`` presumably come from the fastai star import — confirm.
    """

    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        half_up_c = up_in_c//2
        self.shuf = PixelShuffle_ICNR(up_in_c, half_up_c, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        merged_c = half_up_c + x_in_c
        # Default output width: half the upsampled channels, at least 32.
        if nf is None:
            nf = max(half_up_c,32)
        self.conv1 = ConvLayer(merged_c, nf, norm_type=None, **kwargs)
        attention = SelfAttention(nf) if self_attention else None
        self.conv2 = ConvLayer(nf, nf, norm_type=None, xtra=attention, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, up_in:Tensor, left_in:Tensor) -> Tensor:
        upsampled = self.shuf(up_in)
        skip = self.bn(left_in)
        merged = self.relu(torch.cat([upsampled, skip], dim=1))
        return self.conv2(self.conv1(merged))
        
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    """Atrous spatial pyramid pooling head.

    Runs a 1x1 branch, one grouped 3x3 dilated branch per entry of
    ``dilations`` and a global-max-pool branch in parallel, concatenates them
    on dim 1 and projects the result to ``out_c`` channels (``mid_c`` when
    ``out_c`` is None).
    """

    def __init__(self, inplanes=512, mid_c=256, dilations=[6, 12, 18, 24], out_c=None):
        super().__init__()
        branches = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)]
        for rate in dilations:
            branches.append(
                _ASPPModule(inplanes, mid_c, 3, padding=rate, dilation=rate, groups=4))
        self.aspps = nn.ModuleList(branches)
        self.global_pool = nn.Sequential(
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
            nn.BatchNorm2d(mid_c),
            nn.ReLU(),
        )
        if out_c is None:
            out_c = mid_c
        # Pooled branch + 1x1 branch + one branch per dilation rate.
        concat_c = mid_c*(2+len(dilations))
        self.out_conv = nn.Sequential(
            nn.Conv2d(concat_c, out_c, 1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )
        # Not used by forward(); kept so the registered parameters stay the same.
        self.conv1 = nn.Conv2d(concat_c, out_c, 1, bias=False)
        self._init_weight()

    def forward(self, x):
        pooled = self.global_pool(x)
        branch_outputs = [branch(x) for branch in self.aspps]
        # Stretch the 1x1 pooled map to the spatial size of the conv branches.
        target_size = branch_outputs[0].size()[2:]
        pooled = F.interpolate(pooled, size=target_size, mode='bilinear', align_corners=True)
        stacked = torch.cat([pooled] + branch_outputs, dim=1)
        return self.out_conv(stacked)

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm scale set to 1, shift to 0."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)

In [None]:
# https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter

class UneXt50(nn.Module):
    """U-Net style segmentation net on a ResNeXt50 encoder loaded via torch.hub.

    Encoder stages feed an ASPP bottleneck, four ``UnetBlock`` decoder stages,
    an ``FPN`` that merges the decoder outputs, and a 1x1 final conv whose
    output is upsampled by 2.

    Params:
        stride: multiplier for the ASPP dilation rates (1x, 2x, 3x, 4x).
        **kwargs: accepted but not used by this class.
    """

    def __init__(self, stride=1, **kwargs):
        super().__init__()
        # Encoder: stages taken from the hub-loaded backbone.
        backbone = torch.hub.load(
            'facebookresearch/semi-supervised-ImageNet1K-models',
            'resnext50_32x4d_ssl',
        )
        self.enc0 = nn.Sequential(backbone.conv1, backbone.bn1, nn.ReLU(inplace=True))
        self.enc1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1),
            backbone.layer1,
        )  # 256
        self.enc2 = backbone.layer2  # 512
        self.enc3 = backbone.layer3  # 1024
        self.enc4 = backbone.layer4  # 2048
        # ASPP with dilation rates scaled by `stride`.
        aspp_rates = [stride*1, stride*2, stride*3, stride*4]
        self.aspp = ASPP(2048, 256, out_c=512, dilations=aspp_rates)
        self.drop_aspp = nn.Dropout2d(0.5)
        # Decoder.
        self.dec4 = UnetBlock(512, 1024, 256)
        self.dec3 = UnetBlock(256, 512, 128)
        self.dec2 = UnetBlock(128, 256, 64)
        self.dec1 = UnetBlock(64, 64, 32)
        self.fpn = FPN([512, 256, 128, 64], [16]*4)
        self.drop = nn.Dropout2d(0.1)
        self.final_conv = ConvLayer(32+16*4, 1, ks=1, norm_type=None, act_cls=None)

    def forward(self, x):
        # Encoder path; each feature map is kept for its skip connection.
        skip0 = self.enc0(x)
        skip1 = self.enc1(skip0)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        deepest = self.enc4(skip3)
        bottleneck = self.aspp(deepest)

        # Decoder path; dropout is applied only to the input of the first block.
        up3 = self.dec4(self.drop_aspp(bottleneck), skip3)
        up2 = self.dec3(up3, skip2)
        up1 = self.dec2(up2, skip1)
        up0 = self.dec1(up1, skip0)

        fused = self.fpn([bottleneck, up3, up2, up1], up0)
        logits = self.final_conv(self.drop(fused))
        return F.interpolate(logits, scale_factor=2, mode='bilinear')

# Split the model into encoder and decoder parameter groups for fast.ai.
def split_layers(m):
    """Return ``[encoder_params, decoder_params]`` for model ``m``.

    The first list gathers the parameters of enc0..enc4; the second those of
    aspp, dec4..dec1, fpn and final_conv, in that order.
    """
    encoder_parts = (m.enc0, m.enc1, m.enc2, m.enc3, m.enc4)
    decoder_parts = (m.aspp, m.dec4, m.dec3, m.dec2, m.dec1, m.fpn, m.final_conv)

    encoder_params = []
    for part in encoder_parts:
        encoder_params += list(part.parameters())

    decoder_params = []
    for part in decoder_parts:
        decoder_params += list(part.parameters())

    return [encoder_params, decoder_params]

In [None]:
# Model / optimizer setup shared by the training cells below.
# NOTE(review): `smp` is not referenced in the cells visible here — confirm
# whether this import is still needed.
import segmentation_models_pytorch as smp

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = UneXt50().to(device)

n_epochs = 60
loss_func = loss_fn
# Adam wrapped in SWA; the training loop calls update_swa() / swap_swa_sgd() on it.
base_opt = optim.Adam(model.parameters(), lr=3e-4)
optimizer = SWA(base_opt)
# apex AMP returns patched versions of both objects, so they are rebound here.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
# The scheduler is built on the AMP-initialised optimizer; T_max is n_epochs.
scheduler = CosineAnnealingLR(optimizer, n_epochs)

In [None]:
from sklearn.model_selection import KFold

# Single source of truth for the number of folds (used by KFold and the log line).
N_FOLDS = 4

folds = df.copy()
kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

# NOTE(review): `model`, `optimizer` and `scheduler` are created once in the
# setup cell and are not re-created in this loop, so every fold after the first
# starts from the previous fold's weights, which were trained on rows that are
# now in the validation split. Re-initialise them per fold if independent folds
# are intended.
for fold, (train_idx, valid_idx) in enumerate(kf.split(folds)):

    print(f'FOLD: {fold+1}/{N_FOLDS}')

    # Build fresh, re-indexed frames instead of mutating iloc slices in place.
    train_test = folds.iloc[train_idx].reset_index(drop=True)
    valid_test = folds.iloc[valid_idx].reset_index(drop=True)

    train_dataset = HuBMAPDataset(train_test, transforms=transforms_train)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)

    valid_dataset = HuBMAPDataset(valid_test, transforms=transforms_valid)
    valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False, num_workers=4)

    # Per-epoch curves; reset for every fold, so only the last fold's curves
    # remain after the loop.
    loss_history = {"train": [], "valid": []}
    dice_history = {"train": [], "valid": []}
    jaccard_history = {"train": [], "valid": []}

    dice_max = 0.0
    kernel_type = 'unext50'
    best_file = f'./{kernel_type}_best_fold{fold}_strong_aug_60_epochs.bin'

    for epoch in range(n_epochs):

        # The explicit epoch argument restarts the cosine schedule for each fold.
        scheduler.step(epoch)
        avg_train_loss, train_dice_scores, train_jaccard_scores = train_loop_fn(
            model,
            train_loader,
            optimizer,
            loss_func,
            dice_coef_metric,
            jaccard_coef_metric,
            device,
        )

        # Feed the SWA running average every fifth epoch after epoch 30.
        if epoch > 30 and epoch % 5 == 0:
            optimizer.update_swa()

        loss_history["train"].append(avg_train_loss)
        dice_history["train"].append(train_dice_scores)
        jaccard_history["train"].append(train_jaccard_scores)

        avg_val_loss, val_dice_scores, val_jaccard_scores = val_loop_fn(
            model,
            valid_loader,
            optimizer,
            loss_func,
            dice_coef_metric,
            jaccard_coef_metric,
            device,
        )

        loss_history["valid"].append(avg_val_loss)
        dice_history["valid"].append(val_dice_scores)
        jaccard_history["valid"].append(val_jaccard_scores)

        print(f"Epoch: {epoch+1} | lr: {optimizer.param_groups[0]['lr']:.7f} | train loss: {avg_train_loss:.4f} | val loss: {avg_val_loss:.4f}")
        print(f"train dice: {train_dice_scores:.4f} | val dice: {val_dice_scores:.4f} | train jaccard: {train_jaccard_scores:.4f} | val jaccard: {val_jaccard_scores:.4f}")

        # Checkpoint whenever the validation Dice improves.
        if val_dice_scores > dice_max:
            print('score2 ({:.6f} --> {:.6f}).  Saving model ...'.format(dice_max, val_dice_scores))
            torch.save(model.state_dict(), best_file)
            dice_max = val_dice_scores

    # NOTE(review): this runs after the last checkpoint of the fold was written,
    # so `best_file` presumably never holds the SWA-averaged weights — confirm
    # whether the averaged model should be saved as well.
    optimizer.swap_swa_sgd()

In [None]:
import matplotlib.pyplot as plt

from pylab import rcParams
rcParams['figure.figsize'] = 10, 5

# One figure per tracked quantity: (title, history dict, y-axis label).
epoch_axis = range(1,n_epochs+1)
history_plots = (
    ("Train-Val Loss", loss_history, "Loss"),
    ("Dice Train-Val Score", dice_history, "Score"),
    ("Jaccard Train-Val Score", jaccard_history, "Score"),
)

for plot_title, history, y_label in history_plots:
    plt.title(plot_title)
    plt.plot(epoch_axis, history["train"], label="train")
    plt.plot(epoch_axis, history["valid"], label="val")
    plt.ylabel(y_label)
    plt.xlabel("Training Epochs")
    plt.legend()
    plt.show()