This notebook covers the training code for Weight Watcher, a deep-learning bot for food classification and weight/calorie estimation.

From an image, it estimates the weight of the food and then estimates the calories of the dish.

Model Strategy:
- Segmentation for Pretraining, then classification + regression
Segmentation:
- Modified UNet + ASPP, to segment the coin and food(1 Yuan coin for scale, else weight estimation is impossible)
- BackBone is then extracted(now that it has good features), and using the backbone, a Classification and Regression head is applied

Regression: predict the food's weight; calories are then looked up from a per-gram calorie table.
Classification: classify the image into one of 19 food types.


# Import Dependencies

Potential Change: - Object Detection for Coin and Food?
- 2x BBox Regression from image features(Use this if semantic seg fails)

In [None]:
%%capture
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import torchvision

import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy
import numpy as np
import pandas as pd
import json
import cv2

import os
import math
import copy
import random

!pip install efficientnet-pytorch
from efficientnet_pytorch import EfficientNet

!apt-get update
!apt-get install libturbojpeg
!pip install -U git+git://github.com/lilohuang/PyTurboJPEG.git


!pip install PyTurboJPEG
from turbojpeg import TurboJPEG

import pytorch_lightning as pl
from fastai.vision.all import *
from collections import Counter 

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
!pip install xlrd
import xml.etree.ElementTree as ET
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import glob

In [None]:
# Reproducibility:
import os
import random
def seed_all(seed=42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy, torch (CPU and every CUDA device) and
    PyTorch Lightning with the same value, sets ``PYTHONHASHSEED`` and forces
    deterministic cuDNN kernels.

    Args:
        seed: Seed applied to every generator (default 42, the value that
            used to be hard-coded here).
    """
    random.seed(seed)
    # Only affects Python subprocesses started after this point; the hash
    # seed of the running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Deterministic kernels trade a little speed for reproducibility;
    # set benchmark=True for faster (non-deterministic) convolutions.
    torch.backends.cudnn.benchmark = False
    # Pass the seed explicitly. Without an argument, seed_everything picks its
    # own seed (PL_GLOBAL_SEED, or a default/random one depending on the
    # Lightning version) and re-seeds random/numpy/torch, undoing the above.
    pl.seed_everything(seed)


seed_all()
def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from the current torch seed.

    The signature matches ``DataLoader``'s ``worker_init_fn`` hook. The torch
    seed is folded into 32 bits, because that is the range NumPy accepts.

    Args:
        worker_id: Worker index supplied by the caller (not used).
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    numpy.random.seed(derived_seed)
    random.seed(derived_seed)

In [None]:
# Helper functions for quick visual inspection of images and masks.
def _chw_to_hwc(image):
    """Move a channel-first (C, H, W) torch tensor to the CPU as (H, W, C)."""
    return image.cpu().permute(1, 2, 0)


def display_image_pt(image):
    """Show a channel-first torch image tensor with matplotlib."""
    plt.imshow(_chw_to_hwc(image))
    plt.show()


def display_image_np(image):
    """Show an image array that is already in a layout ``plt.imshow`` accepts."""
    plt.imshow(image)
    plt.show()


def display_image_mask(image, mask):
    """Show a channel-first torch image with *mask* overlaid at 50% opacity."""
    plt.imshow(_chw_to_hwc(image))
    plt.imshow(mask, alpha=0.5)
    plt.show()

In [None]:
class BaseDataset(torch.utils.data.Dataset):
    """Index the image files described by a density table.

    Each row of *train_df* describes one food item, using the columns ``id``,
    ``type``, ``volume(mm^3)`` and ``weight(g)``. Every image file whose name
    starts with that id inherits the row's labels. Rows whose id contains
    ``'mix'`` are skipped.

    Args:
        image_base_path: Image directory. It is used as a raw string prefix,
            so it must end with a path separator.
        train_df: DataFrame with the columns listed above.

    Attributes:
        processed: One single-key dict per image file, of the form
            ``{file_stem: {'volume': ..., 'weight': ..., 'class': ...}}``.
    """

    def __init__(self, image_base_path, train_df):
        self.train_df = train_df
        self.image_base_path = image_base_path
        self.processed = self.process_dataframe()

    def process_dataframe(self):
        """Expand the per-item table into one labelled entry per image file.

        The original notes say the 174-row table expands to 2978 images.
        """
        processed = []
        for _, row in self.train_df.iterrows():
            image_id = row['id']
            if 'mix' in image_id:
                continue
            labels = {
                'volume': row['volume(mm^3)'],
                'weight': row['weight(g)'],
                'class': row['type'],
            }
            pattern = self.image_base_path + glob.escape(image_id) + '*'
            # glob's order is filesystem-dependent. Sort so this list, and the
            # seeded train/val split built from it, is the same on every machine.
            for file in sorted(glob.glob(pattern)):
                stem = os.path.splitext(os.path.basename(file))[0]
                processed.append({stem: dict(labels)})
        return processed

In [None]:
def get_augmentations(IMAGE_SIZE):
    """Build the (train, test) albumentations pipelines for square inputs.

    Both pipelines resize to ``IMAGE_SIZE x IMAGE_SIZE`` and end with
    ``A.Normalize()``. The training pipeline adds heavy photometric and
    geometric augmentation in between.

    Args:
        IMAGE_SIZE: Output side length in pixels.

    Returns:
        ``(train_transforms, test_transforms)``, both ``A.Compose`` objects.
    """
    # Cutout holes are 1% of the image side.
    hole = int(.01 * IMAGE_SIZE)

    # Photometric stages.
    color_stage = A.OneOf([A.ColorJitter(p = 1, hue = 0.1, saturation = 0.1), A.RandomGamma(p=1)], p=.5)
    blur_stage = A.OneOf([A.Blur(blur_limit=3, p=1), A.MedianBlur(blur_limit=3, p=1)], p=.25)
    # NOTE(review): IAAAffine, Cutout and ElasticTransform(alpha_affine=...) are
    # deprecated or removed in newer albumentations releases. Pin the version
    # or migrate (e.g. A.Affine / A.CoarseDropout).
    noise_stage = A.OneOf([A.GaussNoise(0.002, p=.5), A.IAAAffine(p=.5)], p=.25)

    # Geometric warps.
    warp_stage = A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * .05, alpha_affine=120 * .03, p=.5),
        A.GridDistortion(p=.5),
        A.OpticalDistortion(distort_limit=2, shift_limit=.5, p=1),
    ], p=.25)

    train_steps = [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        color_stage,
        blur_stage,
        noise_stage,
        warp_stage,
        A.RandomRotate90(p=.5),
        A.HorizontalFlip(p=.5),
        A.Cutout(num_holes=10, max_h_size=hole, max_w_size=hole, p=.25),
        A.ShiftScaleRotate(p=.5),
        A.Normalize(),
    ]
    eval_steps = [A.Resize(IMAGE_SIZE, IMAGE_SIZE), A.Normalize()]
    return A.Compose(train_steps), A.Compose(eval_steps)
def load_excel(df_path):
    """Read every sheet of an Excel workbook and stack them into one DataFrame.

    Args:
        df_path: Path to the workbook.

    Returns:
        All sheets concatenated row-wise in workbook order. Each sheet keeps
        its own index, so index values may repeat. Returns ``None`` if the
        workbook has no sheets.
    """
    # sheet_name=None makes read_excel return a {sheet_name: DataFrame} dict.
    sheets = pd.read_excel(df_path, sheet_name = None)
    if not sheets:
        return None
    # DataFrame.append was removed in pandas 2.0. A single concat is the
    # supported replacement, and it avoids quadratic repeated copying.
    return pd.concat(list(sheets.values()))
class Config:
    """Global paths, data split, label maps and preprocessing settings.

    NOTE: the class body runs at definition time. It reads the density
    spreadsheet, globs the image directory and builds the train/val split.
    """
    annotation_base_path = '../input/fooddataset/dataset/ECUSTFD-resized--master/Annotations/'
    image_base_path = '../input/fooddataset/dataset/ECUSTFD-resized--master/JPEGImages/'
    df_path = '../input/fooddataset/dataset/ECUSTFD-resized--master/density.xls'
    pretrained_model = '../input/pretrainedunet/best.pth' # Trained in prev version of the notebook.
    df = load_excel(df_path)
    # One {file_stem: {'volume', 'weight', 'class'}} dict per image file.
    files = BaseDataset(image_base_path, df).processed
    # 95/5 train/validation split with a fixed seed.
    train, val = train_test_split(files, train_size = 0.95, test_size = 0.05, random_state = 42)
    # Hard-coded calories per gram for each food class.
    # NOTE(review): most values look like kcal/g, but 'fired_dough_twist',
    # 'mooncake' and 'sachima' are far larger and may be kJ/g. Confirm units.
    weights2Cal ={
        'apple': 0.52,
        'banana': 0.89,
        'bread': 3.15,
        'bun':2.23,
        'doughnut':4.34,
        'egg': 1.43,
        'fired_dough_twist': 24.16,
        'grape': 0.69,
        'lemon': 0.29,
        'litchi': 0.66,
        'mango': 0.60,
        'mooncake': 18.83,
        'orange': 0.63,
        'peach': 0.57,
        'pear': 0.39,
        'plum': 0.46,
        'qiwi': 0.61,
        'sachima': 21.45,
        'tomato': 0.27
    }

    # Class indices follow alphabetical order of the class names.
    classes = sorted(weights2Cal.keys())
    # Built without a for-loop so no loop variable leaks into Config as a
    # class attribute. Only the outermost iterable (which is all this
    # comprehension needs) is evaluated in class scope.
    cls2idx = {name: index for index, name in enumerate(classes)}
    idx2cls = dict(enumerate(classes))
    num_classes = len(classes)
    # Hyper Parameters
    IMAGE_SIZE = 256
    train_transforms, test_transforms = get_augmentations(IMAGE_SIZE)
    to_tensor = ToTensorV2()
    

# Data Processing:
- Note: All images are smaller than 1024x1024, but there are only 2,978 of them, so heavy overfitting is expected.
- Note also: images vary in size, so the segmentation mask is first built at each image's original resolution and then resized together with the image by Albumentations.

In [None]:
 class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, dict_mapping, train = True):
        self.dict_mapping = dict_mapping
        self.TurboJPEG = TurboJPEG()
        self.train = train
    def __len__(self):
        return len(self.dict_mapping)
    def __getitem__(self, idx):
        '''
        Generates a Segmentation Mask for the Image 
        
        0: Nothing there
        1: Coin or Food
        Trained with Lovask Loss
        '''
        row = self.dict_mapping[idx]
        # Extract Data
        ids = list(row.keys())[0]
        image_path = Config.image_base_path + ids + '.JPG'
        annot_path = Config.annotation_base_path + ids + '.xml'
        # Load in Image
        with open(image_path, 'rb') as file:
            image = self.TurboJPEG.decode(file.read())
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Load in XML
        root = ET.parse(annot_path).getroot()
        # Extract Major Headings
        size = root[4]
        food = root[6]
        coin = root[7]
        # Parse Further
        width = int(size[0].text)
        height = int(size[1].text)
        
        bnd_box_food = food[4]
        x_min_food = int(bnd_box_food[0].text)
        y_min_food = int(bnd_box_food[1].text)
        x_max_food = int(bnd_box_food[2].text)
        y_max_food = int(bnd_box_food[3].text)
        
        bnd_box_coin = coin[4]
        x_min_coin = int(bnd_box_coin[0].text)
        y_min_coin = int(bnd_box_coin[1].text)
        x_max_coin = int(bnd_box_coin[2].text)
        y_max_coin = int(bnd_box_coin[3].text)
        
        # Generate Segmentation Mask
        seg_mask = np.zeros((height, width))
        # Fill Coin
        seg_mask[y_min_coin: y_max_coin, x_min_coin: x_max_coin] = 1
        # Fill Food
        seg_mask[y_min_food: y_max_food, x_min_food: x_max_food] = 1
        # Augment Both Image and Mask
        if self.train:
            augmented = Config.train_transforms(image = image, mask = seg_mask)
            image = Config.to_tensor(image = augmented['image'])['image']
            seg_mask = augmented['mask']
        else:
            augmented = Config.test_transforms(image = image, mask = seg_mask)
            image = Config.to_tensor(image = augmented['image'])['image']
            seg_mask = augmented['mask']
        return image, seg_mask 
            
class WeightWatcherDataset(torch.utils.data.Dataset):
    """Image dataset for the classification and weight/volume regression stage.

    Args:
        dict_mapping: List of ``{file_stem: {'volume', 'weight', 'class'}}`` dicts.
        train: If truthy, use ``Config.train_transforms``; otherwise use
            ``Config.test_transforms``.
    """

    def __init__(self, dict_mapping, train = True):
        self.dict_mapping = dict_mapping
        self.TurboJPEG = TurboJPEG()
        self.train = train

    def __len__(self):
        return len(self.dict_mapping)

    def __getitem__(self, idx):
        """Return ``(image_tensor, class_idx, volume, weight)`` for item *idx*."""
        entry = self.dict_mapping[idx]
        stem = next(iter(entry))
        path = Config.image_base_path + stem + ".JPG"

        labels = entry[stem]
        volume, weight = labels['volume'], labels['weight']
        class_idx = Config.cls2idx[labels['class']]

        with open(path, 'rb') as handle:
            decoded = self.TurboJPEG.decode(handle.read())
        rgb = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

        pipeline = Config.train_transforms if self.train else Config.test_transforms
        augmented = pipeline(image = rgb)['image']
        tensor = Config.to_tensor(image = augmented)['image']
        return tensor, class_idx, volume, weight
        

In [None]:
class DataModule:
    """Factory methods returning train or validation datasets built from Config's split."""

    @classmethod
    def _build(cls, dataset_cls, train):
        # Same test as before: only a value equal to True selects the training
        # split. Anything else falls back to the validation split.
        if train == True:
            return dataset_cls(Config.train, train = True)
        return dataset_cls(Config.val, train = False)

    @classmethod
    def get_seg(cls, train = True):
        """Return a SegmentationDataset over ``Config.train`` (or ``Config.val``)."""
        return cls._build(SegmentationDataset, train)

    @classmethod
    def get_cls(cls, train = True):
        """Return a WeightWatcherDataset over ``Config.train`` (or ``Config.val``)."""
        return cls._build(WeightWatcherDataset, train)

# Step 1: Segmentation Pretraining
- Design a custom UNet + ASPP architecture that segments out the food and the coin, so the encoder learns key image features before the classification and regression heads are trained.
- BackBone - EfficientNet + GhostNet Blocks + ASPP 
- Weak Decoder, Powerful Encoder, since we want most of the knowledge encoded in the encoder.

Handy CNN Blocks to Define(GhostNet, EffNet, ConvBlocks, etc.)

In [None]:
def initialize_weights(layer):
    """Initialize conv weights (Kaiming normal) and reset BatchNorm2d affine params.

    Walks every submodule of *layer*, including *layer* itself:
    - ``nn.Conv1d`` / ``nn.Conv2d`` weights get ``kaiming_normal_`` with
      ``nonlinearity='relu'``; conv biases are left untouched.
    - ``nn.BatchNorm2d`` gets weight 1 and bias 0.
    Other module types are not modified.
    """
    for submodule in layer.modules():
        if isinstance(submodule, nn.BatchNorm2d):
            nn.init.ones_(submodule.weight)
            nn.init.zeros_(submodule.bias)
        elif isinstance(submodule, (nn.Conv2d, nn.Conv1d)):
            nn.init.kaiming_normal_(submodule.weight, nonlinearity='relu')
class Mish(pl.LightningModule):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def __init__(self):
        # nn.Module.__init__ must run. It creates the parameter/hook registries
        # that Module.__call__, .to() and .parameters() read. Without it,
        # Mish()(x) raises AttributeError.
        super().__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
def replace_mishes(model):
    """Replace every ``nn.SiLU`` / ``nn.ReLU`` submodule of *model* with ``Mish``, in place.

    Nested modules (e.g. ``"encoder.block.0.act"``) are replaced on their
    direct parent. *model* itself is left alone even if it is an activation,
    because it cannot be swapped in place.
    """
    # Collect names first. Replacing children while named_modules() is still
    # walking the tree mutates the dicts being iterated.
    targets = [
        name for name, module in model.named_modules()
        if name and isinstance(module, (nn.SiLU, nn.ReLU))
    ]
    for name in targets:
        parent_name, _, child_name = name.rpartition('.')
        parent = model
        # Walk the dotted path. setattr(model, "a.b", m) would only create a
        # new top-level attribute literally named "a.b".
        for part in (parent_name.split('.') if parent_name else []):
            parent = getattr(parent, part)
        setattr(parent, child_name, Mish())
class Act(pl.LightningModule):
    """Activation layer chosen by ``ModelConfig.act_type`` (ModelConfig is defined elsewhere).

    ``'relu'`` gives an in-place ReLU, ``'mish'`` gives Mish, and any other
    value gives an in-place SiLU.
    """

    def __init__(self):
        super().__init__()
        self.act_type = ModelConfig.act_type
        kind = self.act_type
        self.act = (
            nn.ReLU(inplace = True) if kind == 'relu'
            else Mish() if kind == 'mish'
            else nn.SiLU(inplace = True)
        )

    def forward(self, x):
        return self.act(x)

class ConvBlock(pl.LightningModule):
    """Conv2d (no bias) -> Act -> BatchNorm2d, with Kaiming-initialized weights.

    Args:
        in_features: Input channels.
        out_features: Output channels.
        kernel_size, padding, groups, stride: Passed straight to ``nn.Conv2d``.
    """
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, bias = False)
        self.bn = nn.BatchNorm2d(out_features)
        self.act = Act()
        initialize_weights(self)
    def forward(self, x):
        # NOTE: BatchNorm runs *after* the activation (conv -> act -> bn),
        # the reverse of the common conv -> bn -> act ordering.
        return self.bn(self.act(self.conv(x)))
