# Introduction 

**This is a basic CNN Model training notebook**

It is based on: 
- Thumbnail images
- Basic data transformations (using Albumentations):
    - resizing images to 512x512
    - normalizing pixel values
- CNN Architecture


**Todos:**

- Learn about PyTorch `Dataset` & `DataLoader`
- Add augmentations (Albumentations)
- GeM pooling

In [None]:
!pip install --quiet torch_optimizer

In [None]:
!pip install --quiet mlflow dagshub

In [None]:


import os
import gc
import cv2
import datetime
import math
import copy
import time
import random
import glob
from matplotlib import pyplot as plt
from skimage import io


# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda import amp
import torchvision
import torch_optimizer as torch_optimizer

import optuna
from optuna.trial import TrialState

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict


from PIL import Image
from joblib import Parallel, delayed
from tqdm.auto import tqdm

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score

# For Image Models
import timm

import dagshub
from getpass import getpass
import mlflow.pytorch 
from mlflow import MlflowClient

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings
# warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
# Configure access to the DagsHub-hosted MLflow tracking server.
#
# SECURITY: an access token used to be hard-coded in this cell. Secrets must not
# live in notebook source; the previously committed token should be treated as
# leaked and revoked. The token is now taken from the environment, and only if it
# is missing is the user prompted for it.
# NOTE(review): getpass needs an interactive kernel; for non-interactive runs set
# MLFLOW_TRACKING_PASSWORD before executing this cell.
os.environ.setdefault("MLFLOW_TRACKING_USERNAME", "Niggl0n")
os.environ.setdefault("MLFLOW_TRACKING_PROJECTNAME", "UBC_Cancer_Classification")
if not os.environ.get("MLFLOW_TRACKING_PASSWORD"):
    os.environ["MLFLOW_TRACKING_PASSWORD"] = getpass("DagsHub / MLflow access token: ")

mlflow.set_tracking_uri(
    "https://dagshub.com/"
    + os.environ["MLFLOW_TRACKING_USERNAME"]
    + "/"
    + os.environ["MLFLOW_TRACKING_PROJECTNAME"]
    + ".mlflow"
)

In [None]:
def get_or_create_experiment_id(name):
    """Return the id of the MLflow experiment called ``name``.

    The experiment is looked up by name first; if MLflow reports that it does
    not exist, it is created and the id of the new experiment is returned.
    """
    existing = mlflow.get_experiment_by_name(name)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(name)

mlflow_experiment_id = get_or_create_experiment_id(os.environ['MLFLOW_TRACKING_PROJECTNAME'])
mlflow_experiment_id

In [None]:
# Central run configuration. Several entries are overwritten at runtime:
# get_optimizer() sets learning_rate / weight_decay / momentum / betas / eps / alpha
# for the chosen optimizer, and fetch_scheduler() sets T_0 / T_mult / min_lr for
# CosineAnnealingWarmRestarts. The whole dict is logged to MLflow after training.
CONFIG = {
    "is_submission": False,   # when True, test_on_holdout() skips the holdout evaluation
    "weighted_loss": True,    # when True, CrossEntropyLoss receives the weights from get_class_weights()
    "datetime_now": datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),  # used in the label-encoder file name
    "n_fold":5,               # number of StratifiedKFold splits
    "test_fold": 0,           # fold id held out as df_test
    "seed": 42,               # seed for the fold split and for eval-mode tile selection
    "img_size": 512,          # target side length used by the resize transforms
    "model_name": "tf_efficientnetv2_s_in21ft1k",   # alternative tried: "tf_efficientnet_b0_ns"
    # NOTE(review): the file name below refers to tf_efficientnet_b0 while model_name is
    # tf_efficientnetv2_s -- confirm this checkpoint matches the selected architecture.
    "checkpoint_path": "/kaggle/input/tf-efficientnet-b0-aa-827b6e33-pth/tf_efficientnet_b0_aa-827b6e33.pth",
    "num_classes": 5,         # output size of the classification head
    "train_batch_size": 8,
    "valid_batch_size": 8,    # also used for the holdout loader
    "n_tiles": 10,            # dataset repeats per image for the train/valid loaders
    "n_tiles_test": 10,       # dataset repeats per image for the holdout evaluation
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "num_epochs": 15,
    "early_stopping": True,
    "patience": 6,            # epochs without validation-loss improvement before stopping
    "optimizer": 'adam',      # name resolved by get_optimizer()
    "scheduler": 'CosineAnnealingLR',  # name resolved by fetch_scheduler()
    "min_lr": 1e-6,           # eta_min of the cosine schedulers
    "T_max": 10,              # T_max of CosineAnnealingLR
    "momentum": 0.9,          # overwritten by get_optimizer() for sgd / rmsprop
    "weight_decay": 1e-4,     # overwritten by get_optimizer() for every supported optimizer
}

## 1. Data Preparation

In [None]:
ROOT_DIR = '/kaggle/input/UBC-OCEAN'
TRAIN_DIR = '/kaggle/input/tiles-of-cancer-2048px-scale-0-25/'
TEST_DIR = '/kaggle/input/UBC-OCEAN/test_thumbnails'

