# Import Depencies 

In [None]:
%%capture
import torch
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

import albumentations as A
from albumentations.pytorch import ToTensorV2

import tqdm.notebook as tqdm

!pip install livelossplot
import livelossplot

!pip install timm
import timm

!pip install segmentation_models_pytorch
import segmentation_models_pytorch as seg

import cv2
import PIL
import os
import random
import math
import sys
import copy

import numpy as np 
import pandas as pd
import lovask
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
# Import Ranger Optimizer
%cd ..
!git clone https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 
%cd Ranger-Deep-Learning-Optimizer
!pip install -e .
%cd ..
%cd working
import sys
sys.path.append("../Ranger-Deep-Learning-Optimizer")
from ranger import Ranger
import warnings
warnings.filterwarnings("ignore")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Reproducibility:
import os
import random
seed = 42
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# Slight Stochasticity Tradeoff for Quicker Comp.
torch.backends.cudnn.benchmark = False # True for faster
pl.seed_everything()
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

In [None]:
train_img = '../input/hubmap-256x256/train/'
train_masks = '../input/hubmap-256x256/masks/'

pseudo_labelled_img = '../input/pseudolabelledhubmap/test/'
pseudo_labelled_masks = '../input/pseudolabelledhubmap/masks/'
def get_images():
    train_images = np.array(os.listdir(train_img))
    pseudo_images = np.array(os.listdir(pseudo_labelled_img))
    return train_images,  pseudo_images
def display_image_np(image):
    plt.imshow(image)
    plt.show()
def display_image(image):
    plt.imshow(image.cpu().transpose(0, 1).transpose(1, 2))
    plt.show() 

# Load Dataset

# Load External and Test Data

In [None]:
class Config:
    IMAGE_SIZE = 256
    NUM_FOLDS = 4
    BATCH_SIZE = 32
    TEST_BATCH_SIZE = 48
    NUM_EPOCHS = 30
    device = device

In [None]:
to_tensor = ToTensorV2()
def get_transforms():
    # Data Augmentation on Images and Mask
    train_transforms = A.Compose([
        A.Flip(p = 0.5),
        A.OneOf([
            A.Blur(),
            A.MultiplicativeNoise(),
        ], p = 0.7),
        A.OneOf([
            A.OpticalDistortion(distort_limit=1.0),
        #    #A.GridDistortion(num_steps=5, distort_limit=1.),
        #    #A.ElasticTransform(alpha=3),
        ], p=0.7),
        A.CLAHE(),
        A.ColorJitter(brightness = 0.1, hue = 0.1, contrast = 0.1, saturation = 0.1),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        A.RandomRotate90(),
        A.Normalize()
    ])

    test_transforms = A.Compose([
        A.Normalize()
    ])
    return train_transforms, test_transforms

def get_splits(all_images):
    splitter = KFold(n_splits = 75, shuffle = True, random_state = 42)
    KSPLITS = []
    count = 0
    for train, test in splitter.split(all_images):
        KSPLITS += [(all_images[train], all_images[test])]
        count += 1
        if count == Config.NUM_FOLDS:
            break
    return KSPLITS

In [None]:
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, image_ids, pseudo_ids, transforms):
        # TO avoid messing up the splits, additional(external + pseudolabelled) is appended to train.
        self.mask_base = train_masks
        self.img_base = train_img
        self.image_ids = image_ids
        self.transforms = transforms
        self.len_base = len(self.image_ids)
        
        self.pseudo_mask_base = pseudo_labelled_masks
        self.pseudo_img_base = pseudo_labelled_img
        self.pseudo_ids = pseudo_ids
        self.len_pseudo = len(self.pseudo_ids) if self.pseudo_ids is not None else 0
        
        # Compute length of dataset
        self.total_len = self.len_base + self.len_pseudo 
        # Concatenate the images
        self.total_dataset = self.image_ids
        if self.len_pseudo != 0:
            self.total_dataset = np.concatenate([self.total_dataset, self.pseudo_ids])
        
    def __len__(self):
        return self.total_len
    def __getitem__(self, idx):
        
        if idx >= self.len_base:
           
            # Must be Pseudo Labelled
            #print("USED PSEUDO")
            image_id = self.total_dataset[idx] 
            image_path = self.pseudo_img_base + image_id
            mask_path = self.pseudo_mask_base + image_id

            image = cv2.imread(image_path)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                
        else:
            image_id = self.total_dataset[idx]
            image_path = self.img_base + image_id
            mask_path = self.mask_base + image_id

            # Load in image and masks
            image = cv2.imread(image_path)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        transform = self.transforms(image = image, mask = mask)
        
        image = to_tensor(image = transform['image'])['image']
        mask = np.transpose(transform['mask'], [0, 1])
    
        return image, mask

In [None]:
class DataModule(pl.LightningDataModule):
    def __init__(self, config = Config):
        super().__init__()
        self.config = config
        self.train_images, self.pseudo_images = get_images()
        self.KSPLITS = get_splits(self.train_images)
        self.train_transforms, self.test_transforms = get_transforms()
    def train_dataloader(self, idx):
        train, _ = self.KSPLITS[idx]
        trainDataset = ImageDataset(train, self.pseudo_images, self.test_transforms)
        dataloader = torch.utils.data.DataLoader(trainDataset, shuffle = True, batch_size = self.config.BATCH_SIZE, worker_init_fn = seed_worker)
        return dataloader
    def val_dataloader(self, idx):
        _, val = self.KSPLITS[idx]
        valDataset = ImageDataset(val,  None, self.test_transforms) 
        dataloader= torch.utils.data.DataLoader(valDataset, batch_size = self.config.TEST_BATCH_SIZE, worker_init_fn = seed_worker)
        return dataloader
    def get_both(self, idx):
        trainloader = self.train_dataloader(idx) 
        valloader = self.val_dataloader(idx)
        return trainloader, valloader
dataModule = DataModule()

In [None]:
train, val = dataModule.get_both(0)

In [None]:
for images, labels in train:
    for image_id in range(len(images)):
        
        plt.imshow(images[image_id].transpose(0, 1).transpose(1, 2))
        plt.show()
    break

# ENCODER CNN BLOCKS

In [None]:
class ConvBlock(pl.LightningModule):
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride, act = 'relu'):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, bias = False)
        self.bn = nn.BatchNorm2d(out_features)
        if act == 'relu':
            self.act1 = nn.ReLU(inplace = True)
        else:
            self.act1 = nn.SiLU(inplace = True)
    def forward(self, x):
        return self.bn(self.act1(self.conv(x)))
class SqueezeExcite(pl.LightningModule):
    def __init__(self, in_features, inner_features, dev, act = "relu"):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.dev = dev
        
        self.Squeeze = nn.Linear(self.in_features, self.inner_features)
        if act == 'relu':
            self.act1 = nn.ReLU(inplace = True)
        else:
            self.act1 = nn.SiLU(inplace = True)
        self.Excite = nn.Linear(self.inner_features, self.in_features)
        
        self.gamma = nn.Parameter(torch.zeros((1), device = self.dev))
    def forward(self, x):
        mean = torch.mean(x, dim = -1)
        mean = torch.mean(mean, dim = -1)
        
        squeeze = self.act1(self.Squeeze(mean))
        excite = torch.sigmoid(self.Excite(squeeze)).unsqueeze(-1).unsqueeze(-1)
        return excite * x * self.gamma + (1 - self.gamma) * x 