class SqueezeExcite(pl.LightningModule):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools the input, passes the channel vector through
    Linear(in -> inner) -> Act -> Linear(inner -> in) -> sigmoid, and rescales
    each input channel by the result.

    Args:
        in_features: Channels C of the input.
        inner_features: Bottleneck width of the excitation MLP.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))
        self.Squeeze = nn.Linear(self.in_features, self.inner_features)
        self.act = Act()
        self.Excite = nn.Linear(self.inner_features, self.in_features)
    def forward(self, x):
        # (B, C, H, W) -> (B, C). torch.squeeze drops *every* size-1 dim, so
        # with B == 1 the batch dim goes too. The two unsqueezes below still
        # broadcast correctly against x in that case.
        mean = torch.squeeze(self.global_avg(x))
        squeeze = self.act(self.Squeeze(mean))
        # Per-channel gates in (0, 1), reshaped to (..., C, 1, 1) to scale x.
        excite = torch.sigmoid(self.Excite(squeeze)).unsqueeze(-1).unsqueeze(-1) * x
        return excite
class ECASqueezeExcite(pl.LightningModule):
    """Efficient Channel Attention: a 1-D conv across the pooled channel means.

    Channel gates come from one Conv1d(1 -> 1, kernel 5, padding 2) slid over
    the channel axis of the global-average-pooled input, followed by a
    sigmoid. There is no channel reduction and no parameters beyond the 5 taps.
    """
    def __init__(self):
        super().__init__()
        self.kernel_size = 5
        self.padding = 2
        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv1d(1, 1, kernel_size = self.kernel_size, padding = self.padding, bias = False)
        initialize_weights(self)
    def forward(self, x):
        # (B, C, 1, 1) -> (B, C, 1) -> (B, 1, C): channels become the 1-D "length" axis.
        mean = torch.squeeze(self.global_avg(x), dim = -1).transpose(-1, -2) # (B, 1, C)
        conv = torch.sigmoid(self.conv(mean)).transpose(-1, -2).unsqueeze(-1) # (B, C, 1, 1)
        return conv * x
class SCSqueezeExcite(pl.LightningModule):
    """Concurrent channel and spatial squeeze-excitation.

    The output is the mean of two re-weightings of x:
    - channel SE: GAP -> Linear -> Act -> Linear -> sigmoid, per channel;
    - spatial gate: Conv2d(C -> 1, 3x3) -> sigmoid, per pixel.

    Args:
        in_features: Channels C of the input.
        inner_features: Bottleneck width of the channel-SE MLP.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        
        # Channel branch.
        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))
        self.Squeeze = nn.Linear(self.in_features, self.inner_features)
        self.act = Act()
        self.Excite = nn.Linear(self.inner_features, self.in_features)
        
        # Spatial branch (3x3 conv keeps H, W with padding 1).
        self.kernel_size = 3
        self.padding = 1
        
        self.Conv_Squeeze = nn.Conv2d(self.in_features, 1, kernel_size = self.kernel_size, padding = self.padding, bias = False)
        initialize_weights(self)
    def forward(self, x):
        # torch.squeeze also drops the batch dim when B == 1. The result
        # still broadcasts back onto x after the two unsqueezes.
        mean = torch.squeeze(self.global_avg(x))
        squeeze = self.act(self.Squeeze(mean))
        excite = torch.sigmoid(self.Excite(squeeze)).unsqueeze(-1).unsqueeze(-1) * x
        
        # (B, 1, H, W) gate shared by all channels.
        conv_excite = torch.sigmoid(self.Conv_Squeeze(x)) * x
        excited = (excite + conv_excite) / 2
        return excited
class Attention(pl.LightningModule):
    """Attention module selected by ``ModelConfig.attention_type`` (ModelConfig is defined elsewhere).

    ``'eca'``, ``'se'`` and ``'scse'`` pick the matching squeeze-excite
    variant; ``'none'`` is the identity. When ``ModelConfig.gate_attention``
    is set, the output is ``g * attention(x) + (1 - g) * x`` with
    ``g = sigmoid(gamma)`` and gamma initialised to -10. So g is about 4.5e-5
    at start, and the block begins close to the identity.

    Args:
        in_features: Channels of the input.
        inner_features: Bottleneck width passed to SE / scSE.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.attention_type = ModelConfig.attention_type
        assert self.attention_type in ['eca', 'se', 'scse', 'none']
        if self.attention_type == 'eca':
            self.layer = ECASqueezeExcite()
        elif self.attention_type == 'se':
            self.layer = SqueezeExcite(self.in_features, self.inner_features)
        elif self.attention_type == 'scse':
            self.layer = SCSqueezeExcite(self.in_features, self.inner_features)
        else:
            self.layer = nn.Identity()
        self.gate_attention = ModelConfig.gate_attention
        if self.gate_attention:
            # Learnable gate logit; sigmoid(-10) ~ 0 starts the block near identity.
            self.gamma = nn.Parameter(torch.zeros((1), device = self.device) - 10)
    def forward(self, x):
        excited = self.layer(x)
        if self.gate_attention:
            gamma = torch.sigmoid(self.gamma)
            return gamma * excited + (1 - gamma) * x
        return excited
