In [1]:
# Standard library imports
import copy
import datetime
import gc
import glob
import heapq
import joblib
import math
import os
import random
import time
from collections import defaultdict
from itertools import chain

# Related third-party imports
import albumentations as A
import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from albumentations.pytorch import ToTensorV2
from colorama import Fore, Back, Style
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from PIL import Image
from skimage import io
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score, accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, tqdm_notebook

# Local application/library specific imports
# (Your local imports here, if any)

# Short aliases for colorama escape codes used in console output:
# b_ switches the text colour to blue, sr_ resets all colour/style settings.
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

# Install a blanket 'ignore' filter so no Python warnings are shown from here on.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# imported above -- consider narrowing the filter to specific categories.
import warnings
warnings.filterwarnings('ignore')

# Additional note: The following import seems to be from a very specific context or not typically used.
# You might want to review if it's necessary or correct:
# from joblib.externals.loky.backend.context import get_context


# 3. Model Architecture

In [2]:
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions of the input.

    Every value is clamped from below at ``eps``, raised to the power ``p``,
    averaged over a window spanning the whole last two dimensions, and the
    average is then raised to ``1/p``.  ``p`` is stored as a one-element
    learnable parameter initialised to the constructor argument.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = nn.Parameter(torch.ones(1) * p)

    def gem(self, x, p=3, eps=1e-6):
        """Pool ``x`` with exponent ``p`` and clamp floor ``eps``."""
        # The pooling window covers the full extent of the last two dims,
        # so each of them collapses to size 1 in the result.
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        averaged = F.avg_pool2d(powered, window)
        return averaged.pow(1. / p)

    def forward(self, x):
        """Pool ``x`` using the module's own ``p`` parameter and ``eps``."""
        return self.gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f'{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})'

In [3]:
class UBCModel(nn.Module):
    '''
    Multi-class classifier: timm backbone -> GeM pooling -> linear head.

    The constructor reads ``classifier.in_features`` from the created model,
    so it expects a timm architecture that exposes a ``classifier`` attribute
    (the original note names EfficientNet B0).
    '''
    def __init__(self, model_name, num_classes, pretrained=False, checkpoint_path=None):
        '''
        Build the backbone and attach the pooling and classification layers.
        Args
            model_name : architecture name passed to ``timm.create_model``.
            num_classes : int - Output width of the final linear layer.
            pretrained : forwarded to ``timm.create_model``.
            checkpoint_path : forwarded to ``timm.create_model``.
        '''
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            checkpoint_path=checkpoint_path,
        )

        # Record the input width of the stock classifier before replacing it.
        feature_dim = backbone.classifier.in_features
        # Replace the backbone's own classifier and pooling with pass-throughs;
        # pooling and classification are done by the layers defined below.
        backbone.classifier = nn.Identity()
        backbone.global_pool = nn.Identity()

        self.model = backbone
        self.pooling = GeM()
        self.linear = nn.Linear(feature_dim, num_classes)
        # Kept as an attribute for callers; forward() itself does not apply it.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, images):
        """
        Run the backbone, pool its output and apply the linear head.
        Args
            images: input batch handed directly to the backbone.
        Return
            Output of the linear layer (``self.softmax`` is not applied).
        """
        feature_maps = self.model(images)
        pooled = self.pooling(feature_maps)
        return self.linear(pooled.flatten(1))

In [4]:
# This model is used for binary classification on cancerous vs non-cancerous tiles 
class UBCBinaryModel(nn.Module):
    '''
    Single-logit classifier: timm backbone -> GeM pooling -> linear head.

    The constructor reads ``classifier.in_features`` from the created model,
    so it expects a timm architecture that exposes a ``classifier`` attribute
    (the original note names EfficientNet B0).
    '''
    def __init__(self, model_name, pretrained=False, checkpoint_path=None):
        '''
        Build the backbone and attach the pooling layer and one-unit head.
        Args
            model_name : architecture name passed to ``timm.create_model``.
            pretrained : forwarded to ``timm.create_model``.
            checkpoint_path : forwarded to ``timm.create_model``.
        '''
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            checkpoint_path=checkpoint_path,
        )

        # Record the input width of the stock classifier before replacing it.
        feature_dim = backbone.classifier.in_features
        # Replace the backbone's own classifier and pooling with pass-throughs;
        # pooling and classification are done by the layers defined below.
        backbone.classifier = nn.Identity()
        backbone.global_pool = nn.Identity()

        self.model = backbone
        self.pooling = GeM()
        self.linear = nn.Linear(feature_dim, 1)

    def forward(self, images):
        """Return the output of the one-unit linear head for ``images``."""
        feature_maps = self.model(images)
        pooled = self.pooling(feature_maps)
        return self.linear(pooled.flatten(1))

In [5]:
class EarlyStopping:
    """Track validation loss, checkpoint on improvement and flag when to stop.

    Call the instance once per validation round with the loss and the model.
    When the loss improves, the model's ``state_dict`` is written to ``path``
    and the patience counter is reset; otherwise the counter is incremented
    and ``early_stop`` becomes True once it reaches ``patience``.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: if True, report every checkpoint save through ``trace_func``.
        delta: margin added to the best score; a call counts as an
            improvement only if ``-val_loss`` is not below ``best_score + delta``.
        path: destination passed to ``torch.save`` for the checkpoint.
        trace_func: callable used for status messages (defaults to ``print``).
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss: validation loss for this round (lower is better).
            model: module whose ``state_dict`` is saved on improvement.
        """
        score = -val_loss
        if math.isnan(val_loss):
            # NaN compares False against everything, so without this guard a
            # diverged epoch would be taken as an improvement: the best
            # checkpoint would be overwritten and best_score would become NaN.
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = not (score < self.best_score + self.delta)

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# 4. Training