# Fallback directory used by get_test_file_path() when no thumbnail exists.
# It was previously commented out, which made the fallback branch raise NameError.
ALT_TEST_DIR = '/kaggle/input/UBC-OCEAN/test_images'
# TMA_TRAIN_DIR = '/kaggle/input/UBC-OCEAN/train_images'


def get_train_file_path(df_train_row):
    """Return ``<TRAIN_DIR><image_id>.png`` for a row exposing ``image_id``.

    TRAIN_DIR already ends with a path separator, so none is inserted here.
    """
    return f"{TRAIN_DIR}{df_train_row.image_id}.png"


def get_test_file_path(image_id):
    """Return the path of the test image for ``image_id``.

    The thumbnail ``<TEST_DIR>/<image_id>_thumbnail.png`` is preferred when it
    exists on disk; otherwise ``<ALT_TEST_DIR>/<image_id>.png`` is returned
    (without checking that it exists).
    """
    thumbnail_path = f"{TEST_DIR}/{image_id}_thumbnail.png"
    if os.path.exists(thumbnail_path):
        return thumbnail_path
    return f"{ALT_TEST_DIR}/{image_id}.png"



In [None]:
# Build the training table: read the competition CSV, encode labels, assign
# stratified folds and hold one fold out as the test set.

# PNG files directly under TRAIN_DIR; only referenced by the disabled filter below.
train_images = sorted(glob.glob(f"{TRAIN_DIR}/*.png"))
df_train = pd.read_csv("/kaggle/input/UBC-OCEAN/train.csv")
print(df_train.shape)
df_train['file_path'] = df_train.apply(lambda row: get_train_file_path(row), axis=1)
# only consider WSI / Thumbnail images
#df_train = df_train[ 
#    df_train["file_path"].isin(train_images) ].reset_index(drop=True)
print(df_train.shape)

# encode the string labels to a numerical target
encoder = LabelEncoder()
df_train['target_label'] = encoder.fit_transform(df_train['label'])

# save the fitted encoder under a timestamped file name so predictions can be decoded later
with open("label_encoder_"+ CONFIG["datetime_now"] +".pkl", "wb") as fp:
    joblib.dump(encoder, fp)

# use stratified K-fold for cross-validation
skf = StratifiedKFold(n_splits=CONFIG['n_fold'], shuffle=True, random_state=CONFIG["seed"])

# each row gets the id of the fold in which it belongs to the validation part
for fold, ( _, val_) in enumerate(skf.split(X=df_train, y=df_train.target_label)):
    df_train.loc[val_ , "kfold"] = int(fold)
display(df_train.head())

# separate train and test dataset: the fold CONFIG["test_fold"] becomes the holdout set
df_test = df_train[df_train["kfold"]==CONFIG["test_fold"]].reset_index(drop=True)
df_train = df_train[df_train["kfold"]!=CONFIG["test_fold"]].reset_index(drop=True)
print(f"Shape df_train: {df_train.shape}, Shape df_test: {df_test.shape} ")

In [None]:
df_train.loc[0, "file_path"]

In [None]:
"""
def _color_means(img_path):
    img = np.array(Image.open(img_path))
    mask = np.sum(img[..., :3], axis=2) == 0
    img[mask, :] = 255
    if np.max(img) > 1.5:
        img = img / 255.0
    clr_mean = {i: np.mean(img[..., i]) for i in range(3)}
    clr_std = {i: np.std(img[..., i]) for i in range(3)}
    return clr_mean, clr_std

# os.path.join(DATASET_SMALL_FOLDER, "train_images")
ls_images = glob.glob(os.path.join(TRAIN_DIR, "*", "*.png"))
clr_mean_std = Parallel(n_jobs=os.cpu_count())(delayed(_color_means)(fn) for fn in tqdm(ls_images[:9000]))

img_color_mean = pd.DataFrame([c[0] for c in clr_mean_std]).describe()
display(img_color_mean.T)
img_color_std = pd.DataFrame([c[1] for c in clr_mean_std]).describe()
display(img_color_std.T)

img_color_mean = list(img_color_mean.T["mean"])
img_color_std = list(img_color_std.T["mean"])
print(f"{img_color_mean=}\n{img_color_std=}")
"""

## histogram matching 
#from skimage.exposure import match_histograms
#ref_img = np.array(Image.open("/kaggle/input/tiles-of-cancer-2048px-scale-0-25/10077/000067_16-3.png"))
#bef_img = np.array(Image.open("/kaggle/input/tiles-of-cancer-2048px-scale-0-25/12522/000028_6-2.png"))
#start = time.time()
#aft_img = match_histograms(bef_img, ref_img, channel_axis=-1)
#print(time.time()-start)



