In [None]:
%%capture
!pip install ../input/segmentation-models-pytorch-0-1-3/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install ../input/segmentation-models-pytorch-0-1-3/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install ../input/segmentation-models-pytorch-0-1-3/timm-0.3.2-py3-none-any.whl
!pip install ../input/segmentation-models-pytorch-0-1-3/segmentation_models.pytorch.0.1.3/segmentation_models.pytorch.0.1.3

In [None]:
%%capture
import os
import gc
import cv2
import pdb
import glob
import pytz
import pickle
import random
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset, sampler
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import warnings

import rasterio
from rasterio.windows import Window

import torch
import pytorch_lightning as pl

import sys
sys.path.append("../input/timm-pytorch-image-models/pytorch-image-models-master/")
import timm
from torchvision.models.resnet import ResNet, Bottleneck



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from fastai.vision.all import *
from albumentations import (
    Compose,
    CenterCrop,
    CLAHE,
    Resize,
    Normalize
)

In [None]:
# Size each tile is resized to before it is fed to the network
# (see the cv2.resize call in HuBMAPDataset.__getitem__).
height, width = 512, 512
# NOTE(review): not referenced in the visible cells — presumably a downscale factor; confirm.
reduce = 2
# NOTE(review): THRESHOLD / VOTERS / use_TTA are not referenced in the visible cells;
# presumably prediction threshold, ensemble vote ratio and test-time-augmentation switch — confirm.
THRESHOLD = 0.5
VOTERS = 0.5
# Side length (in source pixels) of the tiles cut by make_grid, and the minimum
# overlap between neighbouring tiles; both are read by HuBMAPDataset.
window = 1024
min_overlap = 256
use_TTA = False
# Input locations: test images, model checkpoint(s) and the sample submission.
DATA = '../input/hubmap-kidney-segmentation/test/'
fold0 = '../input/resnet/models/model.pth'
MODELS = [fold0]
df_sample = pd.read_csv('../input/hubmap-kidney-segmentation/sample_submission.csv')
batch_size = 2
# Same expression as the `device` assignment at the end of the imports cell (re-declared here).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Mask to RLE and RLE to Mask

