# Environment setup

In [None]:
import os
import gc
import cv2
import math
import copy
import time
import random
import glob
from typing import Dict, List
from matplotlib import pyplot as plt
import seaborn as sns
from random import sample
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import torchvision

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score

# For Image Models
import timm
from PIL import Image

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

from pathlib import Path

# Make CUDA kernel launches synchronous: errors surface at the call that caused
# them, giving descriptive tracebacks (at some cost in speed).
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",
    # Raise OpenCV's maximum decoded-image size to 2**60 pixels.
    # NOTE(review): OpenCV may read this variable when cv2 is imported, which
    # happens earlier in this cell -- confirm the setting actually takes effect.
    "OPENCV_IO_MAX_IMAGE_PIXELS": str(2 ** 60),
})
print(f"torch version {torch.__version__}")
print(f'Torchvision version {torchvision.__version__}')

In [None]:
# Global run configuration shared by every cell below.
CONFIG = dict(
    seed=42,
    img_size=1024,
    num_tiles=4,
    model_name="effnet-th-tiles-25",
    num_classes=5,
    valid_batch_size=16,
    test_batch_size=1,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    train=False,      # train and save the model; must be False when submitting
    split_ratio=0.2,  # fraction of rows held out for validation in sandbox mode
    num_workers=os.cpu_count(),  # NOTE: os.cpu_count() may return None
    epochs=50,
    sandbox=True,     # hyper-parameter search mode; must be False when submitting
)

In [None]:
ROOT_DIR = '/kaggle/input/processed-ubc-thumbnails'
# Expected to hold one entry per image_id (see get_train_file_path).
TRAIN_DIR = f'{ROOT_DIR}/train_thumbnails'

# Data processing

