In [None]:
# The 11 target names, grouped by device type (ETT, NGT, CVC, Swan Ganz).
target_cols = [
    # Endotracheal tube
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    # Nasogastric tube
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    # Central venous catheter
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    # Swan Ganz catheter
    'Swan Ganz Catheter Present',
]

# Library

In [None]:
%%capture
# ====================================================
# Library
# ====================================================
import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')
import os
import copy
import math
import ast
import time
import random
import shutil

import scipy as sp
import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.utils import check_random_state
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold, train_test_split

from tqdm.auto import tqdm
from functools import partial

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision

import albumentations
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, HueSaturationValue, CoarseDropout
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

!pip install livelossplot
import livelossplot
import matplotlib.pyplot as plt
# Global compute device: the GPU when CUDA is available, otherwise the CPU.
# Several module constructors below read this global when creating their `gamma` parameter.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
import os
import random

# Single seed shared by every random number generator used in the notebook.
seed = 42

# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this assignment
# only influences processes launched after this point, not the running kernel.
os.environ['PYTHONHASHSEED'] = str(seed)

# Seed Python, NumPy, PyTorch (CPU) and PyTorch (CUDA) in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

# Ask cuDNN to pick deterministic kernels.
torch.backends.cudnn.deterministic = True

# Model(Modified ResNet200D with Attention)
Attention layers are added so that a layer-wise MSE loss can be computed on the attention maps.

