### Installation and package loading

In [None]:
%%capture
import torch
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import albumentations as A
from albumentations.pytorch import ToTensorV2

import tqdm.notebook as tqdm

!pip install timm
import timm

import cv2
import os
import random
import math
import sys
import copy

import numpy
import numpy as np 
import pandas as pd
import lovask
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# imports
!pip install lycon
import lycon

!pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp
from fastai.vision.all import *
from torchvision.models.resnet import ResNet, Bottleneck
import glob

### Config

In [None]:
class CONFIG():
    """Static configuration: dataset locations, sampling sizes and normalisation statistics.

    Every value is a class attribute evaluated once, when the class body
    runs. That includes reading ``info_path`` with pandas, so defining the
    class requires that CSV to exist.
    """
    
    # Tile edge length; selects which set of dataset directories is used below.
    image_size = 512
    if image_size == 512:
        train_path_images = "../input/hubmap-512x512/train/"
        train_path_masks = '../input/hubmap-512x512/masks/'

        pseudo_path_images = '../input/512x512ppseudo/test/'
        pseudo_path_masks = '../input/512x512ppseudo/masks/'

        external_path_images = '../input/external-512x512/images/images/'
        external_path_masks = '../input/external-512x512/masks/masks/'


    else:
        train_path_images = "../input/hubmap-256x256/train/"
        train_path_masks = '../input/hubmap-256x256/masks/'

        # NOTE(review): this non-512 branch still points at a "512x512-pseudo"
        # directory — confirm that is the intended pseudo-label set.
        pseudo_path_images = '../input/512x512-pseudo/test/'
        pseudo_path_masks = '../input/512x512-pseudo/masks/'

        external_path_images = '../input/external-data/images/images/'
        external_path_masks = '../input/external-data/masks/masks/'

    info_path = "../input/hubmap-kidney-segmentation/train.csv"
    
    
    # Items reported per epoch by each dataset; -1 means "one per available file".
    train_samples = -1
    pseudo_samples = -1
    external_samples = -1
    
    # Values of the `id` column of the competition CSV (read at class-definition time).
    files = pd.read_csv(info_path).id.values
    # Per-channel mean / std passed to A.Normalize in the test-time transform.
    mean_train = np.array([0.63701495, 0.4709702, 0.6817423])
    std_train = np.array([0.15978882, 0.2245109, 0.14173926])
    stats = (mean_train, std_train)
    # Alternative kept for reference: stats = imagenet_stats
    # NOTE(review): the original remark here was "2 Kidneys in Validation,
    # 6 Kidneys train (+ pseudo and External)", but get_folds() splits the
    # tile file names, not kidneys — confirm the intended split unit.
    n_folds = 5
    
    # Toggles for adding pseudo-labelled and external data to the training set.
    use_pseudo = True
    use_external = True
cfg = CONFIG()