class SwinTransformerAttention(pl.LightningModule):
    """Multi-head self-attention over fixed-length token windows with a learned bias.

    Args:
        length: Tokens per window (L). Inputs must have exactly this many.
        in_features: Token embedding size.
        inner_features: Per-head query/key/value size.
        num_heads: Number of attention heads.

    Input and output shape: ``(B, L, in_features)``.
    """
    def __init__(self, length, in_features, inner_features, num_heads):
        super().__init__()
        self.length = length
        self.in_features = in_features
        self.inner_features = inner_features
        self.num_heads = num_heads

        self.Keys = nn.Linear(self.in_features, self.inner_features * self.num_heads)
        self.Queries = nn.Linear(self.in_features, self.inner_features * self.num_heads)
        self.Values = nn.Linear(self.in_features, self.inner_features * self.num_heads)
        self.Linear = nn.Linear(self.inner_features * self.num_heads, self.in_features)

        # Learned (heads, L, L) bias added to every window's attention logits.
        self.pos_enc = nn.Parameter(nn.init.xavier_uniform_(torch.zeros((self.num_heads, self.length, self.length), device = self.device)))

    def _split_heads(self, t, B, L):
        # (B, L, heads * inner) -> (B, heads, L, inner)
        return t.reshape(B, L, self.num_heads, self.inner_features).transpose(1, 2)

    def forward(self, x):
        B, L, _ = x.shape
        assert L == self.length, f"expected {self.length} tokens per window, got {L}"
        Q = self._split_heads(self.Queries(x), B, L)
        K = self._split_heads(self.Keys(x), B, L)
        V = self._split_heads(self.Values(x), B, L)

        # pos_enc (heads, L, L) broadcasts over the batch dim, so head h of
        # every sample gets pos_enc[h].
        logits = Q @ K.transpose(-1, -2) / math.sqrt(self.inner_features) + self.pos_enc
        att_mat = F.softmax(logits, dim = -1)  # (B, heads, L, L)
        att_scores = att_mat @ V  # (B, heads, L, inner)

        # Back to (B, L, heads * inner). reshape, not view: the transpose is
        # non-contiguous.
        att_scores = att_scores.transpose(1, 2).reshape(B, L, self.num_heads * self.inner_features)
        return self.Linear(att_scores)
class SwinTransformerEncoder(pl.LightningModule):
    """Two Swin-style blocks over a fixed-size feature map.

    A learned positional encoding is added to the input, then two blocks run:
    - Block 1: LayerNorm -> window attention (+ residual), then
      LayerNorm -> Linear (+ residual).
    - Block 2: same, but attention runs within windows of a map that is
      cyclically shifted by half a window. Its final Linear maps to
      ``out_features``.

    Args:
        H, W: Spatial size of the input map. Both must be divisible by
            ``window_size``.
        in_features: Input channels C.
        inner_features: Per-head attention width.
        out_features: Output channels. The last residual connection is added
            only when ``out_features == in_features``.
        num_heads: Attention heads.
        window_size: Side length of the square attention windows.

    Input shape ``(B, in_features, H, W)``; output shape
    ``(B, out_features, H, W)``.

    NOTE(review): the shifted pass has no attention mask, so tokens that wrap
    around the border can attend to each other. The Swin paper masks these.
    """

    def window(self, x):
        """Split ``(B, C, H, W)`` into ``(B * nH * nW, window_size**2, C)`` windows."""
        B, C, H, W = x.shape
        ws = self.window_size
        windowed = x.reshape(B, C, H // ws, ws, W // ws, ws)
        windowed = windowed.permute(0, 2, 4, 3, 5, 1)  # (B, nH, nW, ws, ws, C)
        # permute makes the tensor non-contiguous, so use reshape, not view.
        return windowed.reshape(-1, ws * ws, C)

    def unwindow(self, x):
        """Inverse of ``window``: ``(B * nH * nW, window_size**2, C)`` -> ``(B, C, H, W)``."""
        num_windows, _, C = x.shape
        ws = self.window_size
        nH, nW = self.H // ws, self.W // ws
        B = num_windows // (nH * nW)
        unwindowed = x.reshape(B, nH, nW, ws, ws, C)
        unwindowed = unwindowed.permute(0, 5, 1, 3, 2, 4)  # (B, C, nH, ws, nW, ws)
        return unwindowed.reshape(B, C, self.H, self.W)

    def shift(self, x):
        """Cyclically shift a ``(B, C, H, W)`` map half a window toward the top-left."""
        s = self.window_size // 2
        return torch.roll(x, shifts = (-s, -s), dims = (2, 3))

    def unshift(self, x):
        """Undo ``shift`` on a ``(B, C, H, W)`` map."""
        s = self.window_size // 2
        return torch.roll(x, shifts = (s, s), dims = (2, 3))

    def __init__(self, H, W, in_features, inner_features, out_features, num_heads, window_size = 4):
        super().__init__()
        if H % window_size or W % window_size:
            raise ValueError(f"H ({H}) and W ({W}) must be divisible by window_size ({window_size})")
        self.H = H
        self.W = W
        self.length = self.H * self.W
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.window_size = window_size

        self.pos_enc = nn.Parameter(nn.init.xavier_uniform_(torch.zeros((1, self.in_features, self.H, self.W), device = self.device)))
        self.norm1 = nn.LayerNorm((self.in_features, self.H, self.W))
        self.att1 = SwinTransformerAttention(self.window_size ** 2, self.in_features, self.inner_features, self.num_heads)
        self.norm2 = nn.LayerNorm((self.length, self.in_features))
        self.linear2 = nn.Linear(self.in_features, self.in_features)

        self.norm3 = nn.LayerNorm((self.in_features, self.H, self.W))
        self.att3 = SwinTransformerAttention(self.window_size ** 2, self.in_features, self.inner_features, self.num_heads)
        self.norm4 = nn.LayerNorm((self.length, self.in_features))
        self.linear4 = nn.Linear(self.in_features, self.out_features)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x + self.pos_enc  # (B, C, H, W)

        # Block 1: window attention + residual, then linear + residual.
        att1 = self.att1(self.window(self.norm1(x)))
        x = self.unwindow(att1) + x  # (B, C, H, W)
        tokens = x.flatten(2).transpose(1, 2)  # (B, HW, C)
        tokens = self.linear2(self.norm2(tokens)) + tokens
        x = tokens.transpose(1, 2).reshape(B, C, H, W)

        # Block 2: shift the *spatial map*, attend within windows, undo the shift.
        att3 = self.att3(self.window(self.shift(self.norm3(x))))
        x = self.unshift(self.unwindow(att3)) + x  # (B, C, H, W)
        tokens = x.flatten(2).transpose(1, 2)  # (B, HW, C)
        out = self.linear4(self.norm4(tokens))  # (B, HW, out_features)
        if self.out_features == self.in_features:
            out = out + tokens
        return out.transpose(1, 2).reshape(B, self.out_features, H, W)
        
        
        
class SplitAttention(pl.LightningModule):
    # split attention like in ResNest.
    """Softmax attention across ``cardinality`` feature splits (ResNeSt-style).

    The splits are summed, global-average-pooled, and passed through a 1x1
    ConvBlock and a grouped 1x1 conv. The result is softmax-normalized across
    the splits, giving per-channel, per-split weights that rescale the input.

    Args:
        in_features: Channels per split (C).
        inner_features: Bottleneck width per split.
        cardinality: Number of splits.

    NOTE(review): conv1 is a ConvBlock (with BatchNorm2d) applied to a 1x1
    map. BatchNorm in training mode raises when a channel has only one value,
    i.e. with batch size 1.
    """
    def __init__(self, in_features, inner_features, cardinality):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.cardinality = cardinality
        assert self.in_features % self.cardinality == 0
        
        self.AvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = ConvBlock(self.in_features, self.inner_features * self.cardinality, 1, 0, 1, 1)
        self.conv2 = nn.Conv2d(self.inner_features * self.cardinality, self.in_features * self.cardinality, kernel_size = 1, groups = self.cardinality, bias = False)
        initialize_weights(self.conv2)
        
        self.gate_attention = ModelConfig.gate_attention
        if self.gate_attention:
            # Gate logit; sigmoid(-10) ~ 0 starts the block near identity.
            self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
        
    def forward(self, x):
        # X: Tensor(B, C, H, W, Cardinality)
        B, C, H, W, Cardinality = x.shape
        summed = torch.sum(x, dim = -1)
        pooled = self.AvgPool(summed) # (B, C, 1, 1)
        
        conv1 = self.conv1(pooled) # (B, inner_features * cardinality, 1, 1)
        conv2 = self.conv2(conv1) # (B, in_features * cardinality, 1, 1)
        # NOTE(review): conv2 is grouped, so its output channels are
        # group-major (g * in_features + c), but this view reads them as
        # channel-major (c * cardinality + g). If each softmax slot should
        # come from one conv group, view as (B, cardinality, in_features) and
        # transpose instead.
        conv2 = conv2.view(B, self.in_features, self.cardinality)# (B, in_features, cardinality)
        conv2 = F.softmax(conv2.unsqueeze(2).unsqueeze(2), dim = -1) # (B, in_features, 1, 1, cardinality), softmax over the splits
        
        excited = x * conv2
        if self.gate_attention:
            gamma = torch.sigmoid(self.gamma)
            return gamma * excited + (1 - gamma) * x
        return excited
class GhostConv(pl.LightningModule):
    """GhostNet-style conv: a 1x1 primary ConvBlock plus a cheap depthwise ConvBlock.

    The primary 1x1 conv produces half of the output channels, rounded up.
    A depthwise conv on that result produces the other half. The two are
    concatenated and trimmed to exactly ``out_features`` channels.

    Args:
        in_features: Input channels.
        out_features: Output channels.
        kernel_size: Kernel size of the cheap depthwise conv.
        padding: Padding of the cheap depthwise conv.
    """
    def __init__(self, in_features, out_features, kernel_size, padding):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.kernel_size = kernel_size
        self.padding = padding

        # Round up so odd out_features still yield enough channels. For even
        # out_features this equals out_features // 2, so parameter shapes
        # (and saved weights) are unchanged.
        self.inner_features = (self.out_features + 1) // 2

        self.squeeze = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1)
        self.cheap = ConvBlock(self.inner_features, self.inner_features, self.kernel_size, self.padding, self.inner_features, 1)
    def forward(self, x):
        squeeze = self.squeeze(x)
        cheap = self.cheap(squeeze) # (B, inner_features, H, W)
        # Drop the single surplus channel produced when out_features is odd.
        return torch.cat([squeeze, cheap], dim = 1)[:, :self.out_features]

class AstrousConvBlock(pl.LightningModule):
    """Dilated ("atrous") Conv2d (no bias) -> Act -> BatchNorm2d, Kaiming-initialized.

    Args:
        in_features: Input channels.
        out_features: Output channels.
        kernel_size, padding, groups, stride, dilation: Passed straight to ``nn.Conv2d``.
    """
    def __init__(self, in_features, out_features, kernel_size, padding, groups, stride, dilation):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, dilation = dilation, bias = False)
        self.bn = nn.BatchNorm2d(out_features)
        self.act = Act()
        initialize_weights(self)
    def forward(self, x):
        # Activation before BatchNorm (conv -> act -> bn).
        return self.bn(self.act(self.conv(x)))
        