In [None]:
class BatchNormBlock(nn.Module):
    """
    Convolution followed by ReLU and then batch normalisation.

    Note the ordering: the normalisation is applied *after* the activation.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        kernel_size, padding, stride, groups: forwarded to ``nn.Conv2d``.
    """
    def __init__(self, in_features, out_features, kernel_size, padding, stride, groups):
        super().__init__()
        self.conv = nn.Conv2d(
            in_features,
            out_features,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        # The ReLU runs in place on the freshly created convolution output.
        activated = F.relu(self.conv(x), inplace=True)
        return self.bn(activated)

Squeeze-and-Excitation block using CBAM (max pooling + average pooling)

In [None]:
class SpatialSqueezeExcite(nn.Module):
    """
    Spatial gate: a per-pixel sigmoid mask computed from the channel-wise
    maximum and mean of the input, multiplied back onto the input.
    """
    def __init__(self):
        super().__init__()
        # 2 input planes (max, mean) -> 1 mask plane; padding keeps H and W unchanged.
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        '''
        x: Tensor(B, C, H, W)
        Returns a tensor of the same shape as x.
        '''
        channel_mean = x.mean(dim=1, keepdim=True)        # (B, 1, H, W)
        channel_max = x.max(dim=1, keepdim=True).values   # (B, 1, H, W)
        # The max plane goes first, then the mean plane.
        pooled = torch.cat((channel_max, channel_mean), dim=1)  # (B, 2, H, W)
        mask = torch.sigmoid(self.conv(pooled))
        return mask * x

In [None]:
class ChannelSqueezeExcite(nn.Module):
    """
    Channel gate: global average- and max-pooled descriptors are passed through
    the same Squeeze -> ReLU -> Excite bottleneck, summed, squashed with a
    sigmoid and used to rescale each channel of the input.

    Args:
        in_dim: number of channels of the input.
        squeeze_dim: width of the bottleneck.
    """
    def __init__(self, in_dim, squeeze_dim):
        super().__init__()
        self.in_dim = in_dim
        self.squeeze_dim = squeeze_dim
        self.Squeeze = nn.Linear(in_dim, squeeze_dim)
        self.Excite = nn.Linear(squeeze_dim, in_dim)

    def _bottleneck(self, descriptor):
        # Shared two-layer bottleneck applied to a (B, C) descriptor.
        return self.Excite(F.relu(self.Squeeze(descriptor)))

    def forward(self, x):
        '''
        x: Tensor(B, C, H, W)
        '''
        # Reduce the last axis, then the (new) last axis -> (B, C).
        avg_desc = x.mean(dim=-1).mean(dim=-1)
        max_desc = x.max(dim=-1).values.max(dim=-1).values
        scale = torch.sigmoid(self._bottleneck(max_desc) + self._bottleneck(avg_desc))
        return scale[:, :, None, None] * x

In [None]:
class CBAMSqueezeExcite(nn.Module):
    """
    Channel gate followed by spatial gate, batch-normalised and added back to
    the input through a learnable scalar ``gamma``.

    ``gamma`` starts at zero, so a freshly built block returns its input unchanged.

    Args:
        in_dim: number of channels of the input.
        inner_dim: bottleneck width handed to ChannelSqueezeExcite.
    """
    def __init__(self, in_dim, inner_dim):
        super().__init__()
        self.in_dim, self.inner_dim = in_dim, inner_dim
        self.Channel = ChannelSqueezeExcite(in_dim, inner_dim)
        self.Spatial = SpatialSqueezeExcite()
        self.bn = nn.BatchNorm2d(in_dim)
        # Created directly on the notebook-level `device`.
        self.gamma = nn.Parameter(torch.zeros(1, device=device))

    def forward(self, x):
        '''
        x: Tensor(B, C, H, W)
        '''
        gated = self.Channel(x)
        gated = self.Spatial(gated)
        branch = self.bn(gated) * self.gamma
        return branch + x

CBAM attention (spatial attention + feature attention)

In [None]:
class SpatialAttention(nn.Module):
    """
    Multi-head attention built from 1x1 convolutions.

    Q, K and V are flattened to (B * num_heads, inner_dim, H * W). The attention
    matrix is ``softmax(Q @ K^T / sqrt(inner_dim))`` and therefore has shape
    (B * num_heads, inner_dim, inner_dim): it mixes the ``inner_dim`` feature
    rows of V, and its size does not depend on H * W.

    Args:
        in_dim: channels of the input (and of the output).
        inner_dim: channels per head.
        num_heads: number of heads.
    """
    def __init__(self, in_dim, inner_dim, num_heads):
        super().__init__()
        self.in_dim = in_dim
        self.inner_dim = inner_dim
        self.num_heads = num_heads
        projected = inner_dim * num_heads
        self.K = nn.Conv2d(in_dim, projected, kernel_size=1)
        self.V = nn.Conv2d(in_dim, projected, kernel_size=1)
        self.Q = nn.Conv2d(in_dim, projected, kernel_size=1)
        self.Linear = nn.Conv2d(projected, in_dim, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        per_head = (B * self.num_heads, self.inner_dim, H * W)
        keys = self.K(x).view(per_head)
        values = self.V(x).view(per_head)
        queries = self.Q(x).view(per_head)      # (BH, I, HW)

        logits = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(self.inner_dim)
        weights = F.softmax(logits, dim=-1)     # (BH, I, I)
        mixed = torch.bmm(weights, values)      # (BH, I, HW)

        # Fold the heads back into the channel axis before the output projection.
        merged = mixed.view(B, self.num_heads * self.inner_dim, H, W)
        return self.Linear(merged)

In [None]:
class ChannelAttention(nn.Module):
    """
    Multi-head attention between spatial positions, built from 1x1 convolutions.

    Q, K and V are flattened to (B * num_heads, inner_dim, H * W). The attention
    matrix ``softmax(Q^T @ K / sqrt(inner_dim))`` has shape
    (B * num_heads, H * W, H * W), so memory grows quadratically with H * W.

    Args:
        in_dim: channels of the input (and of the output).
        inner_dim: channels per head.
        num_heads: number of heads.

    Returns (from ``forward``): a tensor with the same (B, in_dim, H, W) shape as the input.
    """
    def __init__(self, in_dim, inner_dim, num_heads):
        super().__init__()
        self.in_dim = in_dim
        self.inner_dim = inner_dim
        self.num_heads = num_heads
        self.K = nn.Conv2d(self.in_dim, self.inner_dim * self.num_heads, kernel_size = 1)
        self.V = nn.Conv2d(self.in_dim, self.inner_dim * self.num_heads, kernel_size = 1)
        self.Q = nn.Conv2d(self.in_dim, self.inner_dim * self.num_heads, kernel_size = 1)
        self.Linear = nn.Conv2d(self.inner_dim * self.num_heads, self.in_dim, kernel_size = 1)

    def forward(self, x):
        B, C, H, W = x.shape
        heads, inner = self.num_heads, self.inner_dim
        # Project and split the channel axis into heads: (B, heads * I, H, W) -> (BH, I, HW).
        K = self.K(x).reshape(B * heads, inner, H * W)
        V = self.V(x).reshape(B * heads, inner, H * W)
        Q = self.Q(x).reshape(B * heads, inner, H * W)

        # Row i holds the weights query position i assigns to every key position.
        att_mat = F.softmax(torch.bmm(Q.transpose(1, 2), K) / math.sqrt(inner), dim = -1)  # (BH, HW, HW)
        attended = torch.bmm(att_mat, V.transpose(1, 2))  # (BH, HW, I)

        # Fix: the product is position-major (BH, HW, I). It must be transposed to
        # (BH, I, HW) before being folded back into an image; reshaping it
        # directly interleaves channels with spatial positions.
        attended = attended.transpose(1, 2).reshape(B, heads * inner, H, W)
        return self.Linear(attended)

In [None]:
class CBAMAttention(nn.Module):
    """
    Residual attention wrapper: ``bn(SpatialAttention(x)) * gamma + x``.

    ``gamma`` is a learnable scalar initialised to zero, so the block starts out
    as the identity and the network can scale the attention branch up if it helps.
    """
    def __init__(self, in_dim, inner_dim, num_heads):
        super().__init__()
        # Only the SpatialAttention branch is used; the channel attention was
        # removed to reduce memory overhead.
        self.SpatialAtt = SpatialAttention(in_dim, inner_dim, num_heads)
        self.bn = nn.BatchNorm2d(in_dim)
        # Created directly on the notebook-level `device`.
        self.gamma = nn.Parameter(torch.zeros(1, device=device))

    def forward(self, x):
        attention_branch = self.bn(self.SpatialAtt(x))
        return attention_branch * self.gamma + x

In [None]:
class InvertedBottleNeck(nn.Module):
    '''
    Inverted bottleneck residual block (the original author cites MobileNetv3):
    1x1 expand -> 3x3 depthwise -> CBAM squeeze-excite -> 1x1 squeeze,
    batch-normalised, scaled by a zero-initialised ``gamma`` and added to the input.

    Args:
        in_dim: channels of the input and output.
        inner_dim: channels of the expanded representation.
    '''
    def __init__(self, in_dim, inner_dim):
        super().__init__()
        self.in_dim, self.inner_dim = in_dim, inner_dim

        # BatchNormBlock arguments: in, out, kernel_size, padding, stride, groups.
        self.expand = BatchNormBlock(in_dim, inner_dim, 1, 0, 1, 1)
        # groups == channels makes the 3x3 convolution depthwise.
        self.depthwise = BatchNormBlock(inner_dim, inner_dim, 3, 1, 1, inner_dim)
        self.se = CBAMSqueezeExcite(inner_dim, inner_dim // 8)
        self.squeeze = BatchNormBlock(inner_dim, in_dim, 1, 0, 1, 1)

        self.bn = nn.BatchNorm2d(in_dim)
        # Created directly on the notebook-level `device`.
        self.gamma = nn.Parameter(torch.zeros(1, device=device))

    def forward(self, x):
        out = x
        for stage in (self.expand, self.depthwise, self.se, self.squeeze):
            out = stage(out)
        return self.bn(out) * self.gamma + x

In [None]:
class BottleNeck(nn.Module):
    """
    Residual bottleneck: 1x1 squeeze -> 3x3 conv -> 1x1 expand, batch-normalised,
    scaled by a learnable ``gamma`` and added to the input.

    ``gamma`` starts at zero, so a new block is initially the identity mapping.
    """
    def __init__(self, in_dim, inner_dim):
        '''
        in_dim: channels of the input and output.
        inner_dim: channels inside the bottleneck.
        '''
        super().__init__()
        self.in_dim, self.inner_dim = in_dim, inner_dim
        # BatchNormBlock arguments: in_features, out_features, kernel_size, padding, stride, groups.
        self.squeeze = BatchNormBlock(in_dim, inner_dim, 1, 0, 1, 1)
        self.process = BatchNormBlock(inner_dim, inner_dim, 3, 1, 1, 1)
        self.expand = BatchNormBlock(inner_dim, in_dim, 1, 0, 1, 1)
        self.bn = nn.BatchNorm2d(in_dim)
        # Created directly on the notebook-level `device`.
        self.gamma = nn.Parameter(torch.zeros(1, device=device))

    def forward(self, x):
        hidden = self.squeeze(x)
        hidden = self.process(hidden)
        hidden = self.expand(hidden)
        return self.bn(hidden) * self.gamma + x

In [None]:
class ModifiedResNet3(nn.Module):
    """
    timm ``resnet200d_320`` feature extractor with attention / squeeze-excite
    blocks after the tapped feature map and after layer 3, plus a custom
    "layer 3.5" stack of BottleNeck blocks in front of layer 4.

    Args:
        num_classes: size of the logit vector.
        drop_prob: dropout probability applied to the pooled features.

    ``forward`` returns the tuple
    ``(logits, excited_features2, excited_features3, excited_features4, layer4, pooled)``
    so that intermediate activations are available for feature-matching losses.
    """
    def __init__(self, num_classes, drop_prob = 0.2):
        super().__init__()
        self.num_classes = num_classes
        self.model = timm.create_model("resnet200d_320", pretrained=True, features_only=True)
        self.model.global_pool = nn.Identity()
        self.model.fc = nn.Identity()

        self.layer2_att = CBAMAttention(512, 256, 2)
        self.se2 = CBAMSqueezeExcite(512, 256)
        self.layer3_att = CBAMAttention(1024, 512, 2)
        self.se3 = CBAMSqueezeExcite(1024, 512)

        # Custom layer 3.5 (between layer 3 and layer 4):
        # 8 bottlenecks, a stride-2 average pool, then 7 more bottlenecks.
        self.layer3_5 = nn.Sequential(
            *[BottleNeck(1024, 512) for _ in range(8)],
            nn.AvgPool2d(kernel_size=3, padding=1, stride=2),
            *[BottleNeck(1024, 512) for _ in range(7)],
        )

        self.layer4_att = CBAMAttention(1024, 512, 2)
        self.se4 = CBAMSqueezeExcite(1024, 512)

        self.Dropout = nn.Dropout(drop_prob)
        # Fixed 5x5 window: assumes layer 4 emits a 5x5 map -- TODO confirm for the input size in use.
        self.global_avg = nn.AvgPool2d(kernel_size=5)
        self.Linear = nn.Linear(2048, self.num_classes)

    def forward(self, x):
        batch, _, _, _ = x.shape
        # Third-from-last feature map of the backbone (512 channels, per layer2_att above).
        tapped = self.model(x)[-3]
        excited_features2 = self.se2(self.layer2_att(tapped))

        # Layer 3
        stage3 = self.model.layer3(excited_features2)
        excited_features3 = self.se3(self.layer3_att(stage3))

        # Layer 3.5
        stage3_5 = self.layer3_5(excited_features3)
        excited_features4 = self.se4(self.layer4_att(stage3_5))

        # Layer 4
        layer4 = self.model.layer4(excited_features4)

        # Pool, flatten to (B, features), classify.
        pooled = self.global_avg(layer4).view(batch, -1)
        logits = self.Linear(self.Dropout(pooled))

        # Layer-4 features and pooled vector are returned alongside the logits for MSE-style losses.
        return logits, excited_features2, excited_features3, excited_features4, layer4, pooled

In [None]:
def replace_bn_student(module):
    '''
    Recursively replace every ``nn.GroupNorm`` inside ``module`` with a fresh
    ``nn.BatchNorm2d`` of the same channel count (affine, tracking running stats).

    The replacement happens in place; call it with the whole network to start.
    The function returns None.

    Each module is visited exactly once through ``named_children()``. This also
    covers GroupNorm layers stored directly in ``nn.Sequential``, ``nn.ModuleList``
    or ``nn.ModuleDict`` containers, and avoids ``dir()``/``getattr`` scanning,
    which evaluates unrelated properties and walks containers twice.
    '''
    # Snapshot the children first: the loop re-assigns entries of the same mapping.
    for name, child in list(module.named_children()):
        if isinstance(child, torch.nn.GroupNorm):
            new_bn = nn.BatchNorm2d(child.num_channels, affine = True, track_running_stats = True)
            # setattr works for containers too: their children are registered under string keys.
            setattr(module, name, new_bn)
        else:
            replace_bn_student(child)

# Teacher Model Stage 1 Train

In [None]:
class TeacherModel(nn.Module):
    """
    Teacher network: a ModifiedResNet3 bundled with its optimiser, two learning
    rate schedulers and a BCE-with-logits criterion.

    Args:
        num_classes: number of output logits.
        device: device that batches are moved to inside ``training_loop``.
    """
    def __init__(self, num_classes, device):
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        # Honour the requested class count (it used to be hard-coded to 11).
        self.model = ModifiedResNet3(self.num_classes, drop_prob = 0.2)
        for parameter in self.model.parameters():
            parameter.requires_grad = True
        self.optim = optim.Adam(self.model.parameters(), lr = 1e-4, weight_decay = 1e-4)
        # Both schedulers drive the same optimiser and are stepped once per epoch.
        self.lr_decay = optim.lr_scheduler.CosineAnnealingLR(self.optim, 4, eta_min = 1e-7)
        self.lr_decay2 = optim.lr_scheduler.ExponentialLR(self.optim, 0.95)
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Inference only: switches the module to eval mode and runs without gradients."""
        self.eval()
        with torch.no_grad():
            return self.model(x)

    def roc_auc(self, y_pred, y_true):
        '''
        Despite its name this returns the thresholded accuracy, not ROC AUC
        (the original author notes ROC AUC is unusable at small batch sizes).

        y_pred: Tensor(B, 11) of probabilities; values >= 0.5 count as positive.
        y_true: Tensor(B, 11) of 0/1 labels.
        Returns: 0-dim tensor, the fraction of (sample, class) entries predicted correctly.

        The inputs are not modified.
        '''
        predicted = (y_pred >= 0.5).to(y_true.dtype)
        return (predicted == y_true).float().mean()

    def training_loop(self, trainloader, valloader, NUM_EPOCHS, display_every = 64):
        """
        Run ``NUM_EPOCHS`` evaluation passes over ``valloader``, step both
        schedulers after each pass, plot the metrics with livelossplot and save
        the state dict to 'teacher.pth' whenever the validation loss improves.

        The training pass is disabled in this version of the notebook (see the
        commented block below), so ``trainloader`` and ``display_every`` are
        currently unused.

        Both loaders yield ``(image, annotated_image, labels)`` triples; only the
        annotated image and the labels are consumed.

        Raises:
            ValueError: if ``valloader`` yields no batches.
        """
        liveloss = livelossplot.PlotLosses()
        best_val_loss = 9999
        torch.cuda.empty_cache()
        for EPOCH in range(NUM_EPOCHS):
            logs = {}
            # ---- disabled training pass (kept for reference) ----
            # self.train()
            # total_loss = 0.0
            # count = 0
            # for _, annotated_images, labels in trainloader:
            #     self.optim.zero_grad()
            #     annotated_images = annotated_images.to(self.device).to(torch.float32)
            #     labels = labels.to(self.device).to(torch.float32)
            #     pred, _, _, _, _, _ = self.model(annotated_images)
            #     loss = self.criterion(pred,labels)
            #     total_loss += loss.item()
            #     loss.backward()
            #     self.optim.step()
            #     del annotated_images
            #     del labels
            #     del pred
            #     del loss
            #     count += 1
            #     torch.cuda.empty_cache()
            #     if count == display_every:
            #         break
            # logs['loss'] = total_loss / count
            # print(f"EPOCH: {EPOCH}, total_loss: {logs['loss']}")
            self.eval()
            with torch.no_grad():
                logs['accuracy'] = 0
                logs['val_loss'] = 0
                count = 0
                for _, val_annotated_images, val_labels in valloader:
                    val_annotated_images = val_annotated_images.to(self.device).to(torch.float32)
                    val_labels = val_labels.to(self.device).to(torch.float32)
                    pred, _, _, _, _, _ = self.model(val_annotated_images)
                    # .item() keeps plain Python floats in `logs` instead of device tensors.
                    logs['accuracy'] += self.roc_auc(torch.sigmoid(pred), val_labels).item()
                    logs['val_loss'] += self.criterion(pred, val_labels).item()
                    del val_annotated_images
                    del val_labels
                    del pred
                    count += 1
                    torch.cuda.empty_cache()
                if count == 0:
                    raise ValueError("valloader produced no batches")
                logs['accuracy'] /= count
                logs['val_loss'] /= count
            self.lr_decay.step() # Cosine Annealing
            self.lr_decay2.step() # Slowly lower the LR
            print(f"Accuracy: {logs['accuracy']}, loss: {logs['val_loss']}")
            liveloss.update(logs)
            liveloss.send()
            if logs['val_loss'] < best_val_loss:
                best_val_loss = logs['val_loss']
                torch.save(self.state_dict(), 'teacher.pth')

In [None]:
# ====================================================
# Dataset
# ====================================================
# RGB colour used to paint each annotation label onto an image.
COLOR_MAP = {
    # ETT
    'ETT - Abnormal': (255, 0, 0),
    'ETT - Borderline': (0, 255, 0),
    'ETT - Normal': (0, 0, 255),
    # NGT
    'NGT - Abnormal': (255, 255, 0),
    'NGT - Borderline': (255, 0, 255),
    'NGT - Incompletely Imaged': (0, 255, 255),
    'NGT - Normal': (128, 0, 0),
    # CVC
    'CVC - Abnormal': (0, 128, 0),
    'CVC - Borderline': (0, 0, 128),
    'CVC - Normal': (128, 128, 0),
    # Swan Ganz
    'Swan Ganz Catheter Present': (128, 0, 128),
}


class TrainDataset(torch.utils.data.Dataset):
    """
    Dataset yielding ``(image, annotated_image, label)`` triples.

    ``annotated_image`` is a copy of the image with a filled square of side
    ``annot_size`` painted around every annotation point, in the colour that
    COLOR_MAP assigns to the annotation's label.

    Args:
        df: frame indexed by file name; every column except the last is a label.
        df_annotations: frame queried by ``StudyInstanceUID`` with ``label`` and
            ``data`` columns (``data`` is a string literal of ``[x, y]`` points).
        annot_size: side length, in pixels, of the painted squares.
        transform: optional callable invoked as ``transform(image=..., image1=...)``
            that returns a mapping with the keys 'image' and 'image1'.
    """
    def __init__(self, df, df_annotations, annot_size=50, transform=None):
        self.df = df
        self.df_annotations = df_annotations
        self.annot_size = annot_size
        self.file_names = df.index.values
        # Every column but the last is treated as a label.
        self.labels = df.iloc[:, :-1].values
        self.transform = transform
        self.train_path = "../input/ranzcr-clip-catheter-line-classification/train/"

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _paint_annotations(image, points, color, annot_size):
        """
        Paint a square of side ``annot_size`` around each ``(x, y)`` point, in place.

        Slice bounds are clamped at zero: a negative start would otherwise wrap
        around to the far edge and silently drop squares near the top/left border.
        """
        half = annot_size // 2
        for point in points:
            x, y = int(point[0]), int(point[1])
            top, bottom = max(y - half, 0), max(y + half, 0)
            left, right = max(x - half, 0), max(x + half, 0)
            image[top:bottom, left:right, :] = color
        return image

    def __getitem__(self, idx):
        """
        Returns ``(image, annotated_image, label)``; both images are passed
        through ``transform`` when one was supplied, otherwise the raw RGB
        arrays are returned.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        file_name = self.file_names[idx]
        file_path = f'{self.train_path}/{file_name}.jpg'
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Keep an untouched copy; `image` receives the annotations.
        base_image = image.copy()
        query_string = f"StudyInstanceUID == '{file_name}'"
        annotations = self.df_annotations.query(query_string)
        # Add in the Annotations into the image
        for _, row in annotations.iterrows():
            points = np.array(ast.literal_eval(row["data"]))
            self._paint_annotations(image, points, COLOR_MAP[row["label"]], self.annot_size)
        # Without a transform the raw images are returned (this used to raise UnboundLocalError).
        orig_augmented = base_image
        if self.transform:
            augmentations = self.transform(image = base_image, image1 = image)
            orig_augmented = augmentations['image']
            image = augmentations['image1']
        label = torch.tensor(self.labels[idx]).float()
        return orig_augmented, image, label

In [None]:
# ====================================================
# Transforms
# ====================================================
# Side length, in pixels, of the square images produced by both pipelines.
IMAGE_SIZE = 320
# Training pipeline: crop/flip/colour/geometry/noise augmentations, then normalise and convert
# to a tensor. `additional_targets` registers 'image1' as a second image input so that both
# images passed to the pipeline go through the same transform calls.
train_transforms = Compose([
            RandomResizedCrop(IMAGE_SIZE, IMAGE_SIZE, scale = (0.9, 0.9)),
            HorizontalFlip(p=0.5),
            #albumentations.VerticalFlip(p = 0.2),
            RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2)),
            HueSaturationValue(p=0.2, hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2),
            ShiftScaleRotate(p=0.2, shift_limit=0.0025, scale_limit=0.01, rotate_limit=10),
            CoarseDropout(p=0.2),
            albumentations.GaussianBlur(),
            albumentations.GaussNoise(),
            Cutout(p=0.2, max_h_size=16, max_w_size=16, fill_value=(0., 0., 0.), num_holes=16),
            Normalize(),
            ToTensorV2(),
        ], additional_targets = {'image1': 'image'})
# Evaluation pipeline: resize only (no augmentation), then normalise and convert to a tensor.
test_transforms = Compose([
            Resize(IMAGE_SIZE, IMAGE_SIZE),
            Normalize(),
            ToTensorV2(),
        ], additional_targets = {'image1': 'image'})

In [None]:
# Label table for every study, and the annotation table indexed by study UID.
train = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/train.csv')
train_annotations = pd.read_csv(
    '../input/ranzcr-clip-catheter-line-classification/train_annotations.csv',
    index_col = "StudyInstanceUID",
)
# Boolean mask: True for studies that appear in the annotation table.
train_idx = train['StudyInstanceUID'].isin(train_annotations.index)
# Keep only the annotated studies and index them by study UID.
train_with_annotations = train[train_idx].set_index("StudyInstanceUID")

In [None]:
%%capture
# Build the 11-class teacher and move it to the compute device (Module.to returns the module itself).
teacher_model = TeacherModel(11, device).to(device)

In [None]:
#teacher_model.load_state_dict(torch.load("../input/teachermodelbn/teacher.pth", map_location = device))

In [None]:
#teacher_model.training_loop(train_dataloader, val_dataloader, 50)

In [None]:
#torch.save(teacher_model.state_dict(), "./TeacherModel.pth")

# Student Model Stage 2 Train

In [None]:
class CustomConvBlock(nn.Module):
    """
    Convolution -> SiLU -> BatchNorm2d (the normalisation follows the activation).

    Args:
        in_features: input channels.
        out_features: output channels.
        kernel_size, padding, groups: forwarded to ``nn.Conv2d``.
    """
    def __init__(self, in_features, out_features, kernel_size, padding, groups):
        super().__init__()
        self.conv = nn.Conv2d(
            in_features,
            out_features,
            kernel_size,
            padding=padding,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_features)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        activated = self.act(self.conv(x))
        return self.bn(activated)

In [None]:
class ConvSqueezeExcite(nn.Module):
    '''
    Convolutional squeeze-excitation gate.

    Two 1x1 convolutions (in_features -> inner_features -> in_features) with a
    SiLU in between produce a sigmoid gate that multiplies the input. No pooling
    is performed, so the gate is computed independently at every spatial position.
    '''
    def __init__(self, in_features, inner_features):
        super().__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.Squeeze = nn.Conv2d(in_features, inner_features, kernel_size=1)
        self.act = nn.SiLU(inplace=True)
        self.Excite = nn.Conv2d(inner_features, in_features, kernel_size=1)

    def forward(self, x):
        hidden = self.act(self.Squeeze(x))
        gate = torch.sigmoid(self.Excite(hidden))
        return gate * x

In [None]:
class InverseBottleNeck(nn.Module):
    """
    Inverted residual block: 1x1 expand -> 3x3 depthwise -> squeeze-excite ->
    1x1 squeeze, scaled by a zero-initialised learnable ``gamma`` and added to
    the input (so a new block starts out as the identity).

    Args:
        in_features: channels of the input and output.
        inner_features: channels of the expanded representation.
        device: device on which ``gamma`` is created.
    """
    def __init__(self, in_features, inner_features, device):
        super().__init__()
        self.device = device
        self.in_features = in_features
        self.inner_features = inner_features
        # CustomConvBlock arguments: in, out, kernel_size, padding, groups.
        self.Expand = CustomConvBlock(in_features, inner_features, 1, 0, 1)
        # groups == channels makes the 3x3 convolution depthwise.
        self.Depthwise = CustomConvBlock(inner_features, inner_features, 3, 1, inner_features)
        # The gate's bottleneck width is derived from the *outer* channel count.
        self.SE = ConvSqueezeExcite(inner_features, in_features // 4)
        self.Squeeze = CustomConvBlock(inner_features, in_features, 1, 0, 1)
        self.gamma = nn.Parameter(torch.zeros(1, device=device))

    def forward(self, x):
        out = x
        for stage in (self.Expand, self.Depthwise, self.SE, self.Squeeze):
            out = stage(out)
        return out * self.gamma + x

In [None]:
class ModifiedResNetBeta(nn.Module):
    """
    Plain timm ``resnet200d_320`` backbone with global average pooling, dropout
    and a linear classification head.

    Args:
        num_classes: size of the logit vector.
        drop_prob: dropout probability applied to the pooled features.
    """
    def __init__(self, num_classes, drop_prob = 0.0):
        super().__init__()
        self.num_classes = num_classes
        self.model = timm.create_model("resnet200d_320", pretrained = True)
        # Replace the backbone's own pooling and head; pooling and classification are done below.
        self.model.global_pool = nn.Identity()
        self.model.fc = nn.Identity()
        self.Dropout = nn.Dropout(p = drop_prob)
        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))
        self.Dense = nn.Linear(2048, self.num_classes)

    def forward(self, x):
        '''
        x: Image input.

        Returns a 6-tuple whose first element is the (B, num_classes) logits; the
        remaining five entries are zeros that keep the tuple layout of the other
        models in this notebook.
        '''
        pooled = self.global_avg(self.model.forward_features(x))
        # flatten(1) keeps the batch axis; torch.squeeze would also drop it for a
        # batch of one and yield 1-D logits.
        features = pooled.flatten(1)
        dropped = self.Dropout(features)
        return self.Dense(dropped), 0, 0, 0, 0, 0