class CBAMChannel(pl.LightningModule):
    def __init__(self, in_features, inner_features, dev, act = 'relu'):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.dev = dev 
        
        self.Squeeze = nn.Linear(self.in_features, self.inner_features) 
        if act == 'relu':
            self.act1 = nn.ReLU(inplace = True)
        else:
            self.act1 = nn.SiLU(inplace = True)
        self.Excite = nn.Linear(self.inner_features, self.in_features)
        self.gamma = nn.Parameter(torch.zeros((1), device = self.dev))
    def forward(self, x):
        mean = torch.mean(x, dim = -1)
        mean = torch.mean(mean, dim = -1)
        
        max_pool, _ = torch.max(x, dim = -1) 
        max_pool, _ = torch.max(max_pool, dim = -1)
        
        squeeze_mean = self.act1(self.Squeeze(mean))
        excite_mean = self.Excite(squeeze_mean)
        
        squeeze_max = self.act1(self.Squeeze(max_pool))
        excite_max = self.Excite(squeeze_max)
        
        excite = torch.sigmoid((excite_mean + excite_max) / 2).unsqueeze(-1).unsqueeze(-1)
        return excite * x * self.gamma + (1 - self.gamma) * x
        
class Attention(pl.LightningModule):
    def __init__(self, in_features, inner_features, dev, attention_type = 'se', act = 'relu'):
        super().__init__()
        self.attention_type = attention_type
        assert self.attention_type in ['se', 'cbam', 'none']
        if self.attention_type == 'se':
            self.layer = SqueezeExcite(in_features, inner_features, dev, act = act)
        elif self.attention_type == 'cbam':
            self.layer = CBAMChannel(in_features, inner_features, dev, act = act)
        else:
            self.layer= nn.Identity()
    def forward(self, x):
        return self.layer(x)