In [None]:
class UBCDataset(Dataset):
    """Dataset that yields all tiles of one image together with its label.

    Each row of ``df`` needs a ``file_path`` column pointing to a folder whose
    files are the tiles of one image, and an integer ``label`` column.

    ``__getitem__`` returns ``(tiles, label)`` where ``tiles`` is the tiles
    stacked along a new leading dimension and ``label`` is a long scalar tensor.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.file_names = df['file_path'].values
        self.labels = df['label'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _to_uint8_rgb(tile):
        """Return ``tile`` as a 3-channel uint8 array with values in [0, 255].

        ``plt.imread`` returns PNGs as floats in [0, 1] (possibly with an alpha
        channel), whereas the Normalize step uses ``max_pixel_value=255``.
        Arrays that are already 3-channel integers are returned unchanged.
        """
        if tile.ndim == 3 and tile.shape[-1] == 4:
            tile = tile[..., :3]  # drop the alpha channel
        if np.issubdtype(tile.dtype, np.floating):
            tile = np.clip(np.rint(tile * 255.0), 0, 255).astype(np.uint8)
        return tile

    def __getitem__(self, index):
        folder_path = self.file_names[index]
        # os.listdir returns bare file names in arbitrary order: sort them for a
        # deterministic tile order and join them with the folder before reading.
        tile_names = sorted(os.listdir(folder_path))
        if not tile_names:
            raise RuntimeError(f"No tiles found in {folder_path}")

        tiles = []
        for tile_name in tile_names:
            tile = self._to_uint8_rgb(plt.imread(os.path.join(folder_path, tile_name)))
            if self.transforms:
                tile = self.transforms(image=tile)["image"]
            if not torch.is_tensor(tile):
                tile = torch.from_numpy(np.ascontiguousarray(tile))
            tiles.append(tile)

        label = self.labels[index]
        # torch.tensor() cannot build a tensor from a list of tensors; stack instead.
        return torch.stack(tiles), torch.tensor(label, dtype=torch.long)
        

In [None]:
def _imagenet_normalize_and_tensorize():
    """Return fresh [Normalize, ToTensorV2] transforms used at the end of both pipelines.

    Normalize uses the ImageNet channel mean/std and expects pixel values in
    [0, 255] (``max_pixel_value=255.0``).
    """
    return [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(),
    ]


data_transforms = {
    # Deterministic pipeline: resize, then normalize.
    "valid": A.Compose(
        [
            A.Resize(CONFIG['img_size'], CONFIG['img_size']),
            *_imagenet_normalize_and_tensorize(),
        ],
        p=1.,
    ),
    # Training pipeline: geometric, noise/blur and dropout augmentations before normalizing.
    "train": A.Compose(
        [
            A.Resize(CONFIG['img_size'], CONFIG['img_size']),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # A.RandomBrightnessContrast(p=0.75),
            A.ShiftScaleRotate(p=0.75),
            A.OneOf(
                [A.GaussNoise(var_limit=[10, 50]), A.GaussianBlur(), A.MotionBlur()],
                p=0.4,
            ),
            A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
            # NOTE(review): hole size is derived from 512, not CONFIG['img_size'] -- confirm intended.
            A.CoarseDropout(max_holes=1, max_width=int(512 * 0.3), max_height=int(512 * 0.3),
                            mask_fill_value=0, p=0.5),
            *_imagenet_normalize_and_tensorize(),
        ],
        p=1.,
    ),
}

In [None]:
# Load the competition metadata (one row per image) and display it.
traindf = pd.read_csv(filepath_or_buffer='/kaggle/input/UBC-OCEAN/train.csv')
traindf

In [None]:
def get_train_file_path(image_id):
    """Return ``TRAIN_DIR/<image_id>`` if that path exists, else the sentinel "NO FILE"."""
    candidate = f"{TRAIN_DIR}/{image_id}"
    return candidate if os.path.exists(candidate) else "NO FILE"

In [None]:
# Resolve every image_id to its folder (or the "NO FILE" sentinel) and display the table.
traindf['file_path'] = traindf['image_id'].map(get_train_file_path)
traindf

## Preparing the training data frame

In [None]:
# Encode the diagnosis strings as integer class ids.
# Building the whole column with Series.map replaces the chained assignment
# `traindf['label'][mask] = ...`, which may write into a temporary copy (and
# never writes through under pandas Copy-on-Write). Values that are already
# ids are kept, so re-running this cell is harmless.
LABEL_TO_ID = {"HGSC": 0, "EC": 1, "CC": 2, "LGSC": 3, "MC": 4}
_encoded_labels = traindf['label'].map(lambda value: LABEL_TO_ID.get(value, value))
_unknown_labels = set(_encoded_labels) - set(LABEL_TO_ID.values())
if _unknown_labels:
    raise ValueError(f"Unexpected labels in train.csv: {sorted(map(str, _unknown_labels))}")
traindf['label'] = _encoded_labels.astype("int64")

# Keep only images whose folder was found (get_train_file_path marks the rest "NO FILE").
traindf = traindf[traindf['file_path'] != 'NO FILE']
traindf

In [None]:
# Loader over the full training table (rebuilt from the split below when sandboxing).
train_dataset = UBCDataset(traindf, transforms=data_transforms["valid"])  # TODO: Add training transforms
train_dataloader = DataLoader(
    train_dataset,
    batch_size=CONFIG['valid_batch_size'],
    shuffle=True,
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
)

## Training and validation split from the training set

In [None]:
# Sandbox mode: hold out CONFIG["split_ratio"] of the rows for validation.
if CONFIG["sandbox"]:
    train_data, val_data = train_test_split(
        traindf,
        test_size=CONFIG["split_ratio"],
        random_state=CONFIG["seed"],
    )

    # Augmented pipeline for the training part, deterministic one for validation.
    train_dataset = UBCDataset(train_data, transforms=data_transforms["train"])
    val_dataset = UBCDataset(val_data, transforms=data_transforms["valid"])

    train_dataloader = DataLoader(
        train_dataset, batch_size=CONFIG['valid_batch_size'], shuffle=True,
        num_workers=CONFIG["num_workers"], pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=CONFIG['valid_batch_size'], shuffle=False,
        num_workers=CONFIG["num_workers"], pin_memory=True,
    )

# Train and save model

## Save model

In [None]:
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Save a whole PyTorch model object (not just its state_dict) to a directory.

    Args:
        model: The PyTorch model to save.
        target_dir: Directory to save into; created (with parents) if missing.
        model_name: File name for the saved model; must end with ".pth" or ".pt".

    Raises:
        ValueError: If ``model_name`` does not end with ".pth" or ".pt".

    Example usage:
        save_model(model=model_0,
                   target_dir="models",
                   model_name="model_1.pth")
    """
    # Validate before touching the filesystem. An explicit raise, unlike
    # `assert`, is not stripped when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    # Create the target directory.
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)
    model_save_path = target_dir_path / model_name

    # torch.save pickles the full module, so loading it back requires the model's
    # class definitions to be importable (and, on recent torch versions,
    # torch.load(..., weights_only=False)).
    print(f"[INFO] Saving model to : {model_save_path}")
    torch.save(obj=model, f=model_save_path)