In [None]:
class ModifiedResNetAlpha(nn.Module):
    """
    Student network: timm ``resnet200d_320`` feature extractor with extra
    BottleNeck stacks after the tapped feature map ("layer 2.5") and after
    layer 3 ("layer 3.5"). The attention / squeeze-excite slots are currently
    ``nn.Identity`` placeholders.

    Args:
        num_classes: size of the logit vector.
        device: stored on the instance.
        drop_prob: dropout probability applied to the pooled features.

    ``forward`` returns the tuple
    ``(logits, excited_features2, excited_features3, excited_features4, layer4, avg_pooled)``.
    """
    def freeze(self, layer):
        """Disable gradients for every parameter of ``layer``."""
        for parameter in layer.parameters():
            parameter.requires_grad = False

    def __init__(self, num_classes, device, drop_prob = 0.2):
        # modified ResNet Student, with additional processing and one less attention head for maximal performance
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        self.model = timm.create_model("resnet200d_320", pretrained = True, features_only = True)
        self.model.global_pool = nn.Identity()
        self.model.fc = nn.Identity()

        # Aliases to the backbone stages (the same module objects, not copies).
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1

        self.layer1 = self.model.layer1

        self.layer2 = self.model.layer2

        self.layer3 = self.model.layer3

        # Freeze Initial Layers
        #self.freeze(self.layer1)
        #self.freeze(self.conv1)
        #self.freeze(self.bn1)
        #self.freeze(self.layer2)
        #self.freeze(self.layer3)

        self.layer4 = self.model.layer4 # Extract layers from ResNet

        # 12 BottleNeck blocks operating on the 512-channel feature map.
        self.layer2_5 = nn.Sequential(*[
            BottleNeck(512, 128) for i in range(12)
        ])

        self.layer2_att = nn.Identity()#CBAMAttention(512, 256, 2)
        self.se2 = nn.Identity()#CBAMSqueezeExcite(512, 256)

        self.layer3_att = nn.Identity()#CBAMAttention(1024, 512, 2)
        self.se3 = nn.Identity()#CBAMSqueezeExcite(1024, 256)
        # Custom Layer 3.5(Between layer 3 and 4)

        self.layer3_5 = nn.Sequential(*[BottleNeck(1024, 256) for i in range(15)]) #+
            #[nn.MaxPool2d(kernel_size = 3, padding = 1, stride = 2)] +
            #[BottleNeck(1024, 256) for i in range(12)])


        self.layer4_att = nn.Identity()#CBAMAttention(1024, 512, 2)
        self.se4 = nn.Identity()#CBAMSqueezeExcite(1024, 256)

        self.Dropout = nn.Dropout(drop_prob)
        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))
        self.Linear = nn.Linear(2048, self.num_classes)


    def forward(self, x):
        # Third-from-last feature map of the backbone (512 channels, per layer2_5 above).
        base_features = self.model(x)[-3]
        # Process through Layer 2.5
        layer2_5 = self.layer2_5(base_features)

        attention2 = self.layer2_att(layer2_5)
        excited_features2 = self.se2(attention2)

        layer3 = self.layer3(excited_features2)

        attention3 = self.layer3_att(layer3)
        excited_features3 = self.se3(attention3)

        layer3_5 = self.layer3_5(excited_features3)
        attention4 = self.layer4_att(layer3_5)
        excited_features4 = self.se4(attention4)

        layer4 = self.layer4(excited_features4)
        # Pooled Features. flatten(1) keeps the batch axis; torch.squeeze would
        # also drop it for a batch of one and produce 1-D logits.
        avg_pooled = self.global_avg(layer4).flatten(1) # (B, channels)
        dropped_features = self.Dropout(avg_pooled)

        logits = self.Linear(dropped_features) # (B, num_classes)

        return logits, excited_features2, excited_features3, excited_features4, layer4, avg_pooled

