In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Model

## HED

In [None]:
class HED(torch.nn.Module):
    """Multi-scale edge-map network (named HED; presumably Holistically-Nested Edge Detection).

    Five convolutional stages run back to back; stages 2-5 start with a 2x2
    max-pool. Every stage output is reduced to one channel by a 1x1 "side"
    convolution, bilinearly resized to the input resolution, and the five maps
    are fused by a final 1x1 convolution followed by a sigmoid.

    Attribute names and the layer order inside each ``Sequential`` are kept
    exactly as before so that existing checkpoints keep loading.
    """

    @staticmethod
    def _stage(in_channels, out_channels, num_convs, pool):
        """Build one stage: optional 2x2 max-pool, then ``num_convs`` x (3x3 conv + LeakyReLU)."""
        layers = [torch.nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(torch.nn.Conv2d(in_channels=channels, out_channels=out_channels,
                                          kernel_size=3, stride=1, padding=1))
            layers.append(torch.nn.LeakyReLU(inplace=True))
            channels = out_channels
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _side(in_channels):
        """Build the 1x1 convolution that turns a stage output into a 1-channel map."""
        return torch.nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1, stride=1, padding=0)

    def __init__(self):
        super().__init__()

        # Backbone stages (registration order matters for state_dict / init order).
        self.line1 = self._stage(3, 64, num_convs=2, pool=False)
        self.line2 = self._stage(64, 128, num_convs=2, pool=True)
        self.line3 = self._stage(128, 256, num_convs=3, pool=True)
        self.line4 = self._stage(256, 512, num_convs=3, pool=True)
        self.line5 = self._stage(512, 512, num_convs=3, pool=True)

        # One side output per stage.
        self.line1_out = self._side(64)
        self.line2_out = self._side(128)
        self.line3_out = self._side(256)
        self.line4_out = self._side(512)
        self.line5_out = self._side(512)

        # Fuses the five resized side maps into a single map in (0, 1).
        self.output = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            torch.nn.Sigmoid()
        )

    def forward(self, X):
        """Return a fused 1-channel map with the same spatial size as ``X``.

        Args:
            X: tensor of shape (N, 3, H, W) or (N, 1, H, W). A single-channel
                input is broadcast to three channels, because the first
                convolution is declared with ``in_channels=3`` and would
                otherwise raise on grayscale batches.

        Returns:
            Tensor of shape (N, 1, H, W) produced by a sigmoid.
        """
        if X.shape[1] == 1:
            # Zero-copy view; the multiplication below materialises a new tensor.
            X = X.expand(-1, 3, -1, -1)

        # Presumably inputs arrive scaled to [0, 1] and are brought back to a 0-255 range - TODO confirm.
        X = X * 255.0

        stages = (self.line1, self.line2, self.line3, self.line4, self.line5)
        heads = (self.line1_out, self.line2_out, self.line3_out, self.line4_out, self.line5_out)

        # Run the backbone first, keeping every intermediate stage output.
        features = []
        current = X
        for stage in stages:
            current = stage(current)
            features.append(current)

        # Side outputs are resized to the input resolution before fusion.
        target_size = (X.shape[2], X.shape[3])
        side_maps = [
            torch.nn.functional.interpolate(input=head(feature), size=target_size,
                                            mode='bilinear', align_corners=False)
            for head, feature in zip(heads, features)
        ]

        return self.output(torch.cat(side_maps, 1))

## Edge UNET

In [None]:
class CNNBlock(nn.Module):
    """Basic unit: Conv2d -> BatchNorm2d -> in-place LeakyReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0):
        super().__init__()
        convolution = nn.Conv2d(in_channels, out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding)
        normalization = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(inplace=True)
        # Kept under the name `block` with the same layer order (checkpoint keys).
        self.block = nn.Sequential(convolution, normalization, activation)

    def forward(self, X):
        """Apply the three layers in sequence and return the result."""
        activated = self.block(X)
        return activated
    
class DoubleCNNBlock(nn.Module):
    """Two CNNBlocks in a row; the second one maps out_channels -> out_channels."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0):
        super().__init__()
        # Both blocks share the same kernel size, stride and padding.
        conv_args = (kernel_size, stride, padding)
        self.block1 = CNNBlock(in_channels, out_channels, *conv_args)
        self.block2 = CNNBlock(out_channels, out_channels, *conv_args)

    def forward(self, X):
        """Run block1 followed by block2."""
        return self.block2(self.block1(X))