## Importing a pretrained model
This step runs online, before internet access is turned off; the saved model is then reused offline.

#### Identity module for removing the last classifier layer

In [None]:
class Identity(nn.Module):
    """Pass-through layer.

    Installed in place of a network's classifier so that the network returns
    the features that would have fed that classifier.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Hand back the very same input object, untouched.
        return x

In [None]:
if CONFIG["train"] or CONFIG["sandbox"]:
    # 1. Pretrained ImageNet weights for EfficientNet-B7 (downloaded while online).
    pretrained_weights = torchvision.models.EfficientNet_B7_Weights.DEFAULT
    # 2. Build the backbone with those weights and move the whole model to the device.
    pretrained_model = torchvision.models.efficientnet_b7(weights=pretrained_weights).to(CONFIG["device"])

    # 3. Freeze the backbone. The model was already moved as a whole above;
    #    Tensor.to() returns a new tensor, so calling it per parameter would not move anything.
    for parameter in pretrained_model.parameters():
        parameter.requires_grad = False

    # 4. Replace the classifier head so the model outputs pooled features.
    pretrained_model.classifier = Identity()

    # Needs to be saved as a dataset so that it works offline.
    save_model(pretrained_model,
               "/kaggle/working/",
               f"{CONFIG['model_name']}-pretrained.pth")

In [None]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; height and width are preserved."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # Keep the attribute name and layer order: the saved U-Net state_dict keys
        # (e.g. "downs.0.conv.0.weight") depend on them.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),  # "same" conv; no bias since BatchNorm follows
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class U_NET(nn.Module):
    """U-Net with one encoder/decoder level per entry of ``features``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels produced by the final 1x1 convolution.
        features: Channel width of each encoder level; the decoder mirrors them
            in reverse. Every level halves height and width with max-pooling, so
            for sizes not divisible by 2**len(features) the upsampled maps are
            resized to match their skip connections.
    """

    def __init__(self, in_channels=3, out_channels=3, features=(16, 32, 64, 128, 256)):
        super(U_NET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: transposed-conv upsampling, then DoubleConv over [skip, upsampled].
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest level first, matching self.ups

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # ConvTranspose2d doubles height and width
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                # Max-pooling floors odd sizes, so the upsampled map can be one
                # pixel short; resize only (height, width) to match the skip.
                x = F.interpolate(x, size=skip_connection.shape[2:], mode="bilinear", align_corners=False)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
UNET = U_NET()

# map_location="cpu" lets a checkpoint saved from a GPU load on any machine; the
# weights are moved to CONFIG["device"] later, when UNET is wrapped by the encoder.
UNET.load_state_dict(torch.load('/kaggle/input/ubc-u-et/UNET_dict_80e_split_1024_tiles.pth', map_location="cpu"))

In [None]:
class CUT_OFF_UNET(nn.Module):
    """Encoder half of a U-Net that turns a stack of tiles into one 3-channel map.

    Runs every encoder level of ``UNET`` (each followed by ``UNET.pool``), maps
    the deepest level's 256 channels (the U_NET default width) to 3 with a newly
    initialised 3x3 conv, and reshapes the result into one (3, side, side) tensor.

    Raises:
        ValueError: If the values per output channel do not form a square.
    """

    def __init__(self, UNET):
        super(CUT_OFF_UNET, self).__init__()
        self.UNET = UNET
        self.conv = nn.Conv2d(256, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

    def forward(self, x):
        for down in self.UNET.downs:
            x = down(x)
            x = self.UNET.pool(x)
        x = self.conv(x)
        # Side length of the square output, e.g. 4 tiles of 32x32 -> 64x64.
        # NOTE(review): this reinterprets memory rather than laying the tiles out
        # side by side -- confirm that is the intended "feature cube".
        values_per_channel = x.numel() // 3
        side = math.isqrt(values_per_channel)
        if side * side != values_per_channel:
            raise ValueError(
                f"Cannot reshape encoder output of shape {tuple(x.shape)} "
                f"into a square (3, side, side) map"
            )
        return x.reshape(3, side, side)
    
# Smoke test: push a batch of 4 random 1024x1024 tiles through the encoder.
encoder = CUT_OFF_UNET(UNET).to(CONFIG["device"])
test_vector = torch.randn((4, 3, 1024, 1024)).to(CONFIG["device"])
print(encoder)
print(test_vector.shape)
encoder.eval()
# A shape check needs no autograd graph; skipping it saves a lot of memory at this resolution.
with torch.inference_mode():
    test_vector = encoder(test_vector)
print(test_vector.shape)


In [None]:
# Feature extraction --> dimensionality reduction (any backbone will do)
# ([4, 3, 1024, 1024]) --> ([4, 256, 32, 32])

# Go back to three channels
# ([4, 256, 32, 32]) --> ([4, 3, 32, 32])

# Create feature cube (4 * 32 * 32 = 64 * 64 values per channel)
# ([4, 3, 32, 32]) --> ([3, 64, 64])



In [None]:
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last dimension.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` along the last axis with a
    learnable exponent ``p`` (p = 1 is average pooling; larger p moves towards
    max pooling). Input layout is whatever ``F.avg_pool1d`` accepts, e.g. (N, C, L).
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # 1-D pooling whose kernel spans the whole last axis: one value per channel
        # (in this notebook the last axis is the tiles dimension).
        powered = x.clamp(min=eps).pow(p)
        pooled = F.avg_pool1d(powered, x.size(-1))
        return pooled.pow(1. / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={p_value:.4f}, eps={self.eps})"

In [None]:
# Demo: GeM-pool a (16, 25, 2560) tensor over its middle axis.
pooling = GeM()
test_vector = torch.randn((16, 25, 2560))
channels_first = test_vector.permute(0, 2, 1)  # move the pooled axis last -> (16, 2560, 25)
print(channels_first.shape)
print(pooling(channels_first).squeeze(dim=-1).shape)

In [None]:
# Demo: plain average over the middle axis, for comparison with GeM.
test_vector = torch.randn(16, 25, 2560)
print(test_vector.shape)
test_vector = torch.mean(test_vector, dim=1)
print(test_vector.shape)

In [None]:
class SSFE(nn.Module):
    """Tile-stack classifier: feature extractor -> image backbone -> MLP head.

    Args:
        feature_extractor: Module applied to the tiles of one image at a time.
        backbone: Module applied to the extractor output (with a batch dim of 1
            added); must return ``backbone_output_dim`` features per image.
        backbone_output_dim: Width of the backbone output.
    """

    def __init__(self, feature_extractor, backbone, backbone_output_dim):
        super().__init__()
        self.feature_extractor = feature_extractor.to(CONFIG["device"])
        self.backbone = backbone.to(CONFIG["device"])
        self.head = nn.Sequential(
            nn.Dropout(0.8),
            nn.Linear(in_features=backbone_output_dim, out_features=1024),
            nn.ReLU(),
            nn.Dropout(0.8),
            nn.Linear(in_features=1024, out_features=128),
            nn.ReLU(),
            nn.Dropout(0.8),
            nn.Linear(in_features=128, out_features=CONFIG["num_classes"]),
        ).to(CONFIG["device"])

    # `forward` must be a method of the class (not a function nested inside
    # __init__) for nn.Module.__call__ to dispatch to it.
    def forward(self, x):
        """Return class logits with one row per image in ``x``.

        NOTE(review): assumes each element of ``x`` holds the tiles of one image
        (a tensor stacked along dim 0, or a sequence of tile tensors) and that
        ``feature_extractor`` maps that stack to one image-like tensor, e.g.
        (3, 64, 64) for CUT_OFF_UNET -- confirm.
        """
        per_image_features = []
        for image in x:  # [tile1, tile2, ...]
            tiles = image if torch.is_tensor(image) else torch.stack(list(image), dim=0)
            features = self.feature_extractor(tiles)
            per_image_features.append(self.backbone(features.unsqueeze(0)))
        batch_features = torch.cat(per_image_features, dim=0)
        return self.head(batch_features)

In [None]:

# model = SSFE()

In [None]:
def split_images_into_tiles(images, tile_size):
    """Cut a batch of images into non-overlapping square tiles.

    Args:
        images: Tensor of shape (batch, channels, height, width).
        tile_size: Side length of each tile in pixels.

    Returns:
        Tensor of shape (batch, n_tiles, channels, tile_size, tile_size) with the
        tiles in row-major order. Height and width are zero-padded on the
        bottom/right up to a multiple of ``tile_size`` so no pixel is dropped.

    NOTE(review): the zero-padded grid was chosen so that 1024x1024 inputs with
    tile_size=600 give 4 tiles -- confirm this is the intended tiling.
    """
    if images.dim() != 4:
        raise ValueError(f"Expected a (batch, channels, height, width) tensor, got shape {tuple(images.shape)}")
    batch, channels, height, width = images.shape
    pad_h = (-height) % tile_size
    pad_w = (-width) % tile_size
    if pad_h or pad_w:
        images = F.pad(images, (0, pad_w, 0, pad_h))  # (left, right, top, bottom)
    rows = (height + pad_h) // tile_size
    cols = (width + pad_w) // tile_size
    tiles = images.reshape(batch, channels, rows, tile_size, cols, tile_size)
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)  # (batch, rows, cols, channels, tile, tile)
    return tiles.reshape(batch, rows * cols, channels, tile_size, tile_size)


class tilingModel(nn.Module):
    """Classify an image from GeM-pooled backbone features of its tiles.

    Args:
        backbone: Module applied to each tile; must return
            ``backbone_output_dim`` features per tile.
        n_tiles: Number of tiles expected per image.
        tile_size: Tile side length used when splitting full images.
        backbone_output_dim: Width of the backbone output.
    """

    def __init__(self, backbone, n_tiles, tile_size, backbone_output_dim):
        super().__init__()
        self.backbone, self.n_tiles, self.tile_size = backbone.to(CONFIG["device"]), n_tiles, tile_size
        self.GeM_pooling = GeM()
        self.head = nn.Sequential(
            nn.Dropout(0.8),
            nn.Linear(in_features=backbone_output_dim, out_features=1024),
            nn.ReLU(),
            nn.Dropout(0.8),
            nn.Linear(in_features=1024, out_features=128),
            nn.ReLU(),
            nn.Dropout(0.8),
            nn.Linear(in_features=128, out_features=CONFIG["num_classes"]),
        ).to(CONFIG["device"])

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is either a batch of full images (batch, C, H, W), which is split
        into ``tile_size`` tiles, or an already tiled batch (batch, n_tiles, C, H, W).
        """
        images = x if x.dim() == 5 else split_images_into_tiles(x, self.tile_size)
        if images.shape[1] != self.n_tiles:
            raise ValueError(f"Expected {self.n_tiles} tiles per image, got {images.shape[1]}")
        # One backbone pass per tile position over the whole batch -> (batch, n_tiles, features).
        # TODO: Skip completely black tiles? Or maybe trim tiles?
        batch_features = torch.stack(
            [self.backbone(images[:, tile]) for tile in range(self.n_tiles)], dim=1
        )
        # GeM over the tile axis: (batch, features, n_tiles) -> (batch, features).
        batch_features = self.GeM_pooling(batch_features.permute(0, 2, 1)).squeeze(dim=-1)
        return self.head(batch_features)