In [None]:
class StudentTeacher(nn.Module):
    """
    Knowledge-distillation wrapper: a frozen teacher supervises a
    ModifiedResNetAlpha student.

    The training loss is
    ``logits_weight * BCE(student, labels) + copy_weight * BCE(student / T, sigmoid(teacher / T))``
    with temperature ``T = sigmoid_temp``.

    Args:
        num_classes: number of student logits.
        teacher: module whose call returns a 6-tuple starting with the logits;
            all of its parameters are frozen here.
        device: device that batches are moved to.
    """
    def __init__(self, num_classes, teacher, device):
        super().__init__()
        self.num_classes = num_classes
        self.device = device
        # The teacher is registered as a sub-module, so it is part of state_dict().
        self.teacher = teacher
        for parameter in self.teacher.parameters():
            parameter.requires_grad = False
        self.student = ModifiedResNetAlpha(self.num_classes, self.device, drop_prob = 0.2)
        self.optim = optim.Adam(self.student.parameters(), lr = 1e-4, weight_decay = 1e-3)
        self.lr_decay = optim.lr_scheduler.StepLR(self.optim, 5, 0.9)
        # Constructed but never stepped by training_loop.
        self.lr_decay2 = optim.lr_scheduler.CosineAnnealingLR(self.optim, 5, eta_min = 1e-7)
        self.MSELoss = nn.MSELoss()
        self.criterion = nn.BCEWithLogitsLoss()
        self.BCELoss = nn.BCELoss()
        self.copy_weight = 2 # Weight on how much student should copy teacher
        self.logits_weight = 0.5 # Weight on how much student should get answer right
        self.sigmoid_temp = 1.5

    def forward(self, x):
        '''
        Runs Inference on the student model (eval mode, no gradients) and
        returns its logits.
        '''
        self.eval()
        with torch.no_grad():
            pred, _, _, _, _, _= self.student(x)
            return pred

    def sigmoid_temperature(self, teacher_logits, students_logits):
        '''
        Distillation loss: binary cross-entropy between the temperature-softened
        student prediction and the temperature-softened teacher probabilities,
        multiplied by ``copy_weight``.

        The student side stays in logit space (``binary_cross_entropy_with_logits``).
        This is the same quantity as BCE on the two sigmoids, but the loss and its
        gradient stay well-behaved when the student's sigmoid saturates.
        '''
        teacher_sigmoid = torch.sigmoid(teacher_logits / self.sigmoid_temp)
        soft_bce = F.binary_cross_entropy_with_logits(students_logits / self.sigmoid_temp, teacher_sigmoid)
        return soft_bce * self.copy_weight

    def compare_loss(self, teacher, student):
        '''
        Comparison loss if you need to compare features between the student and teacher:
        MSE over the flattened tensors, multiplied by ``copy_weight``.
        '''
        return self.MSELoss(student.view(-1), teacher.view(-1)) * self.copy_weight

    def evaluate(self, pred, y_pred, val = False):
        '''
        BCE-with-logits between the logits ``pred`` and the targets ``y_pred``.
        The result is multiplied by ``logits_weight`` unless ``val`` is True.
        '''
        if val:
            return self.criterion(pred, y_pred)
        return self.criterion(pred, y_pred) * self.logits_weight

    def accuracy(self, student_logits, y_true):
        '''
        Thresholded accuracy. ``student_logits`` is compared against 0.5
        directly, so the caller is expected to pass probabilities (the training
        loop passes ``torch.sigmoid(logits)``).

        Returns a 0-dim tensor: the fraction of (sample, class) entries that
        match ``y_true``. The inputs are not modified.
        '''
        predicted = (student_logits >= 0.5).to(y_true.dtype)
        return (predicted == y_true).float().mean()

    def training_loop(self, train_dataloader, val_dataloader, NUM_EPOCHS, display_every = 64):
        '''
        Train the student for ``NUM_EPOCHS`` epochs of at most ``display_every``
        batches each, validate after every epoch and checkpoint the best models.

        Both loaders yield ``(image, annotated_image, labels)`` triples. The
        teacher sees the annotated image, the student the plain one.

        Metric names are chosen so that livelossplot pairs the curves:
            logs['loss']         - mean total training loss
            logs['val_loss']     - mean unweighted BCE on the *training* batches
            logs['accuracy']     - mean validation accuracy
            logs['val_accuracy'] - mean unweighted BCE on the *validation* batches

        './BestVal.pth' is written when the validation BCE does not get worse,
        './BestAcc.pth' when the validation accuracy does not get worse.

        Raises:
            ValueError: if either loader yields no batches.
        '''
        liveloss = livelossplot.PlotLosses()
        best_val_loss = 999
        best_val_acc = 0
        torch.cuda.empty_cache() # Clear any wasted cuda memory
        self.teacher.eval()
        for EPOCH in range(NUM_EPOCHS):
            logs = {}
            self.student.train()
            total_loss = 0
            count = 0
            logs['val_loss'] = 0
            for images, annotated_images, labels in train_dataloader:
                self.optim.zero_grad()
                images = images.to(self.device)
                annotated_images = annotated_images.to(self.device)
                labels = labels.to(self.device)

                with torch.no_grad():
                    teacher_logits, _, _, _, _, _ = self.teacher(annotated_images)
                student_logits, _, _, _, _, _ = self.student(images)

                copy_loss = self.sigmoid_temperature(teacher_logits, student_logits)
                # Compute the plain BCE once; it is logged as-is and weighted for the loss.
                plain_bce = self.evaluate(student_logits, labels, val = True)
                logits_loss = plain_bce * self.logits_weight

                logs['val_loss'] += plain_bce.item()
                loss = logits_loss + copy_loss
                loss.backward()
                self.optim.step()
                total_loss += loss.item()
                count += 1
                del student_logits
                del teacher_logits
                del loss
                del labels
                del annotated_images
                del images
                torch.cuda.empty_cache()
                if count == display_every:
                    break
            if count == 0:
                raise ValueError("train_dataloader produced no batches")
            logs['loss'] = total_loss / count
            logs['val_loss'] /= count
            self.lr_decay.step()

            print(f"EPOCH: {EPOCH}, loss: {logs['loss']}")
            self.student.eval()
            with torch.no_grad():
                count = 0

                logs['accuracy'] = 0
                logs['val_accuracy'] = 0
                # The annotated image is not needed for validation, so it stays on the host.
                for images, _, labels in val_dataloader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    student_logits, _, _, _, _, _ = self.student(images)

                    # .item() keeps plain Python floats in `logs` instead of device tensors.
                    logs['accuracy'] += self.accuracy(torch.sigmoid(student_logits), labels).item()
                    logs['val_accuracy'] += self.evaluate(student_logits, labels, val = True).item()
                    del student_logits
                    del labels
                    del images
                    torch.cuda.empty_cache()
                    count += 1
                if count == 0:
                    raise ValueError("val_dataloader produced no batches")
                logs['accuracy'] /= count
                logs['val_accuracy'] /= count
            print(f"Val_LOSS: {logs['val_loss']}, accuracy: {logs['accuracy']}")

            liveloss.update(logs)
            liveloss.send()
            if logs['val_accuracy'] <= best_val_loss:
                best_val_loss = logs['val_accuracy']
                torch.save(self.state_dict(), './BestVal.pth')
            if logs['accuracy'] >= best_val_acc:
                best_val_acc = logs['accuracy']
                torch.save(self.state_dict(), "./BestAcc.pth")
            print(f"Val_Loss: {logs['val_accuracy']}")