class BAM(pl.LightningModule):
    """Bottleneck-attention-style block: average of a channel gate and a spatial gate.

    - Channel gate: GAP -> Linear -> Act -> Linear -> sigmoid.
    - Spatial gate: 1x1 ConvBlock -> dilated depthwise 3x3 -> 1x1 conv to one
      map -> sigmoid.
    Each gate rescales x, and the two results are averaged. With
    ``ModelConfig.gate_attention`` the output is blended with x by a learned
    gate that starts near 0.

    Args:
        in_features: Channels of the input.
        inner_features: Bottleneck width of both branches.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features

        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))
        self.Squeeze = nn.Linear(self.in_features, self.inner_features)
        self.Act = Act()
        self.Excite = nn.Linear(self.inner_features, self.in_features)
        self.bam_dilate = ModelConfig.bam_dilate

        self.ConvSqueeze = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1)
        # padding == dilation keeps H, W unchanged for the 3x3 kernel.
        self.DA = AstrousConvBlock(self.inner_features, self.inner_features, 3, self.bam_dilate, self.inner_features, 1, self.bam_dilate)
        self.ConvExcite = nn.Conv2d(self.inner_features, 1, kernel_size = 1, bias = False)
        initialize_weights(self.ConvExcite)

        self.gate_attention = ModelConfig.gate_attention
        if self.gate_attention:
            self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
    def forward(self, x):
        # Channel branch. torch.squeeze also drops the batch dim when B == 1,
        # which still broadcasts correctly below.
        pooled = torch.squeeze(self.global_avg(x))
        squeeze = self.Act(self.Squeeze(pooled))
        excite = torch.sigmoid(self.Excite(squeeze).unsqueeze(-1).unsqueeze(-1)) * x

        # Spatial branch. Squash to (0, 1) like the channel gate (and like
        # SCSqueezeExcite's spatial gate). Multiplying by the raw conv output
        # would scale x by an unbounded, possibly negative, map.
        squeeze = self.ConvSqueeze(x)
        DA = self.DA(squeeze)
        convExcite = torch.sigmoid(self.ConvExcite(DA)) * x

        excited = (convExcite + excite) / 2
        if self.gate_attention:
            gamma = torch.sigmoid(self.gamma)
            return gamma * excited + (1 - gamma) * x
        return excited
    
# -----------------BottleNeck Blocks -------------------#
class GhostBottleNeck(pl.LightningModule):
    """Residual Ghost bottleneck: GhostConv(1x1) -> Attention -> GhostConv(3x3).

    The output is ``g * block(x) + (1 - g) * x`` with ``g = sigmoid(gamma)``
    and gamma initialised to -10, so the block starts close to the identity.

    Args:
        in_features: Input and output channels.
        inner_features: Channels inside the bottleneck.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.reduction = ModelConfig.reduction
        
        self.ghost1 = GhostConv(self.in_features, self.inner_features, 1, 0)
        self.att = Attention(self.inner_features, self.inner_features // self.reduction)
        self.ghost2 = GhostConv(self.inner_features, self.in_features, 3, 1)
        
        # Learned residual gate logit.
        self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
    def forward(self, x):
        ghost1 = self.ghost1(x)
        att1 = self.att(ghost1)
        ghost2 = self.ghost2(att1)
        
        gamma = torch.sigmoid(self.gamma)
        return gamma * ghost2 + (1 - gamma) * x
class GhostDownSampler(pl.LightningModule):
    """Strided Ghost block that changes the channel count and downsamples.

    Shortcut: AvgPool(3, stride) -> GhostConv(1x1, in -> out).
    Main path: GhostConv(1x1, in -> inner) -> depthwise 3x3 ConvBlock with
    ``stride`` -> Attention -> GhostConv(1x1, inner -> out).
    Output is ``g * shortcut + (1 - g) * main`` with ``g = sigmoid(gamma)``.
    Gamma is initialised to -10, so the main path dominates at start.

    Args:
        in_features: Input channels.
        inner_features: Channels inside the main path.
        out_features: Output channels.
        stride: Spatial downsampling factor, applied to both paths.
    """
    def __init__(self, in_features, inner_features, out_features, stride):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride
        self.reduction = ModelConfig.reduction

        self.AvgPool = nn.AvgPool2d(kernel_size = 3, padding = 1, stride = self.stride)
        self.ghostPool = GhostConv(self.in_features, self.out_features, 1, 0)

        self.Ghost1 = GhostConv(self.in_features, self.inner_features, 1, 0)
        # The depthwise conv carries the stride, so the main path downsamples
        # exactly like the AvgPool shortcut (3x3, padding 1). Otherwise the
        # two outputs cannot be added.
        self.DW = ConvBlock(self.inner_features, self.inner_features, 3, 1, self.inner_features, self.stride)
        self.att = Attention(self.inner_features, self.inner_features // self.reduction)
        self.Ghost2 = GhostConv(self.inner_features, self.out_features, 1, 0)

        self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
    def forward(self, x):
        pooled = self.AvgPool(x)
        ghostPool = self.ghostPool(pooled)

        ghost1 = self.Ghost1(x)
        dw = self.DW(ghost1)
        att = self.att(dw)
        ghost2 = self.Ghost2(att)

        gamma = torch.sigmoid(self.gamma)
        return gamma * ghostPool + (1 - gamma) * ghost2
class ResNextBottleNeck(pl.LightningModule):
    """Residual ResNeXt bottleneck with optional ResNeSt split attention.

    1x1 squeeze -> grouped 3x3 (groups = cardinality) -> attention -> 1x1
    expand back to ``in_features``. The output is gated with the input:
    ``g * block(x) + (1 - g) * x``, where ``g = sigmoid(gamma)`` and gamma
    is initialised to -10.

    Args:
        in_features: Input and output channels.
        inner_features: Channels per group.
        cardinality: Number of groups.
        resnest: Use SplitAttention instead of the configured Attention.
    """
    def __init__(self, in_features, inner_features, cardinality, resnest = False):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.cardinality = cardinality
        self.resnest = resnest
        self.reduction = ModelConfig.reduction
        
        self.Squeeze = ConvBlock(self.in_features, self.inner_features * self.cardinality, 1, 0, 1, 1)
        self.Process = ConvBlock(self.inner_features * self.cardinality, self.inner_features * self.cardinality, 3, 1, self.cardinality, 1)
        if self.resnest:
            #  Split Attention
            self.att = SplitAttention(self.inner_features, self.inner_features // self.reduction, self.cardinality)
        else:
            self.att = Attention(self.inner_features * self.cardinality, self.inner_features * self.cardinality // self.reduction)
        self.Expand = ConvBlock(self.inner_features * self.cardinality, self.in_features, 1, 0, 1, 1)
        
        self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
    def forward(self, x):
        squeeze = self.Squeeze(x)
        process = self.Process(squeeze)
        if self.resnest:
            B, C, H, W = process.shape
            # NOTE(review): Process is grouped, so its channels are
            # group-major (g * inner + c). This view instead takes
            # split g from channels g, g + card, g + 2*card, ... If splits
            # should match conv groups, view as (B, cardinality, C // cardinality, H, W).
            process = process.view(B, C // self.cardinality, self.cardinality, H, W)
            process = process.transpose(2, 3).transpose(3, 4) # (B, C // cardinality, H, W, cardinality)
            
            att = self.att(process)
            # Back to (B, C, H, W). .view relies on the attention output keeping
            # a layout compatible with this transpose.
            att = att.transpose(3, 4).transpose(2, 3).view(B, -1, H, W)
        else:
            att = self.att(process)
        expand = self.Expand(att)
        
        gamma = torch.sigmoid(self.gamma)
        return gamma * expand + (1 - gamma) * x
class ResNextDownSampler(pl.LightningModule):
    """Strided ResNeXt block that changes the channel count and downsamples.

    Shortcut: AvgPool(3, stride) -> 1x1 ConvBlock (in -> out).
    Main path: 1x1 squeeze -> grouped 3x3 with ``stride`` -> attention ->
    1x1 expand (-> out).
    Output is ``g * main + (1 - g) * shortcut`` with ``g = sigmoid(gamma)``.
    Gamma is initialised to -10, so the shortcut dominates at start.

    Args:
        in_features: Input channels.
        inner_features: Channels per group.
        out_features: Output channels of both paths.
        stride: Spatial downsampling factor, applied to both paths.
        cardinality: Number of groups.
        resnest: Use SplitAttention instead of the configured Attention.
    """
    def __init__(self, in_features, inner_features, out_features, stride, cardinality, resnest = False):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride
        self.cardinality = cardinality
        self.resnest = resnest
        self.reduction = ModelConfig.reduction

        self.avgPool = nn.AvgPool2d(kernel_size = 3, padding = 1, stride = self.stride)
        # Both paths must end with out_features channels so they can be mixed.
        self.ConvPool = ConvBlock(self.in_features, self.out_features, 1, 0, 1, 1)

        self.Squeeze = ConvBlock(self.in_features, self.inner_features * self.cardinality, 1, 0, 1, 1)
        # The grouped 3x3 carries the stride, matching the AvgPool shortcut
        # (3x3, padding 1), so both paths share the same H, W.
        self.Process = ConvBlock(self.cardinality * self.inner_features, self.inner_features * self.cardinality, 3, 1, self.cardinality, self.stride)
        if self.resnest:
            self.att = SplitAttention(self.inner_features, self.inner_features // self.reduction, self.cardinality)
        else:
            self.att = Attention(self.inner_features * self.cardinality, self.inner_features * self.cardinality // self.reduction)
        self.Expand = ConvBlock(self.inner_features * self.cardinality, self.out_features, 1, 0, 1, 1)

        self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
    def forward(self, x):
        pooled = self.avgPool(x)
        convPool = self.ConvPool(pooled)

        squeeze = self.Squeeze(x)
        process = self.Process(squeeze)
        if self.resnest:
            B, C, H, W = process.shape
            process = process.view(B, C // self.cardinality, self.cardinality, H, W)
            process = process.transpose(2, 3).transpose(3, 4)
            att = self.att(process)
            att = att.transpose(3, 4).transpose(2, 3).reshape(B, -1, H, W)
        else:
            att = self.att(process)
        expand = self.Expand(att)

        gamma = torch.sigmoid(self.gamma)
        return gamma * expand + (1 - gamma) * convPool
class InverseBottleNeck(pl.LightningModule):
    """MobileNetV2-style inverted bottleneck with a learnable residual gate.

    1x1 expand -> 3x3 grouped conv (groups == width) -> channel attention -> 1x1
    squeeze back to ``in_features``. The result is blended with the input as
    ``sigmoid(gamma) * out + (1 - sigmoid(gamma)) * x``.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.reduction = ModelConfig.reduction

        width = self.inner_features
        self.Expand = ConvBlock(self.in_features, width, 1, 0, 1, 1)
        # The 5th ConvBlock argument appears to be the group count, so groups == width
        # makes this a depthwise convolution.
        self.DW = ConvBlock(width, width, 3, 1, width, 1)
        self.Attention = Attention(width, width // self.reduction)
        self.Squeeze = ConvBlock(width, self.in_features, 1, 0, 1, 1)

        # Starts at sigmoid(-10) ~ 0, i.e. the block is initially close to identity.
        self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))

    def forward(self, x):
        out = x
        for stage in (self.Expand, self.DW, self.Attention, self.Squeeze):
            out = stage(out)
        mix = torch.sigmoid(self.gamma)
        return mix * out + (1 - mix) * x
class InverseDownSampler(pl.LightningModule):
    """Inverted-bottleneck downsampler: gated blend of a strided residual branch and a pooled shortcut.

    Args:
        in_features: input channels.
        inner_features: expanded width of the depthwise stage.
        out_features: output channels (both branches project to this).
        stride: spatial stride applied by both branches.
    """
    def __init__(self, in_features, inner_features, out_features, stride):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride 
        self.reduction = ModelConfig.reduction
    
        # Shortcut: strided 3x3 average pool, then 1x1 projection to out_features.
        self.AvgPool = nn.AvgPool2d(kernel_size = 3, padding = 1, stride = self.stride)
        self.ConvPool = ConvBlock(self.in_features, self.out_features, 1, 0, 1, 1)
    
        self.Expand = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, 1)
        # The depthwise conv carries the stride (assuming the last ConvBlock argument is the
        # stride, as in ASPP) so the residual branch downsamples like the 3x3 / pad-1 pool.
        # With stride 1 here, the two branches had different spatial sizes whenever
        # stride > 1 and could not be summed.
        self.DW = ConvBlock(self.inner_features, self.inner_features, 3, 1, self.inner_features, self.stride)
        self.Attention = Attention(self.inner_features, self.inner_features // self.reduction)
        self.Squeeze = ConvBlock(self.inner_features, self.out_features, 1, 0, 1, 1)
        
        # sigmoid(-10) ~ 0: initially the output is almost entirely the shortcut.
        self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
    def forward(self, x):
        pooled = self.AvgPool(x)
        convPool = self.ConvPool(pooled)
        
        expand = self.Expand(x)
        dw = self.DW(expand)
        att = self.Attention(dw)
        squeeze = self.Squeeze(att)
        
        gamma = torch.sigmoid(self.gamma)
        return gamma * squeeze + (1 - gamma) * convPool
class ChooseBottleNeck(pl.LightningModule):
    """Instantiate the bottleneck flavour selected by ``ModelConfig.bottleneck_type``.

    'ghost' -> GhostBottleNeck, 'inverse' -> InverseBottleNeck, any other value ->
    ResNextBottleNeck with ``ModelConfig.groups`` groups (and ResNeSt attention when
    ``ModelConfig.resnest``). The chosen module is stored as ``self.layer``.
    """
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.bottleneck_type = ModelConfig.bottleneck_type

        kind = self.bottleneck_type
        if kind not in ('ghost', 'inverse'):
            # Default / 'resnext': grouped bottleneck.
            self.cardinality = ModelConfig.groups
            self.resnest = ModelConfig.resnest
            chosen = ResNextBottleNeck(self.in_features, self.inner_features, self.cardinality, resnest = self.resnest)
        elif kind == 'ghost':
            chosen = GhostBottleNeck(self.in_features, self.inner_features)
        else:
            chosen = InverseBottleNeck(self.in_features, self.inner_features)
        self.layer = chosen

    def forward(self, x):
        return self.layer(x)
        
class ChooseDownsampler(pl.LightningModule):
    """Instantiate the downsampler matching ``ModelConfig.bottleneck_type`` and delegate to it.

    'ghost' -> GhostDownSampler, 'inverse' -> InverseDownSampler, any other value ->
    ResNextDownSampler with ``ModelConfig.groups`` groups (ResNeSt attention when
    ``ModelConfig.resnest``).
    """
    def __init__(self, in_features, inner_features, out_features, stride):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride
        self.bottleneck_type = ModelConfig.bottleneck_type

        dims = (self.in_features, self.inner_features, self.out_features, self.stride)
        if self.bottleneck_type == 'ghost':
            self.layer = GhostDownSampler(*dims)
        elif self.bottleneck_type == 'inverse':
            self.layer = InverseDownSampler(*dims)
        else:
            self.cardinality = ModelConfig.groups
            self.resnest = ModelConfig.resnest
            self.layer = ResNextDownSampler(*dims, self.cardinality, resnest = self.resnest)

    def forward(self, x):
        return self.layer(x)

Segmentation-Specific Blocks (ASPP, DC-UNet, Swin Transformer Block)

In [None]:
class ASPP(pl.LightningModule):
    def __init__(self, in_features, inner_features, out_features, stride):
        '''
        Atrous Spatial Pyramid Pooling (ASPP) module.

        Five parallel branches read the same input, all with stride ``stride``:
        - 1x1 conv                      -> inner_features channels
        - 3x3 atrous conv, dilation 2   -> inner_features channels
        - 3x3 atrous conv, dilation 3   -> inner_features channels
        - 3x3 atrous conv, dilation 5   -> inner_features channels
        - 3x3 average pool              -> in_features channels (unchanged)
        The branch outputs are concatenated along channels and fused by a 1x1
        conv into ``out_features`` channels.
        '''
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride

        self.conv1 = ConvBlock(self.in_features, self.inner_features, 1, 0, 1, self.stride)
        # Each atrous branch pads by its dilation rate; AstrousConvBlock arguments are
        # presumably (in, out, kernel, padding, groups, stride, dilation).
        for name, rate in (('conv2', 2), ('conv3', 3), ('conv4', 5)):
            setattr(self, name, AstrousConvBlock(self.in_features, self.inner_features, 3, rate, 1, self.stride, rate))
        self.conv5 = nn.AvgPool2d(kernel_size = 3, padding = 1, stride = self.stride)

        fused_channels = 4 * self.inner_features + self.in_features
        self.ConvProj = ConvBlock(fused_channels, self.out_features, 1, 0, 1, 1)

    def forward(self, x):
        branches = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        pyramid = torch.cat([branch(x) for branch in branches], dim = 1)
        return self.ConvProj(pyramid)
        
class DualChannelBlock(pl.LightningModule):
    """DC-UNet style block, reduced to a single path to keep the parameter count down.

    3x3 conv -> 3x3 grouped conv (groups = inner_features); the input and both
    intermediate maps are concatenated and fused by a final 3x3 conv with stride
    ``stride`` into ``out_features`` channels.
    """
    def __init__(self, in_features, inner_features, out_features, stride):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features
        self.stride = stride

        self.Conv1 = ConvBlock(self.in_features, self.inner_features, 3, 1, 1, 1)
        self.Conv2 = ConvBlock(self.inner_features, self.inner_features, 3, 1, self.inner_features, 1)

        dense_channels = self.in_features + 2 * self.inner_features
        self.Conv3 = ConvBlock(dense_channels, self.out_features, 3, 1, 1, self.stride)

    def forward(self, x):
        first = self.Conv1(x)
        second = self.Conv2(first)
        return self.Conv3(torch.cat([x, first, second], dim = 1))
class FPN(pl.LightningModule):
    """Fuse several feature maps: project each, resize to a common size, concatenate, project.

    Args:
        in_features: list with the channel count of each input map (in call order).
        inner_features: channels each map is projected to before fusion.
        out_features: channels of the fused output.
        out_size: side length every projected map is resized to (square output).
    """
    def __init__(self, in_features, inner_features, out_features, out_size):
        super().__init__()
        self.in_features = in_features
        self.num_features = len(self.in_features)
        self.inner_features = inner_features
        self.out_features = out_features
        # Previously never stored, so forward() raised AttributeError on self.out_size.
        self.out_size = out_size
        
        self.convBlocks = nn.ModuleList([
            ConvBlock(channels, self.inner_features) for channels in self.in_features
        ])
        self.Proj = ConvBlock(self.inner_features * self.num_features, self.out_features)
    def forward(self, features):
        """
        Args:
            features: sequence of maps, one per entry of ``in_features`` and in that order.
        Returns:
            The fused map with ``out_features`` channels and spatial size ``out_size`` x ``out_size``.
        Raises:
            ValueError: if the number of maps differs from ``len(in_features)``.
        """
        if len(features) != self.num_features:
            raise ValueError(f"FPN expected {self.num_features} feature maps, got {len(features)}")
        resized = []
        for conv, feature in zip(self.convBlocks, features):
            projected = conv(feature)
            # align_corners=False is the default; stating it avoids the bilinear warning.
            resized.append(F.interpolate(projected, size = (self.out_size, self.out_size), mode = 'bilinear', align_corners = False))
        return self.Proj(torch.cat(resized, dim = 1))
        
class GatedAttentionBlock(pl.LightningModule):
    """Channel gate for upsampled decoder features, conditioned on the skip connection.

    This is in the spirit of Attention-UNet. Both inputs are globally average-pooled,
    squeezed to ``inner_features``, averaged and excited back to one sigmoid weight per
    ``down_features`` channel. The weights rescale ``down_features``. When
    ``ModelConfig.gate_attention`` is set, the rescaled map is also blended with the
    unscaled input through a learnable gate.

    Args:
        left_features: channels of the skip-connection (encoder) map.
        down_features: channels of the decoder map that is gated.
        inner_features: hidden width of the squeeze/excite MLP.
    """
    def __init__(self, left_features, down_features, inner_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.inner_features = inner_features
        
        self.GlobalPool = nn.AdaptiveAvgPool2d((1, 1))
        self.Squeeze_Left = nn.Linear(self.left_features, self.inner_features)
        self.Squeeze_Down = nn.Linear(self.down_features, self.inner_features)
        
        self.Act = Act()
        self.Excite = nn.Linear(self.inner_features, self.down_features)
        self.gate_attention = ModelConfig.gate_attention
        if self.gate_attention:
            # Must be floating point: nn.Parameter on the former integer tensor (-10)
            # raised, since integer tensors cannot require gradients.
            self.gamma = nn.Parameter(torch.tensor(-10.0, device = self.device))
    def forward(self, left_features, down_features):
        """
        Args:
            left_features: skip-connection map, (B, left_features, H, W).
            down_features: decoder map to gate, (B, down_features, H, W).
        Returns:
            A tensor with the same shape as ``down_features``: the channel-reweighted map,
            optionally blended with the un-reweighted input.
        """
        # (B, C, 1, 1) -> (B, C) so the Linear layers act on the channel axis
        # (applied to the 4-D tensor they would act on the trailing size-1 axis).
        left_pooled = torch.flatten(self.GlobalPool(left_features), 1)
        down_pooled = torch.flatten(self.GlobalPool(down_features), 1)
         
        squeeze_left = self.Squeeze_Left(left_pooled)
        squeeze_down = self.Squeeze_Down(down_pooled)
        
        squeezed = self.Act((squeeze_left + squeeze_down) / 2)
        weights = torch.sigmoid(self.Excite(squeezed)).unsqueeze(-1).unsqueeze(-1)
        # Return gated features, not the raw weights: the caller concatenates the result
        # with the skip map, which needs a full-resolution feature map.
        gated = down_features * weights
        if self.gate_attention:
            gamma = torch.sigmoid(self.gamma)
            # Residual blend with the un-gated input (the old code referenced an undefined `x`).
            gated = gamma * gated + (1 - gamma) * down_features
        return gated
class UNetBlock(pl.LightningModule):
    """UNet decoder stage.

    The stage performs, in order: 2x nearest upsample, optional skip gating, concat with
    the skip map, two 3x3 convs, and optional output attention.

    Args:
        left_features: channels of the skip connection (0 for a stage that receives none).
        down_features: channels of the incoming lower-resolution decoder map.
        out_features: channels produced by this stage.
    """
    def __init__(self, left_features, down_features, out_features):
        super().__init__()
        self.left_features = left_features
        self.down_features = down_features
        self.out_features = out_features
        self.reduction = ModelConfig.reduction
        
        self.decoder_attention = ModelConfig.decoder_attention
        # The skip gate and the output attention used to share the name ``self.att``.
        # The second assignment silently replaced the gate, and forward() then called the
        # single-input output attention with two arguments. They now have separate names.
        # No gate is built for a stage without a skip connection.
        if self.decoder_attention and self.left_features > 0:
            self.skip_gate = GatedAttentionBlock(self.left_features, self.down_features, self.down_features // self.reduction)
        else:
            self.skip_gate = None
        self.conv1 = ConvBlock(self.down_features + self.left_features, self.out_features, 3, 1, 1, 1)
        self.conv2 = ConvBlock(self.out_features, self.out_features, 3, 1, 1, 1)
        if self.decoder_attention:
            self.att = Attention(self.out_features, self.out_features // self.reduction)
        else:
            self.att = nn.Identity()
    def forward(self, left_features, down_features):
        """
        Args:
            left_features: skip-connection map, or None for the final stage.
            down_features: decoder map at half the resolution of ``left_features``.
        Returns:
            The stage output with ``out_features`` channels.
        """
        down_features = F.interpolate(down_features, scale_factor = 2, mode = 'nearest')
        if left_features is None:
            concat = down_features
        else:
            if self.skip_gate is not None:
                down_features = self.skip_gate(left_features, down_features)
            concat = torch.cat([down_features, left_features], dim = 1) 
        conv1 = self.conv1(concat)
        conv2 = self.conv2(conv1)
        return self.att(conv2)    

# Segmentation model

In [None]:
class EncoderAlpha(pl.LightningModule):
    """EfficientNet-B0 backbone extended with extra bottlenecks, a DC block and ASPP.

    ``forward`` returns the multi-scale feature list consumed by the segmentation
    decoder. ``forward_cls`` returns only the deepest (512-channel) map used by the
    classification / regression heads.
    """
    def freeze(self, layers):
        """Disable gradients for every parameter of the given modules.

        This only sets ``requires_grad``. It does not switch BatchNorm layers to eval
        mode, so their running statistics still update while training.
        """
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def freeze_cls(self):
        """Freeze the stem and blocks 0-5 (block6, block7 and the attention modules stay trainable)."""
        self.freeze([self.conv1, self.bn1, self.block0, self.block1, self.block2, self.block3, self.block4, self.block5])
    def __init__(self):
        super().__init__()
        self.model_name = 'efficientnet-b0'
        self.model = EfficientNet.from_pretrained(self.model_name)
        
        # Extract Layers
        self.conv1 = self.model._conv_stem # 32, 128
        self.bn1 = self.model._bn0
        self.act1 = self.model._swish
        
        self.block0 = self.model._blocks[0] # 16, 128
        self.block1 = nn.Sequential(*self.model._blocks[1:3]) # 24, 64 
        self.block2 = nn.Sequential(*self.model._blocks[3: 5]) # 40, 32
        self.block3 = nn.Sequential(*self.model._blocks[5: 8]) # 80, 16
        self.block4 = nn.Sequential(*self.model._blocks[8: 11]) # 112, 16
        self.block5 = nn.Sequential(*self.model._blocks[11: 15]) # 192, 8
        self.block6 = self.model._blocks[15] # 320, 8
        
        # Freeze Initial Layers
        self.freeze([self.conv1, self.bn1, self.block0, self.block1])
    
        self.reduction = ModelConfig.bottleneck_reduction
        
        self.Dropout6 = nn.Dropout2d(ModelConfig.drop_prob)
        self.Attention6 = BAM(320, 320 // self.reduction)
        
        # Two bottlenecks, a DC block, then a stride-2 ASPP that widens 320 -> 512 channels.
        self.block7 = nn.Sequential(*[
            ChooseBottleNeck(320, 320 // self.reduction) for i in range(2)
        ] + [
            DualChannelBlock(320, 320 // self.reduction, 320, 1),
            ASPP(320, 320 // self.reduction, 512, 2)
        ])
        
        self.Dropout7 = nn.Dropout2d(ModelConfig.drop_prob)
        self.Attention7 = BAM(512, 512 // self.reduction)
        
        # large encoder, small decoder(Since it's being stripped away)
    def _encode(self, x):
        """Run the whole encoder once and return the maps both public entry points need.

        Returns:
            (block0, block1, block2, block4, block6, block7): 16, 24, 40, 112, 320 and
            512 channels respectively.
        """
        # NOTE(review): efficientnet-pytorch's own forward appears to apply bn0 before
        # swish; here swish runs before BN. This is kept as-is because checkpoints trained
        # with this encoder presumably depend on it. Confirm before changing.
        features0 = self.bn1(self.act1(self.conv1(x)))
        block0 = self.block0(features0)
        block1 = self.block1(block0)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        
        block6 = self.Attention6(self.Dropout6(block6))
        block7 = self.block7(block6)
        block7 = self.Attention7(self.Dropout7(block7))
        return block0, block1, block2, block4, block6, block7
    def forward_cls(self, x):
        """Return only the deepest 512-channel map (input to the task heads)."""
        return self._encode(x)[-1]
    def forward(self, x):
        """Return [block0, block1, block2, block4, block6, block7] for the UNet decoder."""
        return list(self._encode(x))
class DecoderAlpha(pl.LightningModule):
    """UNet-style decoder over the six maps produced by EncoderAlpha.forward.

    Args:
        num_classes: number of output channels of the segmentation head.
    """
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        # Skip-connection channels per stage; 0 marks the final stage, which has no skip.
        self.enc_dims = [320, 112, 40, 24, 16, 0]
        # dec_dims[i] -> dec_dims[i + 1]: input / output channels of stage i.
        self.dec_dims = [512, 256, 128, 64, 32, 16, 16]
        
        self.decoder_blocks = nn.ModuleList([
            UNetBlock(self.enc_dims[idx], self.dec_dims[idx], self.dec_dims[idx + 1]) for idx in range(len(self.enc_dims))
        ])
        
        self.use_FPN = ModelConfig.use_FPN
        if self.use_FPN:
            self.FPN = FPN(self.dec_dims[1:-1], self.dec_dims[-1], self.dec_dims[-1], Config.IMAGE_SIZE // 2)
        # Use the constructor argument. The head previously ignored it and read
        # ModelConfig.num_classes, so DecoderAlpha(k) silently produced the wrong width.
        self.segmentation_head = nn.Conv2d(self.dec_dims[-1], self.num_classes, kernel_size = 3, padding = 1)
        initialize_weights(self.segmentation_head)
    def forward(self, features):
        """
        Args:
            features: the six encoder maps, shallowest first (EncoderAlpha.forward order).
        Returns:
            Segmentation logits with ``num_classes`` channels.
        """
        l0, l1, l2, l3, l4, l5 = tuple(features)
        
        d4 = self.decoder_blocks[0](l4, l5) # 256
        d3 = self.decoder_blocks[1](l3, d4) # 128
        d2 = self.decoder_blocks[2](l2, d3) # 64,
        d1 = self.decoder_blocks[3](l1, d2) # 32
        d0 = self.decoder_blocks[4](l0, d1) # 16
        
        if self.use_FPN:
            d0 = self.FPN([d4, d3, d2, d1, d0])
        
        final = self.decoder_blocks[5](None, d0)
        return self.segmentation_head(final)
        
class SegmentationModelAlpha(pl.LightningModule):
    """EncoderAlpha followed by DecoderAlpha: image batch -> segmentation logits."""
    def __init__(self):
        super().__init__()
        self.num_classes = ModelConfig.num_classes
        self.encoder = EncoderAlpha()
        self.decoder = DecoderAlpha(self.num_classes)

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Train the Segmentation Model

In [None]:
class ModelConfig:
    """Architecture switches read by the model blocks when they are constructed."""
    # Segmentation output width (DecoderAlpha head)
    num_classes: int = 2

    # Regularisation / activations / attention
    drop_prob: float = 0.1
    act_type: str = 'relu'
    attention_type: str = 'se'
    reduction: int = 16            # hidden = channels // reduction in attention / gating MLPs
    bam_dilate: int = 3

    # Bottleneck selection (ChooseBottleNeck / ChooseDownsampler)
    bottleneck_type: str = 'resnext'
    bottleneck_reduction: int = 4  # EncoderAlpha: bottleneck width = channels // bottleneck_reduction
    groups: int = 16               # cardinality of the ResNeXt-style blocks
    resnest: bool = True           # split attention inside the ResNeXt-style blocks

    # Decoder
    gate_attention: bool = True    # learnable blend inside GatedAttentionBlock
    decoder_attention: bool = False
    use_FPN: bool = False

# Metrics and Loss Functions

In [None]:
CRITERION = nn.CrossEntropyLoss()


def dice_loss(y_pred, y_true):
    """Log-cosh soft Dice loss on the foreground channel (index 1).

    Args:
        y_pred: raw logits (B, C, H, W); softmax is taken over C.
        y_true: mask broadcastable against (B, H, W).
    Returns:
        Scalar tensor log(cosh(1 - dice)); dice is pooled over the whole batch.
    """
    foreground = F.softmax(y_pred, dim = 1)[:, 1, :, :]
    smooth = 1e-8
    overlap = (foreground * y_true).sum()
    total = (foreground + y_true).sum()
    dice_error = 1 - (2 * overlap + smooth) / (total + smooth)
    return torch.log((torch.exp(dice_error) + torch.exp(-dice_error)) / 2)


def ce_loss(y_pred, y_true):
    """Pixel-wise cross-entropy; the mask is cast to integer class indices."""
    return CRITERION(y_pred, y_true.long())


def loss_fn(y_pred, y_true):
    """Segmentation objective: cross-entropy plus log-cosh Dice."""
    return ce_loss(y_pred, y_true) + dice_loss(y_pred, y_true)
class Loss(Metric):
    """Running mean of ``loss_fn`` (cross-entropy + log-cosh Dice) over the batches since ``reset``."""
    def __init__(self):
        super().__init__()
        self.loss = 0
        self.count = 0

    def reset(self):
        self.loss, self.count = 0, 0

    def accumulate(self, y_pred, y_true):
        """Score one batch, add it to the running total, and return the differentiable loss."""
        batch_loss = loss_fn(y_pred, y_true)
        self.loss += batch_loss.item()
        self.count += 1
        return batch_loss

    @property
    def value(self):
        return 0 if self.count == 0 else self.loss / self.count
class DiceMetric(Metric):
    """Soft Dice coefficient of the foreground channel (index 1).

    The intersection and union are summed over every batch since ``reset`` before
    the ratio is taken.
    """
    def __init__(self):
        super().__init__()
        self.inter = 0
        self.union = 0
    def reset(self):
        self.inter = 0
        self.union = 0
    def inter_union(self, y_pred, y_true):
        """Add this batch's intersection and union; ``y_pred`` must already be probabilities."""
        y_ones = y_pred[:, 1, :, :]
        # detach: the running sums must never keep an autograd graph alive across batches.
        self.inter += torch.sum(y_ones * y_true).detach()
        self.union += torch.sum(y_ones + y_true).detach()
        
    def accumulate(self, y_pred, y_true):
        """Softmax the logits over the class axis and accumulate the batch statistics."""
        y_pred = F.softmax(y_pred, dim = 1)
        self.inter_union(y_pred, y_true)
    @property
    def value(self):
        """Dice rounded to 3 decimals, or 0.0 if nothing has been accumulated since ``reset``."""
        if not torch.is_tensor(self.union):
            # Nothing accumulated: the counters are still plain ints. The old code called
            # .item() on a Python float here and raised AttributeError.
            return 0.0
        eps = 1e-8
        dice = (2 * self.inter + eps) / (self.union + eps)
        return round(dice.item(), 3)
        

In [None]:
class TrainConfig:
    """Optimisation and data-loading settings shared by both training stages."""
    # Optimiser
    lr = 1e-3
    weight_decay = 1e-1
    max_lr = 1e-2          # OneCycleLR peak learning rate
    eta_min = 1e-7

    # Loop length; training is ended by the EarlyStopping callbacks instead.
    num_epochs = 10000 # Dummy Number that is never reached
    num_steps = 5

    # Data loading
    batch_size = 64
    num_workers = 4

    # NOTE: builds the segmentation training set at class-definition time, only to
    # size the per-epoch step count for the LR schedule.
    sample_train = DataModule.get_seg()
    num_steps_per_epoch = len(sample_train) // batch_size
    

In [None]:
class TrainerSeg(pl.LightningModule):
    """Lightning wrapper that pre-trains SegmentationModelAlpha.

    After every validation epoch it saves the encoder weights with the best
    validation loss (./best.pth) and the best validation Dice (./dice.pth).
    """
    # I don't have much time in this hackathon, so I'm only going to train for 20 or 30 minutes(30 epochs) could probably boost performance with more pretraining
    def __init__(self):
        super().__init__()
        self.model = self.configure_model()
        
        self.TrainLoss = Loss()
        self.ValLoss = Loss()
        self.ValDice = DiceMetric()
        
        # Best validation loss / Dice seen so far; drive checkpointing in save_states().
        self.best = {'loss': float('inf'), 'dice': 0.0}
        self.EPOCHS = 0
    def configure_model(self):
        model = SegmentationModelAlpha()
        return model 
    def configure_optimizers(self):
        """AdamW with a per-step OneCycle schedule, stepped by Lightning after each optimizer step."""
        optimizer = optim.AdamW(self.model.parameters(), lr = TrainConfig.lr, weight_decay = TrainConfig.weight_decay)
        self.lr_decay_cycle = optim.lr_scheduler.OneCycleLR(optimizer, TrainConfig.max_lr, total_steps = TrainConfig.num_steps_per_epoch * TrainConfig.num_epochs, epochs = TrainConfig.num_epochs, steps_per_epoch = TrainConfig.num_steps_per_epoch)
        # Stepping the scheduler manually inside training_step ran it *before*
        # optimizer.step(). PyTorch warns about this, and the first LR value of the
        # schedule is skipped. Lightning steps it in the correct order with interval='step'.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': self.lr_decay_cycle, 'interval': 'step'},
        }
    def training_step(self, batch, batch_idx):
        x, y = batch
    
        pred = self.model(x)
        loss = self.TrainLoss.accumulate(pred, y)
        return loss
    def validation_step(self, batch, batch_idx):
        x, y = batch 
    
        pred = self.model(x)
        self.ValLoss.accumulate(pred, y)
        self.ValDice.accumulate(pred, y)
    def reset_states(self):
        """Clear all running metrics and advance the epoch counter."""
        self.TrainLoss.reset()
        self.ValLoss.reset()
        self.ValDice.reset()
        self.EPOCHS += 1
    def save_states(self):
        """Log val_dice (monitored by EarlyStopping), checkpoint improvements and print a summary."""
        TrainLoss = self.TrainLoss.value
        ValLoss = self.ValLoss.value
        ValDice = self.ValDice.value
        self.log('val_dice', ValDice)
        if ValLoss < self.best['loss']:
            self.best['loss'] = ValLoss
            torch.save(self.model.encoder.state_dict(), './best.pth')
        if ValDice > self.best['dice']:
            self.best['dice'] = ValDice
            torch.save(self.model.encoder.state_dict(), './dice.pth')
        print(f"E: {self.EPOCHS}, TL: {TrainLoss}, VL: {ValLoss} VD: {ValDice}, BL: {self.best['loss']}, BD: {self.best['dice']}")
        
    def validation_epoch_end(self, logs):
        self.save_states()
        self.reset_states()
        
        

In [None]:
def train_seg():
    """Pre-train the segmentation model until val_dice stops improving for 10 epochs."""
    train = DataModule.get_seg()
    val = DataModule.get_seg(train = False)
    
    model = TrainerSeg()
    dls = DataLoaders.from_dsets(train, val, batch_size = TrainConfig.batch_size, num_workers = TrainConfig.num_workers, shuffle = True, pin_memory = True)
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        dls.cuda()
        model.cuda()
    cbs = [pl.callbacks.EarlyStopping(monitor = 'val_dice', patience = 10, mode = 'max')]
    # gpus was hard-coded to 1, which fails on CPU-only machines.
    trainer = pl.Trainer(check_val_every_n_epoch = 1, callbacks = cbs, checkpoint_callback = False, logger = None, gpus = 1 if use_gpu else 0, max_epochs = TrainConfig.num_epochs, num_sanity_val_steps=0)
    trainer.fit(model, dls[0], dls[1])
    

In [None]:
#train_seg()

# Step 2: Classification and Regression Heads
- Cross-stitched classification, weight-regression and volume-regression heads

In [None]:
class FeaturesAlpha(pl.LightningModule):
    """Pre-trained EncoderAlpha plus one 1x1 projection per task (class, weight, volume).

    Each task branch is pooled to a ``(B, 1024)`` vector.
    """
    def freeze(self, layers):
        """Disable gradients for every parameter of the given modules."""
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False
    def __init__(self):
        super().__init__()
        self.model = EncoderAlpha()
        self.model.load_state_dict(torch.load(Config.pretrained_model, map_location = self.device))
        # Freeze the stem and blocks 0-5 of the encoder; block6, block7 and the
        # attention modules remain trainable.
        self.model.freeze_cls()
        # One Last Layer, seperated by Task.
        self.in_dim = 512
        self.out_dim = 1024
        self.reduction = ModelConfig.reduction
        # Identity placeholders; the commented alternatives are per-task BAM / dropout.
        self.cls_att8 = nn.Identity()#BAM(self.in_dim, self.in_dim // self.reduction)
        self.reg_att8 = nn.Identity()#BAM(self.in_dim, self.in_dim // self.reduction)
        self.vol_att8 = nn.Identity()#BAM(self.in_dim, self.in_dim // self.reduction)
        self.drop_att8 = nn.Identity()#nn.Dropout2d(0.1)
        
        self.proj_cls = ConvBlock(self.in_dim, self.out_dim, 1, 0, 1, 1)
        self.proj_reg = ConvBlock(self.in_dim, self.out_dim, 1, 0, 1, 1)
        self.proj_vol = ConvBlock(self.in_dim, self.out_dim, 1, 0, 1, 1)
        
        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))
        
    def forward(self, x):
        """Return (cls, reg, vol) feature vectors, each of shape (B, 1024)."""
        features = self.model.forward_cls(x)
    
        cls_att8 = self.cls_att8(self.drop_att8(features))
        reg_att8 = self.reg_att8(self.drop_att8(features))
        vol_att8 = self.vol_att8(self.drop_att8(features))
        
        proj_cls = self.proj_cls(cls_att8)
        proj_reg = self.proj_reg(reg_att8)
        proj_vol = self.proj_vol(vol_att8)
        
        # flatten(1) keeps the batch axis. torch.squeeze also removed it for a batch
        # of one, producing (1024,) instead of (1, 1024).
        pooled_cls = torch.flatten(self.global_avg(proj_cls), 1)
        pooled_reg = torch.flatten(self.global_avg(proj_reg), 1)
        pooled_vol = torch.flatten(self.global_avg(proj_vol), 1)
        
        return pooled_cls, pooled_reg, pooled_vol

In [None]:
class CrossStitchUnit(pl.LightningModule):
    """Learnable 3x3 mixing of the classification / weight / volume feature streams.

    Output stream i is ``sum_j sigmoid(alpha{i}_{j}) * input_j``. Diagonal logits start
    at 1 and off-diagonal logits at 0, i.e. initial mixing weights of sigmoid(1) ~ 0.73
    and sigmoid(0) = 0.5.
    """
    # Merges Features from Both Branches to better MultiTask Learn
    def __init__(self):
        super().__init__()
        for row in range(1, 4):
            for col in range(1, 4):
                initial = 1. if row == col else 0.
                setattr(self, f'alpha{row}_{col}', nn.Parameter(torch.tensor(initial, device = self.device)))

    def forward(self, CLS, REG, VOL):
        mixed = []
        for row in range(1, 4):
            w_cls, w_reg, w_vol = (torch.sigmoid(getattr(self, f'alpha{row}_{col}')) for col in range(1, 4))
            mixed.append(w_cls * CLS + w_reg * REG + w_vol * VOL)
        return tuple(mixed)
class LinBNReLU(pl.LightningModule):
    """Linear -> ReLU -> BatchNorm1d -> Dropout (ReLU runs before BN, despite the name)."""
    def __init__(self, in_features, out_features, drop_prob = 0.5):
        super().__init__()
        self.in_features, self.out_features, self.drop_prob = in_features, out_features, drop_prob
        stages = [
            nn.Linear(self.in_features, self.out_features),
            nn.ReLU(inplace = True),
            nn.BatchNorm1d(self.out_features),
            nn.Dropout(self.drop_prob),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)
class BaseLineHeadAlpha(pl.LightningModule):
    """Baseline head without cross-stitching: one linear layer per task on the 1024-d features."""
    def __init__(self):
        super().__init__()
        self.in_dim = 1024
        self.num_classes = Config.num_classes
        self.fc_cls = nn.Linear(self.in_dim, self.num_classes)
        self.fc_reg = nn.Linear(self.in_dim, 1)
        self.fc_vol = nn.Linear(self.in_dim, 1)

    def forward(self, CLS, REG, VOL):
        logits = self.fc_cls(CLS)
        weight = self.fc_reg(REG).squeeze()
        volume = self.fc_vol(VOL).squeeze()
        return logits, weight, volume
class CrossStitchAlpha(pl.LightningModule):
    """Two-layer multitask head with a cross-stitch unit before each layer.

    Takes the class / weight / volume feature vectors (1024-d). Returns class logits,
    predicted weight and predicted volume, the last two squeezed with ``torch.squeeze``.
    """
    def __init__(self):
        super().__init__()
        self.in_dim = 1024
        self.num_classes = Config.num_classes
        self.cross_stitch = CrossStitchUnit()

        hidden = 512
        self.fc_cls1 = LinBNReLU(self.in_dim, hidden)
        self.fc_reg1 = LinBNReLU(self.in_dim, hidden)
        self.fc_vol1 = LinBNReLU(self.in_dim, hidden)

        self.cross_stitch2 = CrossStitchUnit()

        self.fc_cls2 = nn.Linear(hidden, self.num_classes)
        self.fc_reg2 = nn.Linear(hidden, 1)  # predicted weight
        self.fc_vol2 = nn.Linear(hidden, 1)  # predicted volume

    def forward(self, classification, regression, volume):
        stitched = self.cross_stitch(classification, regression, volume)
        # A stream may arrive 1-D (e.g. a batch of one after torch.squeeze upstream);
        # restore the batch axis so BatchNorm1d / Linear see (B, C).
        classification, regression, volume = (s if s.dim() >= 2 else s.unsqueeze(0) for s in stitched)

        cls_hidden, reg_hidden, vol_hidden = self.cross_stitch2(
            self.fc_cls1(classification),
            self.fc_reg1(regression),
            self.fc_vol1(volume),
        )
        return (
            self.fc_cls2(cls_hidden),
            torch.squeeze(self.fc_reg2(reg_hidden)),
            torch.squeeze(self.fc_vol2(vol_hidden)),
        )
class ModelAlpha(pl.LightningModule):
    """Full multitask network: FeaturesAlpha backbone followed by the cross-stitched heads.

    Returns (class logits, predicted weight, predicted volume).
    """
    def __init__(self):
        super().__init__()
        self.featureExtractor = FeaturesAlpha()
        # BaseLineHeadAlpha() is a drop-in, non-stitched alternative with the same call signature.
        self.crossStitch = CrossStitchAlpha()

    def forward(self, x):
        cls_feat, reg_feat, vol_feat = self.featureExtractor(x)
        return self.crossStitch(cls_feat, reg_feat, vol_feat)

# Training Code

In [None]:
class Accuracy():
    """Mean of per-batch top-1 accuracies since the last reset (every batch weighted equally)."""
    def __init__(self):
        self.accuracy = 0
        self.count = 0

    def update_state(self, y_pred, y_true):
        probabilities = F.softmax(y_pred, dim = -1)
        predicted = probabilities.max(dim = -1)[1]
        hits = torch.sum(predicted == y_true)
        self.accuracy += (hits / predicted.shape[0]).item()
        self.count += 1

    def reset_states(self):
        self.accuracy, self.count = 0, 0

    @property
    def value(self):
        return round(self.accuracy / self.count, 3) if self.count else 0
class F1Score():
    """Mean of per-batch micro-averaged F1 scores since the last reset (metric for the class head)."""
    def __init__(self):
        self.f1_score = 0
        self.count = 0 
    def update_state(self, y_pred, y_true):
        """Score one batch of class logits against integer labels."""
        y_pred = F.softmax(y_pred, dim = -1)
        _, y_pred = torch.max(y_pred, dim = -1)
        f1 = metrics.f1_score(y_true.cpu(), y_pred.cpu(), average = 'micro')
        self.f1_score += f1
        self.count += 1
    def reset_states(self):
        self.count = 0
        self.f1_score = 0
    @property
    def value(self):
        """Average F1 rounded to 3 decimals, or 0.0 if nothing has been recorded."""
        if self.count != 0:
            return round(self.f1_score / self.count, 3)
        # Previously returned None, which made comparisons such as `value > best` raise TypeError.
        return 0.0
class L2Distance():
    """Weight MSE + volume MSE per batch, averaged over batches since the last reset."""
    def __init__(self):
        self.L2Loss = nn.MSELoss()
        self.dist = 0.0
        self.count = 0

    def update_state(self, y_pred_reg, y_true_reg, y_pred_vol, y_true_vol):
        weight_mse = self.L2Loss(y_pred_reg, y_true_reg.float()).item()
        volume_mse = self.L2Loss(y_pred_vol, y_true_vol.float()).item()
        self.dist += weight_mse
        self.dist += volume_mse
        self.count += 1

    def reset_states(self):
        self.dist, self.count = 0, 0

    @property 
    def value(self):
        return round(self.dist / self.count, 3) if self.count else 0.0
class LossAlpha():
    """Multitask loss: scaled weight MSE + scaled volume MSE + class cross-entropy.

    The MSE terms are divided by 5000 to bring them near the cross-entropy scale.
    ``value`` is the running mean over batches (999999 before any batch).
    """
    def __init__(self):
        self.L2Loss = nn.MSELoss()
        self.CELoss = nn.CrossEntropyLoss()
        self.loss = 0
        self.count = 0

    def update_state(self, y_pred_cls, y_true_cls, y_pred_reg, y_true_reg, y_pred_vol, y_true_vol):
        """Return the differentiable batch loss and add its value to the running total."""
        scale = 5000
        weight_term = self.L2Loss(y_pred_reg, y_true_reg.float()) / scale
        volume_term = self.L2Loss(y_pred_vol, y_true_vol.float()) / scale
        class_term = self.CELoss(y_pred_cls, y_true_cls)
        total = weight_term + volume_term + class_term
        self.loss += total.item()
        self.count += 1
        return total

    def reset_states(self):
        self.loss, self.count = 0, 0

    @property
    def value(self):
        return round(self.loss / self.count, 3) if self.count else 999999

In [None]:
class AlphaTrainer(pl.LightningModule):
    """Lightning wrapper that trains ModelAlpha on class, weight and volume targets.

    After each validation epoch the full state_dict is saved when validation loss
    (./loss.pth), F1 (./f1.pth) or L2 distance (./l2.pth) improves.
    """
    def __init__(self):
        super().__init__()
        self.model = self.configure_model()
        
        self.TrainLoss = LossAlpha()
        self.ValLoss = LossAlpha()
        self.ValF1 = F1Score()
        self.ValL2 = L2Distance()
        self.ValAccuracy = Accuracy()
        self.best = {'loss': float('inf'), 'f1': 0.0, 'l2': float('inf')}
        
        self.EPOCHS = 0
    def configure_model(self):
        model = ModelAlpha()
        return model
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), lr = TrainConfig.lr, weight_decay = TrainConfig.weight_decay)
        #self.lr_decay = optim.lr_scheduler.CosineAnnealingLR(optimizer, 5, 1e-8)
        #self.lr_decay_cycle = optim.lr_scheduler.OneCycleLR(optimizer, TrainConfig.max_lr, total_steps = TrainConfig.num_steps_per_epoch * TrainConfig.num_epochs, epochs = TrainConfig.num_epochs, steps_per_epoch = TrainConfig.num_steps_per_epoch)
        return optimizer
    def training_step(self, batch, batch_idx):
        x, class_idx, volume, weight = batch
        pred_class, pred_weight, pred_volume = self.model(x)
        
        loss = self.TrainLoss.update_state(pred_class, class_idx, pred_weight, weight, pred_volume, volume)
        #self.lr_decay_cycle.step()
        return loss
    def validation_step(self, batch, batch_idx):
        x, class_idx, volume, weight = batch
        pred_class, pred_weight, pred_volume = self.model(x)
        self.ValAccuracy.update_state(pred_class, class_idx)
        self.ValLoss.update_state(pred_class, class_idx, pred_weight, weight, pred_volume, volume)
        self.ValL2.update_state(pred_weight, weight, pred_volume, volume)
        self.ValF1.update_state(pred_class, class_idx)
    def reset_states(self):
        """Clear every running metric and advance the epoch counter."""
        self.EPOCHS += 1
        self.TrainLoss.reset_states()
        self.ValLoss.reset_states()
        self.ValL2.reset_states()
        self.ValF1.reset_states()
        # Previously missing, so the reported validation accuracy averaged over all
        # epochs seen so far instead of just the latest one.
        self.ValAccuracy.reset_states()
    def print_states(self):
        """Log val_loss (monitored by EarlyStopping), checkpoint improvements and print a summary."""
        trainLoss = self.TrainLoss.value
        valLoss = self.ValLoss.value
        valF1 = self.ValF1.value
        valL2 = self.ValL2.value
        valAccuracy = self.ValAccuracy.value
        self.log('val_loss', valLoss)
        #self.lr_decay.step()
        if valLoss < self.best['loss']:
            self.best['loss'] = valLoss
            torch.save(self.state_dict(), "./loss.pth")
        if valF1 > self.best['f1']:
            self.best['f1'] = valF1
            torch.save(self.state_dict(), './f1.pth')
        if valL2 < self.best['l2']:
            self.best['l2'] = valL2
            torch.save(self.state_dict(), './l2.pth')
        print(f"E: {self.EPOCHS},VA: {valAccuracy} TL: {trainLoss}, VL: {valLoss}, F1: {valF1}, L2: {valL2}, BL: {self.best['loss']}, BL2: {self.best['l2']}, BF1: {self.best['f1']}")
        
        
    def validation_epoch_end(self, logs):
        self.print_states()
        self.reset_states()

In [None]:
def trainAlpha():
    """Train ``AlphaTrainer`` on the classification split and return it.

    Builds train/val datasets with ``DataModule.get_cls``, wraps them in fastai
    ``DataLoaders`` and fits a Lightning ``Trainer`` with early stopping on
    ``val_loss`` (patience 20). Lightning's own checkpointing is disabled.

    Returns:
        The trained ``AlphaTrainer`` instance (previously discarded).
    """
    use_gpu = torch.cuda.is_available()
    model = AlphaTrainer()
    train = DataModule.get_cls()
    val = DataModule.get_cls(train = False)
    cbs = [pl.callbacks.EarlyStopping(verbose = True, monitor = 'val_loss', patience = 20)]
    dls = DataLoaders.from_dsets(train, val, batch_size = TrainConfig.batch_size, num_workers = TrainConfig.num_workers, shuffle = True, pin_memory = True)
    if use_gpu:
        model.cuda()
        dls.cuda()

    # Only request a GPU when one exists: a hard-coded gpus=1 defeated the
    # CPU fallback above, since the Trainer rejects gpus=1 without CUDA.
    trainer = pl.Trainer(
        logger = None,
        callbacks = cbs,
        checkpoint_callback = False,
        check_val_every_n_epoch = 1,
        gpus = 1 if use_gpu else 0,
        max_epochs = TrainConfig.num_epochs,
        num_sanity_val_steps = 0,
    )
    trainer.fit(model, dls[0], dls[1])
    return model

In [None]:
#trainAlpha()

# Inference Code: Testing on a Few Samples

In [None]:
class WeightWatcherAlpha(pl.LightningModule):
    """Inference wrapper predicting food class, weight, volume and calories.

    Loads a checkpoint into a ``ModelAlpha`` and maps the predicted class
    index to a food name and to that food's calorie factor in ``weights2Cal``.
    """
    def __init__(self, weight_path = '../input/weightwatcher/f1.pth'):
        """
        Args:
            weight_path: checkpoint passed to ``torch.load``; defaults to the
                original hard-coded ``f1.pth``. It is loaded strictly into this
                module, so its keys must match ``self.state_dict()`` (i.e. the
                ``model.``-prefixed ``ModelAlpha`` parameters).
        """
        super().__init__()
        # Calorie factor per unit of predicted weight (calories = factor * weight).
        # Dict order defines the class indices below, so it presumably has to
        # match the label order used in training. Keys like 'fired_dough_twist'
        # and 'qiwi' presumably mirror the dataset's label names — verify
        # before "correcting" the spelling.
        self.weights2Cal = {
            'apple': 0.52,
            'banana': 0.89,
            'bread': 3.15,
            'bun': 2.23,
            'doughnut': 4.34,
            'egg': 1.43,
            'fired_dough_twist': 24.16,
            'grape': 0.69,
            'lemon': 0.29,
            'litchi': 0.66,
            'mango': 0.60,
            'mooncake': 18.83,
            'orange': 0.63,
            'peach': 0.57,
            'pear': 0.39,
            'plum': 0.46,
            'qiwi': 0.61,
            'sachima': 21.45,
            'tomato': 0.27
        }

        classes = list(self.weights2Cal)
        self.cls2idx = {name: idx for idx, name in enumerate(classes)}
        self.idx2cls = dict(enumerate(classes))

        self.weight_path = weight_path
        self.model = ModelAlpha()
        self.load_state_dict(torch.load(self.weight_path, map_location = self.device))

    def forward(self, x):
        """Predict for a single-sample batch ``x``.

        Returns:
            (class name, predicted weight tensor, predicted volume tensor,
            calories tensor = class calorie factor * predicted weight).

        Raises:
            RuntimeError/ValueError from ``.item()`` if ``x`` holds more than
            one sample.
        """
        self.eval()
        with torch.no_grad():
            class_logits, weights, volume = self.model(x)
            # argmax of the logits equals argmax of their softmax, so the
            # normalisation step is unnecessary.
            cls_idx = torch.argmax(class_logits, dim = -1).item()

            class_val = self.idx2cls[cls_idx]
            calories = self.weights2Cal[class_val] * weights

            return class_val, weights, volume, calories

In [None]:
# Build the inference wrapper; its constructor loads the checkpoint from disk.
model = WeightWatcherAlpha()

In [None]:
# Validation split. Despite the name this is presumably a dataset, not a
# batched loader: trainAlpha passes the same call to DataLoaders.from_dsets,
# and the loop below adds the batch dimension itself with unsqueeze(0).
val_loader = DataModule.get_cls(train = False)

In [None]:
# Run the model on the first validation sample, show the image, and print
# ground truth next to each prediction.
for images, class_idx, volume, weight in val_loader:
    cls_pred, weight_pred, volume_pred, calories_pred = model(images.unsqueeze(0))
    # CHW -> HWC for matplotlib (same result as transpose(0,1).transpose(1,2)).
    plt.imshow(images.permute(1, 2, 0))
    plt.show()
    report = (
        '-----------', class_idx, cls_pred,
        '-----', weight, weight_pred,
        '-----------', volume, volume_pred,
        '-----------------', calories_pred,
    )
    for line in report:
        print(line)
    break