if CONFIG["train"] or CONFIG["sandbox"]:
    tilingmodel = tilingModel(pretrained_model, CONFIG["num_tiles"], 600, 2560).to(CONFIG["device"])

    thumbnails = []
    for path in ("/kaggle/input/UBC-OCEAN/train_thumbnails/10077_thumbnail.png",
                 "/kaggle/input/UBC-OCEAN/train_thumbnails/10143_thumbnail.png"):
        image = plt.imread(path)
        # plt.imread returns PNGs as floats in [0, 1] (possibly RGBA), while the
        # Normalize step expects [0, 255]: convert to 3-channel uint8 first.
        if image.ndim == 3:
            image = image[..., :3]
        if np.issubdtype(image.dtype, np.floating):
            image = np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)
        thumbnails.append(data_transforms["valid"](image=image)["image"])

    img_batch = torch.stack(thumbnails)
    print(f"img_batch: {img_batch.shape}")
    # Eval mode disables the 0.8 dropout in the head so the demo output is
    # deterministic; the batch must live on the same device as the model.
    tilingmodel.eval()
    with torch.inference_mode():
        output = tilingmodel(img_batch.to(CONFIG["device"]))
    print(f"Shape of output: {output.shape}")
    print(output)

## Training the model

In [None]:
"""
Contains functions for training and testing a PyTorch model.
"""
import torch