In [None]:
'''
train_split, val_split = train_test_split(train_with_annotations, train_size = 0.9999, test_size = 0.0001)
## Filter out the Overlapping entries in Train and Val_set
train_unique = set(train_split.PatientID.values)
val_unique = set(val_split.PatientID.values)
overlap = train_unique.intersection(val_unique)
#
train_idx = []
for overlapped in overlap:
    train_idx += train_split.index[train_split.PatientID == overlapped].to_list()
val_split = val_split.append(train_split.loc[train_idx])
train_split = train_split.drop(train_idx)

train_dataset = TrainDataset(train_split, train_annotations, transform = train_transforms) 
val_dataset = TrainDataset(val_split, train_annotations, transform = test_transforms)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = 16, shuffle = True) # 14 is the maximum for ResNetStudent, 24 batch size for ResNetBeta.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size = 32)
'''

In [None]:
%%capture
# Build the 11-class student/teacher pair and move it to the compute device (Module.to returns the module itself).
studentTeacher = StudentTeacher(11, teacher_model, device).to(device)

In [None]:
#studentTeacher.load_state_dict(torch.load("../input/simplemodel/BestVal.pth"))

In [None]:
#studentTeacher.training_loop(train_dataloader, val_dataloader, 50, display_every = 128)

In [None]:
#torch.save(studentTeacher.state_dict(), "./LastEpoch.pth")

