In [1]:
import os
import gc
import cv2
import math
import copy
import time
import random
import glob
from matplotlib import pyplot as plt

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import torchvision

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score

import albumentations as A
from albumentations.pytorch import ToTensorV2

# For Image Models
import timm

# Albumentations for augmentations
# import albumentations as A
# from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings


## 1. Data Preprocessing

In [2]:
class UBCDataset(Dataset):
    """Map-style dataset that loads images from disk for classification.

    Every item is read with OpenCV, converted from BGR to RGB, optionally
    passed through ``crop_vertical`` and finally through ``transforms``.

    Args:
        df: DataFrame exposing a ``file_path`` column (image locations) and a
            ``target_label`` column (labels convertible to ``torch.long``).
        transforms: Optional callable invoked as ``transforms(image=img)`` whose
            result is indexed with ``["image"]``.
        apply_vertical_crop: When true, ``crop_vertical`` is applied to the RGB
            image before ``transforms``.
    """

    def __init__(self, df, transforms=None, apply_vertical_crop=True):
        self.df = df
        self.filenames = df.file_path.values
        self.labels = df.target_label.values
        self.transforms = transforms
        self.apply_vertical_crop = apply_vertical_crop

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with ``"image"`` (the possibly cropped/transformed image) and
            ``"label"`` (a ``torch.long`` tensor).

        Raises:
            FileNotFoundError: if the image cannot be read from ``file_path``.
        """
        img_path = self.filenames[idx]
        img = cv2.imread(img_path)
        # cv2.imread signals a missing/corrupt file by returning None instead of
        # raising; fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {img_path!r}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.apply_vertical_crop:
            img = crop_vertical(img)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return {
            "image": img,
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }

def crop_vertical(image, min_width=200):
    """
    Return the first sufficiently wide slice of an image whose slices are
    separated by all-black vertical bands, with all-black rows removed.

    Args:
        image: array indexed as (rows, columns, channels); a column/row counts
            as black when the sum of its values is exactly zero.
        min_width: a slice is only accepted when it is strictly wider than this
            many columns (default 200, the previously hard-coded threshold).

    Returns:
        * If the image has no black column: the image with its black rows removed.
        * Otherwise: the left-most run of non-black columns wider than
          ``min_width``, with its black rows removed.
        * If no run is wide enough: the input ``image`` object, untouched.

    Note:
        Only the first qualifying slice is returned, even when the image holds
        several. Crops contain content columns only: the separator column on the
        left is excluded and the last image column is kept.
    """
    column_is_black = np.sum(image, axis=(0, 2)) == 0

    if not column_is_black.any():
        # Nothing to split on; the whole image is the single candidate.
        crop = image
    else:
        # Pad with "no content" on both sides so every run of content columns
        # has a rising and a falling edge; edges alternate start, stop, start, ...
        has_content = np.concatenate(([False], ~column_is_black, [False]))
        boundaries = np.flatnonzero(has_content[1:] != has_content[:-1])

        crop = None
        for start, stop in zip(boundaries[0::2], boundaries[1::2]):
            # `stop` is exclusive, so the run width is stop - start.
            if stop - start > min_width:
                crop = image[:, start:stop]
                break

        if crop is None:
            return image

    # Remove every all-black row from the chosen crop (not only top/bottom bars).
    row_is_black = np.sum(crop, axis=(1, 2)) == 0
    return crop[~row_is_black]


def custom_center_crop_or_resize(image, crop_size):
    """Center-crop `image` to `crop_size` when it is large enough, else resize it.

    Args:
        image: array whose first two dimensions are height and width.
        crop_size: indexable pair ``(height, width)`` of the desired output.

    Returns:
        The ``"image"`` entry produced by ``A.CenterCrop`` when both image
        dimensions are at least the requested size, otherwise by ``A.Resize``.
    """
    target_h = crop_size[0]
    target_w = crop_size[1]

    large_enough = image.shape[0] >= target_h and image.shape[1] >= target_w
    if large_enough:
        operation = A.CenterCrop(target_h, target_w)
    else:
        # At least one side is too small to crop from, so rescale instead.
        operation = A.Resize(target_h, target_w)

    return operation(image=image)["image"]

In [3]:
# Albumentations pipelines keyed by split name; looked up as
# data_transforms["train"] / data_transforms["valid"] when datasets are built.
# Both pipelines end with Normalize + ToTensorV2.
data_transforms = {
    # Training: random crop/flip/colour/dropout augmentations, then normalization.
    "train": A.Compose([
        A.RandomResizedCrop(512, 512, scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.2),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        A.CoarseDropout(p=0.2),
        # NOTE(review): Cutout overlaps with CoarseDropout above and is reported as
        # deprecated in newer albumentations releases — confirm against the installed version.
        A.Cutout(p=0.2),
        # Per-channel mean/std; these look like the usual ImageNet statistics — TODO confirm
        # they match the backbone's pretraining.
        A.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225], 
            max_pixel_value=255.0, 
            p=1.0
        ),
        ToTensorV2()], p=1.),
    
    # Validation: no random augmentation, only a fixed resize and the same normalization.
    "valid": A.Compose([
        A.Resize(512, 512),
        A.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225], 
            max_pixel_value=255.0, 
            p=1.0
        ),
        ToTensorV2()], p=1.)
}



In [4]:
def get_dataloaders(df, fold, CONFIG):
    """Build the train/validation loaders for one cross-validation fold.

    Args:
        df: DataFrame with a ``kfold`` column plus the columns ``UBCDataset`` reads.
        fold: rows whose ``kfold`` equals this value form the validation split;
            all other rows form the training split.
        CONFIG: mapping providing ``'train_batch_size'`` and ``'valid_batch_size'``.

    Returns:
        ``(train_loader, valid_loader, df_train, df_valid)`` where both frames
        have a fresh 0..n-1 index.
    """
    df_train = df[df["kfold"] != fold].reset_index(drop=True)
    df_valid = df[df["kfold"] == fold].reset_index(drop=True)

    train_dataset = UBCDataset(df_train, transforms=data_transforms["train"])
    # Training batches must be reshuffled every epoch; with shuffle=False the
    # model sees the rows in DataFrame order every time.
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'],
                              num_workers=2, shuffle=True, pin_memory=True)

    valid_dataset = UBCDataset(df_valid, transforms=data_transforms["valid"])
    # Validation keeps DataFrame order so predictions line up with df_valid rows.
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'],
                              num_workers=2, shuffle=False, pin_memory=True)
    return train_loader, valid_loader, df_train, df_valid

## 2. Model Architecture

In [5]:
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions of the input.

    The input is clamped to at least ``eps``, raised to the power ``p``,
    averaged over a window covering the whole last two dimensions and raised
    to ``1/p``. ``p`` is stored as a one-element ``nn.Parameter``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool `x` with the module's own exponent and epsilon."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Apply the pooling with an explicit exponent `p` and floor `eps`."""
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        averaged = F.avg_pool2d(powered, window)
        return averaged.pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps!s})"