In [None]:
class CancerTilesDataset(Dataset):
    """Serve one tile per sample from per-image tile folders.

    ``df_data`` must provide the columns ``image_id``, ``label`` and
    ``target_label``. The frame is shuffled with a fixed ``random_state`` and
    split at ``split``: modes ``"train"`` and ``"test"`` keep the first part,
    every other mode keeps the remainder. For each kept row the PNG files in
    ``<path_img_dir>/<image_id>/`` are collected, and the whole list of images
    is repeated ``n_tiles`` times so that one pass over the dataset visits
    every image ``n_tiles`` times.

    Args:
        df_data: DataFrame describing the images.
        path_img_dir: Directory containing one sub-folder of tiles per image.
        transforms: Optional callable invoked as ``transforms(image=tile)``;
            its ``"image"`` entry is returned.
        mode: ``"train"`` picks tiles with a fresh random order on every
            access; any other mode uses a deterministic order.
        labels_lut: Optional label-to-index mapping stored on the instance.
        white_thr: Per-pixel channel mean above which a pixel counts as
            background.
        thr_max_bg: A tile is accepted when fewer than this fraction of its
            pixels are background.
        split: Fraction of the shuffled rows assigned to train/test modes.
        n_tiles: Number of times the image list is repeated.
        seed: Base seed for the deterministic tile order; when ``None`` the
            global ``CONFIG["seed"]`` is used.

    Raises:
        NotADirectoryError: If ``path_img_dir`` is not a directory.
        ValueError: If ``split`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        df_data,
        path_img_dir: str = '',
        transforms=None,
        mode: str = 'train',
        labels_lut=None,
        white_thr: int = 225,
        thr_max_bg: float = 0.2,
        split: float = 0.90,
        n_tiles: int = 1,
        seed=None,
    ):
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if not os.path.isdir(path_img_dir):
            raise NotADirectoryError(f"not a directory: {path_img_dir}")
        if not 0.0 <= split <= 1.0:
            raise ValueError(f"split must be within [0, 1], got {split}")

        self.path_img_dir = path_img_dir
        self.transforms = transforms
        self.mode = mode
        self.white_thr = white_thr
        self.thr_max_bg = thr_max_bg
        self.split = split
        self.n_tiles = n_tiles
        self.seed = seed

        self.labels_unique = sorted(df_data["label"].unique())
        self.labels_lut = labels_lut or {lb: i for i, lb in enumerate(self.labels_unique)}

        # Shuffle rows reproducibly, then keep the part that belongs to ``mode``.
        shuffled = df_data.sample(frac=1, random_state=42).reset_index(drop=True)
        frac = int(self.split * len(shuffled))
        self.data = shuffled[:frac] if mode in ["train", "test"] else shuffled[frac:]

        # One list of tile paths per image. Sorting removes the dependency on
        # the file-system listing order, so seeded selection is reproducible.
        tile_lists = [
            sorted(glob.glob(os.path.join(path_img_dir, str(image_id), "*.png")))
            for image_id in self.data["image_id"]
        ]
        # The repeats reference the same inner lists, so they must never be
        # modified in place (see _shuffled_tile_paths).
        self.img_dirs = tile_lists * self.n_tiles
        self.img_paths = []
        self.labels = np.array(self.data.target_label.values.tolist() * self.n_tiles)

    def _shuffled_tile_paths(self, idx: int) -> list:
        """Return a shuffled copy of the tile paths belonging to sample ``idx``.

        A private ``random.Random`` is used so that neither the shared path
        lists nor the global RNG state are altered. In train mode the generator
        is seeded from system entropy; otherwise it is seeded with the base
        seed plus the index of the current repeat, which makes the order a pure
        function of ``idx``.
        """
        if self.mode == "train":
            rng = random.Random()
        else:
            base_seed = CONFIG["seed"] if self.seed is None else self.seed
            nth_iteration = idx // len(self.data)
            rng = random.Random(base_seed + nth_iteration)
        paths = list(self.img_dirs[idx])
        rng.shuffle(paths)
        return paths

    @staticmethod
    def _load_tile(img_path: str):
        """Read ``img_path`` as RGB and turn all-zero (black) pixels white."""
        tile = cv2.imread(img_path)
        if tile is None:
            raise FileNotFoundError(f"missing: {img_path}")
        tile = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)
        black_bg = np.sum(tile, axis=2) == 0
        tile[black_bg, :] = 255
        return tile

    def __getitem__(self, idx: int) -> dict:
        """Return ``{"image": tile, "label": tensor}`` for sample ``idx``.

        Tiles are tried in shuffled order and the first one whose background
        share is below ``thr_max_bg`` is used; if none qualifies, the last tile
        tried is used.

        Raises:
            FileNotFoundError: If the image has no tiles or a tile is unreadable.
        """
        tile = None
        for img_path in self._shuffled_tile_paths(idx):
            tile = self._load_tile(img_path)
            mask_bg = np.mean(tile, axis=2) > self.white_thr
            if np.sum(mask_bg) < (np.prod(mask_bg.shape) * self.thr_max_bg):
                break

        if tile is None:
            image_id = self.data["image_id"].iloc[idx % len(self.data)]
            raise FileNotFoundError(
                f"no tiles found for image_id {image_id} in {self.path_img_dir}"
            )

        if self.transforms:
            tile = self.transforms(image=tile)["image"]
        return {
            "image": tile,
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }

    def __len__(self) -> int:
        """Return the number of samples (images in this split times ``n_tiles``)."""
        return len(self.img_dirs)


In [None]:
# Per-channel mean / std passed to A.Normalize in both pipelines below.
# NOTE(review): presumably produced by the disabled colour-statistics cell of this
# notebook -- confirm before reusing them for a different tile set.
img_color_mean = [0.8661704276539922, 0.7663107094675368, 0.8574260897185548]
img_color_std = [0.08670629753900036, 0.11646580094195522, 0.07164169171856792]

# Both pipelines take their target size from CONFIG['img_size']; the train
# pipeline used to hard-code 512, which would silently diverge from the
# validation pipeline as soon as the configured size changes.
data_transforms = {
    "train": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
            A.GaussNoise(var_limit=[10, 50]),
            A.GaussianBlur(),
            A.MotionBlur(),
        ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        # a single dropped rectangle of at most 30% of the image side
        A.CoarseDropout(
            max_holes=1,
            max_width=int(CONFIG['img_size'] * 0.3),
            max_height=int(CONFIG['img_size'] * 0.3),
            mask_fill_value=0,
            p=0.5,
        ),
        A.Normalize(img_color_mean, img_color_std),
        ToTensorV2()], p=1.),

    "valid": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.Normalize(img_color_mean, img_color_std),
        ToTensorV2()], p=1.)
}

"""        A.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225], 
            max_pixel_value=255.0, 
            p=1.0
        ),"""

## 2. Model Creation

In [None]:
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions of the input.

    The input is clamped from below at ``eps``, raised to the power ``p``,
    averaged over its full spatial extent and raised to ``1 / p``. ``p`` is a
    learnable one-element parameter initialised with the given value.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = nn.Parameter(torch.ones(1) * p)

    def forward(self, x):
        """Pool ``x`` using the module's current ``p`` and ``eps``."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Apply the pooling with explicit ``p`` and ``eps``."""
        full_window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        return F.avg_pool2d(powered, full_window).pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})"