# Student Model Stage 3 Train

Load the full training set; the annotations are no longer needed at this stage.

In [None]:
# Load in the full dataset, indexed by StudyInstanceUID.
train = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/train.csv', index_col = "StudyInstanceUID")
# Hold out a very small random validation split (0.01% of the rows).
train, val = train_test_split(train, train_size = 0.9999, test_size = 0.0001)
# Find leaking patient data: a patient whose images landed in both splits would let the
# model be validated on a patient it has already trained on.
unique_train = set(train.PatientID.values)
unique_val = set(val.PatientID.values)
same_values = unique_train.intersection(unique_val)
# Boolean mask over the training rows that belong to a patient also present in `val`
# (one vectorized pass instead of one DataFrame scan per overlapping patient).
leaking_rows = train.PatientID.isin(same_values)
overlapping_train_vals = train.index[leaking_rows].to_list()
# Move those rows from train to val. `DataFrame.append` was removed in pandas 2.0;
# `pd.concat` is the supported equivalent and also works on older versions.
val = pd.concat([val, train.loc[leaking_rows]])
train = train.loc[~leaking_rows]
# Invariant: after the move no patient may appear in both splits.
assert set(train.PatientID.values).isdisjoint(val.PatientID.values), "patient leakage between train and val"


In [None]:
class TrainDatasetNoAnnot(torch.utils.data.Dataset):
    """Labelled image dataset that does not use catheter annotations.

    Each item is read from ``base_file + <index label> + ".jpg"``.

    Args:
        x: DataFrame whose index holds the image identifiers and whose columns
            hold the target values for that image (all columns are returned
            as the label vector, so they must be numeric).
        base_file: Directory prefix of the images. It is concatenated directly
            with the identifier, so it must end with a path separator.
        transforms: Callable invoked as ``transforms(image = rgb_array)`` that
            returns a mapping with an ``'image'`` entry.
    """
    def __init__(self, x, base_file, transforms):
        self.x = x
        self.files = x.index
        self.transforms = transforms
        self.base_file = base_file
    def __len__(self):
        """Number of labelled images (rows of ``x``)."""
        return len(self.x)
    def _target(self, idx):
        """Return row ``idx`` of ``x`` as a tensor.

        The row is converted to a NumPy array first so torch receives a plain
        array rather than a pandas Series, and it is looked up by position so
        the result does not depend on the index labels.
        """
        return torch.tensor(self.x.iloc[idx].to_numpy())
    def __getitem__(self, idx):
        """Return ``(transformed_image, labels)`` for position ``idx``.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be decoded.
        """
        file_val = str(self.files[idx])
        file_path = self.base_file + file_val + ".jpg"
        GT = self._target(idx)
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        transformed_image = torch.tensor(self.transforms(image = image)['image'])
        return transformed_image, GT
class TestDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset over every file found in ``base_file``.

    Args:
        base_file: Directory containing the test images.
        transforms: Callable invoked as ``transforms(image = rgb_array)`` that
            returns a mapping with an ``'image'`` entry.
    """
    def __init__(self, base_file, transforms):
        self.base_file = base_file
        # os.listdir returns entries in arbitrary order; sort them so the
        # iteration order is the same on every run.
        self.test_images = sorted(os.listdir(self.base_file))
        self.transforms = transforms
    def __len__(self):
        """Number of files listed in ``base_file``."""
        return len(self.test_images)
    @staticmethod
    def _image_id(file_name):
        """Return ``file_name`` without its extension.

        ``str.strip('.jpg')`` would remove any of the characters '.', 'j', 'p'
        and 'g' from both ends (e.g. "imgj.jpg" -> "im"), so the extension is
        split off explicitly instead.
        """
        return os.path.splitext(file_name)[0]
    def __getitem__(self, idx):
        """Return ``(transformed_image, image_id)`` for position ``idx``.

        Raises:
            FileNotFoundError: if the file is missing or cannot be decoded.
        """
        file_name = self.test_images[idx]
        file_val = self._image_id(file_name)
        file_path = os.path.join(self.base_file, file_name)
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        transformed_image = torch.tensor(self.transforms(image = image)['image'])
        return transformed_image, file_val


In [None]:

# Labelled splits: `iloc[:, :-1]` drops the last column so only the remaining columns
# are used as targets (presumably the dropped column is PatientID — confirm against train.csv).
training_dataset = TrainDatasetNoAnnot(
    x = train.iloc[:, :-1],
    base_file = "../input/ranzcr-clip-catheter-line-classification/train/",
    transforms = train_transforms,
)
train_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size = 16,
    shuffle = True,
)

val_dataset = TrainDatasetNoAnnot(
    x = val.iloc[:, :-1],
    base_file = "../input/ranzcr-clip-catheter-line-classification/train/",
    transforms = test_transforms,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size = 32,
    shuffle = False,
)

# Unlabelled test images; each item carries its file identifier instead of a label.
test_dataset = TestDataset(
    base_file = "../input/ranzcr-clip-catheter-line-classification/test/",
    transforms = test_transforms,
)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 16)


In [None]:
class StudentSolver(nn.Module):
    """Wraps the student network for stand-alone training / validation.

    Args:
        student: The student module. Its forward output is either the logits
            tensor or a tuple/list whose first element is the logits.
        device: Device that input batches are moved to.
    """
    def __init__(self, student, device):
        super().__init__()
        self.device = device
        self.student = student
        for parameter in self.student.parameters():
            parameter.requires_grad = True # Unfreeze all Weights
        self.optim = optim.Adam(self.student.parameters(), 1e-6, weight_decay = 1e-3)
        # Both schedulers are attached to the same optimizer and are stepped once per training epoch.
        self.lr_decay = optim.lr_scheduler.CosineAnnealingLR(self.optim, 5, eta_min = 1e-9)
        self.lr_decay2 = optim.lr_scheduler.StepLR(self.optim, 5, 0.95)
        self.criterion = nn.BCEWithLogitsLoss()
    def _student_logits(self, x):
        """Run the student and return only its logits.

        The training/validation loops show the student returning several
        values with the logits first, so a tuple/list output is unpacked here.
        """
        output = self.student(x)
        if isinstance(output, (tuple, list)):
            return output[0]
        return output
    def forward(self, x):
        '''
        Runs Inference on the Student Model, performing a sigmoid on the logits.

        Switches the module to eval mode (and leaves it there) and disables
        gradient tracking. Returns the per-class probabilities.
        '''
        self.eval()
        with torch.no_grad():
            return torch.sigmoid(self._student_logits(x))
    def accuracy(self, y_pred, y_true):
        """Element-wise accuracy of probabilities thresholded at 0.5.

        Args:
            y_pred: (B, C) tensor of probabilities. It is not modified.
            y_true: (B, C) tensor of 0/1 targets.

        Returns:
            Scalar tensor: fraction of the B * C entries predicted correctly.
        """
        # Threshold into a new tensor instead of overwriting the caller's predictions.
        hard_pred = (y_pred >= 0.5).to(y_pred.dtype)
        B, C = y_pred.shape
        return torch.sum((hard_pred == y_true).int()) / B / C
    def _train_epoch(self, trainloader, display_every):
        """Run one training pass and return the mean batch loss.

        The pass stops after ``display_every`` batches, then both learning
        rate schedulers are stepped.

        Raises:
            ValueError: if ``trainloader`` yields no batches.
        """
        self.train()
        count = 0
        total_loss = 0
        for images, labels in trainloader:
            self.optim.zero_grad()
            images = images.to(self.device).to(torch.float32)
            labels = labels.to(self.device).to(torch.float32)
            pred = self._student_logits(images)
            loss = self.criterion(pred, labels)
            loss.backward()
            self.optim.step()
            count += 1
            total_loss += loss.item()
            del images
            del labels
            del pred
            del loss
            torch.cuda.empty_cache()
            if count == display_every:
                break
        if count == 0:
            raise ValueError("trainloader yielded no batches")
        self.lr_decay.step()
        self.lr_decay2.step()
        return total_loss / count
    def _validate(self, valloader):
        """Evaluate on ``valloader``; return ``(mean batch loss, mean batch accuracy)`` as floats.

        Raises:
            ValueError: if ``valloader`` yields no batches.
        """
        self.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        count = 0
        with torch.no_grad():
            for images, labels in valloader:
                images = images.to(self.device).to(torch.float32)
                labels = labels.to(self.device).to(torch.float32)
                pred = self._student_logits(images)
                val_loss += self.criterion(pred, labels).item()
                # .item() keeps the running metric a Python float rather than
                # a (possibly on-device) tensor that the plotting logger would receive.
                val_accuracy += self.accuracy(torch.sigmoid(pred), labels).item()
                count += 1
                del images
                del labels
                del pred
                torch.cuda.empty_cache()
        if count == 0:
            raise ValueError("valloader yielded no batches")
        return val_loss / count, val_accuracy / count
    def training_loop(self, trainloader, valloader, NUM_EPOCHS, display_every = 1, run_training = False):
        '''
        Validates (and optionally trains) the model for NUM_EPOCHS epochs.

        Args:
            trainloader: Yields (images, labels) batches; only read when run_training is True.
            valloader: Yields (images, labels) batches used for validation every epoch.
            NUM_EPOCHS: Number of epochs to run.
            display_every: Number of training batches per epoch when run_training is True.
            run_training: When False (default) each epoch only runs the validation pass.
                When True a training pass over trainloader precedes validation.

        Side effects: writes "./BestLoss.pth" whenever the validation loss is at
        least as good as the best so far, and "./BestAcc.pth" likewise for accuracy.
        '''
        liveloss = livelossplot.PlotLosses()
        best_val_acc = 0
        best_val_loss = 999
        for EPOCH in range(NUM_EPOCHS):
            logs = {}
            if run_training:
                logs['loss'] = self._train_epoch(trainloader, display_every)
                print(f"EPOCH: {EPOCH}, total_loss: {logs['loss']}")
            logs['val_loss'], logs['accuracy'] = self._validate(valloader)
            print(f"ACCURACY: {logs['accuracy']}, loss: {logs['val_loss']}")
            liveloss.update(logs)
            liveloss.send()
            if logs['val_loss'] <= best_val_loss:
                best_val_loss = logs['val_loss']
                torch.save(self.state_dict(), "./BestLoss.pth")
            if logs['accuracy'] >= best_val_acc:
                best_val_acc = logs['accuracy']
                torch.save(self.state_dict(), "./BestAcc.pth")
        

In [None]:
%%capture
# Wrap only the student sub-network of `studentTeacher` in a StudentSolver and move it to `device`.
studentSolver = StudentSolver(studentTeacher.student, device)
studentSolver.to(device)

In [None]:
# Restore the saved solver weights. `map_location` remaps the stored tensors onto the
# active device, so a checkpoint saved from a GPU session also loads on a CPU-only one.
studentSolver.load_state_dict(torch.load("../input/stage3/FinalModel.pth", map_location = device))

In [None]:
# As written, training_loop only runs its validation pass (the training pass is not active),
# so this evaluates the loaded checkpoint on `val_dataloader` for 100 epochs and writes
# BestLoss.pth / BestAcc.pth. `display_every` only affects the inactive training pass.
studentSolver.training_loop(train_dataloader, val_dataloader, 100, display_every = 64)

In [None]:
# Persist the solver's weights (state_dict only) to FinalModel.pth in the working directory.
torch.save(studentSolver.state_dict(), "./FinalModel.pth")