from tqdm.auto import tqdm
from typing import Dict, List, Tuple

def train_step(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Run one training epoch.

    Switches ``model`` to train mode and, for every batch, does a forward pass,
    computes the loss, back-propagates and steps the optimizer.

    Args:
        model: A PyTorch model to be trained.
        dataloader: A DataLoader instance for the model to be trained on.
        loss_fn: A PyTorch loss function to minimize.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        (train_loss, train_accuracy), each averaged over batches, e.g. (0.1112, 0.8743).
    """
    model.train()

    loss_sum = 0
    accuracy_sum = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        loss = loss_fn(logits, targets)
        loss_sum += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-batch accuracy of the arg-max class.
        predicted = torch.softmax(logits, dim=1).argmax(dim=1)
        accuracy_sum += (predicted == targets).sum().item() / len(logits)

    n_batches = len(dataloader)
    return loss_sum / n_batches, accuracy_sum / n_batches

def test_step(model: torch.nn.Module, 
              dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float, float]:
    """Evaluate a PyTorch model for a single epoch.

    Puts the model in eval mode and runs a forward pass over every batch of
    ``dataloader`` under ``torch.inference_mode``.

    Args:
        model: A PyTorch model to be tested.
        dataloader: A DataLoader instance for the model to be tested on.
        loss_fn: A PyTorch loss function to calculate loss on the test data.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A tuple (test_loss, test_accuracy, balanced_accuracy). Loss and accuracy
        are averaged per batch; balanced accuracy is computed over all samples
        at once. For example: (0.0223, 0.8985, 0.8412)
    """
    model.eval()

    test_loss, test_acc = 0.0, 0.0
    pred_labels: List[int] = []
    true_labels: List[int] = []
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            test_pred_logits = model(X)
            test_loss += loss_fn(test_pred_logits, y).item()

            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_acc += (test_pred_labels == y).sum().item() / len(test_pred_labels)
            # extend() appends in place; rebuilding the list every batch is quadratic.
            pred_labels.extend(test_pred_labels.cpu().tolist())
            true_labels.extend(y.cpu().tolist())

    n_batches = len(dataloader)
    test_loss = test_loss / n_batches
    test_acc = test_acc / n_batches
    balanced_acc = balanced_accuracy_score(true_labels, pred_labels)
    return test_loss, test_acc, balanced_acc

def train(model: torch.nn.Module, 
          train_dataloader: torch.utils.data.DataLoader, 
          test_dataloader: torch.utils.data.DataLoader, 
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
    """Train and evaluate a PyTorch model for ``epochs`` epochs.

    Each epoch calls train_step() on ``train_dataloader`` and then test_step()
    on ``test_dataloader``, prints a one-line summary and records the metrics.

    Args:
        model: A PyTorch model to be trained and tested.
        train_dataloader: A DataLoader instance for the model to be trained on.
        test_dataloader: A DataLoader instance for the model to be tested on.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        loss_fn: A PyTorch loss function to calculate loss on both datasets.
        epochs: An integer indicating how many epochs to train for.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A dictionary with one value per epoch under each key:
            {"train_loss": [...], "train_acc": [...], "test_loss": [...],
             "test_acc": [...], "balanced_acc": [...]}
    """
    metric_names = ("train_loss", "train_acc", "test_loss", "test_acc", "balanced_acc")
    results = {name: [] for name in metric_names}

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device)
        test_loss, test_acc, balanced_acc = test_step(model=model,
                                                      dataloader=test_dataloader,
                                                      loss_fn=loss_fn,
                                                      device=device)

        epoch_metrics = dict(zip(metric_names,
                                 (train_loss, train_acc, test_loss, test_acc, balanced_acc)))
        summary = [f"Epoch: {epoch+1}"]
        summary += [f"{name}: {value:.4f}" for name, value in epoch_metrics.items()]
        print(" | ".join(summary))

        for name, value in epoch_metrics.items():
            results[name].append(value)

    return results

In [None]:
def train_model(train_dl, test_dl, model_name, store_model=True):
    """Build a tiling model on the pretrained backbone, train it and optionally save it.

    Args:
        train_dl: DataLoader used for training.
        test_dl: DataLoader evaluated after every epoch; ``None`` reuses ``train_dl``.
        model_name: File name passed to save_model (must end in ".pth" or ".pt").
        store_model: Whether to save the trained model to /kaggle/working/.

    Returns:
        The per-epoch metrics dictionary returned by ``train``.
    """
    tilingmodel = tilingModel(pretrained_model, CONFIG["num_tiles"], 600, 2560).to(CONFIG["device"])
    # Only parameters that can receive gradients are given to the optimizer
    # (frozen backbone parameters are never updated anyway).
    trainable_params = [param for param in tilingmodel.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(params=trainable_params, lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    if test_dl is None:
        test_dl = train_dl

    # Train the GeM pooling and classifier head on top of the backbone features.
    results = train(model=tilingmodel,
                    train_dataloader=train_dl,
                    test_dataloader=test_dl,
                    optimizer=optimizer,
                    loss_fn=loss_fn,
                    epochs=CONFIG["epochs"],
                    device=CONFIG["device"])
    if store_model:
        # Needs to be saved as a dataset so that it works offline.
        save_model(tilingmodel,
                   "/kaggle/working/",
                   model_name)

    return results

In [None]:
if CONFIG["train"] or CONFIG["sandbox"]:
    # Full training evaluates on the training data itself; sandbox mode uses the held-out split.
    evaluation_dl = train_dataloader if CONFIG["train"] else val_dataloader
    results = train_model(train_dl=train_dataloader,
                          test_dl=evaluation_dl,
                          model_name=f"{CONFIG['model_name']}.pth")

In [None]:
def plot_loss_curves(results: Dict[str, List[float]]):
    """Plot per-epoch training curves from a results dictionary.

    First shows loss and accuracy (train vs test) side by side, then the
    validation balanced accuracy.

    Args:
        results (dict): per-epoch value lists under the keys "train_loss",
            "train_acc", "test_loss", "test_acc" and "balanced_acc".
    """
    # Fetch the loss/accuracy series up front.
    series = {key: results[key] for key in ('train_loss', 'test_loss', 'train_acc', 'test_acc')}
    epochs = range(len(results['train_loss']))

    plt.figure(figsize=(15, 7))
    panels = [
        ('Loss', [('train_loss', 'train_loss'), ('test_loss', 'test_loss')]),
        ('Accuracy', [('train_acc', 'train_accuracy'), ('test_acc', 'test_accuracy')]),
    ]
    for position, (title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, curve_label in curves:
            plt.plot(epochs, series[key], label=curve_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.legend()
    plt.show()

    # Validation balanced accuracy.
    plt.subplot(1, 1, 1)
    plt.plot(epochs, results["balanced_acc"], label='balanced_acc')
    plt.title('Balanced Accuracy for validation data')
    plt.xlabel('Epochs')
    plt.legend()
    plt.show()

In [None]:
# Plot the curves whenever a training run produced `results`.
if any(CONFIG[mode] for mode in ("train", "sandbox")):
    plot_loss_curves(results)

In [None]:
# Diagnosis names indexed by the integer label assigned during label encoding
# (HGSC=0, EC=1, CC=2, LGSC=3, MC=4).
class_names = ["HGSC", "EC", "CC", "LGSC", "MC"]


def get_label(label):
    """Map an integer class id back to its diagnosis name."""
    return class_names[label]

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

if CONFIG["sandbox"]:
    # Diagnosis names indexed by integer label (same encoding as the label-mapping cell).
    class_names = ["HGSC", "EC", "CC", "LGSC", "MC"]

    # train_model keeps its model local and only returns metrics, so reload the
    # trained model it saved (as a full pickled module) under /kaggle/working/.
    model = torch.load(f"/kaggle/working/{CONFIG['model_name']}.pth",
                       map_location=CONFIG["device"],
                       weights_only=False)
    model.eval()  # disable dropout for evaluation

    preds = []
    true = []
    with torch.no_grad():
        for data, y in tqdm(val_dataloader, total=len(val_dataloader)):
            images = data.to(CONFIG["device"], dtype=torch.float)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            true.append(y.detach().cpu().numpy())
            preds.append(predicted.detach().cpu().numpy())
    preds = [get_label(item) for item in np.concatenate(preds).flatten()]
    true = [get_label(item) for item in np.concatenate(true).flatten()]

    # Pass labels= explicitly: otherwise sklearn orders rows/columns by sorted
    # label value (alphabetically), which would not match display_labels, and
    # classes absent from both lists would be dropped.
    cm = confusion_matrix(true, preds, labels=class_names)

    cmd = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    cmd.plot(cmap=plt.cm.Blues)  # You can choose a different color map if needed
    plt.show()