In [None]:
# Reproducibility:
def seed_all(seed = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, ``PYTHONHASHSEED``, NumPy, PyTorch (CPU and
    all CUDA devices), PyTorch Lightning and fastai with the same value.

    Args:
        seed: integer seed applied everywhere (defaults to 42, the value that
            was previously hard-coded).
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Slight stochasticity tradeoff for quicker computation.
    torch.backends.cudnn.benchmark = True # True for faster
    # Pass the seed explicitly: called without an argument, Lightning chooses
    # its own seed instead of the one requested here.
    pl.seed_everything(seed)
    # NOTE(review): fastai's set_seed(..., reproducible=True) presumably resets
    # the cudnn flags set above (benchmark back to False) — confirm which
    # setting is actually wanted.
    set_seed(seed, True)
def seed_worker(worker_id):
    """Re-seed NumPy and Python's ``random`` from the current torch seed.

    The torch initial seed is reduced modulo 2**32 and then applied to both
    generators. ``worker_id`` is accepted but not used.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    for reseed in (numpy.random.seed, random.seed):
        reseed(derived_seed)
# Seed all RNGs once, before any randomised object is built.
seed_all()

# Training-time augmentation pipeline. It performs no normalisation or tensor
# conversion; the datasets apply `test_transforms` to the augmented image for that.
# NOTE(review): RandomBrightness, RandomContrast, IAAAffine and Cutout are
# deprecated/removed in newer albumentations releases — confirm the pinned version.
train_transforms = A.Compose([
    # Photometric jitter: one of brightness / contrast / gamma, half of the time.
    A.OneOf([
        A.RandomBrightness(limit=.2, p=1), 
        A.RandomContrast(limit=.2, p=1), 
        A.RandomGamma(p=1)
    ], p=.5),
    # Light blurring.
    A.OneOf([
        A.Blur(blur_limit=3, p=1),
        A.MedianBlur(blur_limit=3, p=1)
    ], p=.25),
    # Additive noise or an affine warp.
    A.OneOf([
        A.GaussNoise(0.002, p=.5),
        A.IAAAffine(p=.5),
    ], p=.25),
    # Non-rigid geometric distortions.
    A.OneOf([
            A.ElasticTransform(alpha=120, sigma=120 * .05, alpha_affine=120 * .03, p=.5),
            A.GridDistortion(p=.5),
            A.OpticalDistortion(distort_limit=2, shift_limit=.5, p=1)                  
    ], p=.25),
    # Rotations by multiples of 90 degrees and flips.
    A.RandomRotate90(p=.5),
    A.HorizontalFlip(p=.5),
    A.VerticalFlip(p=.5),
    # Up to 10 holes, each at most 10% of the tile edge.
    A.Cutout(num_holes=10, 
                max_h_size=int(.1 * cfg.image_size), max_w_size=int(.1 * cfg.image_size), 
                p=.25),
    A.ShiftScaleRotate(p=.5)
])    
# Deterministic preprocessing: normalise with the configured statistics and
# convert to a tensor. Used for validation and as the last training step.
test_transforms = A.Compose([
    A.Normalize(mean = cfg.stats[0], std = cfg.stats[1]),
    ToTensorV2()
])

In [None]:
def get_folds():
    """Split the training tile file names into ``cfg.n_folds`` shuffled folds.

    Returns:
        A list of ``(train_files, val_files)`` tuples of NumPy string arrays,
        one tuple per fold.

    The directory listing is sorted first: ``os.listdir`` returns names in
    arbitrary order, so without sorting the seeded ``KFold`` could produce
    different folds on different machines.

    NOTE(review): the split is over individual file names, not grouped by
    source image — confirm that tiles of one image may land on both sides.
    """
    files = np.array(sorted(os.listdir(cfg.train_path_images)))
    splitter = KFold(n_splits = cfg.n_folds, shuffle = True, random_state = 42)
    return [(files[train], files[test]) for train, test in splitter.split(files)]
FOLDS = get_folds()

# Data

In [None]:
class TrainDataset(torch.utils.data.Dataset):
    """Augmented training tiles read from the configured train image/mask folders.

    ``files`` holds the tile file names. ``cfg.train_samples`` sets the
    reported length, with ``-1`` meaning one item per file. Whenever the
    reported length differs from the number of files, the requested index is
    ignored and a random file is drawn instead.
    """
    def __init__(self, files):
        self.files = files
        self.actual_length = len(files)
        requested = cfg.train_samples
        self.num_samples = self.actual_length if requested == -1 else requested
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        # Resampling mode: the index is replaced by a uniformly random one.
        if self.num_samples != self.actual_length:
            idx = random.randint(0, self.actual_length - 1)
        name = self.files[idx]

        image = lycon.load(f"{cfg.train_path_images}{name}")
        # Only the first channel of the mask file is used.
        mask = lycon.load(f"{cfg.train_path_masks}{name}")[:, :, 0]

        # Collapse mask value 2 into 1.
        mask[mask == 2] = 1

        # Augment image and mask together, then normalise/tensorise the image only.
        sample = train_transforms(image = image, mask = mask)
        tensor = test_transforms(image = sample['image'])['image']
        return tensor, sample['mask']
class ValDataset(torch.utils.data.Dataset):
    """Validation tiles: same loading as training, but without augmentation.

    ``files`` holds the tile file names; the dataset length is ``len(files)``.
    """
    def __init__(self, files):
        self.files = files
        self.actual_length = len(files)
    def __len__(self):
        return self.actual_length
    def __getitem__(self, idx):
        name = self.files[idx]

        image = lycon.load(f"{cfg.train_path_images}{name}")
        # Only the first channel of the mask file is used.
        mask = lycon.load(f"{cfg.train_path_masks}{name}")[:, :, 0]

        # Collapse mask value 2 into 1.
        mask[mask == 2] = 1

        # Normalise and convert the image; the mask is returned as loaded.
        return test_transforms(image = image)['image'], mask
class OtherDataset(torch.utils.data.Dataset):
    """Augmented image/mask pairs from an arbitrary pair of directories.

    Args:
        file_base_images: directory prefix (with trailing separator) of the images.
        file_base_masks: directory prefix (with trailing separator) of the masks;
            the file names found here define the dataset.
        num_samples: reported dataset length; ``-1`` means one item per mask file.

    Whenever the reported length differs from the number of files, the
    requested index is ignored and a random file is drawn instead.
    """
    def __init__(self, file_base_images, file_base_masks, num_samples):
        self.file_base_images = file_base_images
        self.file_base_masks = file_base_masks

        self.num_samples = num_samples

        # Sorted so the index -> file mapping does not depend on the arbitrary
        # order in which os.listdir returns entries.
        self.all_files = sorted(os.listdir(self.file_base_masks))
        self.actual_length = len(self.all_files)
        if self.num_samples == -1:
            self.num_samples = self.actual_length
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        # Resampling mode: the index is replaced by a uniformly random one.
        if self.num_samples != self.actual_length:
            idx = random.randint(0, self.actual_length - 1)
        file = self.all_files[idx]
        image_file = f"{self.file_base_images}{file}"
        mask_file = f"{self.file_base_masks}{file}"

        image = lycon.load(image_file)
        # Only the first channel of the mask file is used.
        mask = lycon.load(mask_file)[:, :, 0]

        # Collapse mask value 2 into 1.
        twos = mask == 2
        mask[twos] = 1

        # Augment image and mask together, then normalise/tensorise the image only.
        augmented = train_transforms(image = image, mask = mask)
        image = test_transforms(image = augmented['image'])['image']
        mask = augmented['mask']

        return image, mask
class ConcatDataset(torch.utils.data.Dataset):
    """Concatenation of the train tiles with the optional external and pseudo sets.

    Item order is: train, then external (if ``cfg.use_external``), then
    pseudo (if ``cfg.use_pseudo``). A disabled part contributes neither to
    the length nor to indexing.

    Args:
        files: training tile file names, forwarded to ``TrainDataset``.
    """
    def __init__(self, files):
        self.train_dataset = TrainDataset(files)
        self.pseudo_dataset = OtherDataset(cfg.pseudo_path_images, cfg.pseudo_path_masks, cfg.pseudo_samples)
        self.external_dataset = OtherDataset(cfg.external_path_images, cfg.external_path_masks, cfg.external_samples)
        self.use_external = cfg.use_external
        self.use_pseudo = cfg.use_pseudo
    def _active_datasets(self):
        # Enabled parts, in the order in which they are concatenated.
        parts = [self.train_dataset]
        if self.use_external:
            parts.append(self.external_dataset)
        if self.use_pseudo:
            parts.append(self.pseudo_dataset)
        return parts
    def __len__(self):
        return sum(len(part) for part in self._active_datasets())
    def __getitem__(self, idx):
        # Walk the enabled parts, turning the global index into a local one.
        # Previously the flags were ignored here, so with external data
        # disabled the indices after the train range still hit the external set.
        offset = idx
        for part in self._active_datasets():
            if offset < len(part):
                return part.__getitem__(offset)
            offset -= len(part)
        raise IndexError(f"index {idx} is out of range for a dataset of length {len(self)}")
class DataModule:
    """Factory for the per-fold training and validation datasets."""
    @classmethod
    def get_both(cls, idx):
        """Return ``(train_dataset, val_dataset)`` built from ``FOLDS[idx]``."""
        train_files, val_files = FOLDS[idx]
        return ConcatDataset(train_files), ValDataset(val_files)

In [None]:
def display_image_np(image):
    """Show ``image`` with matplotlib exactly as given (no axis reordering)."""
    plt.imshow(image); plt.show()
def display_image(image):
    """Show a tensor image after moving axis 0 to the end.

    The tensor is copied to the CPU and its axes reordered from
    ``(0, 1, 2)`` to ``(1, 2, 0)`` before plotting.
    """
    channels_last = image.cpu().permute(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()

Encoder

In [None]:
def initialize_weights(layer):
    """Re-initialise the convolutions and batch norms found inside ``layer``.

    Every ``nn.Conv2d`` weight gets Kaiming-normal initialisation for ReLU;
    every ``nn.BatchNorm2d`` gets weight 1 and bias 0. Other modules, and
    convolution biases, are left untouched.
    """
    for module in layer.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, nonlinearity = 'relu')
            continue
        if isinstance(module, nn.BatchNorm2d):
            # Identity affine transform: scale of ones, shift of zeros.
            module.weight.data.fill_(1)
            module.bias.data.zero_()

In [None]:
class Mish(pl.LightningModule):
    """Mish activation, ``x * tanh(softplus(x))``; can act as a drop-in replacement."""
    def __init__(self):
        super().__init__()
    def forward(self, x):
        softplus = F.softplus(x)
        return x * torch.tanh(softplus)
def replace_all(model):
    """Recursively swap ReLU / SiLU / timm Swish children of ``model`` for ``Mish``.

    The replacement happens in place; children that are not one of those
    activations are descended into.
    """
    for name, module in model.named_children():
        is_activation = isinstance(module, (nn.ReLU, nn.SiLU, timm.models.layers.activations.Swish))
        if is_activation:
            setattr(model, name, Mish())
            continue
        replace_all(module)
class Act(pl.LightningModule):
    """Activation chosen by ``Config.act``: ``'silu'``, ``'mish'``, or ReLU for anything else.

    NOTE(review): this reads ``Config``, while the configuration class defined
    at the top of this file is ``CONFIG`` / ``cfg`` — confirm ``Config`` is
    defined elsewhere.
    """
    def __init__(self):
        super().__init__()
        kind = Config.act
        self.act_type = kind
        if kind == 'silu':
            chosen = nn.SiLU(inplace = True)
        elif kind == 'mish':
            chosen = Mish()
        else:
            # Any other value falls back to ReLU.
            chosen = nn.ReLU(inplace = True)
        self.act = chosen
    def forward(self, x):
        return self.act(x)
class ConvBlock(pl.LightningModule):
    """Bias-free ``Conv2d`` followed by the configured activation and then batch norm.

    Note the order applied in ``forward``: convolution -> activation -> batch norm.
    """
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride):
        super().__init__()
        conv_options = dict(kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, bias = False)
        self.conv = nn.Conv2d(in_features, out_features, **conv_options)
        self.bn = nn.BatchNorm2d(out_features)
        self.act1 = Act()
        initialize_weights(self)
    def forward(self, x):
        activated = self.act1(self.conv(x))
        return self.bn(activated)
class SqueezeExcite(pl.LightningModule):
    """Channel attention via a two-layer linear bottleneck.

    The input is averaged over its last two axes, passed through
    ``Squeeze`` -> activation -> ``Excite`` -> sigmoid, and the result is
    broadcast back over those two axes to rescale the input.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features

        self.Squeeze = nn.Linear(in_features, inner_features)
        self.act1 = Act()
        self.Excite = nn.Linear(inner_features, in_features)
    def forward(self, x):
        # Average over the last axis twice, i.e. over the two trailing axes.
        pooled = x.mean(dim = -1).mean(dim = -1)

        hidden = self.act1(self.Squeeze(pooled))
        gate = torch.sigmoid(self.Excite(hidden))
        return gate[..., None, None] * x
class SCSE(pl.LightningModule):
    """Spatial + channel squeeze-excite.

    The output is the average of a channel-gated copy of the input (linear
    bottleneck over the pooled input) and a spatially-gated copy (1x1
    convolution down to one channel).
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features

        self.squeeze = nn.Linear(in_features, inner_features)
        self.Act = Act()
        self.excite = nn.Linear(inner_features, in_features)

        self.spatial = nn.Conv2d(in_features, 1, kernel_size = 1)
        initialize_weights(self)
    def forward(self, x):
        # Average over the two trailing axes.
        pooled = x.mean(dim = -1).mean(dim = -1)

        hidden = self.Act(self.squeeze(pooled))
        channel_gate = torch.sigmoid(self.excite(hidden))[..., None, None]
        channel_branch = channel_gate * x

        spatial_branch = torch.sigmoid(self.spatial(x)) * x

        return (channel_branch + spatial_branch) / 2

class Attention(pl.LightningModule):
    """Attention wrapper selected by ``Config.attention_type``.

    Supported types are ``'se'`` (SqueezeExcite), ``'scse'`` (SCSE) and
    ``'none'`` (pass-through). With ``Config.gate_attention`` set, the output
    is a learned blend ``g * attended + (1 - g) * x`` where ``g`` is the
    sigmoid of a scalar parameter initialised to zero.

    Args:
        in_features: channel count of the input.
        inner_features: bottleneck width of the attention layer.

    Raises:
        ValueError: if ``Config.attention_type`` is not a supported type.
            Previously this only surfaced in ``forward`` as a missing
            ``layer`` attribute.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.attention_type = Config.attention_type
        self.gate_attention = Config.gate_attention
        if self.attention_type == 'se':
            self.layer = SqueezeExcite(in_features, inner_features)
        elif self.attention_type == 'scse':
            self.layer = SCSE(in_features, inner_features)
        elif self.attention_type != 'none':
            raise ValueError(
                f"Unsupported attention_type {self.attention_type!r}; expected 'se', 'scse' or 'none'"
            )
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        if self.attention_type == 'none':
            return x
        processed = self.layer(x)
        if self.gate_attention:
            gamma = torch.sigmoid(self.gamma)
            return gamma * processed + (1 - gamma) * x
        else:
            return processed

class BottleNeckBlock(pl.LightningModule):
    """Residual bottleneck: 1x1 squeeze -> 3x3 -> 1x1 expand -> attention.

    The branch output and the input are blended with a learned scalar gate
    (sigmoid of ``gamma``, which is initialised to zero).
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features
        self.reduction = Config.reduction
        self.Squeeze = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        self.Process = ConvBlock(inner_features, inner_features, 3, 1, 1, 1)
        self.Expand = ConvBlock(inner_features, in_features, 1, 0, 1, 1)
        self.SE = Attention(in_features, in_features // self.reduction)

        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        branch = self.SE(self.Expand(self.Process(self.Squeeze(x))))
        mix = torch.sigmoid(self.gamma)
        return branch * mix + (1 - mix) * x
class InverseBottleNeckBlock(pl.LightningModule):
    """Inverted residual: 1x1 expand -> 3x3 depthwise -> attention -> 1x1 squeeze.

    The branch output and the input are blended with a learned scalar gate
    (sigmoid of ``gamma``, which is initialised to zero).
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features
        self.reduction = Config.reduction
        self.Expand = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        # groups == channels makes this 3x3 convolution depthwise.
        self.DW = ConvBlock(inner_features, inner_features, 3, 1, inner_features, 1)
        self.SE = Attention(inner_features, inner_features//self.reduction)
        self.Squeeze = ConvBlock(inner_features, in_features, 1, 0, 1, 1)

        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        branch = self.Squeeze(self.SE(self.DW(self.Expand(x))))
        mix = torch.sigmoid(self.gamma)
        return branch * mix + (1 - mix) * x
class AstrousConvolution(pl.LightningModule):
    '''
    Atrous (more properly "à trous", French for "with holes") convolution:
    a bias-free dilated Conv2d followed by the activation and then batch norm.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride, dilation):
        super().__init__()
        conv_options = dict(kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, dilation = dilation, bias = False)
        self.astrous = nn.Conv2d(in_features, out_features, **conv_options)
        self.bn = nn.BatchNorm2d(out_features)
        self.act1 = Act()
        initialize_weights(self)
    def forward(self, x):
        activated = self.act1(self.astrous(x))
        return self.bn(activated)
class ASPP_Pool(pl.LightningModule):
    """Global-pooling branch of ASPP.

    Pools the input to 1x1, applies a 1x1 ConvBlock, and bilinearly resizes
    the result back to the input's height and width.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features

        # Pooling flavour is fixed to average pooling here.
        self.pooling_type = 'mean'
        use_mean = self.pooling_type == 'mean'
        self.pool = nn.AdaptiveAvgPool2d((1, 1)) if use_mean else nn.AdaptiveMaxPool2d((1, 1))
        self.process = nn.Sequential(
            ConvBlock(in_features, out_features, 1, 0, 1, 1)
        )
    def forward(self, x):
        _, _, height, width = x.shape
        projected = self.process(self.pool(x))
        return F.interpolate(projected, size = (height, width), mode = 'bilinear')
class ASPP(pl.LightningModule):
    '''
    à trous spatial pyramid pooling block. Nothing is applied after the final
    1x1 projection; further processing should be added by the caller.

    Six parallel branches read the same input, are concatenated along the
    channel axis and projected to `out_features`:
    - global pooling branch (ASPP_Pool)
    - normal 1x1 conv
    - 3x3 à trous conv, dilation 1 * stride
    - 3x3 à trous conv, dilation 3 * stride
    - 3x3 à trous conv, dilation 5 * stride
    - 3x3 à trous conv, dilation 7 * stride

    `stride` only scales the padding and dilation of the à trous branches;
    the convolution stride itself is always 1. The à trous branches use
    grouped convolutions with `num_groups` = 4 groups.
    '''
    def __init__(self, in_features, inner_features, out_features, stride = 1):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride
        self.num_groups = 4
        
        self.pool = ASPP_Pool(self.in_features, self.inner_features)
        self.conv1 = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1)
        # Padding equals dilation for each 3x3 branch.
        self.conv2 = AstrousConvolution(self.in_features, self.inner_features, 3, self.stride * 1, self.num_groups, 1, self.stride * 1)
        self.conv3 = AstrousConvolution(self.in_features, self.inner_features, 3, self.stride * 3, self.num_groups, 1, self.stride * 3)
        self.conv4 = AstrousConvolution(self.in_features, self.inner_features, 3, self.stride * 5, self.num_groups, 1, self.stride * 5)
        self.conv5 = AstrousConvolution(self.in_features, self.inner_features, 3, self.stride * 7, self.num_groups, 1, self.stride * 7)
        
        # Six branches of inner_features channels each are concatenated.
        self.conv_proj = ConvBlock(self.inner_features * 6, self.out_features, 1, 0, 1, 1)
        initialize_weights(self)
    def forward(self, x):
        pool = self.pool(x)
        conv1 = self.conv1(x)
        conv2 = self.conv2(x)
        conv3 = self.conv3(x)
        conv4 = self.conv4(x)
        conv5 = self.conv5(x)
        
        concat = torch.cat([pool, conv1, conv2, conv3, conv4, conv5], dim = 1)
        return self.conv_proj(concat)
class BAM(pl.LightningModule):
    """Attention block combining a channel branch and a dilated-conv spatial branch.

    Channel logits come from a linear bottleneck over the pooled input;
    spatial logits come from 1x1 conv -> dilated depthwise 3x3 -> 1x1 conv
    down to one channel. Their average is passed through a sigmoid and used
    to rescale the input. With ``Config.gate_attention`` the result is
    blended with the input through a learned scalar gate.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features

        self.Squeeze = nn.Linear(in_features, inner_features)
        self.Act = Act()
        self.Excite = nn.Linear(inner_features, in_features)

        self.dilation = Config.bam_dilate
        self.SqueezeConv = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        # Depthwise 3x3 whose padding equals its dilation.
        self.DA = AstrousConvolution(inner_features, inner_features, 3, self.dilation, inner_features, 1, self.dilation)
        self.ExciteConv = ConvBlock(inner_features, 1, 1, 0, 1, 1)
        self.gate_attention = Config.gate_attention
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        # Channel branch: average over the two trailing axes, then bottleneck.
        pooled = x.mean(dim = -1).mean(dim = -1)
        channel_logits = self.Excite(self.Act(self.Squeeze(pooled)))[..., None, None]

        # Spatial branch.
        spatial_logits = self.ExciteConv(self.DA(self.SqueezeConv(x)))

        attended = torch.sigmoid((spatial_logits + channel_logits) / 2) * x
        if not self.gate_attention:
            return attended
        mix = torch.sigmoid(self.gamma)
        return mix * attended + (1 - mix) * x
class SEM(pl.LightningModule):
    """Multi-dilation block: squeeze, 3x3 conv, then four parallel dilated depthwise convs.

    The four branch outputs are concatenated with the original input and
    projected back to ``in_features`` channels by a 1x1 ConvBlock.
    """
    def __init__(self, in_features, inner_features, stride = 1):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features
        self.stride = stride

        self.Squeeze = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        self.FS = ConvBlock(inner_features, inner_features, 3, 1, 1, 1)

        def dilated_depthwise(rate):
            # Depthwise 3x3 with padding == dilation == stride * rate.
            spacing = self.stride * rate
            return AstrousConvolution(inner_features, inner_features, 3, spacing, inner_features, 1, spacing)

        # Dilation ASPP
        self.conv1 = dilated_depthwise(1)
        self.conv2 = dilated_depthwise(2)
        self.conv3 = dilated_depthwise(3)
        self.conv4 = dilated_depthwise(4)

        self.proj = ConvBlock(inner_features * 4 + in_features, in_features, 1, 0, 1, 1)
    def forward(self, x):
        shared = self.FS(self.Squeeze(x))
        branches = [conv(shared) for conv in (self.conv1, self.conv2, self.conv3, self.conv4)]
        return self.proj(torch.cat([x, *branches], dim = 1))

# Simplified Model to Match SMP performance, then improve on it.

In [None]:
class EncoderUNext(pl.LightningModule):
    """ResNeXt50 (32x4d) encoder with an ASPP head on the deepest feature map.

    Weights are fetched through ``torch.hub`` at construction time, so
    building the module needs network access (or a populated hub cache).
    ``forward`` returns the input plus five feature maps.
    """
    def freeze(self, layers):
        """Disable gradients for every parameter of every layer in ``layers`` (an iterable)."""
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def unfreeze(self, layer):
        """Enable gradients for every parameter of a single ``layer`` (not a list, unlike ``freeze``)."""
        for parameter in layer.parameters():
            parameter.requires_grad = True
    def __init__(self):
        super().__init__()
        # Build the architecture locally, then copy in the hub model's weights.
        self.model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)
        weights = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext50_32x4d_swsl')
        self.model.load_state_dict(weights.state_dict())
        
        self.conv1 = self.model.conv1 # 64
        self.bn1 = self.model.bn1
        # The stem activation is replaced by Mish.
        self.act1 = Mish()
        self.maxpool = self.model.maxpool
        
        self.layer1 = self.model.layer1 # 256
        self.layer2 = self.model.layer2 # 512
        # Freeze initial layers (stem conv, stem batch norm and layer1; layer2 stays trainable).
        self.freeze([self.conv1, self.bn1, self.layer1])
        
        self.layer3 = self.model.layer3 # 1024
        self.layer4 = self.model.layer4 # 2048
        
        self.aspp_reduction = Config.aspp_reduction
        self.ASPP = ASPP(2048, 2048 // self.aspp_reduction, 512)
        # Drop the wrapper so only the extracted layers stay registered on this module.
        del self.model
    def forward(self, x):
        # NOTE(review): the stem applies the activation before bn1
        # (conv -> act -> bn) — confirm this ordering is intended with the
        # pretrained batch-norm statistics.
        features0 = self.bn1(self.act1(self.conv1(x)))
        layer1 = self.layer1(self.maxpool(features0))
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer4 = self.ASPP(layer4)
        
        features = [x, features0, layer1, layer2, layer3, layer4]
        return features
class EncoderResNet(pl.LightningModule):
    """timm ``resnet34d`` encoder with an optional ASPP head on the deepest feature map.

    ``forward`` returns ``[x, features0, layer1, layer2, layer3, layer4]``.
    When ``Config.use_ASPP`` is false the ASPP head is an identity, matching
    how ``EncoderQTPi`` handles its optional blocks.
    """
    def freeze(self, layer):
        """Disable gradients for every parameter of ``layer``."""
        for parameter in layer.parameters():
            parameter.requires_grad = False
    def unfreeze(self, layer):
        """Enable gradients for every parameter of ``layer``."""
        for parameter in layer.parameters():
            # Fixed: this used to set False, making unfreeze a second freeze.
            parameter.requires_grad = True
    def __init__(self):
        super().__init__()
        self.model_name = 'resnet34d'
        self.model = timm.create_model(self.model_name, pretrained = True)
        # Extract Layers
        self.enc_dims = [64, 64, 128, 256, 512]
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1
        self.act1 = self.model.act1
        self.maxpool = self.model.maxpool
        self.layer1 = self.model.layer1
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4
        # NOTE(review): unlike EncoderUNext, self.model is kept, so the same
        # layers stay registered twice (under model.* and directly) — confirm
        # this is wanted before changing it, since it affects state_dict keys.

        self.aspp_reduction = Config.aspp_reduction
        self.use_aspp = Config.use_ASPP
        if self.use_aspp:
            self.ASPP = ASPP(self.enc_dims[-1], self.enc_dims[-1] // self.aspp_reduction, self.enc_dims[-1])
        else:
            # Fixed: forward() always calls self.ASPP, which previously did
            # not exist when ASPP was disabled.
            self.ASPP = nn.Identity()


    def forward(self, x):
        # NOTE(review): the stem applies the activation before bn1
        # (conv -> act -> bn) — confirm this ordering is intended with the
        # pretrained batch-norm statistics.
        features0 = self.bn1(self.act1(self.conv1(x))) # 64
        layer1 = self.layer1(self.maxpool(features0)) # 64
        layer2 = self.layer2(layer1) # 128
        layer3 = self.layer3(layer2) # 256
        layer4 = self.layer4(layer3) # 512

        layer4 = self.ASPP(layer4)
        features = [x, features0, layer1, layer2, layer3, layer4]
        return features
        
class EncoderQTPi(pl.LightningModule):
    """Encoder taken from an smp ``Unet('efficientnet-b0')`` plus optional extra blocks.

    The deepest feature map is passed through, in order: ``bam1``, ``block7``
    (ASPP), ``block8`` (inverse bottlenecks) and ``bam2``. Each of these is an
    ``nn.Identity`` when the corresponding ``Config`` flag is off.
    """
    def freeze_beginning(self):
        """Freeze the whole smp encoder."""
        self.freeze([self.model.encoder])
    def freeze(self, layers):
        """Disable gradients for every parameter of every layer in ``layers``."""
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def unfreeze(self, layers):
        """Enable gradients for every parameter of every layer in ``layers``."""
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = True
    def __init__(self):
        super().__init__()
        # Channel table; only index 7 (320) is read in this class.
        self.enc_dims = [3, 32, 16, 24, 40, 80, 112, 320]
        # HYPER PARAMETERS
        self.base_name = 'efficientnet-b0'
        # END OF HYPER PARAMETERS
        # NOTE(review): the full smp.Unet is built although only its encoder
        # is called in forward() — the remaining parameters presumably stay
        # registered on this module; confirm that is acceptable.
        self.model = smp.Unet(self.base_name)
        # Freeze Layer
        # Custom Layers(Attention - SE, Dropout2d)
        self.use_ASPP = Config.use_ASPP
        self.aspp_reduction = Config.aspp_reduction
    
        if self.use_ASPP:
            self.block7 = nn.Sequential(*[
                ASPP(self.enc_dims[7], self.enc_dims[7] // self.aspp_reduction, self.enc_dims[7])
            ])
        else:
            self.block7 = nn.Identity()
        self.buff_encoder = Config.buffed_encoder
        if self.buff_encoder:
            # Stack of num_blocks inverse bottlenecks, expanded by `expansion`.
            self.num_blocks = Config.num_blocks
            self.expansion = Config.expand
            self.block8 = nn.Sequential(*[
                InverseBottleNeckBlock(self.enc_dims[7], self.enc_dims[7] * self.expansion) for i in range(self.num_blocks)
            ])
        else:
            self.block8 = nn.Identity()
        self.use_bam = Config.use_bam
        self.reduction = Config.reduction
        if self.use_bam:
            # Two BAM blocks: bam1 runs right after the encoder, bam2 after block7 and block8.
            self.bam1 = BAM(self.enc_dims[7], self.enc_dims[7] // self.reduction)
            self.bam2 = BAM(self.enc_dims[7], self.enc_dims[7] // self.reduction)
        else:
            self.bam1 = nn.Identity()
            self.bam2 = nn.Identity()
    def forward(self, x):
        '''
        Run the smp encoder and post-process its deepest feature map.

        x: input batch (the earlier note here assumed Tensor(B, 3, 512, 512)).
        Returns:
        a list [x, l0, l1, l2, l3, l4] of the six outputs unpacked from
        self.model.encoder(x), where l4 has additionally been passed through
        bam1 -> block7 -> block8 -> bam2. Note that x is rebound to the
        encoder's first output.

        NOTE(review): l4 presumably has enc_dims[7] (= 320) channels, since the
        extra blocks are built for that width. The per-level shapes previously
        listed here (16/24/48/120/352/512 channels, six levels) did not match
        enc_dims or the five levels returned — confirm against the encoder.
        '''
        x, l0, l1, l2, l3, l4 = tuple(self.model.encoder(x))
        l4 = self.bam1(l4)
        l4 = self.block7(l4)
        l4 = self.block8(l4)
        l4 = self.bam2(l4)
        features = [x, l0, l1, l2, l3, l4]
        return features

In [None]:
# Special Convolutional Blocks for the UNet Decoder:
class RecurrentConvolution(pl.LightningModule):
    '''
    Recurrent Convolution Block

    Applies one shared ConvBlock `t` times. The first pass sees the input;
    every later pass sees the average of the input and the previous output.

    Args:
        in_features: channel count (input and output).
        kernel_size, padding, groups: forwarded to the ConvBlock (stride 1).
        t: number of passes through the block; must be at least 1.

    Raises:
        ValueError: if `t` is smaller than 1. Previously this surfaced only
            in forward() as an unbound local variable.
    '''
    def __init__(self, in_features, kernel_size, padding, groups, t = 2):
        super().__init__()
        if t < 1:
            raise ValueError(f"t must be at least 1, got {t}")
        self.in_features = in_features
        self.kernel_size = kernel_size
        self.padding = padding
        self.groups = groups
        self.t = t

        self.block = ConvBlock(self.in_features, self.in_features, self.kernel_size, self.padding, self.groups, 1)
    def forward(self, x):
        x1 = self.block(x)
        # Remaining t - 1 passes refine the output using the original input.
        for _ in range(self.t - 1):
            x1 = self.block((x + x1) / 2)
        return x1
class RecurrentBlock(pl.LightningModule):
    """Residual block: 1x1 conv -> recurrent depthwise 3x3 -> attention -> 1x1 conv.

    The branch output and the input are blended with a learned scalar gate
    (sigmoid of ``gamma``, which is initialised to zero).
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features

        self.reduction = Config.reduction
        self.conv = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        # groups == channels makes the recurrent 3x3 convolution depthwise.
        self.recurrent = RecurrentConvolution(inner_features, 3, 1, inner_features)
        self.SE = Attention(inner_features, inner_features // self.reduction)
        self.conv2 = ConvBlock(inner_features, in_features, 1, 0, 1, 1)

        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        branch = self.conv2(self.SE(self.recurrent(self.conv(x))))

        mix = torch.sigmoid(self.gamma)
        return mix * branch + (1 - mix) * x
class GatedSpatialAttention(pl.LightningModule):
    '''
    Base gated spatial attention.

    Both inputs are projected to `inner_features` channels with 1x1 convs and
    averaged; after activation and batch norm, a second 1x1 conv + batch norm
    + sigmoid produces a mask that rescales `left_features`. With
    Config.gate_attention the result is blended with `left_features` through
    a learned scalar gate.
    '''
    def __init__(self, left_features, down_features, inner_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.inner_features = inner_features

        self.ConvLeft = nn.Conv2d(left_features, inner_features, kernel_size = 1, bias = False)
        self.ConvDown = nn.Conv2d(down_features, inner_features, kernel_size = 1, bias = False)

        self.BatchNorm = nn.BatchNorm2d(inner_features)
        self.act = Act()

        self.ConvBlock = nn.Conv2d(inner_features, left_features, kernel_size = 1, bias = False)
        self.BatchNorm2 = nn.BatchNorm2d(left_features)
        self.gate_attention = Config.gate_attention
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
        initialize_weights(self)
    def forward(self, left_features, down_features):
        projected_left = self.ConvLeft(left_features)
        projected_down = self.ConvDown(down_features)

        fused = self.BatchNorm(self.act((projected_down + projected_left) / 2))
        mask = torch.sigmoid(self.BatchNorm2(self.ConvBlock(fused)))
        attended = mask * left_features
        if not self.gate_attention:
            return attended
        mix = torch.sigmoid(self.gamma)
        return mix * attended + (1 - mix) * left_features
        
class GatedChannelAttention(pl.LightningModule):
    '''
    Similar to the Attention UNet, but with SE principles: both inputs are
    pooled over their two trailing axes, squeezed with linear layers,
    averaged, and turned into a per-channel gate for `left_features`.

    Author's note: I find that Conv2d never works for attention.
    '''
    def __init__(self, left_features, down_features, inner_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.inner_features = inner_features

        self.LeftSqueeze = nn.Linear(left_features, inner_features)
        self.Act = Act()
        self.DownSqueeze = nn.Linear(down_features, inner_features)

        self.Excite = nn.Linear(inner_features, left_features)
        self.gate_attention = Config.gate_attention
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, left_features, down_features):
        # Average each input over its two trailing axes.
        pooled_left = left_features.mean(dim = -1).mean(dim = -1)
        pooled_down = down_features.mean(dim = -1).mean(dim = -1)

        squeezed_left = self.LeftSqueeze(pooled_left)
        squeezed_down = self.DownSqueeze(pooled_down)

        hidden = self.Act((squeezed_left + squeezed_down) / 2)

        gate = torch.sigmoid(self.Excite(hidden))[..., None, None]
        attended = gate * left_features
        if not self.gate_attention:
            return attended
        mix = torch.sigmoid(self.gamma)
        return mix * attended + (1 - mix) * left_features
class ChooseBottleNeck(pl.LightningModule):
    """Bottleneck block selected by ``Config.bottleneck_type``.

    ``'recurrent'``, ``'inverse'`` and ``'bottleneck'`` map to the matching
    block class; any other value that passes the assertion (i.e. ``'none'``)
    yields an identity layer.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features, self.inner_features = in_features, inner_features

        self.bottleneck_type = Config.bottleneck_type
        assert self.bottleneck_type in ['none', 'recurrent', 'inverse', 'bottleneck']
        block_classes = {
            'recurrent': RecurrentBlock,
            'inverse': InverseBottleNeckBlock,
            'bottleneck': BottleNeckBlock,
        }
        block_cls = block_classes.get(self.bottleneck_type)
        if block_cls is None:
            self.layer = nn.Identity()
        else:
            self.layer = block_cls(in_features, inner_features)
    def forward(self, x):
        return self.layer(x)

In [None]:
class FPN(pl.LightningModule):
    """Project several feature maps to a common width and stack them on a base map.

    Args:
        in_channels: list with the channel count of every incoming feature map.
        out_channels: channel count each feature map is projected to.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert isinstance(self.in_channels, list)

        def make_branch(channels):
            # Two stacked ConvBlocks: channels -> 2 * out_channels -> out_channels.
            return nn.Sequential(
                ConvBlock(channels, self.out_channels * 2, 3, 1, 1, 1),
                ConvBlock(self.out_channels * 2, self.out_channels, 3, 1, 1, 1),
            )

        self.conv_proj = nn.ModuleList([make_branch(channels) for channels in self.in_channels])

    def forward(self, features, last_dim):
        """Return ``last_dim`` concatenated (dim 1) with every projected, resized feature.

        ``features[i]`` goes through ``conv_proj[i]`` and is bilinearly resized
        to the last two dimensions of ``last_dim`` (which must be 4-D).
        """
        _, _, height, width = last_dim.shape
        resized = [
            F.interpolate(self.conv_proj[level](feature), size = (height, width), mode = 'bilinear')
            for level, feature in enumerate(features)
        ]
        return torch.cat([last_dim] + resized, dim = 1)

class LinkNetBlockQTPi(pl.LightningModule):
    """Decoder stage that runs ``PixelShuffle_ICNR`` on the deep path before fusing the skip.

    Args:
        left_features: channel count of the skip ("left") input; 0 when the
            stage is meant to be called without a skip.
        down_features: channel count of the deeper ("down") input.
        out_features: channel count produced by the stage.

    Optional pieces are switched by ``Config``: ``use_decoder_attention``
    adds a ``GatedChannelAttention`` on the skip, ``buffed_decoder`` appends
    ``Config.num_blocks`` ``InverseBottleNeckBlock`` layers.
    """
    def __init__(self, left_features, down_features, out_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.out_features = out_features

        self.PixelShuffle = PixelShuffle_ICNR(self.down_features, self.down_features, blur = True)
        self.reduction = Config.reduction

        self.use_attention = Config.use_decoder_attention
        fused_channels = self.down_features + self.left_features
        self.Conv1 = ConvBlock(fused_channels, self.out_features, 3, 1, 1, 1)
        self.Conv2 = ConvBlock(self.out_features, self.out_features, 3, 1, 1, 1)
        self.attention2 = Attention(self.out_features, self.out_features // self.reduction)

        has_skip = self.left_features != 0
        if self.use_attention and has_skip:
            self.attention1 = GatedChannelAttention(
                self.left_features, self.down_features, self.left_features // self.reduction
            )

        self.buff_decoder = Config.buffed_decoder
        if self.buff_decoder:
            # Extra residual-style blocks stacked after the attention layer.
            self.num_blocks = Config.num_blocks
            self.expand = Config.expand
            extra = [
                InverseBottleNeckBlock(self.out_features, self.out_features * self.expand)
                for _ in range(self.num_blocks)
            ]
            self.additional_blocks = nn.Sequential(*extra)

    def forward(self, left_features, down_features):
        """Fuse ``down_features`` with the optional skip and return the stage output.

        ``left_features`` may be ``None``, in which case only the pixel-shuffled
        ``down_features`` are fed to the convolutions.
        """
        fused = self.PixelShuffle(down_features)
        if left_features is not None:
            skip = left_features
            if self.use_attention:
                skip = self.attention1(skip, fused)
            fused = torch.cat([fused, skip], dim = 1)

        out = self.attention2(self.Conv2(self.Conv1(fused)))
        if self.buff_decoder:
            # Gives slightly more power to the decoder. Use with risk.
            out = self.additional_blocks(out)
        return out
class DecoderBlockQTPi(pl.LightningModule):
    """Decoder stage that doubles the deep path with nearest-neighbour interpolation.

    Args:
        left_features: channel count of the skip ("left") input; 0 when the
            stage is meant to be called without a skip.
        down_features: channel count of the deeper ("down") input.
        out_features: channel count produced by the stage.

    ``Config.use_decoder_attention`` adds a ``GatedChannelAttention`` on the
    skip (only when both channel counts are non-zero); ``Config.buffed_decoder``
    appends ``Config.num_blocks`` ``InverseBottleNeckBlock`` layers.
    """
    def __init__(self, left_features, down_features, out_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.out_features = out_features
        self.reduction = Config.reduction

        self.use_attention = Config.use_decoder_attention
        fused_channels = self.left_features + self.down_features
        self.conv1 = ConvBlock(fused_channels, self.out_features, 3, 1, 1, 1)
        self.conv2 = ConvBlock(self.out_features, self.out_features, 3, 1, 1, 1)
        self.att2 = Attention(self.out_features, self.out_features // self.reduction)

        both_present = self.left_features != 0 and self.down_features != 0
        if self.use_attention and both_present:
            self.attention = GatedChannelAttention(
                self.left_features, self.down_features, self.left_features // self.reduction
            )

        self.buff_decoder = Config.buffed_decoder
        if self.buff_decoder:
            self.num_blocks = Config.num_blocks
            self.expand = Config.expand
            extra = [
                InverseBottleNeckBlock(self.out_features, self.out_features * self.expand)
                for _ in range(self.num_blocks)
            ]
            self.additional_blocks = nn.Sequential(*extra)

    def forward(self, left_features, down_features):
        """Upsample ``down_features`` x2, fuse the optional skip, return the stage output.

        ``left_features`` may be ``None``, in which case no concatenation happens.
        """
        fused = F.interpolate(down_features, scale_factor = 2, mode = 'nearest')
        if left_features is not None:
            skip = left_features
            if self.use_attention:
                # Let the skip be attended by the upsampled deep features.
                skip = self.attention(skip, fused)
            fused = torch.cat([fused, skip], dim = 1)

        out = self.att2(self.conv2(self.conv1(fused)))
        if self.buff_decoder:
            out = self.additional_blocks(out)
        return out
class DecoderQTPi(pl.LightningModule):
    """Five-stage decoder that turns encoder features into ``Config.num_classes`` logits.

    Channel layouts (``left_dim`` for the skips, ``down_dim`` for the decoder
    path) are chosen from ``Config.encoder_type``. ``Config.use_linkNet``
    selects the stage class, ``Config.use_sem`` enables two SEM blocks after
    the first two stages, and ``Config.use_FPN`` adds an FPN branch before the
    last stage.
    """
    def __init__(self):
        super().__init__()
        self.num_classes = Config.num_classes
        self.encoder_type = Config.encoder_type

        # (skip channels, decoder-path channels) per encoder family; any type
        # other than 'resnet' / 'unext' uses the fallback layout.
        known_layouts = {
            'resnet': ([256, 128, 64, 64, 0], [512, 256, 128, 64, 32, 16]),
            'unext': ([1024, 512, 256, 64, 0], [512, 256, 128, 64, 32, 16]),
        }
        fallback_layout = ([112, 40, 24, 32, 0], [320, 256, 128, 64, 32, 16])
        self.left_dim, self.down_dim = known_layouts.get(self.encoder_type, fallback_layout)

        self.useLinkNet = Config.use_linkNet
        stage_cls = LinkNetBlockQTPi if self.useLinkNet else DecoderBlockQTPi
        self.decoder_blocks = nn.ModuleList([
            stage_cls(skip_channels, self.down_dim[stage], self.down_dim[stage + 1])
            for stage, skip_channels in enumerate(self.left_dim)
        ])

        self.use_SEM = Config.use_sem
        self.aspp_reduction = Config.aspp_reduction
        if self.use_SEM:
            # Two SEM blocks, placed in the early decoder to save memory
            # (mirrors the two BAM blocks in the encoder, per the author's note).
            self.sem1 = SEM(self.down_dim[1], self.down_dim[1] // self.aspp_reduction)
            self.sem2 = SEM(self.down_dim[2], self.down_dim[2] // self.aspp_reduction)
        else:
            self.sem1 = nn.Identity()
            self.sem2 = nn.Identity()

        self.use_FPN = Config.use_FPN
        if self.use_FPN:
            self.FPN = FPN(self.down_dim[0:-2], self.down_dim[-2])
        # Both dropouts are currently disabled (p = 0.0); the author's notes
        # suggest 0.1 at the end and 0.5 in the middle when ASPP is used.
        self.drop_final = nn.Dropout2d(0.0)
        self.drop_middle = nn.Dropout2d(0.0)
        if self.use_FPN:
            # The FPN output stacks the base map with four projected maps.
            self.fpn_proj = ConvBlock(self.down_dim[-2] * 5, self.down_dim[-2], 1, 0, 1, 1)

        self.proj = nn.Conv2d(16, self.num_classes, kernel_size = 3, padding = 1)

    def forward(self, x0, l0, l1, l2, l3, l4):
        '''
        Decode the encoder pyramid into logits. ``x0`` is accepted but not used.

        Shapes noted by the original author (presumably for the default
        encoder and a 256x256 input -- confirm):
        l0: Tensor(B, 16, 128, 128)
        l1: Tensor(B, 24, 64, 64) - FPN 2x
        l2: Tensor(B, 40, 32, 32) - FPN 4x
        l3: Tensor(B, 112, 16, 16) - FPN 8x
        l4: Tensor(B, 320, 8, 8) - FPN 16x
        '''
        stages = self.decoder_blocks

        # Middle dropout on the deepest feature map.
        l4 = self.drop_middle(l4)
        d4 = self.sem1(stages[0](l3, l4))   # 16
        d3 = self.sem2(stages[1](l2, d4))   # 32
        d2 = stages[2](l1, d3)              # 64
        d1 = stages[3](l0, d2)              # 128
        if self.use_FPN:
            d1 = self.fpn_proj(self.FPN([l4, d4, d3, d2], d1))
        d0 = stages[4](None, d1)            # 256

        # Final dropout, then the segmentation head.
        return self.proj(self.drop_final(d0))

In [None]:
class UNetQTPi(pl.LightningModule):
    """Encoder/decoder segmentation network assembled from ``Config``.

    ``Config.encoder_type`` picks the encoder (``'resnet'``, ``'unext'`` or,
    for any other value, ``EncoderQTPi``). When ``Config.act == 'mish'`` the
    finished module is passed to ``replace_all``.
    """
    def __init__(self):
        super().__init__()
        self.encoder_type = Config.encoder_type
        if self.encoder_type == 'resnet':
            self.encoder = EncoderResNet()
        elif self.encoder_type == 'unext':
            self.encoder = EncoderUNext()
        else:
            self.encoder = EncoderQTPi()

        self.decoder = DecoderQTPi()
        if Config.act == 'mish':
            replace_all(self)

    def forward(self, x):
        """Return the decoder logits for ``x``.

        A channel axis of size 1 (single-logit head) is dropped so the output
        lines up with the targets. Only that axis is squeezed: a bare
        ``torch.squeeze`` would also remove a batch axis of size 1 (e.g. the
        last, incomplete batch of an epoch) and hand a wrongly-shaped tensor
        to the losses.
        """
        logits = self.decoder(*self.encoder(x))
        if logits.dim() > 1 and logits.shape[1] == 1:
            logits = torch.squeeze(logits, 1)
        return logits

In [None]:
class Lovask(pl.LightningModule):
    """Symmetric Lovasz hinge loss built on ``lovask.lovasz_hinge``."""
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        '''
        y_pred: Logits
        y_true: Targets

        Averages the hinge loss on (logits, targets) with the hinge loss on
        the negated logits and the complemented targets.
        '''
        logits = torch.squeeze(y_pred)
        direct_term = lovask.lovasz_hinge(logits, y_true)
        flipped_term = lovask.lovasz_hinge(-logits, 1 - y_true)
        return 0.5 * direct_term + 0.5 * flipped_term
class DiceLoss(pl.LightningModule):
    """Log-cosh of (1 - soft Dice), computed on sigmoid probabilities."""
    def __init__(self):
        super().__init__()

    def soft_dice_score(self, output, target, smooth = 0.0):
        """Return the soft Dice score of two equally-sized tensors.

        ``smooth`` is accepted but not used; a fixed epsilon of 1e-7 is added
        to numerator and denominator instead.
        """
        assert output.size() == target.size()
        eps = 1e-7
        numerator = torch.sum(output * target) * 2 + eps
        denominator = torch.sum(output + target) + eps
        return numerator / denominator

    def forward(self, y_pred, y_true):
        """``y_pred`` are logits; ``y_true`` is cast to float before scoring."""
        probabilities = torch.sigmoid(y_pred)
        targets = y_true.to(torch.float)
        loss = 1 - self.soft_dice_score(probabilities, targets)
        # log(cosh(loss)), written out with exponentials.
        return torch.log((torch.exp(loss) + torch.exp(-loss)) / 2)
class BCE(pl.LightningModule):
    """Binary cross-entropy on logits, optionally balanced per class.

    ``self.symmetric`` (False by default) switches ``forward`` from the plain
    ``nn.BCEWithLogitsLoss`` to ``symmetric_bce``.
    """
    def __init__(self):
        super().__init__()
        self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
        self.symmetric = False

    def symmetric_bce(self, y_pred, y_true):
        """Average the BCE of the positive pixels and the BCE of the negative pixels.

        A class with no pixels in ``y_true`` is skipped: averaging BCE over an
        empty selection yields NaN, which previously poisoned the loss for
        targets that are all background or all foreground. With one class
        present the result is that class's BCE alone; with no 0/1 pixels at
        all a zero that is still attached to ``y_pred``'s graph is returned.
        """
        y_true = y_true.to(torch.float)
        selections = (
            (y_true == 1, torch.ones_like),
            (y_true == 0, torch.zeros_like),
        )
        class_losses = []
        for mask, make_target in selections:
            picked = y_pred[mask]
            if picked.shape[0] == 0:
                continue
            class_losses.append(self.BCEWithLogitsLoss(picked, make_target(picked)))
        if not class_losses:
            return y_pred.sum() * 0
        return sum(class_losses) / len(class_losses)

    def regular_bce(self, y_pred, y_true):
        """Plain ``BCEWithLogitsLoss`` against the float-cast targets."""
        return self.BCEWithLogitsLoss(y_pred, y_true.to(torch.float))

    def forward(self, y_pred, y_true):
        """Dispatch to the symmetric or the regular variant."""
        if self.symmetric:
            return self.symmetric_bce(y_pred, y_true)
        return self.regular_bce(y_pred, y_true)
class CEJaccard(pl.LightningModule):
    """Log-cosh Jaccard loss for two-class logits (index 1 on dim 1 is scored)."""
    def __init__(self):
        super().__init__()

    def jaccard_score(self, y_pred, y_true):
        """Soft Jaccard between the softmax probability at index 1 and ``y_true``."""
        foreground = F.softmax(y_pred, dim = 1)[:, 1, :, :]
        eps = 1e-8
        overlap = torch.sum(foreground * y_true)
        # |A| + |B| - |A and B|
        union = torch.sum(foreground + y_true) - overlap
        return (overlap + eps) / (union + eps)

    def forward(self, y_pred, y_true):
        """Return log(cosh(1 - jaccard)), written out with exponentials."""
        loss = 1 - self.jaccard_score(y_pred, y_true)
        return torch.log((torch.exp(loss) + torch.exp(-loss)) / 2)
class CEDice(pl.LightningModule):
    """Log-cosh Dice loss for two-class logits (index 1 on dim 1 is scored)."""
    def __init__(self):
        super().__init__()
        self.num_classes = Config.num_classes

    def dice_score(self, y_pred, y_true):
        """Soft Dice between index 1 of ``y_pred`` (dim 1) and ``y_true``.

        ``y_pred`` is expected to hold probabilities already; ``forward``
        applies the softmax before calling this.
        """
        foreground = y_pred[:, 1, :, :]
        eps = 1e-8
        numerator = torch.sum(foreground * y_true) * 2 + eps
        denominator = torch.sum(foreground + y_true) + eps
        return numerator / denominator

    def forward(self, y_pred, y_true):
        """Return log(cosh(1 - dice)) of the softmaxed logits."""
        probabilities = F.softmax(y_pred, dim = 1)
        loss = 1 - self.dice_score(probabilities, y_true)
        return torch.log((torch.exp(loss) + torch.exp(-loss)) / 2)
class CrossEntropy(pl.LightningModule):
    """``nn.CrossEntropyLoss`` that casts the targets to ``torch.long`` first."""
    def __init__(self):
        super().__init__()
        self.CrossEntropy = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        """Return the cross-entropy of logits ``y_pred`` against ``y_true``."""
        class_indices = y_true.to(torch.long)
        return self.CrossEntropy(y_pred, class_indices)
class CrossEntropyDice(pl.LightningModule):
    """Unweighted sum of ``CrossEntropy`` and ``CEDice``."""
    def __init__(self):
        super().__init__()
        self.CE = CrossEntropy()
        self.Dice = CEDice()

    def forward(self, y_pred, y_true):
        """Evaluate both terms on the same logits/targets and add them."""
        ce_term = self.CE(y_pred, y_true)
        dice_term = self.Dice(y_pred, y_true)
        return ce_term + dice_term
class CrossEntropyJaccard(pl.LightningModule):
    """Unweighted sum of ``CrossEntropy`` and ``CEJaccard``.

    The Jaccard term is stored under the attribute name ``Dice``.
    """
    def __init__(self):
        super().__init__()
        self.CE = CrossEntropy()
        self.Dice = CEJaccard()

    def forward(self, y_pred, y_true):
        """Evaluate both terms on the same logits/targets and add them."""
        ce_term = self.CE(y_pred, y_true)
        jaccard_term = self.Dice(y_pred, y_true)
        return ce_term + jaccard_term


class BCEDice(pl.LightningModule):
    """Unweighted sum of ``BCE`` and ``DiceLoss`` for single-logit outputs."""
    def __init__(self):
        super().__init__()
        self.BCE = BCE()
        self.Dice = DiceLoss()

    def forward(self, y_pred, y_true):
        """Evaluate both terms on the same logits/targets and add them."""
        bce_term = self.BCE(y_pred, y_true)
        dice_term = self.Dice(y_pred, y_true)
        return bce_term + dice_term


In [None]:
class SMP(pl.LightningModule):
    """Baseline ``smp.Unet('efficientnet-b0')`` with optional BAM / ASPP on the deepest feature.

    ``Config.use_ASPP`` and ``Config.use_bam`` enable the extra modules; when
    disabled they are replaced by ``nn.Identity`` so ``forward`` needs no
    branching. ``input_size`` (320) is the channel count handed to them.
    """
    def __init__(self):
        super().__init__()
        self.Model = smp.Unet('efficientnet-b0', classes = Config.num_classes)
        self.use_ASPP = Config.use_ASPP
        self.aspp_reduction = Config.aspp_reduction
        self.input_size = 320

        channels = self.input_size
        if self.use_ASPP:
            self.ASPP = ASPP(channels, channels // self.aspp_reduction, channels)
        else:
            self.ASPP = nn.Identity()

        self.use_BAM = Config.use_bam
        self.bam_dilate = Config.bam_dilate
        self.reduction = Config.reduction

        if self.use_BAM:
            self.bam1 = BAM(channels, channels // self.reduction)
            self.bam2 = BAM(channels, channels // self.reduction)
        else:
            self.bam1 = nn.Identity()
            self.bam2 = nn.Identity()

    def forward(self, x):
        """Encode, refine the deepest feature (bam1 -> ASPP -> bam2), decode, apply the head.

        The encoder must return exactly six feature maps.
        """
        stem, f0, f1, f2, f3, deepest = tuple(self.Model.encoder(x))
        deepest = self.bam2(self.ASPP(self.bam1(deepest)))
        decoded = self.Model.decoder(stem, f0, f1, f2, f3, deepest)
        return self.Model.segmentation_head(decoded)

### Metrics: running loss and Dice trackers

In [None]:
class Loss(Metric):
    """Running mean of the training criterion, exposed through the ``Metric`` interface.

    The criterion is ``CrossEntropyJaccard`` when ``Config.num_classes == 2``
    and ``BCEDice`` otherwise; it is stored under the attribute name ``BCE``.
    """
    def __init__(self):
        super().__init__()
        self.count = 0
        self.loss = 0
        self.num_classes = Config.num_classes
        self.BCE = CrossEntropyJaccard() if self.num_classes == 2 else BCEDice()

    def reset(self):
        """Forget everything accumulated so far."""
        self.count, self.loss = 0, 0

    def accumulate(self, learn):
        """Score ``learn.pred`` against ``learn.y``, record it, and return the loss tensor."""
        batch_loss = self.BCE(learn.pred, learn.y)
        self.loss += batch_loss.item()
        self.count += 1
        return batch_loss

    @property
    def value(self):
        """Mean recorded loss rounded to 3 decimals, or 0 when nothing was recorded."""
        if self.count == 0:
            return 0
        return round(self.loss / self.count, 3)
class Dice_soft(Metric):
    """Soft (un-thresholded) Dice accumulated over batches.

    With ``Config.num_classes == 2`` the softmax probability at index 1 of
    dim 1 is scored; otherwise the sigmoid of the prediction is used.
    """
    def __init__(self, axis=1):
        self.axis = axis
        self.num_classes = Config.num_classes
        # Start with defined accumulators so `value` is safe before any reset().
        self.reset()

    def reset(self):
        """Zero the running intersection and union sums."""
        self.inter, self.union = 0, 0

    def accumulate(self, learn):
        """Add the batch held by ``learn`` (``.pred`` logits, ``.y`` targets) to the sums."""
        targ = learn.y
        if self.num_classes == 2:
            pred_ones = F.softmax(learn.pred, dim = 1)[:, 1, :, :]
            self.inter += (pred_ones * targ).float().sum().item()
            self.union += (torch.sum(pred_ones + targ)).float().item()
        else:
            pred = torch.sigmoid(learn.pred)
            self.inter += (pred * targ).float().sum().item()
            self.union += (pred + targ).float().sum().item()

    @property
    def value(self):
        """Print and return the Dice score rounded to 3 decimals.

        Returns 0.0 when nothing has been accumulated (``union == 0``), which
        matches how the thresholded Dice metric treats an empty union. The
        previous code passed ``None`` to ``round`` in that case and raised
        ``TypeError``.
        """
        dice = round(2.0 * self.inter / self.union, 3) if self.union > 0 else 0.0
        print(f'--------DICE: {dice}')
        return dice
class Dice_th(Metric):
    """Best hard Dice over a grid of probability thresholds.

    One intersection/union accumulator is kept per threshold in ``ths``;
    ``value`` reports the highest resulting Dice score.
    """
    def __init__(self, ths=np.arange(0.1,0.9,0.05), axis=1): 
        self.axis = axis
        self.ths = ths
        self.num_classes = Config.num_classes
        self.CEDice = CEDice()

    def reset(self):
        """Zero the per-threshold accumulators."""
        n_thresholds = len(self.ths)
        self.inter = torch.zeros(n_thresholds)
        self.union = torch.zeros(n_thresholds)

    def accumulate(self, learn):
        """Binarise the batch at every threshold and add to the accumulators."""
        targ = learn.y
        if self.num_classes == 2:
            # Two-class logits: score the softmax probability at index 1.
            foreground = F.softmax(learn.pred, dim = 1)[:, 1, :, :]
            for slot, th in enumerate(self.ths):
                hard = (foreground > th).float()
                self.inter[slot] += (hard * targ).float().sum().item()
                self.union[slot] += (torch.sum(hard + targ)).float().item()
            return

        # Single-logit output: score the sigmoid probability.
        probabilities = torch.sigmoid(learn.pred)
        for slot, th in enumerate(self.ths):
            hard = (probabilities > th).float()
            self.inter[slot] += (hard * targ).float().sum().item()
            self.union[slot] += (hard + targ).float().sum().item()

    @property
    def value(self):
        """Highest Dice over all thresholds, rounded to 3 decimals (0 where union is 0)."""
        per_threshold = torch.where(
            self.union > 0.0,
            2.0 * self.inter / self.union,
            torch.zeros_like(self.union),
        )
        return round(per_threshold.max().item(), 3)
class Best_dice_th(Metric):
    """Threshold (from ``ths``) that yields the highest hard Dice.

    Accumulates exactly like a per-threshold Dice metric, but ``value``
    returns the winning threshold rather than its score.
    """
    def __init__(self, ths=np.arange(0.2,0.7,0.01), axis=1): 
        self.axis = axis
        self.ths = ths
        self.num_classes = Config.num_classes

    def reset(self):
        """Zero the per-threshold accumulators."""
        n_thresholds = len(self.ths)
        self.inter = torch.zeros(n_thresholds)
        self.union = torch.zeros(n_thresholds)

    def accumulate(self, learn):
        """Binarise the batch at every threshold and add to the accumulators."""
        targ = learn.y
        if self.num_classes == 2:
            # Two-class logits: score the softmax probability at index 1.
            foreground = F.softmax(learn.pred, dim = 1)[:, 1, :, :]
            for slot, th in enumerate(self.ths):
                hard = (foreground > th).float()
                self.inter[slot] += (hard * targ).float().sum().item()
                self.union[slot] += (torch.sum(hard + targ)).float().item()
            return

        # Single-logit output: score the sigmoid probability.
        probabilities = torch.sigmoid(learn.pred)
        for slot, th in enumerate(self.ths):
            hard = (probabilities > th).float()
            self.inter[slot] += (hard * targ).float().sum().item()
            self.union[slot] += (hard + targ).float().sum().item()

    @property
    def value(self):
        """Threshold with the highest Dice, rounded to 3 decimals."""
        per_threshold = torch.where(
            self.union > 0.0,
            2.0 * self.inter / self.union,
            torch.zeros_like(self.union),
        )
        # Pick the threshold sitting at the position of the best score.
        best_threshold = self.ths[per_threshold.argmax()]
        return round(best_threshold.item(), 3)

In [None]:
class Config:
    """Model and data-loading switches read by the network, loss and metric classes."""

    # ---- data / loop ----
    IMAGE_SIZE = cfg.image_size
    BATCH_SIZE = 36   # small-ish batch size needed to support ASPP + FPN
    NUM_EPOCHS = 42   # 200 in an earlier setting
    NUM_WORKERS = 4
    device = device

    # ---- backbone / head ----
    encoder_type = 'effnet'
    num_classes = 2

    # ---- optional modules ----
    use_ASPP = False
    use_FPN = False
    attention_type = "none"
    use_linkNet = True             # author's note: linkNet blocks should perform better
    use_decoder_attention = False  # special attention on the decoder skips
    gate_attention = True          # reduces instability of attention layers early in training
    use_bam = False                # in testing
    bam_dilate = 3
    use_sem = False                # in testing

    # ---- block internals ----
    act = 'relu'                   # author's note: actually performs better than SiLU
    bottleneck_type = 'inverse'
    buffed_decoder = False         # adds bottlenecks and more processing to the decoder
    buffed_encoder = False         # adds bottlenecks to the encoder, after the ASPP module
    num_blocks = 1
    reduction = 1                  # reduction factor
    aspp_reduction = 2             # reduction factor for ASPP modules
    expand = 2                     # expansion factor

In [None]:
class ClipGrad(Callback):
    """Callback that clips the global gradient norm of ``self.learn.model``.

    Args:
        max_norm: maximum allowed norm, passed to ``nn.utils.clip_grad_norm_``.
    """
    def __init__(self, max_norm):
        super().__init__()
        self.max_norm = max_norm

    def on_backward_end(self, **kwargs):
        """Clip the gradients in place.

        This is the fastai v1 hook name; it is kept so existing direct callers
        keep working, but the v2 training loop never dispatches it.
        """
        nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.max_norm)

    def after_backward(self):
        """fastai v2 event fired once the backward pass has produced gradients.

        Without this hook the callback was silently inert under the v2
        ``Learner`` used in this file.
        NOTE(review): with ``learn.to_fp16()`` the gradients may still be
        loss-scaled at this point -- confirm against the installed fastai
        version whether clipping belongs on a later event.
        """
        self.on_backward_end()
class TrainingConfig:
    """Optimiser, scheduler and loader hyper-parameters for training."""

    # ---- optimiser ----
    lr = 5e-4
    weight_decay = 0   # to be increased later (as is dropout)
    clip_grads = 20

    # ---- learning-rate schedules ----
    patience = 3
    factor = 0.2
    eta_min = 1e-9
    num_steps = 5

    # ---- data loading ----
    NUM_WORKERS = 4

In [None]:
import copy
class Store:
    """Minimal stand-in for a learner: exposes ``pred`` and ``y`` to the metric classes."""
    def __init__(self, pred, y):
        self.pred, self.y = pred, y
class TrainingModel(pl.LightningModule):
    """Lightning wrapper around ``UNetQTPi`` with manual metric tracking and checkpointing.

    Metrics are accumulated per step; ``validation_epoch_end`` reports them,
    steps the learning-rate schedules, saves the best weights, and resets
    the trackers.
    """
    def unfreeze_model(self):
        """Mark every parameter of the wrapped model as trainable."""
        for param in self.model.parameters():
            param.requires_grad = True

    def __init__(self):
        super().__init__()
        self.model = self.configure_model()
        # Epoch index from which the plateau scheduler starts being stepped.
        self.decay_after = 10

        # Running trackers.
        self.TrainLoss = Loss()
        self.ValLoss = Loss()
        self.DiceSoft = Dice_soft()
        self.DiceTh = Dice_th()
        self.BestDiceTh = Best_dice_th()

        self.best = {'val_thresh': 0, "val_loss": float('inf'), 'val_dice': 0, 'val_soft': 0}
        # Starts at -1 because the reset() below immediately bumps it to 0.
        self.EPOCHS = -1
        self.reset()

    def reset(self):
        """Clear all trackers and advance the epoch counter by one."""
        trackers = (self.TrainLoss, self.ValLoss, self.DiceSoft, self.DiceTh, self.BestDiceTh)
        for tracker in trackers:
            tracker.reset()
        self.EPOCHS += 1

    def configure_model(self):
        """Build the network to be trained."""
        return UNetQTPi()

    def configure_optimizers(self):
        """Create AdamW and attach two schedulers as attributes.

        Only the optimiser is returned; the schedulers are stepped by hand in
        ``print_results``.
        """
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr = TrainingConfig.lr,
            weight_decay = TrainingConfig.weight_decay,
        )
        self.lr_decay1 = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode = 'max',
            patience = TrainingConfig.patience,
            factor = TrainingConfig.factor,
            min_lr = TrainingConfig.eta_min,
            verbose = True,
        )
        self.lr_decay2 = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            TrainingConfig.num_steps,
            eta_min = TrainingConfig.eta_min,
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        """Run one training batch and return its loss (also recorded in ``TrainLoss``)."""
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        outcome = Store(self.model(inputs), targets)
        return self.TrainLoss.accumulate(outcome)

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and feed it to every validation tracker."""
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        outcome = Store(self.model(inputs), targets)
        self.ValLoss.accumulate(outcome)
        self.DiceSoft.accumulate(outcome)
        self.DiceTh.accumulate(outcome)
        self.BestDiceTh.accumulate(outcome)

    def print_results(self):
        """Report the epoch's metrics, step the schedulers and save improved weights.

        Writes ``./loss.pth``, ``./dice.pth`` and ``./soft.pth`` whenever the
        validation loss, thresholded Dice or soft Dice (respectively) ties or
        beats its best value so far.
        """
        print(f'-----------EPOCH {self.EPOCHS}-----------------')
        val_loss = self.ValLoss.value
        train_loss = self.TrainLoss.value
        dice_soft = self.DiceSoft.value
        dice_th = self.DiceTh.value
        best_dice_th = self.BestDiceTh.value

        self.log('dice_soft', dice_soft)
        # The plateau schedule only kicks in after the warm-up epochs;
        # the cosine schedule is stepped every time.
        if self.EPOCHS >= self.decay_after:
            self.lr_decay1.step(dice_soft)
        self.lr_decay2.step()

        best = self.best
        if val_loss <= best['val_loss']:
            best['val_loss'] = val_loss
            torch.save(self.state_dict(), "./loss.pth")
            print('-----------------------Saved Best Val Loss---------------------------')
        if dice_th >= best['val_dice']:
            best['val_dice'] = dice_th
            best['val_thresh'] = best_dice_th
            torch.save(self.state_dict(),"./dice.pth")
            print("----------------Saved Dice Th--------------------")
        if dice_soft >= best['val_soft']:
            best['val_soft'] = dice_soft
            torch.save(self.state_dict(), "./soft.pth")
            print("-------------Saved Dice Soft-----------------")

        print(f"E: {self.EPOCHS} BT: {best['val_thresh']} BS: {best['val_soft']} BL: {best['val_loss']} BD: {best['val_dice']} TL: {train_loss} VL: {val_loss} DS: {dice_soft} DT: {dice_th} BDT: {best_dice_th} ")

    def validation_epoch_end(self, logs):
        """Report, checkpoint, then reset the trackers for the next epoch."""
        self.print_results()
        self.reset()

In [None]:
def unfreeze_whole_model(model):
    """Set ``requires_grad = True`` on every parameter yielded by ``model.parameters()``."""
    for param in model.parameters():
        setattr(param, 'requires_grad', True)
def train_folds(idx, model_path = "./", load_prev = None):
    """Train the ``SMP`` baseline on fold ``idx`` with a fastai ``Learner``.

    Args:
        idx: fold index handed to ``DataModule.get_both``.
        model_path: accepted for interface compatibility; not used here.
        load_prev: accepted for interface compatibility; not used here.

    Returns nothing. Monitoring, LR reduction and early stopping all key on
    the ``dice_soft`` metric. Every object that references the model or the
    data is dropped before the CUDA cache is emptied.
    """
    import gc  # local import: only needed for the cleanup at the end

    seed_all()
    model = SMP()
    train, val = DataModule.get_both(idx)

    # Dataloaders and learner.
    dls = DataLoaders.from_dsets(
        train, val,
        shuffle = True,
        pin_memory = True,
        worker_init_fn = seed_worker,
        num_workers = TrainingConfig.NUM_WORKERS,
        bs = Config.BATCH_SIZE,
        after_batch = Normalize.from_stats(*cfg.stats),
    )
    if torch.cuda.is_available():
        dls.cuda()
        model.cuda()

    metrics = [Dice_soft(), Dice_th(), Best_dice_th()]
    cbs = [
        SaveModelCallback(monitor = 'dice_soft', comp = np.greater),
        ReduceLROnPlateau(monitor = 'dice_soft', comp = np.greater, patience = TrainingConfig.patience, factor = TrainingConfig.factor),
        EarlyStoppingCallback(monitor = 'dice_soft', patience = 10, comp = np.greater),
    ]

    def optimizer(*args, **kwargs):
        # Ignores whatever the Learner passes in and always optimises all
        # model parameters with the configured lr / weight decay.
        return Lookahead(Adam(model.parameters(), lr = TrainingConfig.lr, wd = TrainingConfig.weight_decay))

    learn = Learner(dls, model, metrics = metrics, wd = TrainingConfig.weight_decay, loss_func = CrossEntropyJaccard(), opt_func = optimizer, cbs = cbs)
    learn.to_fp16()
    learn.fit_one_cycle(Config.NUM_EPOCHS, lr_max = TrainingConfig.lr)

    # The learner (and the optimiser factory closing over `model`) keep the
    # model and data alive, so they must go too; otherwise empty_cache()
    # has nothing to release. gc.collect() breaks learner<->callback cycles.
    del learn, optimizer
    del model
    del train, val
    del dls
    del cbs
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Kick off training on the first fold.
train_folds(idx = 0)