In [6]:
def fetch_scheduler(optimizer, CONFIG, **kwargs):
    """Create the learning-rate scheduler selected by ``CONFIG['scheduler']``.

    Args:
        optimizer: optimizer the scheduler will drive.
        CONFIG: mapping with a ``'scheduler'`` entry.  ``'CosineAnnealingLR'``
            additionally reads ``'T_max'`` and ``'min_lr'``.
            ``'CosineAnnealingWarmRestarts'`` overwrites ``'T_0'``,
            ``'T_mult'`` and ``'min_lr'`` in CONFIG with fixed values.
        **kwargs: optional overrides -- ``factor`` and ``patience`` for
            ``'ReduceLROnPlateau'`` (defaults 0.1 and 5) and ``lr_lambda``
            for ``'LambdaLR'``.

    Returns:
        The scheduler instance, or ``None`` when ``CONFIG['scheduler']`` is None.

    Raises:
        ValueError: if the scheduler name is not one of the supported ones.
    """
    name = CONFIG['scheduler']
    if name is None:
        return None

    # `verbose` is not passed to any scheduler: False is already the default
    # and the argument is deprecated in recent PyTorch releases.
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['T_max'],
                                              eta_min=CONFIG['min_lr'])

    if name == 'CosineAnnealingWarmRestarts':
        # NOTE(review): these assignments replace any values the caller put
        # in CONFIG (including 'min_lr'); kept as-is for compatibility.
        CONFIG['T_0'] = 10
        CONFIG['T_mult'] = 2
        CONFIG['min_lr'] = 1e-6
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CONFIG['T_0'],
                                                        T_mult=CONFIG['T_mult'],
                                                        eta_min=CONFIG['min_lr'])

    if name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                              factor=kwargs.get('factor', 0.1),
                                              patience=kwargs.get('patience', 5))

    if name == 'LambdaLR':
        # Prefer an explicitly supplied function; otherwise fall back to a
        # module-level `lr_lambda`, which must be defined elsewhere.
        schedule_fn = kwargs['lr_lambda'] if 'lr_lambda' in kwargs else lr_lambda
        return lr_scheduler.LambdaLR(optimizer, schedule_fn)

    raise ValueError(f"Invalid Scheduler given: {name!r}")

def get_optimizer(optimizer_name, model):
    """Build an optimizer for ``model`` and record its hyperparameters in CONFIG.

    The name is matched case-insensitively.  Each branch first writes its
    fixed hyperparameters into the module-level ``CONFIG`` mapping
    (overwriting existing entries with the same keys) and then constructs
    the optimizer from those entries.

    Args:
        optimizer_name: one of ``"adam"``, ``"sgd"``, ``"radam"``, ``"rmsprop"``.
        model: module whose ``parameters()`` are optimised.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``optimizer_name`` is not recognised.
    """
    name = optimizer_name.lower()

    if name == "adam":
        CONFIG['learning_rate'] = 1e-4
        CONFIG['weight_decay'] = 1e-5
        CONFIG['betas'] = (0.9, 0.999)
        CONFIG['eps'] = 1e-8
        return optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], betas=CONFIG['betas'],
                          eps=CONFIG['eps'], weight_decay=CONFIG['weight_decay'])

    if name == "sgd":
        CONFIG['learning_rate'] = 1e-3
        CONFIG['weight_decay'] = 1e-3
        # NOTE(review): a momentum of 1e-3 is unusually small -- confirm it is
        # intended and not copied from weight_decay.  Value left unchanged.
        CONFIG['momentum'] = 1e-3
        return optim.SGD(model.parameters(), lr=CONFIG['learning_rate'],
                         momentum=CONFIG['momentum'], weight_decay=CONFIG['weight_decay'])

    if name == "radam":
        CONFIG['learning_rate'] = 1e-4
        CONFIG['weight_decay'] = 0
        CONFIG['betas'] = (0.9, 0.999)
        CONFIG['eps'] = 1e-8
        # Uses PyTorch's built-in RAdam; the previously referenced
        # `torch_optimizer` package is not imported by this notebook.
        return optim.RAdam(
            model.parameters(),
            lr=CONFIG['learning_rate'],
            betas=CONFIG['betas'],
            eps=CONFIG['eps'],
            weight_decay=CONFIG['weight_decay'],
        )

    if name == "rmsprop":
        CONFIG['learning_rate'] = 0.256
        CONFIG['alpha'] = 0.9
        CONFIG['momentum'] = 0.9
        CONFIG['weight_decay'] = 1e-5
        # Each argument reads its own CONFIG entry (previously the learning
        # rate was passed for alpha, momentum and weight_decay as well).
        return optim.RMSprop(model.parameters(), lr=CONFIG['learning_rate'], alpha=CONFIG['alpha'],
                             momentum=CONFIG['momentum'], weight_decay=CONFIG['weight_decay'])

    raise ValueError("Invalid Optimizer given!")
    