class UBCModel(nn.Module):
    """timm backbone followed by GeM pooling and a linear classification head.

    Args:
        model_name: Architecture name handed to ``timm.create_model``.
        num_classes: Number of outputs of the linear head.
        pretrained: Forwarded to ``timm.create_model``.
        checkpoint_path: Forwarded to ``timm.create_model``.

    The backbone must expose a ``classifier`` attribute with ``in_features``;
    both its classifier and its global pooling are replaced by identities so
    that the backbone's forward pass yields the un-pooled feature map.
    """

    def __init__(self, model_name, num_classes, pretrained=False, checkpoint_path=None):
        super().__init__()
        backbone = timm.create_model(
            model_name, pretrained=pretrained, checkpoint_path=checkpoint_path
        )
        in_features = backbone.classifier.in_features
        backbone.classifier = nn.Identity()
        backbone.global_pool = nn.Identity()

        self.model = backbone
        self.pooling = GeM()
        self.linear = nn.Linear(in_features, num_classes)
        # Not applied in forward(); evaluation code calls it on the returned logits.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, images):
        """Return the raw class scores (no softmax) for a batch of images."""
        feature_map = self.model(images)
        pooled = self.pooling(feature_map).flatten(1)
        return self.linear(pooled)


## 3. Training

In [None]:
class EarlyStopping:
    """Track the validation loss and signal when training should stop.

    Each call compares the negated validation loss with the best value seen so
    far. An improvement (by at least ``delta``) saves the model's state dict to
    ``path`` and resets the counter; anything else increments the counter, and
    once it reaches ``patience`` the ``early_stop`` flag is set.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: When True, a message is emitted on every checkpoint save.
        delta: Minimum improvement of the score required to count.
        path: File the state dict is written to.
        trace_func: Callable used to emit messages.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func
        self.best_score = None
        self.val_loss_min = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss
        has_baseline = self.best_score is not None

        if has_baseline and score < self.best_score + self.delta:
            # no sufficient improvement: count it and possibly request the stop
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = score
        self.save_checkpoint(val_loss, model)
        if has_baseline:
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write the model's state dict to ``path`` and remember ``val_loss``."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model to path {self.path}')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [None]:
def fetch_scheduler(optimizer, *, config=None, **kwargs):
    """Build the learning-rate scheduler selected by ``config['scheduler']``.

    Args:
        optimizer: Optimizer the scheduler is attached to.
        config: Mapping holding the scheduler settings; defaults to the global
            ``CONFIG``. For ``CosineAnnealingWarmRestarts`` the keys ``T_0``,
            ``T_mult`` and ``min_lr`` are written into it.
        **kwargs: Optional extras: ``factor`` / ``patience`` for
            ``ReduceLROnPlateau`` and ``lr_lambda`` for ``LambdaLR``.

    Returns:
        The scheduler instance, or ``None`` when the configured name is ``None``.

    Raises:
        ValueError: For an unknown scheduler name, or when ``LambdaLR`` is
            requested without an ``lr_lambda``.

    Note:
        train_one_epoch() calls ``scheduler.step()`` without arguments, whereas
        ``ReduceLROnPlateau`` expects the monitored metric in ``step``; the
        training loop has to be adapted before that scheduler can be used.
    """
    cfg = CONFIG if config is None else config
    name = cfg['scheduler']

    if name is None:
        return None
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg['T_max'], eta_min=cfg['min_lr']
        )
    if name == 'CosineAnnealingWarmRestarts':
        cfg['T_0'] = 20
        cfg['T_mult'] = 2
        cfg['min_lr'] = 1e-6
        return lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=cfg['T_0'], T_mult=cfg['T_mult'], eta_min=cfg['min_lr']
        )
    if name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=kwargs.get('factor', 0.1),
            patience=kwargs.get('patience', 5),
        )
    if name == 'LambdaLR':
        lr_lambda = kwargs.get('lr_lambda')
        if lr_lambda is None:
            raise ValueError("LambdaLR requires an 'lr_lambda' keyword argument")
        return lr_scheduler.LambdaLR(optimizer, lr_lambda)

    raise ValueError(f"Invalid scheduler given: {name!r}")

def get_optimizer(optimizer_name, model, *, config=None):
    """Create the optimizer called ``optimizer_name`` for ``model``.

    The hyper-parameters of the chosen optimizer are written into ``config``
    (the global ``CONFIG`` by default), replacing any values already stored
    under the same keys, and are then used to build the optimizer.

    Args:
        optimizer_name: One of ``adam``, ``sgd``, ``radam`` or ``rmsprop``
            (case-insensitive).
        model: Module whose ``parameters()`` are optimised.
        config: Mapping receiving the hyper-parameters; defaults to ``CONFIG``.

    Returns:
        The configured optimizer.

    Raises:
        ValueError: If ``optimizer_name`` is not supported.
    """
    cfg = CONFIG if config is None else config
    name = optimizer_name.lower()

    if name == "adam":
        cfg['learning_rate'] = 1e-4
        cfg['weight_decay'] = 1e-5
        cfg['betas'] = (0.9, 0.999)
        cfg['eps'] = 1e-8
        return optim.Adam(
            model.parameters(),
            lr=cfg['learning_rate'],
            betas=cfg['betas'],
            eps=cfg['eps'],
            weight_decay=cfg['weight_decay'],
        )
    if name == "sgd":
        cfg['learning_rate'] = 1e-3
        cfg['weight_decay'] = 1e-3
        # NOTE(review): a momentum of 1e-3 is unusually small (CONFIG starts
        # with 0.9) -- confirm this is intended and not a copy-paste slip.
        cfg['momentum'] = 1e-3
        return optim.SGD(
            model.parameters(),
            lr=cfg['learning_rate'],
            momentum=cfg['momentum'],
            weight_decay=cfg['weight_decay'],
        )
    if name == "radam":
        cfg['learning_rate'] = 1e-4
        cfg['weight_decay'] = 0
        cfg['betas'] = (0.9, 0.999)
        cfg['eps'] = 1e-8
        return torch_optimizer.RAdam(
            model.parameters(),
            lr=cfg['learning_rate'],
            betas=cfg['betas'],
            eps=cfg['eps'],
            weight_decay=cfg['weight_decay'],
        )
    if name == "rmsprop":
        cfg['learning_rate'] = 0.256
        cfg['alpha'] = 0.9
        cfg['momentum'] = 0.9
        cfg['weight_decay'] = 1e-5
        # Each argument now receives its own setting; previously the learning
        # rate was passed as alpha, momentum and weight_decay as well.
        return optim.RMSprop(
            model.parameters(),
            lr=cfg['learning_rate'],
            alpha=cfg['alpha'],
            momentum=cfg['momentum'],
            weight_decay=cfg['weight_decay'],
        )

    raise ValueError("Invalid Optimizer given!")
    

In [None]:
def convert_dict_to_tensor(dict_):
    """Return a 1-D tensor holding the values of ``dict_`` in iteration order.

    The keys are ignored. The tensor is allocated with ``torch.empty`` and
    therefore has torch's default dtype; every value is written into its slot.
    """
    values = list(dict_.values())
    out = torch.empty(len(values))
    for position, value in enumerate(values):
        out[position] = value
    return out

def get_class_weights(df_train):
    """Return normalised inverse-frequency weights for ``target_label``.

    For every class (ordered by label value) the weight is the reciprocal of
    its share of the rows in ``df_train``; the weights are then divided by
    their sum so that they add up to one. The result is a 1-D tensor produced
    by ``convert_dict_to_tensor``.
    """
    n_rows = df_train.shape[0]
    counts = df_train.target_label.value_counts().sort_index().to_dict()

    # reciprocal of each class's relative frequency
    inverse = {label: 1 / (count / n_rows) for label, count in counts.items()}

    total = 0
    for inv in inverse.values():
        total += inv

    normalised = {label: inv / total for label, inv in inverse.items()}
    return convert_dict_to_tensor(normalised)

def get_dataloaders(df, n_tiles=1):
    """Create the training and validation loaders for ``df``.

    Both datasets are built from the same frame; ``CancerTilesDataset`` selects
    its own part of the rows depending on the mode ("train" vs. "valid").

    Args:
        df: DataFrame with the columns required by ``CancerTilesDataset``.
            It is now actually used; the function previously ignored it and
            read the global ``df_train`` instead.
        n_tiles: Number of dataset repeats per image, for both loaders.

    Returns:
        ``(train_loader, valid_loader, df)``.
    """
    train_dataset = CancerTilesDataset(
        df, TRAIN_DIR, transforms=data_transforms["train"], mode="train", n_tiles=n_tiles
    )
    # NOTE(review): the training loader does not shuffle, so batches are made of
    # the same images in the same order every epoch -- confirm this is intended.
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'],
                              num_workers=2, shuffle=False, pin_memory=True)
    valid_dataset = CancerTilesDataset(
        df, TRAIN_DIR, transforms=data_transforms["valid"], mode="valid", n_tiles=n_tiles
    )
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'],
                              num_workers=2, shuffle=False, pin_memory=True)
    print(f"Len Train Dataset: {len(train_dataset)}, Len Validation Dataset: {len(valid_dataset)}" )
    return train_loader, valid_loader, df

def print_logged_info(r):
    """Print run id, "model" artifacts, params, metrics and user tags of run ``r``.

    Tags whose key starts with ``"mlflow."`` are left out. The artifact paths
    are fetched through a new ``MlflowClient`` for the run's ``"model"`` folder.
    """
    user_tags = {
        key: value
        for key, value in r.data.tags.items()
        if not key.startswith("mlflow.")
    }
    client = MlflowClient()
    artifacts = [entry.path for entry in client.list_artifacts(r.info.run_id, "model")]

    print(f"run_id: {r.info.run_id}")
    print(f"artifacts: {artifacts}")
    print(f"params: {r.data.params}")
    print(f"metrics: {r.data.metrics}")
    print(f"tags: {user_tags}")


In [None]:
def train_one_epoch(model, train_loader, optimizer, criterion, device, writer, epoch, scheduler=None):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module to train; switched to train mode.
        train_loader: Loader yielding dicts with the keys ``'image'`` and ``'label'``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        writer: TensorBoard writer receiving ``loss/train_batch`` per batch and
            ``loss/train_epoch`` once per epoch.
        epoch: Zero-based epoch index, used to compute the global step.
        scheduler: Optional LR scheduler, stepped once after the epoch.

    Returns:
        The sample-weighted mean training loss of the epoch.
    """
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
    model.train()
    train_loss = 0.0
    steps_per_epoch = len(train_loader)
    bar = tqdm(enumerate(train_loader), total=steps_per_epoch)
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # weight by batch size so the epoch mean is exact for a smaller last batch
        train_loss += batch_loss * images.size(0)

        # Per-batch logging belongs inside the loop. It used to sit after the
        # loop under `if scheduler:`, so only the last batch of an epoch was
        # logged, and only when a scheduler was configured.
        writer.add_scalar('loss/train_batch', batch_loss, epoch * steps_per_epoch + step)

    # Update the learning rate once per epoch
    if scheduler is not None:
        scheduler.step()

    train_loss /= len(train_loader.dataset)
    # Log the average training loss for the epoch to TensorBoard
    writer.add_scalar('loss/train_epoch', train_loss, epoch)
    return train_loss