In [None]:
class Expansion(nn.Module):
    """Widen the channel dimension by an integer factor.

    Conv2d (1x1 by default) from ``in_channels`` to ``in_channels * ratio``,
    followed by BatchNorm2d and an in-place LeakyReLU. The widened channel
    count is exposed as ``self.out_channels``.
    """

    def __init__(self, in_channels, ratio, kernel_size=1, stride=1, padding=0):
        super().__init__()
        widened = in_channels * ratio
        self.out_channels = widened
        layers = [
            nn.Conv2d(in_channels, widened, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(widened),
            nn.LeakyReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, X):
        """Return the widened feature map."""
        widened_features = self.block(X)
        return widened_features
    
class DepthwiseCNN(nn.Module):
    """Per-channel convolution: Conv2d with ``groups=in_channels``, then BatchNorm2d and in-place LeakyReLU.

    The channel count is unchanged by this block.
    """

    def __init__(self, in_channels, kernel_size=3, stride=1, padding=0):
        super().__init__()
        per_channel_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,  # one filter group per input channel
        )
        self.block = nn.Sequential(
            per_channel_conv,
            nn.BatchNorm2d(in_channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, X):
        """Return the filtered feature map."""
        filtered = self.block(X)
        return filtered
    
class Compression(nn.Module):
    """Project features to ``out_channels`` with Conv2d (1x1 by default) + BatchNorm2d.

    No activation follows the normalization.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        normalization = nn.BatchNorm2d(out_channels)
        self.block = nn.Sequential(projection, normalization)

    def forward(self, X):
        """Return the projected feature map."""
        projected = self.block(X)
        return projected
    
class MBConv(nn.Module):
    """Expansion -> depthwise 3x3 convolution -> compression.

    Channels are widened to ``in_channels * ratio``, filtered per channel and
    then projected to ``out_channels``.

    NOTE(review): ``kernel_size``, ``stride`` and ``padding`` are accepted but
    not used; the depthwise convolution is always built with kernel_size=3 and
    padding=1.
    """

    def __init__(self, in_channels, out_channels, ratio, kernel_size, stride=1, padding=1):
        super().__init__()
        hidden_channels = in_channels * ratio
        # Attribute name (including its spelling) is part of the checkpoint keys.
        self.expanssion_block = Expansion(in_channels, ratio)
        self.depthwise_cnn_block = DepthwiseCNN(hidden_channels, kernel_size=3, padding=1)
        self.compression_block = Compression(hidden_channels, out_channels)

    def forward(self, X):
        """Apply the three sub-blocks in order."""
        for stage in (self.expanssion_block, self.depthwise_cnn_block, self.compression_block):
            X = stage(X)
        return X

In [None]:
class DownSample(nn.Module):
    """Encoder stage: DoubleCNNBlock, then 2x2 max-pool and in-place Dropout2d(p=0.3).

    ``forward`` returns both the pooled tensor (input to the next stage) and
    the pre-pool tensor (used later as a skip connection).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.double_cnn_block = DoubleCNNBlock(in_channels, out_channels, kernel_size, stride, padding)
        self.max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout_layer = nn.Dropout2d(p=0.3, inplace=True)

    def forward(self, X):
        """Return ``(downsampled, skip)``."""
        skip = self.double_cnn_block(X)
        downsampled = self.dropout_layer(self.max_pool_layer(skip))
        return downsampled, skip
    
class DownSampleMBConv(nn.Module):
    """Encoder stage built on MBConv, then 2x2 max-pool and in-place Dropout2d(p=0.3).

    ``forward`` returns both the pooled tensor and the pre-pool tensor.

    NOTE(review): ``stride`` and ``padding`` are accepted but not forwarded to
    the MBConv block.
    """

    def __init__(self, in_channels, out_channels, ratio, kernel_size, stride=1, padding=1):
        super().__init__()
        self.mbconv_block = MBConv(in_channels, out_channels, ratio, kernel_size)
        self.max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout_layer = nn.Dropout2d(p=0.3, inplace=True)

    def forward(self, X):
        """Return ``(downsampled, skip)``."""
        skip = self.mbconv_block(X)
        downsampled = self.dropout_layer(self.max_pool_layer(skip))
        return downsampled, skip
    
class UpSample(nn.Module):
    """Decoder stage that merges upsampled features with a skip tensor and an edge-weighted copy of it.

    The incoming features are bilinearly resized to the skip tensor's spatial
    size and reduced to ``in_channels // 2`` channels by a 1x1 convolution.
    They are concatenated with ``out * E`` and ``out``, passed through
    Dropout2d(p=0.3), and finally through a DoubleCNNBlock.

    The DoubleCNNBlock is built for ``(in_channels * 3) // 2`` input channels,
    so ``out`` is expected to carry ``in_channels // 2`` channels.

    NOTE(review): ``kernel_size``, ``stride`` and ``padding`` are accepted but
    not used by this block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        half_channels = in_channels // 2
        self.conv_layer = nn.Conv2d(in_channels, half_channels, kernel_size=1, padding=0)
        self.dropout_layer = nn.Dropout2d(p=0.3)
        self.conv_in = (in_channels * 3) // 2
        self.double_conv_block = DoubleCNNBlock(self.conv_in, out_channels, padding=1)

    def forward(self, X, out, E):
        """Fuse deeper features ``X`` with skip tensor ``out`` and edge map ``E``."""
        resized = F.interpolate(X, out.shape[2:], mode='bilinear', align_corners=False)
        reduced = self.conv_layer(resized)
        edge_weighted = torch.mul(out, E)
        # Channel order of the concatenation: edge-weighted skip, skip, upsampled features.
        merged = torch.cat([edge_weighted, out, reduced], dim=1)
        merged = self.dropout_layer(merged)
        return self.double_conv_block(merged)

In [None]:
class Encoder(nn.Module):
    """Four downsampling stages followed by a bottleneck DoubleCNNBlock.

    The first stage is a plain DownSample taking a 1-channel input; stages 2-4
    are MBConv based with expansion ratios 2, 4 and 4. ``filters`` is indexed
    at positions 0-4 for the stage widths.
    """

    def __init__(self, filters):
        super().__init__()
        self.filters = filters
        self.down_sample_block1 = DownSample(1, filters[0])
        self.down_sample_block2 = DownSampleMBConv(filters[0], filters[1], 2, kernel_size=3)
        self.down_sample_block3 = DownSampleMBConv(filters[1], filters[2], 4, kernel_size=3)
        self.down_sample_block4 = DownSampleMBConv(filters[2], filters[3], 4, kernel_size=3)
        self.bottleneck_block = DoubleCNNBlock(filters[3], filters[4], padding=1)

    def forward(self, X):
        """Return ``(out1, out2, out3, out4, bottleneck)``: four skip tensors and the bottleneck output."""
        stages = (
            self.down_sample_block1,
            self.down_sample_block2,
            self.down_sample_block3,
            self.down_sample_block4,
        )
        skips = []
        for stage in stages:
            X, skip = stage(X)
            skips.append(skip)
        return (*skips, self.bottleneck_block(X))
    
class Decoder(nn.Module):
    """Four UpSample stages applied from the deepest level (4) back to level 1."""

    def __init__(self, filters):
        super().__init__()
        self.filters = filters
        # Created deepest-first, same order as before (affects parameter registration order).
        self.up_sample_block4 = UpSample(filters[4], filters[3])
        self.up_sample_block3 = UpSample(filters[3], filters[2])
        self.up_sample_block2 = UpSample(filters[2], filters[1])
        self.up_sample_block1 = UpSample(filters[1], filters[0])

    def forward(self, out1, out2, out3, out4, E1, E2, E3, E4, bottleneck):
        """Decode ``bottleneck`` using skip tensors ``out1..out4`` and edge maps ``E1..E4``."""
        levels = (
            (self.up_sample_block4, out4, E4),
            (self.up_sample_block3, out3, E3),
            (self.up_sample_block2, out2, E2),
            (self.up_sample_block1, out1, E1),
        )
        X = bottleneck
        for up_block, skip, edge in levels:
            X = up_block(X, skip, edge)
        return X

In [None]:
class EDGEUnet(nn.Module):
    """Encoder-decoder segmentation network whose decoder is guided by an HED edge map.

    The edge map is resized to each skip tensor's spatial size and handed to
    the matching decoder stage. A final 1x1 convolution produces
    ``num_classes`` output channels; no activation is applied to them here.
    """

    def __init__(self, num_classes, encoder_filters, decoder_filters):
        super().__init__()
        self.num_classes = num_classes
        self.encoder_filters = encoder_filters
        self.decoder_filters = decoder_filters
        self.hed_model = HED()
        self.encoder = Encoder(encoder_filters)
        self.decoder = Decoder(decoder_filters)
        self.output_conv_layer = nn.Conv2d(decoder_filters[0], num_classes, kernel_size=1)

    def forward(self, X):
        """Return a tensor with ``num_classes`` channels computed from ``X``."""
        # NOTE(review): the same tensor feeds both branches, yet the encoder's first
        # stage is built for 1 input channel while HED's first convolution is declared
        # with in_channels=3 - confirm the HED branch accepts this input.
        edge = self.hed_model(X)
        *skips, bottleneck = self.encoder(X)

        # One edge map per skip level, matched to that level's resolution.
        edge_maps = [
            F.interpolate(edge, skip.shape[2:], mode='bilinear', align_corners=False)
            for skip in skips
        ]

        decoded = self.decoder(*skips, *edge_maps, bottleneck)
        return self.output_conv_layer(decoded)

In [None]:
# Use the GPU when one is present, otherwise fall back to CPU so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Stage widths, shared by the encoder and the decoder.
filters = (64, 128, 256, 512, 1024)
NUM_CLASSES = 3
model = EDGEUnet(NUM_CLASSES, filters, filters).to(device)
# map_location lets a checkpoint saved on a GPU be restored on whichever device was selected above.
model.load_state_dict(torch.load("/kaggle/input/edge-unet/last_epoch-00.bin", map_location=device))

# Data and training

In [None]:
import pandas as pd
import numpy as np
import gc
import copy
from collections import defaultdict

# Sample index for the dataset; later cells read its `image_path` and `mask_path` columns.
df = pd.read_csv("/kaggle/input/gitractcsv/data.csv")
# Preview the first rows as the cell output.
df.head()

In [None]:
from PIL import Image

def load_img(path):
    """Read an image file into a float32 array with a trailing singleton axis.

    The array is divided by its maximum value unless that maximum is zero,
    which leaves an all-zero image untouched instead of dividing by zero.
    """
    pixels = np.array(Image.open(path))
    pixels = pixels[..., np.newaxis].astype('float32')
    peak = pixels.max()
    if peak:
        pixels /= peak
    return pixels

def load_mask(path):
    """Load a mask saved with ``np.save`` and return it as float32 divided by 255."""
    raw = np.load(path)
    scaled = raw.astype('float32')
    scaled /= 255.
    return scaled

In [None]:
class UWMGIDataset(torch.utils.data.Dataset):
    """Dataset that yields channel-first image tensors, optionally paired with masks.

    Args:
        df: DataFrame with an ``image_path`` column and, when ``label`` is
            true, a ``mask_path`` column.
        label: when true, ``__getitem__`` returns ``(image, mask)``; otherwise
            only ``image``.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            (or ``transforms(image=...)`` without labels) that returns a dict
            with the same keys.
    """

    def __init__(self, df, label=True, transforms=None):
        self.df = df
        self.label = label
        self.transforms = transforms

    @staticmethod
    def _to_chw_tensor(array):
        """Move the last (channel) axis to the front and wrap the result in a tensor."""
        return torch.tensor(np.transpose(array, (-1, 0, 1)))

    def __getitem__(self, i):
        # Read from the frame this dataset was built with (not the notebook-level `df`),
        # by position, so that train/valid splits index their own rows whatever their index labels.
        row = self.df.iloc[i]
        img = load_img(row.image_path)

        if not self.label:
            if self.transforms:
                img = self.transforms(image=img)['image']
            return self._to_chw_tensor(img)

        mask = load_mask(row.mask_path)
        if self.transforms:
            data = self.transforms(image=img, mask=mask)
            img, mask = data['image'], data['mask']
        return self._to_chw_tensor(img), self._to_chw_tensor(mask)

    def __len__(self):
        return len(self.df)

In [None]:
import albumentations as A
import cv2

# Not referenced by the other visible cells; presumably the target (height, width) - TODO confirm.
IMG_SIZE = [320, 384]

# Per-split transform passed to the dataset; None means samples are used as loaded.
data_transforms = {
    "train": None,
    "valid": None
}

In [None]:
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
# Training batch size; the validation loader below uses twice this value.
BATCH_SIZE = 16

# Hold out 20% of the rows for validation; the fixed seed keeps the split repeatable.
train_df, valid_df = train_test_split(df, test_size=.2, random_state=42)
# Renumber each split from 0.
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)

train_dataset = UWMGIDataset(train_df, label=True, transforms=data_transforms["train"])
valid_dataset = UWMGIDataset(valid_df, label=True, transforms=data_transforms["valid"])

# Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, BATCH_SIZE*2)

In [None]:
!pip -qq install segmentation_models_pytorch

import segmentation_models_pytorch as smp

In [None]:
DiceLoss    = smp.losses.DiceLoss(mode='multilabel')
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()

def criterion(y_pred, y_true):
    """Training objective: equal-weight sum of the BCE-with-logits loss and the Dice loss."""
    bce_term = 0.5 * BCELoss(y_pred, y_true)
    dice_term = 0.5 * DiceLoss(y_pred, y_true)
    return bce_term + dice_term

In [None]:
# Number of mini-batches whose gradients are accumulated before each optimizer step.
N_ACCUMULATIONS = 2
from tqdm import tqdm
from torch.cuda import amp

def train(model, optimizer, scheduler, dataloader, device, epoch):
    """Run one training epoch with mixed precision and gradient accumulation.

    Args:
        model: network to optimise; switched to training mode.
        optimizer: optimizer stepped once per accumulation group.
        scheduler: optional LR scheduler stepped together with the optimizer
            (pass ``None`` to disable).
        dataloader: iterable of ``(images, masks)`` batches.
        device: device the batches are moved to.
        epoch: current epoch number (not used by the computation).

    Returns:
        Sample-weighted mean of the per-batch loss over the epoch, or ``0.0``
        when the dataloader yields no batches. The value is the plain
        criterion output (not divided by ``N_ACCUMULATIONS``), so it is on the
        same scale as the validation loss.
    """
    model.train()
    scaler = amp.GradScaler()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # defined up front so an empty dataloader does not raise
    num_steps = len(dataloader)

    # Begin the epoch without gradients left over from earlier work.
    optimizer.zero_grad()

    pbar = tqdm(enumerate(dataloader), total=num_steps, desc='Train ')
    for step, (images, masks) in pbar:
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)

        batch_size = images.size(0)

        with amp.autocast(enabled=True):
            y_pred = model(images)
            loss   = criterion(y_pred, masks)

        # Only the value sent to backward is divided; `loss` stays unscaled for reporting.
        scaler.scale(loss / N_ACCUMULATIONS).backward()

        # Also step on the final batch so that a trailing, incomplete accumulation
        # group is applied instead of leaking into the next epoch.
        is_last_step = (step + 1) == num_steps
        if (step + 1) % N_ACCUMULATIONS == 0 or is_last_step:
            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss

In [None]:
def dice_coef(y_true, y_pred, thr=0.5, dim=(2, 3), epsilon=0.001):
    """Mean Dice coefficient between binary masks and thresholded predictions.

    Predictions are binarised with ``thr``; the score is computed per sample
    and channel over ``dim`` and then averaged. ``epsilon`` keeps the ratio
    finite when both masks are empty.
    """
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred > thr).to(torch.float32)
    intersection = (y_true * y_pred).sum(dim=dim)
    total = y_true.sum(dim=dim) + y_pred.sum(dim=dim)
    return ((2 * intersection + epsilon) / (total + epsilon)).mean()


def iou_coef(y_true, y_pred, thr=0.5, dim=(2, 3), epsilon=0.001):
    """Mean intersection-over-union between binary masks and thresholded predictions.

    Same conventions as ``dice_coef``.
    """
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred > thr).to(torch.float32)
    intersection = (y_true * y_pred).sum(dim=dim)
    union = (y_true + y_pred - y_true * y_pred).sum(dim=dim)
    return ((intersection + epsilon) / (union + epsilon)).mean()


@torch.no_grad()
def valid(model, dataloader, device, epoch):
    """Evaluate the model for one pass over ``dataloader`` without gradients.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: iterable of ``(images, masks)`` batches.
        device: device the batches are moved to.
        epoch: current epoch number (not used by the computation).

    Returns:
        ``(epoch_loss, val_scores)`` where ``epoch_loss`` is the
        sample-weighted mean loss and ``val_scores`` is a length-2 array with
        the Dice and IoU scores averaged over batches. For an empty dataloader
        the loss is ``0.0`` and both scores are NaN.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # defined up front so an empty dataloader does not raise

    val_scores = []

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid')
    for step, (images, masks) in pbar:
        images  = images.to(device, dtype=torch.float)
        masks   = masks.to(device, dtype=torch.float)

        batch_size = images.size(0)

        y_pred  = model(images)
        loss    = criterion(y_pred, masks)

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        # The criterion works on logits; the metrics work on probabilities.
        y_pred = torch.sigmoid(y_pred)
        val_dice = dice_coef(masks, y_pred).cpu().detach().numpy()
        val_jaccard = iou_coef(masks, y_pred).cpu().detach().numpy()
        val_scores.append([val_dice, val_jaccard])

    if val_scores:
        val_scores = np.mean(val_scores, axis=0)
    else:
        # Keep the (dice, jaccard) shape so callers can still unpack the result.
        val_scores = np.array([np.nan, np.nan])
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores

In [None]:
def full_training(model, optimizer, scheduler, device, num_epochs, fold=0):
    """Train for ``num_epochs`` epochs, checkpointing the best and the latest weights.

    Each epoch runs ``train`` on ``train_dataloader`` and ``valid`` on
    ``valid_dataloader`` (both notebook-level objects). Weights are written to
    ``best_epoch-XX.bin`` whenever the validation Dice does not decrease, and
    to ``last_epoch-XX.bin`` after every epoch, where ``XX`` is ``fold``.

    Args:
        model: network to train.
        optimizer: optimizer passed to ``train``.
        scheduler: optional LR scheduler passed to ``train`` (may be ``None``).
        device: device passed to ``train`` / ``valid``.
        num_epochs: number of epochs to run.
        fold: integer used in the checkpoint file names (defaults to 0).

    Returns:
        ``(model, history)``: the model with the best-scoring weights restored
        (left as is when no epoch produced a comparable score) and a dict of
        per-epoch lists keyed by 'Train Loss', 'Valid Loss', 'Valid Dice' and
        'Valid Jaccard'.
    """
    best_model_wts = None
    best_dice      = -np.inf
    history = defaultdict(list)

    best_path = f"best_epoch-{fold:02d}.bin"
    last_path = f"last_epoch-{fold:02d}.bin"

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        train_loss = train(model, optimizer, scheduler,
                           dataloader=train_dataloader,
                           device=device, epoch=epoch)

        val_loss, val_scores = valid(model, valid_dataloader,
                                     device=device,
                                     epoch=epoch)
        val_dice, val_jaccard = val_scores

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)

        if val_dice >= best_dice:
            best_dice = val_dice
            # Deep copy: state_dict() returns references to the live parameters.
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), best_path)

        torch.save(model.state_dict(), last_path)

    # Nothing to restore when no epoch ran or the score was never comparable (e.g. NaN).
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    return model, history

In [None]:
from torch.optim import lr_scheduler
import torch.optim as optim

# Adam over all model parameters (HED branch included, since it is a submodule of the model).
optimizer = optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-6)
# Number of epochs passed to full_training below.
EPOCHS = 5

In [None]:
# Keep the returned model and history instead of dumping the tuple's repr as the cell output.
model, history = full_training(model, optimizer, scheduler=None, device=device, num_epochs=EPOCHS)