class EfficientNetB0(nn.Module):
    '''
    timm backbone with its classifier and global pooling replaced by
    identities, followed by GeM pooling and a linear classification layer.
    '''

    def __init__(self, model_name, num_classes, pretrained=False, checkpoint_path=None):
        '''
        Build the backbone and the new pooling/classification head.
        Args
            model_name : str - Name handed to ``timm.create_model``.
            num_classes : int - Number of outputs of the final linear layer.
            pretrained : bool - Forwarded to ``timm.create_model``.
            checkpoint_path : Forwarded to ``timm.create_model``.
        '''
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            checkpoint_path=checkpoint_path,
        )
        self.model = backbone

        # Read the feature width before the original classifier is discarded.
        feature_dim = backbone.classifier.in_features
        backbone.classifier = nn.Identity()
        backbone.global_pool = nn.Identity()

        self.pooling = GeM()
        self.linear = nn.Linear(feature_dim, num_classes)
        # Not applied in forward(); kept as an attribute for callers that
        # want to turn the returned logits into probabilities.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, images):
        """
        Run the backbone, pool its output and classify.
        Args
            images: batch passed unchanged to the backbone.
        Return
            Output of the linear layer (softmax is NOT applied).
        """
        feature_maps = self.model(images)
        pooled = self.pooling(feature_maps)
        return self.linear(pooled.flatten(1))