def validate_one_epoch(model, valid_loader, criterion, device, writer, epoch):
    """Evaluate ``model`` on ``valid_loader`` and log the results to TensorBoard.

    The model is put into eval mode and run without gradient tracking.
    Predictions are the arg-max of ``model.softmax`` applied to the outputs.
    Per batch, the loss and the number of correct predictions are logged; per
    epoch, the mean loss, accuracy, balanced accuracy and macro / micro /
    weighted F1 are logged.

    Returns:
        ``(valid_loss, valid_acc, bal_acc, weighted_f1)``.
    """
    model.eval()
    loss_sum = 0.0
    correct_sum = 0.0
    predictions = []
    targets = []
    n_batches = len(valid_loader)

    with torch.no_grad():
        for step, batch in tqdm(enumerate(valid_loader), total=n_batches):
            images = batch['image'].to(device, dtype=torch.float)
            labels = batch['label'].to(device, dtype=torch.long)
            outputs = model(images)

            # cross-entropy loss, weighted by batch size for an exact epoch mean
            loss = criterion(outputs, labels)
            loss_sum += loss.item() * images.size(0)

            predicted = torch.max(model.softmax(outputs), 1)[1]
            n_correct = torch.sum(predicted == labels)
            correct_sum += n_correct.item()

            predictions.extend(predicted.cpu().numpy())
            targets.extend(labels.cpu().numpy())

            global_step = epoch * n_batches + step
            writer.add_scalar('loss/valid_batch', loss.item(), global_step)
            # note: this is the count of correct predictions in the batch, not a ratio
            writer.add_scalar('acc/valid_batch', n_correct.item(), global_step)

    n_samples = len(valid_loader.dataset)
    valid_loss = loss_sum / n_samples
    valid_acc = correct_sum / n_samples
    bal_acc = balanced_accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average='macro')
    micro_f1 = f1_score(targets, predictions, average='micro')
    weighted_f1 = f1_score(targets, predictions, average='weighted')

    # epoch-level scalars, written in a fixed order
    epoch_scalars = {
        'loss/val_epoch': valid_loss,
        'acc/val_epoch': valid_acc,
        'balanced_acc/val_epoch': bal_acc,
        'F1/macro': macro_f1,
        'F1/micro': micro_f1,
        'F1/weighted': weighted_f1,
    }
    for tag, value in epoch_scalars.items():
        writer.add_scalar(tag, value, epoch)

    return valid_loss, valid_acc, bal_acc, weighted_f1