In [None]:
#functions to convert encoding to mask and mask to encoding
def enc2mask(encs, shape):
    """Decode run-length encodings into a single label mask.

    Args:
        encs: iterable of RLE strings of the form "start length start length ...",
            where each start is a 1-based index into the flat pixel buffer.
            Entries that are a float NaN (i.e. "no mask") are skipped.
        shape: two-element sequence; the flat buffer holds shape[0] * shape[1]
            pixels and is reshaped to `shape` before being transposed.

    Returns:
        np.uint8 array of shape (shape[1], shape[0]). Pixels covered by the
        m-th encoding (0-based) hold the value m + 1; later encodings
        overwrite earlier ones where they overlap.
    """
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` was only an alias of the builtin float and no longer exists
        # in current NumPy; np.floating additionally covers NumPy float scalars.
        if isinstance(enc, (float, np.floating)) and np.isnan(enc):
            continue
        s = enc.split()
        # Iterate over the (start, length) pairs of THIS encoding. The original
        # used len(height), i.e. len() of an unrelated global int, which raises.
        for i in range(len(s)//2):
            start = int(s[2*i]) - 1
            length = int(s[2*i+1])
            img[start:start+length] = 1 + m
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    """Run-length encode each label 1..n of a label mask.

    The mask is transposed and flattened first. For every label the result is
    a string "start length start length ..." with 1-based starts, or np.nan
    when the label does not occur in the mask.

    Returns:
        list with n entries (one per label, in label order).
    """
    flat = mask.T.flatten()

    def _encode(label):
        hits = (flat == label).astype(np.int8)
        if hits.sum() == 0:
            return np.nan
        # Zero padding on both sides guarantees every run has a rising and a falling edge.
        padded = np.concatenate([[0], hits, [0]])
        edges = np.where(padded[1:] != padded[:-1])[0] + 1
        # Even slots are run starts, odd slots run ends; turn the ends into lengths.
        edges[1::2] -= edges[::2]
        return ' '.join(map(str, edges))

    return [_encode(label) for label in range(1, n + 1)]

#https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
#with transposed mask
def rle_encode_less_memory(img):
    """Run-length encode a binary mask as "start length start length ...".

    The mask is transposed and flattened into a copy, so `img` itself is not
    modified. Starts are 1-based. The first and last pixel of the flat copy
    are forced to zero because the edge detection below has no padding.
    """
    flat = img.T.flatten()

    # Simplification: requires the first and last pixel to be background.
    flat[0] = 0
    flat[-1] = 0

    # Positions where the value changes; +2 converts to 1-based pixel numbers.
    edges = np.flatnonzero(flat[1:] != flat[:-1]) + 2
    # Odd slots hold run ends; subtracting the matching start yields lengths.
    edges[1::2] -= edges[::2]

    return ' '.join(map(str, edges))

In [None]:
# Per-channel mean and standard deviation handed to albumentations `Normalize`
# (through get_transforms). NOTE(review): the earlier comment called these
# ImageNet statistics, but the values differ from the usual ImageNet mean/std
# (0.485, 0.456, 0.406 / 0.229, 0.224, 0.225) — presumably dataset-specific; confirm.
mean = np.array([0.63701495, 0.4709702, 0.6817423])
std = np.array([0.15978882, 0.2245109, 0.14173926])
# Identity affine transform (unit scale, no shear, no offset); not referenced in the visible cells.
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
def get_transforms(mean, std):
    """Build the preprocessing pipeline: normalise with `mean`/`std`, then convert to a tensor.

    Returns:
        albumentations `Compose` made of `Normalize(mean, std)` followed by `ToTensorV2()`.
    """
    return Compose([
        Normalize(mean = mean, std = std),
        ToTensorV2(),
    ])

def make_grid(shape, window=256, min_overlap=32):
    """
        Tile a 2-D extent with windows of side `window`.

        Returns an int64 array of size (N,4), where N is the number of tiles
        and the 2nd axis holds the slice bounds x1,x2,y1,y2. Along each axis
        the starts are evenly spaced; the last start is moved to
        `extent - window` so the final tile ends at the border.
    """
    def _axis_bounds(extent):
        # Number of tiles needed so that neighbours overlap by at least min_overlap.
        count = extent // (window - min_overlap) + 1
        starts = np.linspace(0, extent, num=count, endpoint=False, dtype=np.int64)
        starts[-1] = extent - window
        stops = (starts + window).clip(0, extent)
        return starts, stops

    x, y = shape
    x1, x2 = _axis_bounds(x)
    y1, y2 = _axis_bounds(y)
    nx, ny = len(x1), len(y1)

    # Row i*ny + j describes tile (i, j): x bounds vary slowly, y bounds vary fast.
    return np.stack(
        [np.repeat(x1, ny), np.repeat(x2, ny), np.tile(y1, nx), np.tile(y2, nx)],
        axis=1,
    )

class HuBMAPDataset(Dataset):
    """Serve fixed-size tiles of one opened rasterio dataset.

    The image is covered by a grid from `make_grid` (module-level `window` and
    `min_overlap`). Each item is one tile, resized to the module-level
    (`height`, `width`), normalised with `get_transforms(mean, std)`, together
    with the tile's bounds in the source image.

    Args:
        data: an opened rasterio dataset. If it does not have exactly three
            bands, its subdatasets are opened and each one is read as one
            output channel.
    """
    def __init__(self, data):
        self.data = data
        if self.data.count != 3:
            subdatasets = self.data.subdatasets
            self.layers = []
            if len(subdatasets) > 0:
                for i, subdataset in enumerate(subdatasets, 0):
                    self.layers.append(rasterio.open(subdataset))
        self.shape = self.data.shape
        self.mask_grid = make_grid(self.data.shape, window=window, min_overlap=min_overlap)
        self.transforms = get_transforms(mean, std)

    def __len__(self):
        """Number of tiles in the grid."""
        return len(self.mask_grid)

    def __getitem__(self, idx):
        """Return (image tensor, tensor([x1, x2, y1, y2])) for tile `idx`."""
        x1, x2, y1, y2 = self.mask_grid[idx]
        tile = Window.from_slices((x1, x2), (y1, y2))
        if self.data.count == 3:
            # Read from this instance's dataset; the original used a global
            # `data`, which breaks (or reads the wrong image) when several
            # datasets are created.
            img = self.data.read([1,2,3], window=tile)
            # rasterio returns band-first data; move bands to the last axis.
            img = np.moveaxis(img, 0, -1)
        else:
            img = np.zeros((window, window, 3), dtype=np.uint8)
            for i, layer in enumerate(self.layers):
                img[:,:,i] = layer.read(window=tile)
        # cv2.resize's third positional parameter is `dst`, not the interpolation
        # flag, so the flag must be passed by keyword to take effect. dsize is
        # ordered (width, height).
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
        augmented = self.transforms(image=img)
        img = augmented['image']
        vertices = torch.tensor([x1, x2, y1, y2])

        return img, vertices

## Initialize models and load checkpoints

Full Model

In [None]:
class Config:
    """Settings namespace; the architecture switches are read by the model classes at construction time."""

    # ---- data / loop settings ----
    IMAGE_SIZE = 512
    BATCH_SIZE = 8    # author note: small-ish batch size needed to support ASPP + FPN
    NUM_EPOCHS = 48   # author note: alternative value 200

    NUM_WORKERS = 4
    device = device   # takes the module-level `device`

    # ---- architecture selection ----
    encoder_type = 'effnet'
    num_classes = 2
    use_ASPP = True
    use_FPN = True
    attention_type = "none"          # Attention handles 'se', 'scse' and 'none'
    use_linkNet = True               # author note: linkNet blocks should perform better
    use_decoder_attention = False    # author note: special attention
    gate_attention = False           # author note: reduces instability of attention layers early in training
    act = 'mish'                     # author note: performs better than SiLU; Act also handles 'silu', else ReLU
    bottleneck_type = 'inverse'      # one of 'none', 'recurrent', 'inverse', 'bottleneck'
    buffed_decoder = False           # adds bottlenecks and more processing to the decoder
    buffed_encoder = False           # adds bottlenecks to the encoder, after the ASPP module
    num_blocks = 1

    # ---- experimental modules (author note: in testing) ----
    use_bam = False
    bam_dilate = 3
    use_sem = False

    # ---- width factors ----
    reduction = 1        # reduction factor
    aspp_reduction = 2   # reduction factor for ASPP modules
    expand = 2           # expansion factor

In [None]:
def initialize_weights(layer):
    """Re-initialise every Conv2d and BatchNorm2d found inside `layer` (in place).

    Conv2d weights get Kaiming-normal initialisation with the ReLU gain;
    BatchNorm2d gets weight 1 and bias 0. Other module types are left untouched.
    """
    for module in layer.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, nonlinearity = 'relu')
            continue
        if isinstance(module, nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()

In [None]:
class Mish(pl.LightningModule):
    """Mish activation, x * tanh(softplus(x)); parameter-free drop-in replacement for other activations."""
    def forward(self, x):
        softplus = F.softplus(x)
        return x * torch.tanh(softplus)
def replace_all(model):
    """Recursively swap every ReLU / SiLU / timm Swish child of `model` for a fresh Mish (in place)."""
    for child_name, child in model.named_children():
        swappable = (nn.ReLU, nn.SiLU, timm.models.layers.activations.Swish)
        if not isinstance(child, swappable):
            # Not an activation to replace: descend into its own children.
            replace_all(child)
            continue
        setattr(model, child_name, Mish())
class Act(pl.LightningModule):
    """Activation chosen by `Config.act`: 'mish' -> Mish, 'silu' -> SiLU, anything else -> ReLU."""
    def __init__(self):
        super().__init__()
        self.act_type = Config.act
        if self.act_type == 'mish':
            self.act = Mish()
        elif self.act_type == 'silu':
            self.act = nn.SiLU(inplace = True)
        else:
            self.act = nn.ReLU(inplace = True)

    def forward(self, x):
        return self.act(x)
class ConvBlock(pl.LightningModule):
    """Bias-free Conv2d followed by the configured activation and then BatchNorm2d.

    Note the order applied in forward: convolution -> activation -> batch norm.
    """
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride):
        super().__init__()
        conv_options = dict(kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, bias = False)
        self.conv = nn.Conv2d(in_features, out_features, **conv_options)
        self.bn = nn.BatchNorm2d(out_features)
        self.act1 = Act()
        initialize_weights(self)

    def forward(self, x):
        activated = self.act1(self.conv(x))
        return self.bn(activated)
class SqueezeExcite(pl.LightningModule):
    """Channel gate: scale the input by a sigmoid computed from its mean over the two trailing axes.

    Args:
        in_features: size of the axis that the linear layers act on (the channel count of the input).
        inner_features: width of the hidden linear layer.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features

        self.Squeeze = nn.Linear(in_features, inner_features)
        self.act1 = Act()
        self.Excite = nn.Linear(inner_features, in_features)

    def forward(self, x):
        # Average over the last axis twice, collapsing the two trailing axes.
        pooled = x.mean(dim = -1).mean(dim = -1)

        hidden = self.act1(self.Squeeze(pooled))
        gate = torch.sigmoid(self.Excite(hidden))
        # Two trailing singleton axes let the gate broadcast back over the input.
        return gate[..., None, None] * x
class SCSE(pl.LightningModule):
    """Spatial + channel squeeze-excite: average of a channel-gated and a spatially-gated copy of the input."""
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features

        # Channel branch: mean over the trailing axes -> linear -> act -> linear -> sigmoid.
        self.squeeze = nn.Linear(in_features, inner_features)
        self.Act = Act()
        self.excite = nn.Linear(inner_features, in_features)

        # Spatial branch: 1x1 convolution down to a single-channel map.
        self.spatial = nn.Conv2d(in_features, 1, kernel_size = 1)
        initialize_weights(self)

    def forward(self, x):
        pooled = x.mean(dim = -1).mean(dim = -1)

        hidden = self.Act(self.squeeze(pooled))
        channel_gate = torch.sigmoid(self.excite(hidden))
        channel_out = channel_gate[..., None, None] * x

        spatial_out = torch.sigmoid(self.spatial(x)) * x

        return (channel_out + spatial_out) / 2

class Attention(pl.LightningModule):
    """Attention wrapper selected by `Config.attention_type`.

    'se' builds a SqueezeExcite, 'scse' builds an SCSE; with 'none' forward
    returns its input unchanged. When `Config.gate_attention` is set, the
    attended output is blended with the input through a learned sigmoid gate
    (initialised at 0, i.e. an even blend).
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.attention_type = Config.attention_type
        self.gate_attention = Config.gate_attention

        builders = {'se': SqueezeExcite, 'scse': SCSE}
        builder = builders.get(self.attention_type)
        if builder is not None:
            self.layer = builder(in_features, inner_features)
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros(1, device = self.device))

    def forward(self, x):
        if self.attention_type == 'none':
            return x
        processed = self.layer(x)
        if not self.gate_attention:
            return processed
        gamma = torch.sigmoid(self.gamma)
        return gamma * processed + (1 - gamma) * x

class BottleNeckBlock(pl.LightningModule):
    """Residual bottleneck: 1x1 to `inner_features`, 3x3, 1x1 back to `in_features`, attention, gated skip.

    The output is sigmoid(gamma) * branch + (1 - sigmoid(gamma)) * x with a
    learned scalar gamma initialised at 0.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.reduction = Config.reduction

        self.Squeeze = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        self.Process = ConvBlock(inner_features, inner_features, 3, 1, 1, 1)
        self.Expand = ConvBlock(inner_features, in_features, 1, 0, 1, 1)
        self.SE = Attention(in_features, in_features // self.reduction)

        self.gamma = nn.Parameter(torch.zeros(1, device = self.device))

    def forward(self, x):
        branch = self.Expand(self.Process(self.Squeeze(x)))
        attended = self.SE(branch)
        gamma = torch.sigmoid(self.gamma)
        return attended * gamma + (1 - gamma) * x
class InverseBottleNeckBlock(pl.LightningModule):
    """Inverted residual block: 1x1 expand, depthwise 3x3, attention, 1x1 squeeze, gated skip.

    The output is sigmoid(gamma) * branch + (1 - sigmoid(gamma)) * x with a
    learned scalar gamma initialised at 0.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.reduction = Config.reduction

        self.Expand = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        # groups == channels makes this 3x3 convolution depthwise.
        self.DW = ConvBlock(inner_features, inner_features, 3, 1, inner_features, 1)
        self.SE = Attention(inner_features, inner_features // self.reduction)
        self.Squeeze = ConvBlock(inner_features, in_features, 1, 0, 1, 1)

        self.gamma = nn.Parameter(torch.zeros(1, device = self.device))

    def forward(self, x):
        widened = self.DW(self.Expand(x))
        branch = self.Squeeze(self.SE(widened))
        gamma = torch.sigmoid(self.gamma)
        return branch * gamma + (1 - gamma) * x
class AstrousConvolution(pl.LightningModule):
    '''
    Dilated ("à trous", French for "with holes") convolution block:
    bias-free Conv2d with a dilation, then the configured activation, then BatchNorm2d.
    '''
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride, dilation):
        super().__init__()
        conv_options = dict(kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, dilation = dilation, bias = False)
        self.astrous = nn.Conv2d(in_features, out_features, **conv_options)
        self.bn = nn.BatchNorm2d(out_features)
        self.act1 = Act()
        initialize_weights(self)

    def forward(self, x):
        activated = self.act1(self.astrous(x))
        return self.bn(activated)
class ASPP_Pool(pl.LightningModule):
    """Image-level branch: pool the map to 1x1, project with a 1x1 ConvBlock, resize back to the input size."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Fixed to average pooling; the max-pooling alternative is kept selectable here.
        self.pooling_type = 'mean'
        pool_cls = nn.AdaptiveAvgPool2d if self.pooling_type == 'mean' else nn.AdaptiveMaxPool2d
        self.pool = pool_cls((1, 1))
        self.process = nn.Sequential(
            ConvBlock(self.in_features, self.out_features, 1, 0, 1, 1)
        )

    def forward(self, x):
        _, _, rows, cols = x.shape
        projected = self.process(self.pool(x))
        # Bilinear resize of the 1x1 map back to the spatial size of the input.
        return F.interpolate(projected, size = (rows, cols), mode = 'bilinear')
class ASPP(pl.LightningModule):
    '''
    à trous spatial pyramid pooling block. No further processing; that should be added later.

    Six parallel branches, concatenated along dim 1 and projected by a 1x1 ConvBlock:
    - image-level pooling (ASPP_Pool)
    - plain 1x1 convolution
    - four grouped (4 groups) 3x3 à trous convolutions with dilation
      stride * 1, stride * 3, stride * 5 and stride * 7

    `stride` only scales the dilation (and matching padding); the convolutions themselves use stride 1.
    '''
    def __init__(self, in_features, inner_features, out_features, stride = 1):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride
        self.num_groups = 4

        self.pool = ASPP_Pool(in_features, inner_features)
        self.conv1 = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        # conv2..conv5: same layout, only the dilation rate (and matching padding) differs.
        for index, rate in zip((2, 3, 4, 5), (1, 3, 5, 7)):
            dilation = self.stride * rate
            branch = AstrousConvolution(in_features, inner_features, 3, dilation, self.num_groups, 1, dilation)
            setattr(self, 'conv%d' % index, branch)

        self.conv_proj = ConvBlock(inner_features * 6, out_features, 1, 0, 1, 1)
        initialize_weights(self)

    def forward(self, x):
        branches = [
            self.pool(x),
            self.conv1(x),
            self.conv2(x),
            self.conv3(x),
            self.conv4(x),
            self.conv5(x),
        ]
        return self.conv_proj(torch.cat(branches, dim = 1))
class BAM(pl.LightningModule):
    """Bottleneck-attention style block combining a channel and a spatial branch.

    Channel branch: spatial mean -> Linear -> act -> Linear, giving one logit
    per channel. Spatial branch: 1x1 ConvBlock -> depthwise dilated 3x3
    (dilation `Config.bam_dilate`) -> 1x1 ConvBlock to a single-channel map.
    The two are broadcast, averaged, passed through a sigmoid and multiplied
    with the input. With `Config.gate_attention` the result is blended with
    the input through a learned sigmoid gate initialised at 0.

    Args:
        in_features: channel count of the input feature map.
        inner_features: hidden width of both branches.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features

        self.Squeeze = nn.Linear(self.in_features, self.inner_features)
        self.Act = Act()
        self.Excite = nn.Linear(self.inner_features, self.in_features)

        self.dilation = Config.bam_dilate
        self.SqueezeConv = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1)
        self.DA = AstrousConvolution(self.inner_features, self.inner_features, 3, self.dilation, self.inner_features, 1, self.dilation)
        self.ExciteConv = ConvBlock(self.inner_features, 1, 1, 0, 1, 1)
        self.gate_attention = Config.gate_attention
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros((1), device = self.device))
    def forward(self, x):
        # Global average over both spatial axes, as in SqueezeExcite / SCSE.
        # The original called torch.squeeze here, which leaves a 4-D map
        # untouched whenever its spatial size is not 1 and therefore fed the
        # wrong axis into the Linear layer.
        mean = torch.mean(x, dim = -1)
        mean = torch.mean(mean, dim = -1)

        squeeze = self.Act(self.Squeeze(mean))
        # Trailing singleton axes let the per-channel logits broadcast over space.
        excite = self.Excite(squeeze).unsqueeze(-1).unsqueeze(-1)

        squeeze_conv = self.SqueezeConv(x)
        DA = self.DA(squeeze_conv)
        excite_conv = self.ExciteConv(DA)

        # The single-channel spatial map and the per-channel logits broadcast against each other.
        excited = torch.sigmoid((excite_conv + excite) / 2) * x
        if self.gate_attention:
            gamma = torch.sigmoid(self.gamma)
            return gamma * excited + (1 - gamma) * x
        return excited
class SEM(pl.LightningModule):
    """Multi-dilation context block.

    The input is squeezed by a 1x1 ConvBlock and refined by a 3x3 ConvBlock;
    four depthwise dilated 3x3 convolutions (dilation stride * 1..4) run on
    that result in parallel. Their outputs are concatenated with the original
    input along dim 1 and projected back to `in_features` channels.
    """
    def __init__(self, in_features, inner_features, stride = 1):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.stride = stride

        self.Squeeze = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        self.FS = ConvBlock(inner_features, inner_features, 3, 1, 1, 1)

        # Dilation pyramid: conv1..conv4, depthwise, padding equal to the dilation.
        for rate in (1, 2, 3, 4):
            dilation = self.stride * rate
            branch = AstrousConvolution(inner_features, inner_features, 3, dilation, inner_features, 1, dilation)
            setattr(self, 'conv%d' % rate, branch)

        self.proj = ConvBlock(inner_features * 4 + in_features, in_features, 1, 0, 1, 1)

    def forward(self, x):
        refined = self.FS(self.Squeeze(x))

        pyramid = [
            self.conv1(refined),
            self.conv2(refined),
            self.conv3(refined),
            self.conv4(refined),
        ]

        return self.proj(torch.cat([x, *pyramid], dim = 1))

In [None]:
class EncoderUNext(pl.LightningModule):
    """ResNeXt-50 (32x4d) encoder built from torchvision's ResNet, with an ASPP block on the deepest stage.

    The stem (conv1, bn1) and layer1 are frozen at construction. forward
    returns [input, stem output, layer1, layer2, layer3, ASPP(layer4)].
    """
    def __init__(self):
        super().__init__()
        backbone = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)
        self.model = backbone
        # Loading of pretrained weights is disabled here (originally via torch.hub:
        # 'facebookresearch/semi-supervised-ImageNet1K-models', 'resnext50_32x4d_swsl').

        # Stem; the ResNet's own ReLU is replaced by Mish.
        self.conv1 = backbone.conv1 # 64
        self.bn1 = backbone.bn1
        self.act1 = Mish()
        self.maxpool = backbone.maxpool

        self.layer1 = backbone.layer1 # 256
        self.layer2 = backbone.layer2 # 512
        # Freeze the initial layers.
        self.freeze([self.conv1, self.bn1, self.layer1])

        self.layer3 = backbone.layer3 # 1024
        self.layer4 = backbone.layer4 # 2048

        self.aspp_reduction = Config.aspp_reduction
        self.ASPP = ASPP(2048, 2048 // self.aspp_reduction, 512)
        # Drop the container; the stages stay registered under their own names.
        del self.model

    def freeze(self, layers):
        """Disable gradients for every parameter of every module in `layers` (an iterable)."""
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False

    def unfreeze(self, layer):
        """Enable gradients for every parameter of the single module `layer`."""
        for parameter in layer.parameters():
            parameter.requires_grad = True

    def forward(self, x):
        # Order used by this model: convolution -> activation -> batch norm.
        features0 = self.bn1(self.act1(self.conv1(x)))
        layer1 = self.layer1(self.maxpool(features0))
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.ASPP(self.layer4(layer3))

        return [x, features0, layer1, layer2, layer3, layer4]
class EncoderResNet(pl.LightningModule):
    """timm `resnet34d` encoder (created without pretrained weights), optionally followed by ASPP.

    forward returns [input, stem output, layer1, layer2, layer3, layer4],
    where layer4 has passed through the ASPP block when `Config.use_ASPP` is
    set and is returned as is otherwise.
    """
    def freeze(self, layer):
        """Disable gradients for every parameter of `layer`."""
        for parameter in layer.parameters():
            parameter.requires_grad = False
    def unfreeze(self, layer):
        """Enable gradients for every parameter of `layer`."""
        for parameter in layer.parameters():
            # The original assigned False here, making unfreeze identical to freeze.
            parameter.requires_grad = True
    def __init__(self):
        super().__init__()
        self.model_name = 'resnet34d'
        self.model = timm.create_model(self.model_name, pretrained = False)
        # Extract Layers
        self.enc_dims = [64, 64, 128, 256, 512]
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1
        self.act1 = self.model.act1
        self.maxpool = self.model.maxpool
        self.layer1 = self.model.layer1
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4

        self.aspp_reduction = Config.aspp_reduction
        self.use_aspp = Config.use_ASPP
        if self.use_aspp:
            self.ASPP = ASPP(self.enc_dims[-1], self.enc_dims[-1] // self.aspp_reduction, self.enc_dims[-1])
        else:
            # forward always calls self.ASPP; a parameter-free pass-through keeps
            # it working (and the state_dict unchanged) when ASPP is disabled.
            self.ASPP = nn.Identity()


    def forward(self, x):
        # Order used by this model: convolution -> activation -> batch norm.
        features0 = self.bn1(self.act1(self.conv1(x))) # 64
        layer1 = self.layer1(self.maxpool(features0)) # 64
        layer2 = self.layer2(layer1) # 128
        layer3 = self.layer3(layer2) # 256
        layer4 = self.layer4(layer3) # 512

        layer4 = self.ASPP(layer4)
        features = [x, features0, layer1, layer2, layer3, layer4]
        return features
        
class EncoderQTPi(pl.LightningModule):
    """EfficientNet-B0 encoder taken from an smp Unet, with optional extra blocks on the deepest features.

    Depending on `Config`, the deepest feature map passes through, in order:
    BAM (`use_bam`), ASPP (`use_ASPP`), a stack of InverseBottleNeckBlocks
    (`buffed_encoder`) and a second BAM. Each disabled stage is an
    `nn.Identity`.
    """
    def __init__(self):
        super().__init__()
        self.enc_dims = [3, 32, 16, 24, 40, 80, 112, 320]
        # HYPER PARAMETERS
        self.base_name = 'timm-efficientnet-b0'
        # Author note: larger EfficientNets gave similar performance; this can be scaled up if needed.
        self.model_name = 'tf_efficientnet_b0_ns'
        # END OF HYPER PARAMETERS
        self.model = smp.Unet(self.base_name, encoder_weights = None)
        # Copy the state of a (non-pretrained) timm model into the smp encoder, then drop the timm model.
        self.weights = timm.create_model(self.model_name, pretrained = False)
        self.model.encoder.load_state_dict(self.weights.state_dict())
        del self.weights

        deepest = self.enc_dims[7]

        # Custom layers on top of the encoder.
        self.use_ASPP = Config.use_ASPP
        self.aspp_reduction = Config.aspp_reduction
        if self.use_ASPP:
            self.block7 = nn.Sequential(
                ASPP(deepest, deepest // self.aspp_reduction, deepest)
            )
        else:
            self.block7 = nn.Identity()

        self.buff_encoder = Config.buffed_encoder
        if self.buff_encoder:
            self.num_blocks = Config.num_blocks
            self.expansion = Config.expand
            extra_blocks = [
                InverseBottleNeckBlock(deepest, deepest * self.expansion)
                for _ in range(self.num_blocks)
            ]
            self.block8 = nn.Sequential(*extra_blocks)
        else:
            self.block8 = nn.Identity()

        self.use_bam = Config.use_bam
        self.reduction = Config.reduction
        if self.use_bam:
            # Two BAM blocks: one right after the encoder, one after ASPP / extra blocks.
            self.bam1 = BAM(deepest, deepest // self.reduction)
            self.bam2 = BAM(deepest, deepest // self.reduction)
        else:
            self.bam1 = nn.Identity()
            self.bam2 = nn.Identity()

    def freeze_beginning(self):
        """Freeze the whole smp encoder."""
        self.freeze([self.model.encoder])

    def freeze(self, layers):
        """Disable gradients for every parameter of every module in `layers`."""
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False

    def unfreeze(self, layers):
        """Enable gradients for every parameter of every module in `layers`."""
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = True

    def forward(self, x):
        '''
        x: input image batch.
        Returns a list of the six feature maps produced by the smp encoder, in
        the encoder's order, with the last (deepest) one passed through
        bam1 -> block7 -> block8 -> bam2.
        NOTE(review): exact channel counts / resolutions depend on the smp
        encoder and are not established by this code — confirm before relying on them.
        '''
        x, l0, l1, l2, l3, deepest = tuple(self.model.encoder(x))
        for stage in (self.bam1, self.block7, self.block8, self.bam2):
            deepest = stage(deepest)
        return [x, l0, l1, l2, l3, deepest]

In [None]:
# Special Convolutional Blocks for the UNet Decoder:
class RecurrentConvolution(pl.LightningModule):
    '''
    Recurrent convolution: one shared ConvBlock applied `t` times.

    The first pass runs on the input; every later pass runs on the average of
    the input and the previous pass' output.
    '''
    def __init__(self, in_features, kernel_size, padding, groups, t = 2):
        super().__init__()
        self.in_features = in_features
        self.kernel_size = kernel_size
        self.padding = padding
        self.groups = groups
        self.t = t

        self.block = ConvBlock(in_features, in_features, kernel_size, padding, groups, 1)

    def forward(self, x):
        for step in range(self.t):
            x1 = self.block(x) if step == 0 else self.block((x + x1) / 2)
        return x1
class RecurrentBlock(pl.LightningModule):
    """Residual block: 1x1 projection, depthwise recurrent 3x3 convolution, attention, 1x1 back, gated skip.

    The output is sigmoid(gamma) * branch + (1 - sigmoid(gamma)) * x with a
    learned scalar gamma initialised at 0.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features

        self.reduction = Config.reduction
        self.conv = ConvBlock(in_features, inner_features, 1, 0, 1, 1)
        # groups == channels: the recurrent 3x3 convolution is depthwise.
        self.recurrent = RecurrentConvolution(inner_features, 3, 1, inner_features)
        self.SE = Attention(inner_features, inner_features // self.reduction)
        self.conv2 = ConvBlock(inner_features, in_features, 1, 0, 1, 1)

        self.gamma = nn.Parameter(torch.zeros(1, device = self.device))

    def forward(self, x):
        hidden = self.recurrent(self.conv(x))
        branch = self.conv2(self.SE(hidden))

        gamma = torch.sigmoid(self.gamma)
        return gamma * branch + (1 - gamma) * x
class GatedSpatialAttention(pl.LightningModule):
    '''
    Base gated spatial attention.

    Both inputs are projected to `inner_features` channels with 1x1
    convolutions and averaged; after activation and batch norm a second 1x1
    convolution + batch norm + sigmoid yields a gate with `left_features`
    channels that multiplies `left_features`. With `Config.gate_attention`
    the gated result is blended with the ungated input through a learned
    sigmoid gate initialised at 0.
    '''
    def __init__(self, left_features, down_features, inner_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.inner_features = inner_features

        self.ConvLeft = nn.Conv2d(left_features, inner_features, kernel_size = 1, bias = False)
        self.ConvDown = nn.Conv2d(down_features, inner_features, kernel_size = 1, bias = False)

        self.BatchNorm = nn.BatchNorm2d(inner_features)
        self.act = Act()

        self.ConvBlock = nn.Conv2d(inner_features, left_features, kernel_size = 1, bias = False)
        self.BatchNorm2 = nn.BatchNorm2d(left_features)
        self.gate_attention = Config.gate_attention
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros(1, device = self.device))
        initialize_weights(self)

    def forward(self, left_features, down_features):
        projected_left = self.ConvLeft(left_features)
        projected_down = self.ConvDown(down_features)

        fused = self.BatchNorm(self.act((projected_down + projected_left) / 2))
        gate = torch.sigmoid(self.BatchNorm2(self.ConvBlock(fused)))
        excite = gate * left_features
        if not self.gate_attention:
            return excite
        gamma = torch.sigmoid(self.gamma)
        return gamma * excite + (1 - gamma) * left_features
        
class GatedChannelAttention(pl.LightningModule):
    '''
    Attention-UNet-like gating built on squeeze-excite principles
    (author note: Conv2d-based attention did not work well for them).

    Both inputs are averaged over their two trailing axes, projected to
    `inner_features` by separate linear layers and averaged; after the
    activation, a linear layer + sigmoid yields one gate per channel of
    `left_features`. With `Config.gate_attention` the gated result is blended
    with the ungated input through a learned sigmoid gate initialised at 0.
    '''
    def __init__(self, left_features, down_features, inner_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.inner_features = inner_features

        self.LeftSqueeze = nn.Linear(left_features, inner_features)
        self.Act = Act()
        self.DownSqueeze = nn.Linear(down_features, inner_features)

        self.Excite = nn.Linear(inner_features, left_features)
        self.gate_attention = Config.gate_attention
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.zeros(1, device = self.device))

    def forward(self, left_features, down_features):
        pooled_left = left_features.mean(dim = -1).mean(dim = -1)
        pooled_down = down_features.mean(dim = -1).mean(dim = -1)

        squeeze_left = self.LeftSqueeze(pooled_left)
        squeeze_down = self.DownSqueeze(pooled_down)
        hidden = self.Act((squeeze_left + squeeze_down) / 2)

        gate = torch.sigmoid(self.Excite(hidden))
        excite = gate[..., None, None] * left_features
        if not self.gate_attention:
            return excite
        gamma = torch.sigmoid(self.gamma)
        return gamma * excite + (1 - gamma) * left_features
class ChooseBottleNeck(pl.LightningModule):
    """Build the residual block named by `Config.bottleneck_type`.

    'recurrent' -> RecurrentBlock, 'inverse' -> InverseBottleNeckBlock,
    'bottleneck' -> BottleNeckBlock, 'none' -> nn.Identity. Any other value
    fails the assertion in the constructor.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features

        self.bottleneck_type = Config.bottleneck_type
        assert self.bottleneck_type in ['none', 'recurrent', 'inverse', 'bottleneck']
        block_classes = {
            'recurrent': RecurrentBlock,
            'inverse': InverseBottleNeckBlock,
            'bottleneck': BottleNeckBlock,
        }
        if self.bottleneck_type in block_classes:
            self.layer = block_classes[self.bottleneck_type](in_features, inner_features)
        else:
            self.layer = nn.Identity()

    def forward(self, x):
        return self.layer(x)

In [None]:
class FPN(pl.LightningModule):
    """Feature-pyramid head.

    Each input feature map gets its own pair of 3x3 ConvBlocks
    (in -> 2 * out_channels -> out_channels), is resized bilinearly to the
    spatial size of `last_dim`, and everything is concatenated after
    `last_dim` along dim 1.

    Args:
        in_channels: list with the channel count of each feature map, in the order they are passed to forward.
        out_channels: channel count every feature map is projected to.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert isinstance(self.in_channels, list)

        def _projection(channels):
            return nn.Sequential(
                ConvBlock(channels, out_channels * 2, 3, 1, 1, 1),
                ConvBlock(out_channels * 2, out_channels, 3, 1, 1, 1),
            )

        self.conv_proj = nn.ModuleList([_projection(in_channels[idx]) for idx in range(len(in_channels))])

    def forward(self, features, last_dim):
        _, _, rows, cols = last_dim.shape
        resized = []
        for idx, feature in enumerate(features):
            projected = self.conv_proj[idx](feature)
            resized.append(F.interpolate(projected, size = (rows, cols), mode = 'bilinear'))
        return torch.cat([last_dim, *resized], dim = 1)

class LinkNetBlockQTPi(pl.LightningModule):
    """Decoder block that upsamples with ``PixelShuffle_ICNR`` before fusing the skip input.

    Pipeline in ``forward``: PixelShuffle the deeper features, optionally gate
    the skip features with ``GatedChannelAttention``, concatenate, run two
    ``ConvBlock`` layers and an ``Attention`` layer, then (optionally) a stack
    of ``InverseBottleNeckBlock`` layers.

    Options read from ``Config`` at construction time: ``reduction``,
    ``use_decoder_attention``, ``buffed_decoder`` and, when the latter is set,
    ``num_blocks`` and ``expand``.

    Args:
        left_features: size associated with the skip (encoder-side) input;
            ``0`` means no skip attention module is created.
        down_features: size associated with the deeper (to-be-upsampled) input.
        out_features: size produced by the two ``ConvBlock`` layers.
    """
    def __init__(self, left_features, down_features, out_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.out_features = out_features

        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict layout are unaffected.
        self.PixelShuffle = PixelShuffle_ICNR(self.down_features, self.down_features, blur = True)
        self.reduction = Config.reduction
        self.use_attention = Config.use_decoder_attention

        fused_features = self.down_features + self.left_features
        self.Conv1 = ConvBlock(fused_features, self.out_features, 3, 1, 1, 1)
        self.Conv2 = ConvBlock(self.out_features, self.out_features, 3, 1, 1, 1)
        self.attention2 = Attention(self.out_features, self.out_features // self.reduction)

        has_skip = self.left_features != 0
        if self.use_attention and has_skip:
            self.attention1 = GatedChannelAttention(
                self.left_features, self.down_features, self.left_features // self.reduction
            )

        self.buff_decoder = Config.buffed_decoder
        if self.buff_decoder:
            # Add a few residual blocks after the attention layer.
            self.num_blocks = Config.num_blocks
            self.expand = Config.expand
            extra_blocks = [
                InverseBottleNeckBlock(self.out_features, self.out_features * self.expand)
                for _ in range(self.num_blocks)
            ]
            self.additional_blocks = nn.Sequential(*extra_blocks)

    def forward(self, left_features, down_features):
        """Fuse ``left_features`` (may be ``None``) with the upsampled ``down_features``.

        NOTE(review): ``attention1`` only exists when the block was built with
        ``left_features != 0``; passing a non-None skip tensor to a block built
        with ``left_features == 0`` while attention is enabled would fail.
        """
        upsampled = self.PixelShuffle(down_features)
        if left_features is not None:
            skip = left_features
            if self.use_attention:
                skip = self.attention1(skip, upsampled)
            upsampled = torch.cat([upsampled, skip], dim = 1)

        out = self.attention2(self.Conv2(self.Conv1(upsampled)))
        if self.buff_decoder:
            # Gives slightly more power to the decoder. Use with risk.
            out = self.additional_blocks(out)
        return out
class DecoderBlockQTPi(pl.LightningModule):
    """Decoder block that upsamples 2x with nearest-neighbour interpolation.

    Pipeline in ``forward``: interpolate the deeper features by a factor of 2,
    optionally gate the skip features with ``GatedChannelAttention``,
    concatenate, run two ``ConvBlock`` layers and an ``Attention`` layer, then
    (optionally) a stack of ``InverseBottleNeckBlock`` layers.

    Options read from ``Config`` at construction time: ``reduction``,
    ``use_decoder_attention``, ``buffed_decoder`` and, when the latter is set,
    ``num_blocks`` and ``expand``.

    Args:
        left_features: size associated with the skip (encoder-side) input.
        down_features: size associated with the deeper (to-be-upsampled) input.
        out_features: size produced by the two ``ConvBlock`` layers.
    """
    def __init__(self, left_features, down_features, out_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.out_features = out_features
        self.reduction = Config.reduction
        self.use_attention = Config.use_decoder_attention

        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict layout are unaffected.
        fused_features = self.left_features + self.down_features
        self.conv1 = ConvBlock(fused_features, self.out_features, 3, 1, 1, 1)
        self.conv2 = ConvBlock(self.out_features, self.out_features, 3, 1, 1, 1)
        self.att2 = Attention(self.out_features, self.out_features // self.reduction)

        # The skip attention is only built when both inputs are non-empty.
        both_inputs_present = self.left_features != 0 and self.down_features != 0
        if self.use_attention and both_inputs_present:
            self.attention = GatedChannelAttention(
                self.left_features, self.down_features, self.left_features // self.reduction
            )

        self.buff_decoder = Config.buffed_decoder
        if self.buff_decoder:
            self.num_blocks = Config.num_blocks
            self.expand = Config.expand
            extra_blocks = [
                InverseBottleNeckBlock(self.out_features, self.out_features * self.expand)
                for _ in range(self.num_blocks)
            ]
            self.additional_blocks = nn.Sequential(*extra_blocks)

    def forward(self, left_features, down_features):
        """Fuse ``left_features`` (may be ``None``) with the 2x-upsampled ``down_features``."""
        upsampled = F.interpolate(down_features, scale_factor = 2, mode = 'nearest')
        if left_features is not None:
            skip = left_features
            if self.use_attention:
                # Attend: gate the skip features using the upsampled features.
                skip = self.attention(skip, upsampled)
            upsampled = torch.cat([upsampled, skip], dim = 1)

        out = self.att2(self.conv2(self.conv1(upsampled)))
        if self.buff_decoder:
            out = self.additional_blocks(out)
        return out
class DecoderQTPi(pl.LightningModule):
    """Five-stage decoder followed by a 3x3 convolution segmentation head.

    Every option is read from ``Config`` at construction time:
    ``num_classes``, ``encoder_type``, ``use_linkNet``, ``use_sem``,
    ``aspp_reduction`` and ``use_FPN``.

    ``left_dim[i]`` / ``down_dim[i]`` / ``down_dim[i + 1]`` are the three
    constructor arguments of decoder block ``i``.
    """
    def __init__(self):
        super().__init__()
        self.num_classes = Config.num_classes
        self.encoder_type = Config.encoder_type

        # (left_dim, down_dim) per encoder type; any other type uses the fallback.
        dims_by_encoder = {
            'resnet': ([256, 128, 64, 64, 0], [512, 256, 128, 64, 32, 16]),
            'unext': ([1024, 512, 256, 64, 0], [512, 256, 128, 64, 32, 16]),
        }
        fallback_dims = ([112,  40,  24, 32,  0], [320, 256, 128, 64, 32, 16])
        self.left_dim, self.down_dim = dims_by_encoder.get(self.encoder_type, fallback_dims)

        # One block per skip level, all of the same kind.
        self.useLinkNet = Config.use_linkNet
        block_cls = LinkNetBlockQTPi if self.useLinkNet else DecoderBlockQTPi
        self.decoder_blocks = nn.ModuleList([
            block_cls(left, down, out)
            for left, down, out in zip(self.left_dim, self.down_dim, self.down_dim[1:])
        ])

        self.use_SEM = Config.use_sem
        self.aspp_reduction = Config.aspp_reduction
        if self.use_SEM:
            # 2 SEM blocks - like the 2 BAM blocks in the encoder - placed in
            # the early decoder to save memory.
            self.sem1 = SEM(self.down_dim[1], self.down_dim[1] // self.aspp_reduction)
            self.sem2 = SEM(self.down_dim[2], self.down_dim[2] // self.aspp_reduction)
        else:
            self.sem1 = nn.Identity()
            self.sem2 = nn.Identity()

        self.use_FPN = Config.use_FPN
        if self.use_FPN:
            self.FPN = FPN(self.down_dim[0:-2], self.down_dim[-2])
        # Both dropouts are currently disabled (p = 0.0).
        self.drop_final = nn.Dropout2d(0.0) # Small DropProb at end 0.1 Default
        self.drop_middle = nn.Dropout2d(0.0) # Large Drop in Middle, 0.5 for ASPP
        if self.use_FPN:
            # forward() feeds the FPN four feature maps plus d1, hence the factor 5.
            self.fpn_proj = ConvBlock(self.down_dim[-2] * 5, self.down_dim[-2], 1, 0, 1, 1)

        # 16 equals down_dim[-1] in every configuration above.
        self.proj = nn.Conv2d(16, self.num_classes, kernel_size = 3, padding = 1)

    def forward(self, x0, l0, l1, l2, l3, l4):
        '''
        Decode the encoder outputs into per-class logits. ``x0`` is accepted but unused.

        Shapes recorded by the original author (they appear to describe the
        fallback encoder branch; not verifiable from this class alone):
        l0: Tensor(B, 16, 128, 128)
        l1: Tensor(B, 24, 64, 64) - FPN 2x
        l2: Tensor(B, 40, 32, 32) - FPN 4x
        l3: Tensor(B, 112, 16, 16) - FPN 8x
        l4: Tensor(B, 320, 8, 8) - FPN 16x

        NOTE(review): the fallback ``left_dim`` gives the block that consumes
        ``l0`` a left size of 32, not 16 - confirm which one is right.
        '''
        block4, block3, block2, block1, block0 = self.decoder_blocks

        # Drop Middle
        deepest = self.drop_middle(l4)

        d4 = self.sem1(block4(l3, deepest)) # 16
        d3 = self.sem2(block3(l2, d4)) # 32
        d2 = block2(l1, d3) # 64
        d1 = block1(l0, d2) # 128

        if self.use_FPN:
            d1 = self.fpn_proj(self.FPN([deepest, d4, d3, d2], d1))

        d0 = block0(None, d1) # 256
        # Drop Final
        d0 = self.drop_final(d0)
        # Segmentation Head
        return self.proj(d0)

In [None]:
class UNetQTPi(pl.LightningModule):
    """Encoder/decoder segmentation network assembled from ``Config``.

    ``Config.encoder_type`` selects the encoder (``'resnet'`` ->
    ``EncoderResNet``, ``'unext'`` -> ``EncoderUNext``, anything else ->
    ``EncoderQTPi``); the decoder is always ``DecoderQTPi``. When
    ``Config.act == 'mish'`` the helper ``replace_all`` is applied to the whole
    module (presumably swapping activations - confirm against its definition).
    """
    def __init__(self):
        super().__init__()
        kind = Config.encoder_type
        self.encoder_type = kind
        if kind == 'unext':
            self.encoder = EncoderUNext()
        elif kind == 'resnet':
            self.encoder = EncoderResNet()
        else:
            self.encoder = EncoderQTPi()

        self.decoder = DecoderQTPi()
        if Config.act == 'mish':
            replace_all(self)

    def forward(self, x):
        """Run encoder then decoder and squeeze the result.

        The encoder output is unpacked positionally into the decoder.
        NOTE(review): ``torch.squeeze`` without ``dim`` removes *every*
        size-1 dimension, so a batch of one loses its batch axis.
        """
        encoder_outputs = self.encoder(x)
        logits = self.decoder(*encoder_outputs)
        return torch.squeeze(logits)

In [None]:
class TrainingConfig:
    """Training hyper-parameters kept as plain class attributes.

    Nothing in this part of the notebook reads these values; the groupings
    below are inferred from the names and should be confirmed against the
    training code.
    """
    # --- optimiser (presumably) ---
    lr = 1e-3
    weight_decay = 1e-3  # Increase later (author note; also: increase dropout later)
    clip_grads = 20

    # --- learning-rate schedule (presumably) ---
    patience = 3
    factor = 0.2
    eta_min = 1e-9
    num_steps = 5

    # --- data loading ---
    NUM_WORKERS = 4

    # --- model / training switches ---
    use_SWA = False  # Stochastic Weight Averaging leads to better results often.
    mish = True

In [None]:
class SMP(pl.LightningModule):
    """Baseline model: ``smp.Unet`` with an ``efficientnet-b0`` encoder.

    The encoder is created without pretrained weights
    (``encoder_weights = None``) and the number of output classes comes from
    ``Config.num_classes``. The wrapped network lives in ``self.Model``; that
    attribute name is part of the checkpoint's state_dict keys.
    """
    def __init__(self):
        super().__init__()
        unet_options = dict(encoder_weights = None, classes = Config.num_classes)
        self.Model = smp.Unet('efficientnet-b0', **unet_options)

    def forward(self, x):
        """Return the wrapped U-Net's output for ``x``."""
        out = self.Model(x)
        return out

In [None]:
class Model(pl.LightningModule):
    """Thin wrapper that stores an ``SMP`` instance in ``self.model``.

    Compared with using ``SMP`` directly, this adds one level (``model.``) to
    every state_dict key.
    """
    def __init__(self):
        super().__init__()
        self.model = SMP()

    def forward(self, x):
        """Delegate to the wrapped ``SMP`` network."""
        out = self.model(x)
        return out

In [None]:
%%capture
# Build the baseline network on the target device, then restore the first
# (and currently only) checkpoint listed in MODELS.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model = SMP().to(device)
model.load_state_dict(torch.load(MODELS[0], map_location = device))

In [None]:
class TTAModel(pl.LightningModule):
    """Flip-based test-time augmentation around ``model``.

    ``forward`` evaluates the wrapped model on four views of the input - the
    original, a flip of the last dimension, a flip of the last two dimensions,
    and a flip of the second-to-last dimension - undoes each flip on the
    corresponding output and returns the mean of the four results.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the flip-averaged prediction for ``x``."""
        # Dimensions flipped for each augmented view; the same flip is applied
        # again to the prediction to undo it.
        flip_dims = [(-1,), (-1, -2), (-2,)]

        views = [x] + [x.flip(dims) for dims in flip_dims]
        outputs = [self.model(view) for view in views]

        # Accumulate in the fixed order: identity, then each un-flipped view.
        total = outputs[0]
        for dims, output in zip(flip_dims, outputs[1:]):
            total = total + output.flip(dims)
        return total / 4

In [None]:
class EnsembleModel(pl.LightningModule):
    """Inference wrapper that turns model outputs into a hard class map.

    Despite the name, ``models`` is called as ONE module (``self.models(x)``);
    averaging over several models and the ``use_TTA`` switch are not
    implemented here - wrap the model in ``TTAModel`` before passing it in if
    test-time augmentation is wanted.
    """
    def __init__(self, models):
        super().__init__()
        self.models = models

    def forward(self, x):
        """Return the arg-max class index over dim 1, with size-1 dims squeezed out.

        Runs in eval mode, without gradients and under CUDA autocast. The
        wrapped output is softmaxed over dim 1 before the arg-max.

        NOTE(review): the final ``torch.squeeze`` removes every size-1
        dimension, so a batch of one comes back without its batch axis.
        """
        self.eval()
        with torch.no_grad(), torch.cuda.amp.autocast():
            probabilities = F.softmax(self.models(x), dim = 1)
            # ArgMax over the class dimension.
            selected = torch.max(probabilities, dim = 1)[1]
        return torch.squeeze(selected)

In [None]:
# Wrap the loaded network: flip-TTA first, then the arg-max inference wrapper.
# NOTE(review): TTA is applied unconditionally; the `use_TTA` flag from the
# configuration cell is not consulted here.
# NOTE(review): this rebinds `model`, so re-running the cell without re-running
# the checkpoint-loading cell wraps the model a second time.
model = EnsembleModel(models = TTAModel(model = model))

## Inference

In [None]:
# Tile-wise inference over every image listed in the sample submission.
# For each image: open the TIFF, predict every tile produced by HuBMAPDataset,
# accumulate the per-tile binary votes into a full-size uint8 canvas, threshold
# the vote count with VOTERS and run-length encode the result.
names, predictions = [], []
for _, row in tqdm(df_sample.iterrows(), total=len(df_sample)):
    imageId = row['id']
    tiff_path = os.path.join(DATA, imageId + '.tiff')

    # Open through a context manager so the file handle is released as soon as
    # all tiles of this image have been read (it was previously never closed).
    with rasterio.open(tiff_path, transform = identity, num_threads='all_cpus') as data:
        preds = np.zeros(data.shape, dtype=np.uint8)
        dataset = HuBMAPDataset(data)
        dataloader = DataLoader(dataset, batch_size = batch_size, num_workers=0, shuffle=False, pin_memory=True)
        for img, vertices in dataloader:
            img = img.to(device)
            pred = model(img)
            # The model squeezes its output, so a final batch holding a single
            # tile comes back as (H, W). Restore the batch axis explicitly;
            # otherwise the zip below would iterate over the ROWS of that tile
            # and paste a corrupted mask.
            pred = pred.reshape(img.shape[0], *pred.shape[-2:])
            pred = pred.cpu().float().numpy()
            vertices = vertices.numpy()
            assert not np.isnan(np.sum(pred))
            for p, vert in zip(pred, vertices):
                # Tile bounds inside the full image: first-axis range, then second-axis range.
                x1, x2, y1, y2 = vert
                p = cv2.resize(p, (window, window))
                p = (p > THRESHOLD).astype(np.uint8)
                # Overlapping tiles add up; each tile contributes one vote per pixel.
                preds[x1:x2, y1:y2] += p
        del dataset
        del dataloader

    # A pixel is foreground when its vote count exceeds VOTERS.
    preds = (preds > VOTERS).astype(np.uint8)

    #convert to rle
    rle = rle_encode_less_memory(preds)
    names.append(imageId)
    predictions.append(rle)
    del preds
    gc.collect()

In [None]:
# Assemble the submission table (one row per image: id, RLE-encoded mask) and
# write it without the positional index.
df = pd.DataFrame({'id': names, 'predicted': predictions})
df.to_csv('submission.csv', index = False)