In [6]:
class EarlyStopping:
    """Track validation loss, checkpoint on improvement and flag when to stop.

    Call the instance once per epoch with the validation loss and the model.
    ``early_stop`` becomes True once ``patience`` calls have gone by without
    the loss improving on the best value by more than ``delta``.

    Args:
        patience: number of non-improving calls tolerated before stopping.
        verbose: when true, report every checkpoint through ``trace_func``.
        delta: minimum improvement of the score (negative loss) that counts.
        path: destination of the saved ``state_dict``.
        trace_func: callable used for all messages (defaults to ``print``).
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        # configuration
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func
        # running state
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        # Higher score is better, so work with the negated loss.
        score = -val_loss
        is_first_call = self.best_score is None

        if not is_first_call and score < self.best_score + self.delta:
            # No (sufficient) improvement this time.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = score
        self.save_checkpoint(val_loss, model)
        if not is_first_call:
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Write the model's state_dict to `path` and remember `val_loss` as the minimum.'''
        if self.verbose:
            message = f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...'
            self.trace_func(message)
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [7]:
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, OneCycleLR

def get_lr_scheduler(optimizer, scheduler_type, **kwargs):
    """Create a learning-rate scheduler selected by a short name.

    Args:
        optimizer: optimizer the scheduler will drive.
        scheduler_type: one of ``'steplr'``, ``'plateau'`` or ``'onecycle'``.
        **kwargs: optional settings, by scheduler:
            * steplr: ``step_size`` (default 10), ``gamma`` (default 0.1)
            * plateau: ``factor`` (default 0.1), ``patience`` (default 5)
            * onecycle: ``max_lr`` (default 0.01), ``num_epochs`` (default 10) and
              either ``steps_per_epoch`` (int) or ``train_loader`` (anything with
              ``len``). For backward compatibility a module-level ``train_loader``
              is used when neither is given.

    Raises:
        ValueError: for an unknown ``scheduler_type``, or when the number of
            steps per epoch cannot be determined for ``'onecycle'``.
    """
    if scheduler_type == 'steplr':
        return StepLR(optimizer, step_size=kwargs.get('step_size', 10), gamma=kwargs.get('gamma', 0.1))

    if scheduler_type == 'plateau':
        # NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases;
        # kept as-is to preserve the existing logging behaviour.
        return ReduceLROnPlateau(optimizer, mode='min', factor=kwargs.get('factor', 0.1),
                                 patience=kwargs.get('patience', 5), verbose=True)

    if scheduler_type == 'onecycle':
        steps_per_epoch = kwargs.get('steps_per_epoch')
        if steps_per_epoch is None:
            # Prefer an explicitly passed loader; only then fall back to the
            # notebook-global one the original implementation silently relied on.
            loader = kwargs.get('train_loader')
            if loader is None:
                loader = globals().get('train_loader')
            if loader is None:
                raise ValueError(
                    "OneCycleLR needs 'steps_per_epoch' or 'train_loader' to be passed as a keyword argument"
                )
            steps_per_epoch = len(loader)
        return OneCycleLR(optimizer, max_lr=kwargs.get('max_lr', 0.01),
                          steps_per_epoch=steps_per_epoch, epochs=kwargs.get('num_epochs', 10))

    raise ValueError(f"Unknown scheduler type: {scheduler_type}")
        
def fetch_scheduler(optimizer, CONFIG):
    """Create the scheduler named by ``CONFIG['scheduler']``.

    Args:
        optimizer: optimizer the scheduler will drive.
        CONFIG: dict-like settings. ``'scheduler'`` selects the type; the other
            keys read depend on it:
            * ``'CosineAnnealingLR'``: ``'T_max'``, ``'min_lr'``
            * ``'CosineAnnealingWarmRestarts'``: ``'T_0'``, ``'min_lr'``
            * ``'ReduceLROnPlateau'``: optional ``'factor'`` (default 0.1) and
              ``'patience'`` (default 5)
            * ``None``: no scheduler

    Returns:
        The scheduler instance, or ``None`` when ``CONFIG['scheduler']`` is ``None``.

    Raises:
        ValueError: if the scheduler name is not one of the above.
    """
    name = CONFIG['scheduler']
    if name is None:
        return None

    # `verbose=False` is the schedulers' default, so it is simply not passed.
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['T_max'],
                                              eta_min=CONFIG['min_lr'])
    if name == 'CosineAnnealingWarmRestarts':
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CONFIG['T_0'],
                                                        eta_min=CONFIG['min_lr'])
    if name == 'ReduceLROnPlateau':
        # Settings come from CONFIG; this function has no **kwargs to read from.
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                              factor=CONFIG.get('factor', 0.1),
                                              patience=CONFIG.get('patience', 5))

    raise ValueError(f"Unknown scheduler type: {name}")

def get_optimizer(optimizer_name, model, CONFIG):
    """Build an optimizer over ``model.parameters()`` from a case-insensitive name.

    Args:
        optimizer_name: ``"adam"`` or ``"sgd"`` in any letter case.
        model: object exposing ``parameters()``.
        CONFIG: mapping with ``'learning_rate'`` and ``'weight_decay'``; SGD
            additionally reads ``'momentum'``.

    Raises:
        ValueError: if the name is neither Adam nor SGD.
    """
    kind = optimizer_name.lower()

    if kind == "adam":
        return optim.Adam(
            model.parameters(),
            lr=CONFIG['learning_rate'],
            weight_decay=CONFIG['weight_decay'],
        )

    if kind == "sgd":
        return optim.SGD(
            model.parameters(),
            lr=CONFIG['learning_rate'],
            momentum=CONFIG['momentum'],
            weight_decay=CONFIG['weight_decay'],
        )

    raise ValueError("Invalid Optimizer given!")

In [8]:
def convert_dict_to_tensor(dict_):
    """Pack a dict's values, in iteration order, into a 1-D PyTorch tensor.

    Keys are ignored. The result is allocated with ``torch.empty`` (default
    dtype) and has one element per dict entry.
    """
    packed = torch.empty(len(dict_))
    position = 0
    for item in dict_.values():
        packed[position] = item
        position += 1
    return packed

def get_class_weights(df_train):
    """Compute normalized inverse-frequency weights for the labels in `df_train`.

    For every distinct ``target_label`` (ascending label order) the weight is
    ``1 / frequency`` divided by the sum of those inverse frequencies, so the
    weights add up to one and rarer labels get larger weights.

    Returns:
        Tensor from ``convert_dict_to_tensor`` with one weight per label that
        occurs in `df_train`.
    """
    counts_by_label = df_train.target_label.value_counts().sort_index().to_dict()

    inverse_freq = {}
    inverse_total = 0
    for label, count in counts_by_label.items():
        frequency = count / df_train.shape[0]
        inverse_freq[label] = 1 / frequency
        inverse_total += 1 / frequency

    normalized = {}
    for label, inverse in inverse_freq.items():
        normalized[label] = inverse / inverse_total

    return convert_dict_to_tensor(normalized)

In [9]:
def predict_val_dataset(model, CONFIG, df_validate, encoder, TRAIN_DIR=None, val_size=1.0):
    """Run the model over a validation DataFrame and report classification metrics.

    Args:
        model: callable returning logits for a batch, exposing a ``softmax``
            attribute that maps logits to probabilities.
        CONFIG: mapping with ``'valid_batch_size'`` and ``'device'``.
        df_validate: DataFrame accepted by ``UBCDataset``. It is modified IN
            PLACE: ``pred`` and ``pred_labels`` columns are added.
        encoder: object whose ``inverse_transform`` maps predicted class indices
            back to label names.
        TRAIN_DIR: unused; kept for interface compatibility.
        val_size: unused; kept for interface compatibility.

    Returns:
        ``(df_validate, preds, labels_list, output_list)`` where ``preds`` and
        ``labels_list`` are flat arrays and ``output_list`` is a list holding
        one probability array per batch.

    Side effects:
        Prints accuracy, balanced accuracy, F1 scores and the confusion matrix.
    """
    valid_dataset = UBCDataset(df_validate, transforms=data_transforms["valid"])
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'],
                              num_workers=2, shuffle=False, pin_memory=True)

    preds = []
    labels_list = []
    output_list = []
    valid_acc = 0.0

    # Predictions must be made in eval mode (dropout off, batch-norm using running
    # statistics). Remember the caller's mode so it can be restored afterwards.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            bar = tqdm(enumerate(valid_loader), total=len(valid_loader))
            for step, data in bar:
                images = data['image'].to(CONFIG["device"], dtype=torch.float)
                labels = data['label'].to(CONFIG["device"], dtype=torch.long)

                outputs = model(images)
                outputs = model.softmax(outputs)
                _, predicted = torch.max(outputs, 1)
                preds.append(predicted.detach().cpu().numpy())
                labels_list.append(labels.detach().cpu().numpy())
                output_list.append(outputs.detach().cpu().numpy())
                acc = torch.sum(predicted == labels)
                valid_acc += acc.item()
    finally:
        if was_training:
            model.train()

    valid_acc /= len(valid_loader.dataset)
    preds = np.concatenate(preds).flatten()
    labels_list = np.concatenate(labels_list).flatten()
    pred_labels = encoder.inverse_transform(preds)

    # Calculate Balanced Accuracy
    bal_acc = balanced_accuracy_score(labels_list, preds)
    # Calculate Confusion Matrix
    conf_matrix = confusion_matrix(labels_list, preds)
    macro_f1 = f1_score(labels_list, preds, average='macro')
    micro_f1 = f1_score(labels_list, preds, average='micro')
    weighted_f1 = f1_score(labels_list, preds, average='weighted')

    print(f"Validation Accuracy: {valid_acc}")
    print(f"Balanced Accuracy: {bal_acc}")
    print(f"Macro F1-Score: {macro_f1}")
    print(f"Micro F1-Score: {micro_f1}")
    print(f"Weighted F1-Score: {weighted_f1}")
    print(f"Confusion Matrix: {conf_matrix}")

    # add to validation dataframe (in place, so the caller's frame gains the columns)
    df_validate["pred"] = preds
    df_validate["pred_labels"] = pred_labels
    return df_validate, preds, labels_list, output_list