def train_model(model, train_loader, valid_loader, optimizer, criterion, device, num_epochs, scheduler, save_model_path=None):
    """Run the full training loop with TensorBoard / MLflow logging.

    Args:
        model, train_loader, valid_loader, optimizer, criterion, device:
            Passed on to train_one_epoch() and validate_one_epoch().
        num_epochs: Maximum number of epochs.
        scheduler: Optional LR scheduler, passed to train_one_epoch().
        save_model_path: Destination of the checkpoint; a timestamped file
            name is generated when omitted.

    Returns:
        ``(train_loss, valid_loss, valid_acc, save_model_path)`` with the
        metrics of the last executed epoch (``None`` if no epoch ran).

    With ``CONFIG["early_stopping"]`` enabled the checkpoint holds the weights
    with the lowest validation loss (written by ``EarlyStopping``); otherwise
    the final weights are saved, so that ``save_model_path`` exists either way.
    """
    model_name = (
        f"model_epochs{CONFIG['num_epochs']}_bs{CONFIG['train_batch_size']}"
        f"_opt{CONFIG['optimizer']}_sched{CONFIG['scheduler']}"
        f"_lr{CONFIG['learning_rate']}_wd{CONFIG['weight_decay']}"
    )
    print(f"Training model: {model_name}")
    datetime_now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if not save_model_path:
        save_model_path = 'best_model_checkpoint' + datetime_now + '.pth'
    print(f"Path for saving model: {save_model_path}")

    # Initialize TensorBoard writer
    writer = SummaryWriter('logs/fit/' + model_name)
    early_stopping = EarlyStopping(patience=CONFIG["patience"], verbose=True, path=save_model_path)

    train_loss = valid_loss = valid_acc = None
    try:
        for epoch in range(num_epochs):
            train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, writer, epoch, scheduler)
            valid_loss, valid_acc, bal_acc, weighted_f1 = validate_one_epoch(model, valid_loader, criterion, device, writer, epoch)
            print(f"Epoch {epoch+1}/{num_epochs} - Train loss: {train_loss:.4f}, Validation loss: {valid_loss:.4f}, Validation acc: {valid_acc:.4f}, Balanced acc: {bal_acc:.4f}, Weighted F1-Score: {weighted_f1:.4f}")

            # Log before the early-stopping check so the stopping epoch is recorded too.
            # MLflow logging stays best-effort, but failures are no longer silent.
            try:
                mlflow.log_metrics({
                    'epoch': epoch,
                    'train_loss': train_loss,
                    'valid_loss': valid_loss,
                    'valid_acc': valid_acc,
                    'balanced_acc': bal_acc,
                    'weighted_f1': weighted_f1
                }, step=epoch)
            except Exception as err:
                print(f"[WARN] Could not log epoch metrics to MLflow: {err}")

            if CONFIG["early_stopping"]:
                early_stopping(valid_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break
    finally:
        # Close the writer exactly once, also on early stop or on an error.
        # It used to be closed inside the loop after every epoch.
        writer.close()

    if not CONFIG["early_stopping"]:
        # Without early stopping nothing has written a checkpoint yet.
        torch.save(model.state_dict(), save_model_path)

    return train_loss, valid_loss, valid_acc, save_model_path
    # Load the last checkpoint with the best model
    #model.load_state_dict(torch.load('best_model_checkpoint.pth'))



In [None]:
def test_on_holdout(model, CONFIG, df_test, TRAIN_DIR=None, val_size=1.0, n_tiles=1):
    """Evaluate ``model`` on the holdout frame and return per-tile predictions.

    Args:
        model: Trained model exposing a ``softmax`` attribute.
        CONFIG: Run configuration (shadows the global of the same name); the
            keys ``is_submission``, ``valid_batch_size`` and ``device`` are read.
        df_test: Holdout DataFrame; it is not modified.
        TRAIN_DIR: Directory with the per-image tile folders (shadows the global).
        val_size: Unused; kept for backward compatibility.
        n_tiles: Number of tiles evaluated per image.

    Returns:
        ``None`` when ``CONFIG["is_submission"]`` is set. Otherwise a copy of the
        evaluated rows, in the order used by the dataset, extended with the
        columns ``label_tile_<i>``, ``pred_tile_<i>`` and ``pred_label_tile_<i>``
        for every tile index ``i``.

    ``CancerTilesDataset`` shuffles its rows internally, so the predictions come
    out in the dataset's row order rather than in the order of ``df_test``. They
    are therefore attached to ``test_dataset.data``; writing them onto
    ``df_test`` directly, as was done before, assigned them to the wrong rows.
    """
    if CONFIG["is_submission"]:
        print("Skip validation on training set due to submission!")
        return None

    model.eval()
    test_dataset = CancerTilesDataset(df_test, TRAIN_DIR, transforms=data_transforms["valid"], mode="test", split=1.0, n_tiles=n_tiles)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['valid_batch_size'],
                             num_workers=2, shuffle=False, pin_memory=True)
    print(f"Test-Dataset Size: {len(test_dataset)}")

    preds = []
    labels_list = []
    test_acc = 0.0

    with torch.no_grad():
        bar = tqdm(enumerate(test_loader), total=len(test_loader))
        for step, data in bar:
            images = data['image'].to(CONFIG["device"], dtype=torch.float)
            labels = data['label'].to(CONFIG["device"], dtype=torch.long)

            outputs = model(images)
            _, predicted = torch.max(model.softmax(outputs), 1)
            preds.append(predicted.detach().cpu().numpy())
            labels_list.append(labels.detach().cpu().numpy())
            test_acc += torch.sum(predicted == labels).item()

    test_acc /= len(test_loader.dataset)
    preds = np.concatenate(preds).flatten()
    labels_list = np.concatenate(labels_list).flatten()
    # map the numeric predictions back to label names with the fitted global encoder
    pred_labels = encoder.inverse_transform(preds)

    bal_acc = balanced_accuracy_score(labels_list, preds)
    conf_matrix = confusion_matrix(labels_list, preds)
    macro_f1 = f1_score(labels_list, preds, average='macro')

    print(f"Test Accuracy: {test_acc}")
    print(f"Balanced Accuracy: {bal_acc}")
    print(f"Confusion Matrix: {conf_matrix}")

    # The loader does not shuffle, so block i of the outputs holds the i-th tile
    # of every image, in the row order of test_dataset.data.
    df_result = test_dataset.data.copy()
    num_samples = len(df_result)
    for i in range(n_tiles):
        block = slice(i * num_samples, (i + 1) * num_samples)
        df_result[f"label_tile_{i}"] = labels_list[block]
        df_result[f"pred_tile_{i}"] = preds[block]
        df_result[f"pred_label_tile_{i}"] = pred_labels[block]

    # MLflow logging is best-effort, but failures are reported instead of hidden.
    try:
        mlflow.log_metrics({
            'test_acc': test_acc,
            'test_balanced_acc': bal_acc,
            'test_f1_score': macro_f1,
        })
    except Exception as err:
        print(f"[WARN] Could not log holdout metrics to MLflow: {err}")
    return df_result