# Self Attention Blocks
class ConvPlusBatchNorm(pl.LightningModule):
    '''
    Conv2d + BN, no activation.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size = kernel_size, padding = padding, groups = groups, stride = stride)
        self.bn1 = nn.BatchNorm2d(out_features)
    def forward(self, x):
        return self.bn1(self.conv(x))
class SelfAttention(pl.LightningModule):
    # Non Local Block.
    def __init__(self, in_features, inner_features, num_heads):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.num_heads = num_heads
        self.K = ConvPlusBatchNorm(self.in_features, self.inner_features * self.num_heads, 3, 1, 1, 1)
        self.V = ConvPlusBatchNorm(self.in_features, self.inner_features * self.num_heads, 3, 1, 1, 1)
        self.Q = ConvPlusBatchNorm(self.in_features, self.inner_features * self.num_heads, 3, 1, 1, 1)
        self.Linear = ConvPlusBatchNorm(self.inner_features * self.num_heads, self.in_features, 3, 1, 1, 1)
        
        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        B, C, H, W = x.shape
        Keys = self.K(x)
        Values = self.V(x)
        Queries = self.Q(x)
        
        Keys = Keys.reshape(B, self.num_heads, self.inner_features, H, W)
        Values = Values.reshape(B, self.num_heads, self.inner_features, H, W)
        Queries = Queries.reshape(B, self.num_heads, self.inner_features, H, W) 
        
        Keys = Keys.reshape(B * self.num_heads, self.inner_features, H * W)
        Values = Values.view(B * self.num_heads, self.inner_features, H * W)
        Queries = Queries.view(B * self.num_heads, self.inner_features, H * W)
        
        att_mat = F.softmax(torch.bmm(Keys.transpose(1, 2), Queries) / math.sqrt(self.inner_features))
        att_vals = torch.bmm(att_mat, Values.transpose(1, 2))
        
        scores = att_vals.view(B, self.num_heads, self.inner_features, H, W)
        scores = scores.view(B, self.num_heads * self.inner_features, H, W)
        output = self.Linear(scores) 
        return output * self.gamma + (1 - self.gamma) * x
class CBAMSqueezeAttend(pl.LightningModule):
    def __init__(self, in_features, inner_features, out_features, out_size, squeeze_factor = 4, act = 'relu'):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.out_size = out_size
        self.squeeze_factor = squeeze_factor
        self.act = act 
        
        self.proj = ConvBlock(self.in_features, self.out_features, 3, 1, 1, 1, act = self.act)
        self.max_pool = nn.MaxPool2d(kernel_size = 5, padding = 2, stride = self.squeeze_factor)
        self.avg_pool = nn.AvgPool2d(kernel_size = 5, padding = 2, stride = self.squeeze_factor)
        
        self.Squeeze = ConvBlock(self.out_features, self.inner_features, 3, 1, 1, 1, act = self.act)
        self.Excite = ConvPlusBatchNorm(self.inner_features, self.out_features, 3, 1, 1, 1)
        
        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        
        max_pool = self.max_pool(x)
        avg_pool = self.avg_pool(x)
        
        squeeze_max = self.Squeeze(max_pool)
        squeeze_avg = self.Squeeze(avg_pool) 
        
        excite_max = self.Excite(squeeze_max)
        excite_avg = self.Excite(squeeze_avg)
        
        excite = torch.sigmoid((excite_max + excite_avg) / 2)
        excited = avg_pool * self.gamma * excite + (1 - self.gamma) * avg_pool
    
        # Interpolate Upward
        excited = F.interpolate(excited, size = (self.out_size, self.out_size), mode = 'nearest')
        return excited
class SESqueezeAttend(pl.LightningModule):
    def __init__(self, in_features, inner_features, out_features, out_size, squeeze_factor = 4, act = 'relu'):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_size = out_size
        self.squeeze_factor = squeeze_factor 
        self.act = act
        self.avg_pool = nn.AvgPool2d(kernel_size = 5, padding = 2, stride = squeeze_factor)
        
        self.proj = ConvBlock(self.in_features, self.out_features, 3, 1, 1, 1, act = self.act)
        self.Squeeze = ConvBlock(self.in_features, self.inner_features, 3, 1, 1, 1, act = self.act)
        self.Excite = ConvPlusBatchNorm(self.inner_features, self.in_features, 3, 1, 1, 1)
        
        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        '''
        x: Tensor(B, C, H, W)
        '''
        B, C, H, W = x.shape
        x = self.proj(x)
        pooled = self.avg_pool(x)
        
        squeeze = self.Squeeze(pooled)
        excite = torch.sigmoid(self.Excite(squeeze))
        
        excited = self.gamma * pooled * excite + (1 - self.gamma) * pooled 
        
        # Interpolate Back Up.
        excited = F.interpolate(excited, size = (self.out_size, self.out_size), mode = 'nearest')
        return excited
class SqueezeAttend(pl.LightningModule):
    def __init__(self, in_features, inner_features, out_features, out_size, squeeze_factor = 4, act = 'relu', attention_type = 'se'):
        super().__init__()
        
        self.attention_type = attention_type
        assert self.attention_type in ['se', 'cbam', 'none']
        if self.attention_type == 'se':
            self.layer = SESqueezeAttend(in_features, inner_features, out_features, out_size, squeeze_factor = squeeze_factor, act = act)
        elif self.attention_type =='cbam':
            self.layer = CBAMSqueezeAttend(in_features, inner_features, out_features, out_size, squeeze_factor = squeeze_factor, act = act)
        else:
            self.layer = nn.Identity()
    def forward(self,x):
        return self.layer(x)

class BottleNeckBlock(pl.LightningModule):
    def __init__(self, in_features, inner_features, dev, attention_type = 'se', stochastic_depth = 0, act = 'relu'):
        super().__init__()
        self.stochastic_depth = stochastic_depth
        self.in_features = in_features
        self.inner_features = inner_features
        self.dev = dev
        self.attention_type = attention_type
        
        self.Squeeze = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1, act = act) 
        self.Process = ConvBlock(self.inner_features, self.inner_features, 3, 1, 1, 1, act = act)
        self.Expand = ConvBlock(self.inner_features, self.in_features, 1, 0, 1, 1, act = act)
        self.SE = Attention(self.in_features, self.in_features // 4, self.dev, attention_type = self.attention_type)

        self.gamma = nn.Parameter(torch.zeros((1), device = self.dev))
    def forward(self, x):
        if self.training and random.random() < self.stochastic_depth:
            return x
        squeeze = self.Squeeze(x)
        process = self.Process(squeeze)
        expand = self.Expand(process)
        SE = self.SE(expand)
        return SE * self.gamma + (1 - self.gamma) * x
        
class InverseBottleNeckBlock(pl.LightningModule):
    def __init__(self, in_features, inner_features, dev, attention_type = 'se', stochastic_depth = 0, act = 'relu'):
        super().__init__()
        self.stochastic_depth = stochastic_depth
        self.in_features = in_features
        self.inner_features = inner_features
        self.dev = dev
        self.attention_type = attention_type
        
        self.Expand = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1, act = act) 
        self.DW = ConvBlock(self.inner_features, self.inner_features, 3, 1, self.inner_features, 1, act = act)
        self.SE = Attention(self.inner_features, self.inner_features//4, self.dev, attention_type = self.attention_type, act = act)
        self.Squeeze = ConvBlock(self.inner_features, self.in_features, 1, 0, 1, 1, act = act)
        
        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        if self.training and random.random() < self.stochastic_depth:
            return x
        expand = self.Expand(x)
        dw = self.DW(expand)
        se = self.SE(dw)
        squeeze = self.Squeeze(se)
        return squeeze * self.gamma + (1 - self.gamma) * x

class DownSamplerBottleNeck(pl.LightningModule):
    def __init__(self, in_features, inner_features, out_features, stride, dev, attention_type = 'se', act = 'relu'):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride
        self.dev = dev
        self.attention_type = attention_type
        
        self.pool = nn.AvgPool2d(kernel_size = 3, padding = 1, stride =stride)
        self.pool_conv = ConvBlock(self.in_features, self.out_features, 1, 0, 1, 1, act = act)
        self.Squeeze = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1, act = act) 
        self.Process = ConvBlock(self.inner_features, self.inner_features, 3, 1, 1, 1, act = act)
        self.Expand = ConvBlock(self.inner_features, self.out_features, 1, 0, 1, self.stride, act = act)
        self.SE = Attention(self.out_features, self.out_features // 4, self.dev, attention_type = self.attention_type, act = act)
        
        self.gamma = nn.Parameter(torch.zeros((1), device = self.dev))
    def forward(self, x):
        pool = self.pool_conv(self.pool(x))
        conv_features = self.SE(self.Expand(self.Process(self.Squeeze(x))))
        return pool * self.gamma + pool * (1 - self.gamma)
class DownSamplerInverse(pl.LightningModule):
    def __init__(self, in_features, inner_features, out_features, stride, dev, attention_type = 'se', act = 'relu'):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride
        self.dev = dev
        self.attention_type = attention_type
    
        self.pool = nn.AvgPool2d(kernel_size = 3, padding = 1, stride = self.stride)
        self.pool_conv = ConvBlock(self.in_features, self.out_features, 1, 0, 1, 1, act = act)
        
        self.squeeze = ConvBlock(self.in_features, self.inner_features,1, 0, 1, 1, act = act)
        self.process = ConvBlock(self.inner_features, self.inner_features, 3, 1, self.inner_features, 1, act = act)
        self.SE = Attention(self.inner_features, self.inner_features // 4, self.dev, act = act)
        self.expand = ConvBlock(self.inner_features, self.out_features, 1, 0, 1, self.stride, act = act)
        
        self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        pool = self.pool_conv(self.pool(x))
        conv = self.expand(self.SE(self.process(self.squeeze(x))))
        return conv * self.gamma + (1 - self.gamma) * pool

class AstrousConvolution(pl.LightningModule):
    '''
    Astrous(More Properly - à trous(at holes in french)) Convolution
    '''
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride, dilation, act = 'relu'):
        super().__init__()
        self.astrous = nn.Conv2d(in_features, out_features, kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, dilation = dilation, bias = False)
        self.bn = nn.BatchNorm2d(out_features)
        if act == 'relu':
            self.act1 = nn.ReLU(inplace = True)
        else:
            self.act1 = nn.SiLU(inplace = True)
    def forward(self, x):
        return self.bn(self.act1(self.astrous(x)))
class ASPP(pl.LightningModule):
    '''
    à trous spatial pooling pyramid block. No further Processing, this should be added later.
    
    5 Part:
    - Normal Conv
    - à trous: 3 dilation
    - à trous: 5 dilation
    - à trous: 7 dilation
    '''
    def __init__(self, in_features, inner_features, out_features, act = 'relu'):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.conv1 = ConvBlock(self.in_features, self.inner_features, 3, 1, 1, 2, act = act)
        self.conv2 = AstrousConvolution(self.in_features, self.inner_features, 3, 1, 1, 1, 3, act = act)
        self.conv3 = AstrousConvolution(self.in_features, self.inner_features, 3, 3, 1, 1, 5, act = act)
        self.conv4 = AstrousConvolution(self.in_features, self.inner_features, 3, 5, 1, 1, 7, act = act) 
        
        self.proj = ConvBlock(4 * self.inner_features, self.out_features, 1, 0, 1, 1, act = act)
        
    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(x)
        conv3 = self.conv3(x)
        conv4 = self.conv4(x)
        # concat
        concat = torch.cat([conv1, conv2, conv3, conv4], dim = 1)
        proj = self.proj(concat)
        return proj

# ENCODER

Various Encoder Types(BaseLine, ResNet, EffNet)

In [None]:
# ResNet Based Complex Encoder(+ SE + dropout2d)
class ResNetEncoderAlpha(pl.LightningModule):
    '''
    ResNet34d encoder + SE and Dropout
    
    I would scale the model larger(ex. ResNet50), but larger models have the wrong dimensions.
    '''
    def freeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def unfreeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = True 
    def increase_drop(self):
        self.drop_prob += self.increase_dropout
    def __init__(self, increase_drop, attention_type):
        super().__init__()
        # HYPER PARAMETERS
        self.increase_dropout = increase_drop
        self.attention_type = attention_type
        self.drop_prob = 0
        self.model_name = 'resnet34d'
        # END OF HYPER PARAMETERS
        self.model = timm.create_model(self.model_name, pretrained = True) 
        # Extract Layers
        self.conv1 = self.model.conv1 # (B, 64, 128, 128)
        self.bn1 = self.model.bn1
        self.act1 = self.model.act1
        self.maxpool = self.model.maxpool
        
        self.layer1 = self.model.layer1 # (b, 64, 64, 64)
        self.layer2 = self.model.layer2 # (b, 128, 32, 32)
        self.layer3 = self.model.layer3 # (b, 256, 16, 16)
        self.layer4 = self.model.layer4 # (b, 512, 8, 8)
        
        self.Dropout0 = nn.Dropout2d(self.drop_prob)
        self.Attention0 = Attention(64, 16, self.device, attention_type = self.attention_type)
        self.increase_drop()
        
        self.Dropout1 = nn.Dropout2d(self.drop_prob)
        self.increase_drop()
        self.Attention1 = Attention(64, 16, self.device, attention_type = self.attention_type)
        self.Dropout2 = nn.Dropout2d(self.drop_prob)
        self.Attention2 = Attention(128, 32, self.device, attention_type = self.attention_type)
        self.increase_drop()
        self.Dropout3 = nn.Dropout2d(self.drop_prob)
        self.Attention3 = Attention(256, 64, self.device, attention_type = self.attention_type)
        self.increase_drop()
        self.Dropout4 = nn.Dropout2d(self.drop_prob)
        self.Attention4 = Attention(512, 128, self.device, attention_type = self.attention_type)
        
        del self.model
    def forward(self, x):
        features0 = self.bn1(self.act1(self.conv1(x)))
        features0 = self.Dropout0(features0)
        features0 = self.Attention0(features0)
        
        layer1 = self.layer1(self.maxpool(features0))
        layer1 = self.Dropout1(layer1)
        layer1 = self.Attention1(layer1)
        
        layer2 = self.layer2(layer1)
        layer2 = self.Dropout2(layer2)
        layer2 = self.Attention2(layer2)
        
        layer3 = self.layer3(layer2)
        layer3 = self.Dropout3(layer3)
        layer3 = self.Attention3(layer3)
        
        layer4 = self.layer4(layer3)
        layer4 = self.Dropout4(layer4)
        layer4 = self.Attention4(layer4)
        
        return x, features0, layer1, layer2, layer3, layer4
class EffNetEncoderAlpha(pl.LightningModule):
    '''
    EfficientNet-b4 based Encoder(SE + Dropout)
    '''
    def freeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def unfreeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = True 
    def increase_drop(self):
        self.drop_prob += self.increase_drop
    def __init__(self, increase_drop, attention_type, act = 'relu'):
        super().__init__()
        self.drop_prob = 0
        self.act = act 
        self.increase_drop = increase_drop
    
        self.attention_type = attention_type
        
        self.model_name = 'tf_efficientnet_b4_ns'
        self.model = timm.create_model(self.model_name, pretrained = True)
        
        self.conv1 = self.model.conv_head
        self.bn1 = self.model.bn1
        self.act1 = self.model.act1
    
        self.block0 = self.model.blocks[0] # 24
        self.block1 = self.model.blocks[1] # 32
        self.block2 = self.model.blocks[2] # 56
        self.block3 = self.model.blocks[3] # 112
        self.block4 = self.model.blocks[4] # 160 
        self.block5 = self.model.blocks[5] # 272
        self.block6 = self.model.blocks[6] # 448
        
        # Custom Layer
        self.Dropout0 = nn.Dropout2d(self.drop_prob)
        self.Attention0 = Attention(24, 6, self.device, attention_type= self.attention_type, act = self.act)
        self.increase_dropout()
        
        self.Dropout1 = nn.Dropout2d(self.drop_prob)
        self.Attention1 = Attention(32, 8, self.device, attention_type = self.attention_type, act = self.act)
        self.increase_dropout()
        
        self.Dropout2 = nn.Dropout2d(self.drop_prob)
        self.Attention2 = Attention(56, 16, self.device, attention_type = self.attention_type, act = self.act)
        self.increase_dropout()
    
        self.Dropout3 = nn.Dropout2d(self.drop_prob)
        self.Attention3 = Attention(160, 48, self.device, attention_type = self.attention_type, act = self.act)
        self.increase_dropout()
        
        self.Dropout4 = nn.Dropout2d(self.drop_prob)
        self.Attention4 = Attention(448, 128, self.device, attention_type = self.attention_type, act = self.act)
        
        # Proj Blocks(To Match ResBlocks)
        self.proj0 = ConvBlock(24, 64, 3, 1, 1, 1)
        self.proj1 = ConvBlock(32, 64, 3, 1, 1, 1)
        self.proj2 = ConvBlock(56, 128, 3, 1, 1, 1)
        self.proj3 = ConvBlock(160, 256, 3, 1, 1, 1)
        self.proj4 = ConvBlock(448, 512, 3, 1, 1, 1)
    def forward(self, x):
        '''
        l0: (b, 3, 256, 256)
        l1: (B, 64, 128, 128)
        l2: (B, 64, 64, 64)
        l3: (B, 128, 32, 32)
        l4: (B, 256, 16, 16)
        l5: (B, 512, 8, 8)
        '''
        features0 = self.bn1(self.act1(self.conv1(x))) # (B, 48, 128, 128)
        block0 = self.block0(features0) # (B, 24, 128, 128)
        block0 = self.Dropout0(block0)
        block0 = self.Attention0(block0)
        
        block1 = self.block1(block0) # (B, 32, 64, 64)
        block1 = self.Dropout1(block1)
        block1 = self.Attention1(block1)
        
        block2 = self.block2(block1) # (B, 56, 32, 32)
        block2 = self.Dropout2(block2)
        block2 = self.Attention2(block2)
        
        block3 = self.block3(block2) # (B, 112, 16, 16)
        block4 = self.block4(block3) # (B, 160, 16, 16)
        block4 = self.Dropout3(block4)
        block4 = self.Attention3(block4)
        
        block5 = self.block5(block4) # (B, 272, 8, 8)
        block6 = self.block6(block5) # (B, 448, 8, 8)
        block6 = self.Dropout4(block6)
        block6 = self.Attention4(block6)
        
        # Project Block
        l1 = self.proj0(block0)
        l2 = self.proj1(block1)
        l3 = self.proj2(block2)
        l4 = self.proj3(block4)
        l5 = self.proj4(block6)
        return x, l1, l2, l3, l4, l5   
        

class EncoderQTPi(pl.LightningModule):
    def freeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def unfreeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = True
    def increase_dropout(self):
        self.drop_prob += self.increase_drop
    def increase_stochasticity(self):
        '''
        Increases the Rate of Stochastic Depth Dropout(Deeper should drop more.)
        '''
        self.stochastic_depth += self.increase_stoc
    def __init__(self, increase_drop, increase_stoc, attention_type, use_ASPP = False, encoder_type = 'resnet', act = 'relu'):
        # Suggested Increase_drop = 0.05, increase_stoc = 0.1
        super().__init__()
        self.act = act
        self.encoder_type = encoder_type
        assert self.encoder_type in ['resnet', 'effnet']
        self.increase_drop = increase_drop
        self.drop_prob = 5 * self.increase_drop
        
        self.stochastic_depth = 0
        self.increase_stoc = increase_stoc
        
        self.attention_type = attention_type
        
        self.backbone = ResNetEncoderAlpha(self.increase_drop, self.attention_type) if self.encoder_type == 'resnet' else EffNetEncoderAlpha(self.increase_drop, self.attention_type, act = self.act)
        
        self.use_ASPP = use_ASPP
        
        def add_block_stoc(x):
            self.increase_stochasticity()
            return x
        def add_block(x):
            # Adds a Block and Increases the Stochasticity of the model
            self.increase_dropout()
            self.increase_stochasticity()
            return x
        if self.use_ASPP:
            self.ASPP = nn.Sequential(*[
                ASPP(512, 256, 1024, act = self.act)
            ] + [
                add_block_stoc(BottleNeckBlock(1024, 256, self.device, attention_type = self.attention_type, stochastic_depth = self.stochastic_depth, act = self.act)) for i in range(3)
            ])
            
        else:
            self.ASPP = nn.Sequential(*[
                DownSamplerBottleNeck(512, 256, 1024, 2, self.device, attention_type = self.attention_type, act = self.act),
            ] + [
                add_block_stoc(BottleNeckBlock(1024, 256, self.device, attention_type = self.attention_type, stochastic_depth = self.stochastic_depth, act = self.act)) for i in range(3)
            ])
        
        self.layer7 = nn.Sequential(*[
            DownSamplerBottleNeck(1024, 512, 2048, 2, self.device, attention_type = self.attention_type, act = self.act)
        ] + [
            add_block_stoc(BottleNeckBlock(2048, 512, self.device, attention_type= self.attention_type, stochastic_depth = self.stochastic_depth, act = self.act)) for i in range(2)
        ])
        
        self.Dropout6 = nn.Dropout2d(self.drop_prob)
        self.increase_dropout()
        self.Attention6 = Attention(1024, 256, self.device, attention_type = self.attention_type, act = self.act)
        
        self.Dropout7 = nn.Dropout2d(self.drop_prob)
        self.increase_dropout()
        self.Attention7 = Attention(2048, 512, self.device, attention_type = self.attention_type, act = self.act)
        
    def forward(self, x):
        l0, l1, l2, l3, l4, l5 = self.backbone(x) 
        # L5: (B, 512, 8, 8) 
        l6 = self.ASPP(l5)
        l6 = self.Dropout6(l6)
        l6 = self.Attention6(l6)
        
        l7 = self.layer7(l6)
        l7 = self.Dropout7(l7)
        l7 = self.Attention7(l7)
        return l0, l1, l2, l3, l4, l5, l6, l7

# BASE LINE CODE.

In [None]:
class EncoderBaseLine(pl.LightningModule):
    '''
    ResNet34 Pretrained Model
    '''
    def freeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def unfreeze(self, layers):
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = True
    def __init__(self):
        super().__init__()
        self.model_name = 'resnet34d'
        self.model = timm.create_model(self.model_name, pretrained = True)
        # Extract Layers
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1
        self.act1 = self.model.act1
        self.pool = self.model.maxpool
        
        self.layer1 = self.model.layer1
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4
        # Freeze Initial Layers
        #self.freeze([self.conv1, self.bn1, self.layer1])
    def forward(self, x):
        features0 = self.bn1(self.act1(self.conv1(x))) # (B, 64, 128, 128)
        
        layer1 = self.layer1(self.pool(features0)) # (B, 64, 64, 64)
        layer2 = self.layer2(layer1) # (B, 128, 32, 32)
        layer3 = self.layer3(layer2) # (B, 256, 16, 16)
        layer4 = self.layer4(layer3) # (B, 512, 8, 8)
        return x, features0, layer1, layer2, layer3, layer4

class BaseLineUNetBlock(pl.LightningModule):
    '''
    UNet Block, upsamples using interpolation(Transposed Convolutions are very unstable and annoying to deal with.)
    '''
    def __init__(self, left_features, down_features, out_features, act = 'relu'):
        super().__init__()
        self.act = act
        self.left_features = left_features
        self.down_features = down_features
        self.out_features = out_features
        
        self.proj = ConvBlock(self.left_features + self.down_features, self.out_features, 3, 1, 1, 1, act = self.act) 
        
        self.process = ConvBlock(self.out_features, self.out_features, 3, 1, 1, 1, act = self.act)
        
    def forward(self, left_features, down_features):
        B, C, H, W = down_features.shape
        upsampled = F.interpolate(down_features, scale_factor= 2, mode = 'nearest') # Upsample images
        if left_features != None:
            upsampled = torch.cat([left_features, upsampled], dim = 1) # Concatenate
        features = self.proj(upsampled)
        return self.process(features)

class DecoderBaseLine(pl.LightningModule):
    '''
    Decoder with nothing Fancy. For Testing and Sanity Check
    '''
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
    
        self.left_features = [256, 128, 64, 64, 0]
        self.down_dims = [512, 256, 128, 64, 32, 16]
        self.dec_blocks = nn.ModuleList([
            BaseLineUNetBlock(self.left_features[i], self.down_dims[i], self.down_dims[i + 1]) for i in range(len(self.left_features))
        ])
        self.proj = nn.Conv2d(16, 1, kernel_size = 3, padding =1)
    def forward(self, l0, l1, l2, l3, l4, l5):
        '''
        Encoder Dims:
        [B, 3, 256, 256]
        [B, 64, 128, 128],
        [B, 64, 64, 64],
        [B, 128, 32, 32],
        [B, 256, 16, 16]
        [B, 512, 8, 8]
        '''
        d4 = self.dec_blocks[0](l4, l5)
        d3 = self.dec_blocks[1](l3, d4)
        d2 = self.dec_blocks[2](l2, d3)
        d1 = self.dec_blocks[3](l1, d2)
        d0 = self.dec_blocks[4](None, d1)
        return self.proj(d0)
        

class BaseLineSolution(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # HYPER PARAMETERS-----------------
        self.num_classes = 1
        # END OF HYPER PARAMETERS ---------
        self.encoder = EncoderBaseLine()
        self.decoder = DecoderBaseLine(self.num_classes) 
    def forward(self, x):
        return torch.squeeze(self.decoder(*self.encoder(x)))

# DECODER

In [None]:
class FPN(pl.LightningModule):
    '''
    Feature Pyramid Network, incorporates information at all scales of the network
    '''
    def __init__(self, in_features, out_features, out_size, attention_type = 'se', act = 'relu'):
        super().__init__()
        self.in_features = in_features
        self.num_blocks = len(self.in_features)
        self.out_features = out_features
        self.out_size = out_size
        self.act = act
        self.attention_type = attention_type
        
        self.SABlocks = nn.ModuleList([
            SqueezeAttend(self.in_features[i], self.out_features // 4, self.out_features, self.out_size, act = self.act, attention_type = self.attention_type) for i in range(self.num_blocks)   
        ])
        self.proj = ConvBlock(self.num_blocks * self.out_features, self.out_features, 3, 1, self.out_features, 1)
        
    def forward(self, x):
        assert isinstance(x, list) and len(x) == self.num_blocks
        # Process Each of the Features
        features = []
        for i in range(self.num_blocks):
            features += [self.SABlocks[i](x[i])]
        # Concatenate
        concat = torch.cat(features, dim = 1)
        return self.proj(concat)
        
        

class DecoderBlockQTPi(pl.LightningModule):
    '''
    Uses Pixel Shuffle, Concatenation, and Attention to Upsample Blocks(Mimic ResNet on the Way up)
    '''
    def increase_stochasticity(self):
        self.stochastic_depth += self.increase_stochastic
    def __init__(self, left_features, down_features, out_features, num_blocks, attention_type, drop_prob, use_pixel_shuffle = True, act = 'relu', stochastic_depth = 0, increase_stochastic = 0):
        super().__init__()
        self.left_features = left_features
        self.out_features = out_features
        self.stochastic_depth = stochastic_depth
        self.increase_stochastic = increase_stochastic
        self.act = act
        self.num_blocks = num_blocks
        self.down_features = down_features
        self.use_pixel_shuffle = use_pixel_shuffle
        self.attention_type = attention_type
        self.drop_prob = drop_prob
        assert self.down_features % 4 == 0
        
        if self.use_pixel_shuffle:
            self.pixel_shuffle = nn.PixelShuffle(2)
            self.att1 = Attention(self.down_features // 4, self.down_features // 16, self.device, act = self.act, attention_type = self.attention_type)
            self.concat_dim = self.left_features + self.down_features // 4
        else:
            self.concat_dim = self.down_features + self.left_features
            self.att1= Attention(self.down_features, self.down_features // 4, self.device, act = self.act, attention_type = self.attention_type)
    
        self.proj = ConvBlock(self.concat_dim, self.out_features, 3, 1, 1, 1)
        
        def add_block(x):
            self.increase_stochasticity()
            return x
        self.blocks = nn.Sequential(*[
            add_block(BottleNeckBlock(self.out_features, self.out_features // 4, self.device, attention_type = self.attention_type, act= self.act, stochastic_depth = self.stochastic_depth)) for i in range(self.num_blocks)
        ])
        self.dropout = nn.Dropout2d(self.drop_prob)
        self.att2 = Attention(self.out_features, self.out_features // 4, self.device, act = self.act, attention_type = self.attention_type)
    
        
    def forward(self, left_features, down_features):
        '''
        x: Tensor(B, C, H, W) 
        '''
        if self.use_pixel_shuffle:
            # Pixel Shuffle Upsample Down Features
            upsampled = self.pixel_shuffle(down_features)
        else:
            upsampled = F.interpolate(down_features, scale_factor = 2, mode = 'nearest')
        upsampled = self.att1(upsampled)
        if left_features != None:
            concat = torch.cat([left_features, upsampled], dim = 1)
        else:
            concat = upsampled # Final Layer
        proj = self.proj(concat)
        blocks = self.blocks(proj)
        dropped = self.dropout(blocks)
        return self.att2(dropped)
        
        

class DecoderQTPi(pl.LightningModule):
    def increase_stochasticity(self):
        self.stochastic_depth += self.increase_stochastic
    def increase_dropout(self):
        self.drop_prob += self.increase_drop
    def __init__(self, num_classes, attention_type, act, use_pixel_shuffle = True, drop_prob = 0, increase_drop = 0, stochastic_depth = 0, increase_stochastic = 0):
        super().__init__()
        self.drop_prob = drop_prob
        self.increase_drop = increase_drop
        self.use_pixel_shuffle = use_pixel_shuffle
        self.stochastic_depth = stochastic_depth
        self.increase_stochastic = increase_stochastic
        self.attention_type = attention_type
        self.act = act
        self.num_classes = num_classes
    
        self.left_features = [1024, 512, 256, 128, 64, 64, 0]
        self.down_dims = [2048, 1024, 512, 256, 128, 64, 64, 16]
        self.num_blocks = [5, 4, 4, 3, 3, 2, 2] # slightly Mimics ResNet's Block Structure
        
        def add_block(i):
            block = DecoderBlockQTPi(self.left_features[i], self.down_dims[i], self.down_dims[i + 1], self.num_blocks[i], self.attention_type, self.drop_prob, stochastic_depth = self.stochastic_depth, increase_stochastic = self.increase_stochastic, act = self.act, use_pixel_shuffle = self.use_pixel_shuffle)
            for x in range(self.num_blocks[i]):
                self.increase_stochasticity()
            self.increase_dropout()
            return block
        self.dec_blocks = nn.ModuleList([
           add_block(i) for i in range(len(self.left_features))
        ])
        
        # FPN layers
        self.FPN = FPN(self.down_dims[1:-2], 64, out_size = 128, attention_type = self.attention_type, act = self.act) 
        self.proj_FPN = ConvBlock(128, 64, 3, 1, 64, 1)
        
        self.proj = nn.Conv2d(self.down_dims[-1], self.num_classes, kernel_size = 3, padding =1, bias = False)
    def forward(self, l0, l1, l2, l3, l4, l5, l6, l7):
        '''
        Encoder Dims:
        [B, 3, 256, 256]
        [B, 64, 128, 128],
        [B, 64, 64, 64],
        [B, 128, 32, 32],
        [B, 256, 16, 16]
        [B, 512, 8, 8]
        '''
        d6 = self.dec_blocks[0](l6, l7)
        d5 = self.dec_blocks[1](l5, d6)
        d4 = self.dec_blocks[2](l4, d5)
        d3 = self.dec_blocks[3](l3, d4)
        d2 = self.dec_blocks[4](l2, d3)
        d1 = self.dec_blocks[5](l1, d2)
        
        
        fpn = self.FPN([d6, d5, d4, d3, d2])
        # Concatenate with the d1
        concat = torch.cat([fpn, d1], dim = 1) 
        fpn_proj = self.proj_FPN(concat)
        
        d0 = self.dec_blocks[6](None, fpn_proj)
        return self.proj(d0)
        
        

# ENTIRE MODEL

In [None]:
class UNetQTPi(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Params
        self.increase_drop_prob = 0.05
        self.stochastic_depth = 0.1
        self.num_classes = 1
        self.attention_type = 'cbam'
        self.model_type = 'resnet'
        self.act = 'relu'
        self.use_ASPP = True
        
        self.use_pixel_shuffle = False
        self.decoder_stoc = 0
        self.decoder_increase_stoc = 0.0
        self.decoder_drop = 0
        self.decoder_increase_drop = 0.00
        # END OF HYPER PARAMETERS
        
        self.encoder = EncoderQTPi(self.increase_drop_prob, self.stochastic_depth, self.attention_type, use_ASPP = self.use_ASPP, encoder_type = self.model_type, act = self.act)
        self.decoder = DecoderQTPi(self.num_classes, self.attention_type, self.act, use_pixel_shuffle = self.use_pixel_shuffle, stochastic_depth = self.decoder_stoc, drop_prob = self.decoder_drop, increase_drop = self.decoder_increase_drop, increase_stochastic = self.decoder_increase_stoc)
        
    def forward(self, x):
        return torch.squeeze(self.decoder(*self.encoder(x)))
        

Loss Functions(BCELoss, Dice Loss, Lovask Loss)

In [None]:
class SymmetricBCE(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.criterion = nn.BCEWithLogitsLoss()
    def forward(self, y_pred, y_true):
        '''
        Symmetric BCE Loss, on both 0s and 1s
        '''
        ones_bools = y_true == 1
        zeros_bools = y_true == 0
        
        y_pred_zeros = y_pred[zeros_bools]
        y_pred_ones = y_pred[ones_bools]
        
        loss_ones = self.criterion(y_pred_ones, torch.ones_like(y_pred_ones, device = y_pred_ones.device)) * 0.5
        loss_zeros = self.criterion(y_pred_zeros, torch.zeros_like(y_pred_zeros, device = y_pred_zeros.device)) * 0.5
        return loss_ones + loss_zeros
class Symmetric_Lovask(nn.Module):
    '''
    Symmetric Lovask loss.
    '''
    def __init__(self):
        super().__init__()
    def forward(self, y_pred, y_true):
        '''
        y_pred: Logits
        y_true: Targets
        '''
        return 0.5 * lovask.lovasz_hinge(y_pred, y_true) + 0.5 * lovask.lovasz_hinge(-y_pred, 1 - y_true)
class DiceLoss(pl.LightningModule):
    def __init__(self):
        super().__init__()
    def forward(self, y_pred, y_true):
        '''
        Computes the Binary Cross Entropy Based Dice Loss.
        y_pred: Binary Predictions, before sigmoid: Shape(B, H, W)  
        y_true: Ground Truth, Shape (B, H, W)
        '''
        B, H, W = y_pred.shape
        sigmoided = torch.sigmoid(y_pred)
        smooth = 1e-6
        
        numerator = 2 * torch.sum(sigmoided * y_true, [1, 2]) + smooth
        denominator = torch.sum(y_true, [1, 2]) + torch.sum(sigmoided, [1, 2]) + smooth
        loss = 1 - numerator / denominator
        loss = torch.sum(loss)
        return loss / B
class CustomLoss(pl.LightningModule):
    '''
    custom Loss function, merging Lovask Loss(0.8) and Cross Entropy(0.2).
    
    Lovask Loss is great in terms of optimizing dice, but it needs help from just predicting 0s.
    '''
    def __init__(self):
        super().__init__()
        self.lovask = Symmetric_Lovask()
        self.bce = SymmetricBCE()
        
        self.lovask_weight = 0.8
        self.bce_weight = 1 - self.lovask_weight
    def forward(self, y_pred, y_true):
        lovask_loss = self.lovask(y_pred, y_true) * self.lovask_weight
        bce_loss = self.bce(y_pred, y_true) * self.bce_weight
        return lovask_loss + bce_loss

Metrics

In [None]:
class AccuracyMetric(pl.LightningModule):
    '''
    Per Pixel Metric
    '''
    def __init__(self):
        super().__init__()
    def threshold(self, y_pred):
        bools = y_pred >= 0.5
        y_pred[:, :, :] = 0
        y_pred[bools] = 1
        return y_pred
    def pos_and_neg(self, y_pred, y_true):
        
        y_pred = torch.sigmoid(y_pred)
        y_pred = self.threshold(y_pred)
        pos_bools= y_true== 1
        neg_bools= y_true == 0
        
        pos_entries = y_pred[pos_bools]
        N = pos_entries.shape[0]
        pos_acc = torch.sum(pos_entries) / N
        
        neg_entries = y_pred[neg_bools]
        N = neg_entries.shape[0]
        neg_acc = torch.sum(1 - neg_entries) / N
        
        return pos_acc, neg_acc
    def forward(self, y_pred, y_true):
        '''
        Logits
        '''
        y_pred = torch.sigmoid(y_pred) 
        y_pred = self.threshold(y_pred)
        B, H, W = y_pred.shape
        return torch.sum((y_pred == y_true).int()) / B / H / W
class DiceMetric(pl.LightningModule):
    def __init__(self):
        super().__init__()
    def threshold(self, y_pred):
        bools = y_pred >= 0.5
        y_pred[:, :] = 0
        y_pred[bools] = 1
        return y_pred
    def forward(self, y_pred, y_true):
        '''
        Measures Dice Metric over logits 
        '''
        B, _, _ = y_pred.shape
        y_pred = torch.sigmoid(y_pred)
        y_pred = self.threshold(y_pred)
        smooth = 1e-6
        numerator = 2 * torch.sum(y_pred * y_true, [1, 2]) + smooth
        denominator = torch.sum(y_pred, [1, 2]) + torch.sum(y_true, [1, 2]) + smooth
        coef = numerator / denominator
        metric = torch.sum(coef) / B
        return metric

# TRAINING MODULES

In [None]:
class TrainingConfig:
    model_name = 'QTPi' # QTPi = Powerful, BaseLine is Simple Transfer Learned
    optim = 'adam'
    criterion_type = 'lovask'
    lr = 1e-3
    weight_decay = 1e-3
    
    num_steps = 5
    step_size = 0.9
    eta_min = 1e-7

In [None]:
class TrainingSolverQTPi(pl.LightningModule):
    def __init__(self, dev, fold_idx = 0):
        super().__init__()
        self.fold_idx = fold_idx
        self.dev = dev
        self.config = TrainingConfig
        self.criterion_type = self.config.criterion_type
        self.model_name = self.config.model_name
        assert self.model_name in ['baseline', 'QTPi']
        assert self.criterion_type in ['custom', 'dice', 'bce', 'lovask']
        self.criterion = self.configure_loss()
        self.model = self.configure_model()
        self.AccuracyMetric = AccuracyMetric()
        self.DiceMetric = DiceMetric()
        self.initialize_states()
        # Send Model to Device
        self.to(self.dev)
    def configure_loss(self):
        if self.criterion_type == 'custom':
            criterion = CustomLoss()
        elif self.criterion_type == 'dice':
            criterion = DiceLoss()
        elif self.criterion_type == 'bce':
            criterion = SymmetricBCE()
        else:
            criterion = Symmetric_Lovask()
        return criterion
    def initialize_states(self):
        # Initializes Hidden States
        self.training_loss = 0
        self.training_pos_acc = 0
        self.training_neg_acc = 0
        self.training_dice = 0
        self.training_steps = 0
        
        self.val_loss = 0
        self.val_neg_acc = 0
        self.val_pos_acc = 0
        self.val_dice = 0
        self.val_steps = 0
        
        self.NUM_EPOCHS = 0
        
        self.best_loss = float('inf')
        self.best_pos_acc = 0
        self.best_neg_acc = 0
        self.best_dice = 0
        
        self.liveloss = livelossplot.PlotLosses()
    def round_states(self):
        # Rounds Loss Stats to 3 Decimals
        if self.training_steps != 0:
            self.training_loss /= self.training_steps
            self.training_pos_acc /= self.training_steps
            self.training_neg_acc /= self.training_steps
            self.training_dice /= self.training_steps
            
            self.training_loss = round(self.training_loss, 3)
            self.training_pos_acc = round(self.training_pos_acc, 3)
            self.training_neg_acc = round(self.training_neg_acc, 3)
            self.training_dice = round(self.training_dice, 3)
        if self.val_steps != 0:
            self.val_loss /= self.val_steps
            self.val_pos_acc /= self.val_steps
            self.val_neg_acc /= self.val_steps
            self.val_dice /= self.val_steps
            
            self.val_loss = round(self.val_loss, 3)
            self.val_neg_acc = round(self.val_neg_acc, 3)
            self.val_pos_acc = round(self.val_pos_acc, 3)
            self.val_dice = round(self.val_dice, 3)

    def display_logs(self):
        logs = {}
        logs['loss'] = self.training_loss
        logs['pos_acc'] = self.training_pos_acc
        logs['neg_acc'] = self.training_neg_acc
        logs['dice'] = self.training_dice
        
        logs['val_loss'] = self.val_loss
        logs['val_pos_acc'] = self.val_pos_acc
        logs['val_neg_acc'] = self.val_neg_acc
        logs['val_dice']= self.val_dice
        
        self.liveloss.update(logs)
        self.liveloss.send()
    def save_states(self):
        # Saves Loss and Dice if improved
        if self.val_loss <= self.best_loss:
            self.best_loss = self.val_loss
            torch.save(self.state_dict(), f"./fold_{self.fold_idx}_loss.pth")
        if self.val_dice >= self.best_dice:
            self.best_dice = self.val_dice
            torch.save(self.state_dict(), f"./fold_{self.fold_idx}_dice.pth")
        # Optionally Save, Uncomment to do so.
        if self.val_pos_acc >= self.best_pos_acc:
            self.best_pos_acc = self.val_pos_acc
            #torch.save(self.state_dict(), f"./fold_{self.fold_idx}_pos_acc.pth")
        if self.val_neg_acc >= self.best_neg_acc:
            self.best_neg_acc = self.val_neg_acc
            #torch.save(self.state_dict(), f"./fold_{self.fold_idx}_neg_acc.pth")
    def reset_states(self):
        # Resets Epoch by Epoch Stats
        self.training_loss = 0
        self.training_pos_acc = 0
        self.training_neg_acc = 0
        self.training_dice = 0
        self.training_steps = 0
        
        self.val_loss = 0
        self.val_neg_acc = 0
        self.val_pos_acc = 0
        self.val_dice = 0
        self.val_steps = 0
        
        self.NUM_EPOCHS += 1
        
    def configure_model(self):
        '''
        Loads in the Model
        '''
        if self.model_name == 'baseline':
            model = BaseLineSolution()
        else:
            model = UNetQTPi()
        return model
    def configure_optimizers(self):
        '''
        Loads in LR Optims and Scheduler.
        '''
        if self.config.optim == 'ranger':
            # Use Ranger(Radam + LookAhead)
            optimizer = Ranger(self.model.parameters(), lr = self.config.lr, weight_decay = self.config.weight_decay)
        else:
            # Just use Adam.
            optimizer = optim.Adam(self.model.parameters(), lr = self.config.lr, weight_decay = self.config.weight_decay) 
        # Load in Both Schedulers
        self.lr_decay1 = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.config.num_steps, eta_min = self.config.eta_min) 
        self.lr_decay2 = optim.lr_scheduler.StepLR(optimizer, self.config.num_steps, self.config.step_size) 
        return [optimizer]
    
    # Training and Val Logic
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.to(self.dev)
        y = y.to(self.dev)
        pred = self.model(x)
        loss = self.criterion(pred, y)
        pos_acc, neg_acc = self.AccuracyMetric.pos_and_neg(pred.detach(), y)
        dice = self.DiceMetric(pred.detach(), y)
        
        # Log States
        self.log('loss', loss)
        self.log('pos_acc', pos_acc)
        self.log('neg_acc', neg_acc)
        self.log("dice", dice)
        
        # Print States(Uncomment) 
        print(f"STEP: {batch_idx}, L: {round(loss.item(), 3)} PA: {round(pos_acc.item(), 3)}, NA: {round(neg_acc.item(), 3)}, D: {round(dice.item(), 3)}")
        # Update States
        self.training_pos_acc += pos_acc.item()
        self.training_neg_acc += neg_acc.item()
        self.training_dice += dice.item()
        self.training_loss += loss.item()
        self.training_steps += 1
        return loss
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.to(self.dev) 
        y = y.to(self.dev)
        with torch.no_grad():
            pred = self.model(x)
        
            
        loss = self.criterion(pred, y) 
        pos_acc, neg_acc = self.AccuracyMetric.pos_and_neg(pred, y)
        dice = self.DiceMetric(pred, y)
        
        self.log('val_loss', loss)
        self.log('val_pos_acc', pos_acc)
        self.log('val_neg_acc', neg_acc)
        self.log('val_dice', dice)
        
        self.val_dice += dice.item()
        self.val_loss += loss.item()
        self.val_neg_acc += neg_acc.item()
        self.val_pos_acc += pos_acc.item()
        self.val_steps += 1
    def training_epoch_end(self, _):
        self.lr_decay1.step()
        self.lr_decay2.step()
    def validation_epoch_end(self, _):
        self.round_states()
        self.save_states()
        self.display_logs()
        self.reset_states()

In [None]:
pl.seed_everything()
def get_model(fold_idx):
    model = TrainingSolverQTPi(Config.device, fold_idx = fold_idx)
    # Construct Trainer
    callbacks = []
    #callbacks = [pl.callbacks.EarlyStopping(
    #    monitor = 'val_dice',
    #    mode = 'max',
    #    patience = 100000
    #)]
    trainer = pl.Trainer(num_sanity_val_steps = 5, max_epochs = Config.NUM_EPOCHS, checkpoint_callback = False, logger = None, check_val_every_n_epoch = 1, precision = 16, gpus = 1, callbacks = callbacks, deterministic = True, benchmark = False)
    return model, trainer
def overfit_batches(train, num_ex):
    train_dataset = train.dataset
    train, _ = torch.utils.data.random_split(train_dataset, [num_ex, len(train_dataset) - num_ex], generator = torch.Generator().manual_seed(42))
    train_dataloader = torch.utils.data.DataLoader(train, batch_size = 32, worker_init_fn = seed_worker)
    return train_dataloader
def MultiFoldTraining(fold_idx, load_prev = None):
    train, val = dataModule.get_both(fold_idx)
    #train = overfit_batches(train, 32)
    model, trainer = get_model(fold_idx)
    if load_prev != None:
        model.load_state_dict(torch.load(f"{load_prev}fold_{fold_idx}_loss.pth", map_location = Config.device))
    trainer.fit(model, train, val) 

In [None]:
MultiFoldTraining(0, load_prev = "../input/fold0trained/")