In [None]:
# Loss function: cross-entropy, optionally weighted per class with the
# weights computed by get_class_weights() on the training frame.
class_weights = None
if CONFIG["weighted_loss"]:
    class_weights = get_class_weights(df_train).to(CONFIG['device'], dtype=torch.float)
    print(f"Class weights: {class_weights}")
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
# Full experiment: train, reload the saved checkpoint, evaluate on the holdout
# fold and record everything in one MLflow run.
print(f"Shape df_train: {df_train.shape}, Shape df_test: {df_test.shape}")
with mlflow.start_run(experiment_id=mlflow_experiment_id) as run:
    train_loader, valid_loader, df_train_fold = get_dataloaders(df_train.copy(), n_tiles=CONFIG["n_tiles"])

    model = UBCModel(CONFIG['model_name'], CONFIG['num_classes'], pretrained=False, checkpoint_path=CONFIG["checkpoint_path"])
    model.to(CONFIG['device'])

    optimizer = get_optimizer(CONFIG["optimizer"], model)
    scheduler = fetch_scheduler(optimizer)

    _, _, _, save_model_path = train_model(model, train_loader, valid_loader, optimizer, criterion, CONFIG["device"], CONFIG["num_epochs"], scheduler)
    # map_location keeps the reload working when the checkpoint was written on
    # a different device than the one currently in use
    model.load_state_dict(torch.load(save_model_path, map_location=CONFIG['device']))

    print("Validate on Holdout Set:")
    df_test = test_on_holdout(model, CONFIG, df_test, TRAIN_DIR, val_size=1, n_tiles=CONFIG["n_tiles_test"])
    df_test_file_path = "df_test_results.csv"
    # test_on_holdout() returns None in submission mode
    if df_test is not None:
        df_test.to_csv(df_test_file_path, index=False)

    # MLflow logging is best-effort, but failures are reported instead of hidden.
    try:
        mlflow.log_params(CONFIG)
        mlflow.pytorch.log_model(model, "model")
        # log_params() expects a mapping. Passing the bare path string raised,
        # and the old bare `except` hid it, skipping the remaining logging steps.
        mlflow.log_param("save_model_path", save_model_path)
        if df_test is not None:
            mlflow.log_artifact(df_test_file_path)
        print_logged_info(mlflow.get_run(run_id=run.info.run_id))
    except Exception as err:
        print(f"[WARN] MLflow logging failed: {err}")



In [None]:
# model.load_state_dict(torch.load('/kaggle/working/best_model_checkpoint' + CONFIG["datetime_now"] + '.pth'))
# df_test = test_on_holdout(model, CONFIG, df_test, TRAIN_DIR, val_size=1)
# df_test

In [None]:
# Rebuild the architecture without any initial checkpoint and load previously
# trained weights. map_location lets a checkpoint saved on another device
# (e.g. a GPU) be loaded on the device configured for this session.
model = UBCModel(CONFIG['model_name'], CONFIG['num_classes'], pretrained=False, checkpoint_path=None)
model.load_state_dict(
    torch.load(
        "/kaggle/input/effnet-version-28/best_model_checkpoint2023-11-21_15-47-39.pth",
        map_location=CONFIG['device'],
    )
)
model.to(CONFIG['device']);

In [None]:
# Evaluate the reloaded model on the holdout fold; df_test is rebound to the returned frame.
df_test = test_on_holdout(model, CONFIG, df_test, TRAIN_DIR, val_size=1, n_tiles=CONFIG["n_tiles_test"])
