# Introduction 

**This is a basic CNN Model training notebook**

It is based on: 
- Thumbnail images
- Basic data transformation (using Albumentation):
    - resizing images to 512x512
    - normalizing pixel values
- CNN Architecture


**Todos:**

- Learn about Dataset & DataLoader
- add augmentations (albumentation)
- gem pooling

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input/tiles-of-cancer-2048px-scale-0-25'):
    for filename in filenames:
        # print(os.path.join(dirname, filename))
        continue



# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:


import os
import gc
import cv2
import datetime
import math
import copy
import time
import random
import glob
from matplotlib import pyplot as plt
from skimage import io


# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda import amp
import torchvision

import optuna
from optuna.trial import TrialState

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict



from PIL import Image
from joblib import Parallel, delayed
from tqdm.auto import tqdm

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score

# For Image Models
import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings
# warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"



In [3]:
def count_files_and_folders(directory):
    n_folders = 0
    n_files = 0

    # Walk through all subdirectories and files in the directory
    for root, dirs, files in os.walk(directory):
        n_folders += len(dirs)
        n_files += len(files)

    return n_folders, n_files

# Specify the directory path
directory_path = "/kaggle/input/tiles-of-cancer-2048px-scale-0-25"

# Call the function and get the counts
n_folders, n_files = count_files_and_folders(directory_path)

print(f"Number of folders: {n_folders}")
print(f"Number of files: {n_files}")

Number of folders: 538
Number of files: 123715


In [4]:
!pip install --quiet mlflow dagshub

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ydata-profiling 4.3.1 requires dacite>=1.8, but you have dacite 1.6.0 which is incompatible.
ydata-profiling 4.3.1 requires scipy<1.11,>=1.4.1, but you have scipy 1.11.2 which is incompatible.[0m[31m
[0m

In [5]:
import dagshub
from getpass import getpass
import mlflow.pytorch 
from mlflow import MlflowClient

In [6]:
os.environ["MLFLOW_TRACKING_USERNAME"]="Niggl0n"
os.environ["MLFLOW_TRACKING_PASSWORD"] = "7a3590e8c5558d4598dacc7810befa70a4baac9e"
os.environ['MLFLOW_TRACKING_PROJECTNAME'] = "UBC_Cancer_Classification"
#dagshub.auth.add_app_token("7a3590e8c5558d4598dacc7810befa70a4baac9e")
mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME'] + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')

In [7]:
def get_or_create_experiment_id(name):
    exp = mlflow.get_experiment_by_name(name)
    if exp is None:
        exp_id = mlflow.create_experiment(name)
        return exp_id
    return exp.experiment_id

mlflow_experiment_id = get_or_create_experiment_id(os.environ['MLFLOW_TRACKING_PROJECTNAME'])
mlflow_experiment_id

'1'

In [8]:
def get_random_image_file_path(directory):
    # Get all subdirectories
    subdirs = [subdir for subdir, dirs, files in os.walk(directory) if files]
    # Choose a random subdirectory
    random_subdir = random.choice(subdirs)
    # Get all files in that subdirectory
    files = os.listdir(random_subdir)
    # Choose a random file
    random_file = random.choice(files)
    # Return the full path to the file
    return os.path.join(random_subdir, random_file)

# Specify the directory path
directory_path = "/kaggle/input/tiles-of-cancer-2048px-scale-0-25"

# Get a random file path
random_image_path = get_random_image_file_path(directory_path)

# Read the image using OpenCV
image = cv2.imread(random_image_path)

# Print the shape of the image
print(f"The shape of the image is: {image.shape}")

The shape of the image is: (512, 512, 3)


In [9]:
CONFIG = {
    "is_submission": False,
    "crop_vertical": True,
    "weighted_loss": True,
    "datetime_now": datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 
    "n_fold": 5,
    "test_fold": 0,
    "fold": 1,
    "seed": 42,
    "img_size": 512,
    'center_crop_size': 1024,
    "model_name": "tf_efficientnet_b0_ns",
    "num_classes": 5,
    "train_batch_size": 16,
    "valid_batch_size": 16,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "num_epochs": 30,
    "early_stopping": True,
    "patience": 7,
    "optimizer": 'Adam',
    "learning_rate": 1e-4,
    "scheduler": 'CosineAnnealingLR',
    "min_lr": 1e-6,
    "T_max": 30,
    "momentum": 0.9,
    "weight_decay": 1e-4,
}

## 1. Data Preparation

In [10]:
ROOT_DIR = '/kaggle/input/UBC-OCEAN'
TRAIN_DIR = '/kaggle/input/tiles-of-cancer-2048px-scale-0-25/'
TEST_DIR = '/kaggle/input/UBC-OCEAN/test_thumbnails'

# ALT_TEST_DIR = '/kaggle/input/UBC-OCEAN/test_images'
# TMA_TRAIN_DIR = '/kaggle/input/UBC-OCEAN/train_images'

def get_train_file_path(df_train_row):
    return f"{TRAIN_DIR}/{df_train_row.image_id}_thumbnail.png"

def get_test_file_path(image_id):
    if os.path.exists(f"{TEST_DIR}/{image_id}_thumbnail.png"):
        return f"{TEST_DIR}/{image_id}_thumbnail.png"
    else:
        return f"{ALT_TEST_DIR}/{image_id}.png"



In [11]:
train_images = sorted(glob.glob(f"{TRAIN_DIR}/*.png"))
df_train = pd.read_csv("/kaggle/input/UBC-OCEAN/train.csv")
print(df_train.shape)
df_train['file_path'] = df_train.apply(lambda row: get_train_file_path(row), axis=1)
# only consider WSI / Thumbnail images
#df_train = df_train[ 
#    df_train["file_path"].isin(train_images) ].reset_index(drop=True)
print(df_train.shape)

# encode to numericalt target
encoder = LabelEncoder()
df_train['target_label'] = encoder.fit_transform(df_train['label'])

# save encoder
with open("label_encoder_"+ CONFIG["datetime_now"] +".pkl", "wb") as fp:
    joblib.dump(encoder, fp)
    
# use stratified K Fold for crossvalidation 
skf = StratifiedKFold(n_splits=CONFIG['n_fold'], shuffle=True, random_state=CONFIG["seed"])

for fold, ( _, val_) in enumerate(skf.split(X=df_train, y=df_train.target_label)):
    df_train.loc[val_ , "kfold"] = int(fold)
df_train.head()

(538, 5)
(538, 6)


Unnamed: 0,image_id,label,image_width,image_height,is_tma,file_path,target_label,kfold
0,4,HGSC,23785,20008,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,2,3.0
1,66,LGSC,48871,48195,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,3,2.0
2,91,HGSC,3388,3388,True,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,2,4.0
3,281,LGSC,42309,15545,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,3,2.0
4,286,EC,37204,30020,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,1,2.0


In [12]:
df_test = pd.read_csv("/kaggle/input/UBC-OCEAN/test.csv")
df_test['file_path'] = df_test['image_id'].apply(get_test_file_path)
df_test["target_label"] = 0 
df_test

Unnamed: 0,image_id,image_width,image_height,file_path,target_label
0,41,28469,16987,/kaggle/input/UBC-OCEAN/test_thumbnails/41_thu...,0


In [13]:
class CancerTilesDataset(Dataset):
    def __init__(
        self,
        df_data,
        path_img_dir: str =  '',
        transforms = None,
        mode: str = 'train',
        labels_lut = None,
        white_thr: int = 225,
        thr_max_bg: float = 0.2,
        split: float = 0.90
    ):
        assert os.path.isdir(path_img_dir)
        self.path_img_dir = path_img_dir
        self.transforms = transforms
        self.mode = mode
        self.white_thr = white_thr
        self.thr_max_bg = thr_max_bg
        self.split = split

        self.data = df_data
        self.labels_unique = sorted(self.data["label"].unique())
        self.labels_lut = labels_lut or {lb: i for i, lb in enumerate(self.labels_unique)}
        # shuffle data
        self.data = self.data.sample(frac=1, random_state=42).reset_index(drop=True)

        # split dataset
        assert 0.0 <= self.split <= 1.0
        frac = int(self.split * len(self.data))
        self.data = self.data[:frac] if mode in ["train", "test"] else self.data[frac:]
        self.img_dirs = [glob.glob(os.path.join(path_img_dir, str(idx), "*.png")) for idx in self.data["image_id"]]
        #print(f"missing: {sum([not os.path.isfile(os.path.join(self.path_img_dir, im))
        #                       for im in self.img_names])}")
        # self.labels = list(self.data['label'])
        self.labels =  self.data.target_label.values

    def __getitem__(self, idx: int) -> tuple:
        random.shuffle(self.img_dirs[idx])
        for img_path in self.img_dirs[idx]:
            assert os.path.isfile(img_path), f"missing: {img_path}"
            tile = cv2.imread(img_path)
            tile = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)
        
            # tile = np.array(Image.open(img_path))[..., :3]
            black_bg = np.sum(tile, axis=2) == 0
            tile[black_bg, :] = 255
            mask_bg = np.mean(tile, axis=2) > self.white_thr
            if np.sum(mask_bg) < (np.prod(mask_bg.shape) * self.thr_max_bg):
                break

        # augmentation
        if self.transforms:
            tile = self.transforms(image=tile)["image"]
        #print(f"img dim: {img.shape}")
        return {
            "image": tile,
            "label": torch.tensor(self.labels[idx], dtype=torch.long)
               }

    def __len__(self) -> int:
        return len(self.data)

In [14]:
data_transforms = {
    "train": A.Compose([
        A.RandomResizedCrop(CONFIG['img_size'], CONFIG['img_size'], scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.ShiftScaleRotate(shift_limit=0.125, scale_limit=0.2, rotate_limit=15, p=0.5),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        A.CoarseDropout(p=0.2),
        A.Cutout(p=0.2),
        A.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225], 
            max_pixel_value=255.0, 
            p=1.0
        ),
        ToTensorV2()], p=1.),
    
    "valid": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225], 
            max_pixel_value=255.0, 
            p=1.0
        ),
        ToTensorV2()], p=1.)
}



## 2. Model Creation

In [15]:
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)
        
    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
        
    def __repr__(self):
        return self.__class__.__name__ + \
                '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
                ', ' + 'eps=' + str(self.eps) + ')'


class EfficientNetB5(nn.Module):
    '''
    EfficientNet B0 fine-tune.
    '''
    def __init__(self, model_name, num_classes, pretrained=False, checkpoint_path=None):
        '''
        Fine tune for EfficientNetB0
        Args
            n_classes : int - Number of classification categories.
            learnable_modules : tuple - Names of the modules to fine-tune.
        Return
            
        '''
        super(EfficientNetB5, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, checkpoint_path=checkpoint_path)

        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.linear = nn.Linear(in_features, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, images):
        """
        Forward function for the fine-tuned model
        Args
            x: 
        Return
            result
        """
        features = self.model(images)
        pooled_features = self.pooling(features).flatten(1)
        output = self.linear(pooled_features)
        return output


## 3. Training

In [16]:
class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [17]:
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, OneCycleLR

def get_lr_scheduler(optimizer, scheduler_type, **kwargs):
    if scheduler_type == 'steplr':
        return StepLR(optimizer, step_size=kwargs.get('step_size', 10), gamma=kwargs.get('gamma', 0.1))
    
    elif scheduler_type == 'plateau':
        return ReduceLROnPlateau(optimizer, mode='min', factor=kwargs.get('factor', 0.1), patience=kwargs.get('patience', 5), verbose=True)
    
    elif scheduler_type == 'onecycle':
        return OneCycleLR(optimizer, max_lr=kwargs.get('max_lr', 0.01), steps_per_epoch=len(train_loader), epochs=kwargs.get('num_epochs', 10))
    
    else:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")
        
def fetch_scheduler(optimizer):
    if CONFIG['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=CONFIG['T_max'], 
                                                   eta_min=CONFIG['min_lr'], verbose=False)
    elif CONFIG['scheduler'] == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=CONFIG['T_0'], 
                                                             eta_min=CONFIG['min_lr'], verbose=False)
    elif CONFIG['scheduler'] == 'ReduceLROnPlateau':
        scheduler =  ReduceLROnPlateau(optimizer, mode='min', factor=kwargs.get('factor', 0.1), patience=kwargs.get('patience', 5), verbose=False)
    elif CONFIG['scheduler'] == None:
        return None
        
    return scheduler

def get_optimizer(optimizer_name, model):
    if optimizer_name.lower() == "adam":
            optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'],  weight_decay=CONFIG['weight_decay'])
    elif optimizer_name.lower() == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=CONFIG['learning_rate'], momentum=CONFIG['momentum'], weight_decay=CONFIG['weight_decay'])
    else:
        raise ValueError("Invalid Optimizer given!")
    return optimizer
    

In [18]:
def train_one_epoch(model, train_loader, optimizer, criterion, device, writer, epoch, scheduler=None):
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
    model.train()
    train_loss = 0.0
    bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(images)
        
        # crossentropy loss
        loss = criterion(outputs, labels)
        # Focal Loss
        #criterion = FocalLoss(gamma=0.7)
        #m = torch.nn.Softmax(dim=-1)
        #loss = criterion(m(outputs), labels)
        
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)
        # Update learning rate using the scheduler
        if scheduler:
            scheduler.step()
        
        # Log the training loss to TensorBoard
        writer.add_scalar('loss/train_batch', loss.item(), epoch * len(train_loader) + step)
    
    train_loss /= len(train_loader.dataset)
    # Log the average training loss for the epoch to TensorBoard
    writer.add_scalar('loss/train_epoch', train_loss, epoch)
    # gc.collect()
    return train_loss

def validate_one_epoch(model, valid_loader, criterion, device, writer, epoch):
    model.eval()
    valid_loss = 0.0
    valid_acc = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        bar_val = tqdm(enumerate(valid_loader), total=len(valid_loader))
        for step, data in bar_val:
            images = data['image'].to(device, dtype=torch.float)
            labels = data['label'].to(device, dtype=torch.long)
            outputs = model(images)
            
            # crossentropy loss
            loss = criterion(outputs, labels)
            # Focal Loss
            #criterion = FocalLoss(gamma=0.7)
            #m = torch.nn.Softmax(dim=-1)
            #loss = criterion(m(outputs), labels)
        
            valid_loss += loss.item() * images.size(0)
            _, predicted = torch.max(model.softmax(outputs), 1)
            acc = torch.sum( predicted == labels )
            valid_acc  += acc.item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
        
            writer.add_scalar('loss/valid_batch', loss.item(), epoch * len(valid_loader) + step)
            writer.add_scalar('acc/valid_batch', acc.item(), epoch * len(valid_loader) + step)
    valid_loss /= len(valid_loader.dataset)
    valid_acc /= len(valid_loader.dataset)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    # At the end of your validation loop:
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    micro_f1 = f1_score(all_labels, all_preds, average='micro')
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted')

    # Logging to TensorBoard
    writer.add_scalar('loss/val_epoch', valid_loss, epoch)
    writer.add_scalar('acc/val_epoch', valid_acc, epoch)
    writer.add_scalar('balanced_acc/val_epoch', bal_acc, epoch)
    writer.add_scalar('F1/macro', macro_f1, epoch)
    writer.add_scalar('F1/micro', micro_f1, epoch)
    writer.add_scalar('F1/weighted', weighted_f1, epoch)
    # in order to put multiple lines within one graph
    #writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
    #                        'xcosx':i*np.cos(i/r),
    #                        'tanx': np.tan(i/r)}, i)
    return valid_loss, valid_acc, bal_acc, weighted_f1

def train_model(model, train_loader, valid_loader, optimizer, criterion, device, num_epochs, scheduler, save_model_path=None):
    model_name = "epochs" + str(CONFIG["num_epochs"]) + "_bs"+str(CONFIG["train_batch_size"] )+ "_opt" +CONFIG["optimizer"]+ "_sched" + CONFIG["scheduler"] + "_lr"+str(CONFIG["learning_rate"])+ "_wd" + str(CONFIG["weight_decay"])
    print(f"Training model: {model_name}")
    datetime_now =  datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if not save_model_path:
        save_model_path = 'best_model_checkpoint' + datetime_now + '.pth'
    print(f"Path for saving model: {save_model_path}")
    # Initialize TensorBoard writer
    writer = SummaryWriter('logs/fit/' + model_name)
    early_stopping = EarlyStopping(patience=CONFIG["patience"], verbose=True, path=save_model_path)
    #if scheduler_type:
    #    print(f"Define {scheduler_type} scheduler")
    #    scheduler = get_lr_scheduler(optimizer, scheduler_type, num_epochs=num_epochs)
    
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, writer, epoch, scheduler)
        valid_loss, valid_acc, bal_acc, weighted_f1 = validate_one_epoch(model, valid_loader, criterion, device, writer, epoch)
        print(f"Epoch {epoch+1}/{num_epochs} - Train loss: {train_loss:.4f}, Validation loss: {valid_loss:.4f}, Validation acc: {valid_acc:.4f}, Balanced acc: {bal_acc:.4f}, Weighted F1-Score: {weighted_f1:.4f}")
        
        # Log metrics for each epoch
        mlflow.log_metrics({
            'epoch': epoch,
            'train_loss': train_loss,
            'valid_loss': valid_loss,
            'valid_acc': valid_acc,
            'balanced_acc': bal_acc,
            'weighted_f1': weighted_f1
        }, step=epoch)
    
        # Call early stopping
        if CONFIG["early_stopping"]:
            early_stopping(valid_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            
    writer.close()
    return train_loss, valid_loss, valid_acc, save_model_path
    # Load the last checkpoint with the best model
    #model.load_state_dict(torch.load('best_model_checkpoint.pth'))



In [19]:
def test_on_holdout(model, CONFIG, df_test, TRAIN_DIR=None, val_size=1.0):
    if not CONFIG["is_submission"]:
        test_dataset = CancerTilesDataset(df_test, TRAIN_DIR, transforms=data_transforms["valid"], mode="test", split=1.0)
        test_loader = DataLoader(test_dataset, batch_size=CONFIG['valid_batch_size'], 
                                  num_workers=2, shuffle=False, pin_memory=True)

        preds = []
        labels_list = []
        test_acc = 0.0

        with torch.no_grad():
            bar = tqdm(enumerate(test_loader), total=len(test_loader))
            for step, data in bar: 
                # print(step)
                images = data['image'].to(CONFIG["device"], dtype=torch.float)
                labels = data['label'].to(CONFIG["device"], dtype=torch.long)

                batch_size = images.size(0)
                outputs = model(images)
                _, predicted = torch.max(model.softmax(outputs), 1)
                preds.append(predicted.detach().cpu().numpy() )
                labels_list.append(labels.detach().cpu().numpy() )
                acc = torch.sum(predicted == labels )
                test_acc  += acc.item()
        test_acc /= len(test_loader.dataset)
        preds = np.concatenate(preds).flatten()
        labels_list = np.concatenate(labels_list).flatten()
        pred_labels = encoder.inverse_transform( preds )
        
        # Calculate Balanced Accuracy
        bal_acc = balanced_accuracy_score(labels_list, preds)
        # Calculate Confusion Matrix
        conf_matrix = confusion_matrix(labels_list, preds)
        macro_f1 = f1_score(labels_list, preds, average='macro')

    
        print(f"Test Accuracy: {test_acc}")
        print(f"Balanced Accuracy: {bal_acc}")
        print(f"Confusion Matrix: {conf_matrix}")
        
        # add to validation dataframe
        df_test["pred"] = preds
        df_test["pred_labels"] = pred_labels
        
        mlflow.log_metrics({
            'test_acc': test_acc,
            'test_balanced_acc': bal_acc,
            'test_f1_score': macro_f1,
        })
        return df_test
    else:
        print("Skip validation on training set due to submission!")
        return None

In [20]:
def convert_dict_to_tensor(dict_):
    """Converts the values of a dict into a PyTorch tensor."""

    # Create a new PyTorch tensor
    tensor = torch.empty(len(dict_))

    # Iterate over the dict and for each key-value pair, convert the value to a PyTorch tensor and add it to the new tensor
    for i, (key, value) in enumerate(dict_.items()):
        tensor[i] = value

    # Return the new tensor
    return tensor

def get_class_weights(df_train):
    label_counts = df_train.target_label.value_counts().sort_index().to_dict()
    ratios_dict = {}
    for key,val in label_counts.items():
        ratios_dict[key] = val / df_train.shape[0]
    ratios_dict
    weights = {}
    sum_weights = 0
    for key, val in ratios_dict.items():
        weights[key] = 1 / val
        sum_weights +=  1 / val
    for key, val in weights.items():
        weights[key] = val / sum_weights
    weight_tensor = convert_dict_to_tensor(weights)
    return weight_tensor



### Training N-Fold Models

In [21]:
if CONFIG["weighted_loss"]:
    class_weights = get_class_weights(df_train).to(CONFIG['device'], dtype=torch.float)
    print(f"Class weights: {class_weights}")
else:
    class_weights=None
criterion = nn.CrossEntropyLoss(weight=class_weights)


Class weights: tensor([0.1538, 0.1228, 0.0686, 0.3239, 0.3310], device='cuda:0')


In [22]:
def get_dataloaders(df, fold=CONFIG["fold"]):
    df_train = df[df["kfold"]!=fold].reset_index(drop=True)
    train_dataset = CancerTilesDataset(df_train, TRAIN_DIR, transforms=data_transforms["train"], mode="train")
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)
    valid_dataset = CancerTilesDataset(df_train, TRAIN_DIR, transforms=data_transforms["valid"], mode="valid")
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)
    print(f"Len Train Dataset: {len(train_dataset)}, Len Validation Dataset: {len(valid_dataset)}" )
    return train_loader, valid_loader, df_train


In [23]:
def print_logged_info(r):
    tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
    artifacts = [f.path for f in MlflowClient().list_artifacts(r.info.run_id, "model")]
    print(f"run_id: {r.info.run_id}")
    print(f"artifacts: {artifacts}")
    print(f"params: {r.data.params}")
    print(f"metrics: {r.data.metrics}")
    print(f"tags: {tags}")


In [24]:
for fold in range(1,2):
    #mlflow.set_tracking_uri('https://dagshub.com/Niggl0n/UBC_Cancer_Classification.mlflow')
    mlflow.set_experiment(experiment_name="UBC_cancer_classification")

    with mlflow.start_run(experiment_id=mlflow_experiment_id) as run:
        print(f"Train on Fold: {str(fold)}")
        mlflow.log_params(CONFIG)
        mlflow.log_param("fold", fold)
        train_loader, valid_loader, df_train_fold = get_dataloaders(df_train.copy(), fold=fold)

        checkpoint_path='//kaggle/input/tf-efficientnet-b0-aa-827b6e33-pth/tf_efficientnet_b0_aa-827b6e33.pth'
        model = EfficientNetB5(CONFIG['model_name'], CONFIG['num_classes'], pretrained=False , checkpoint_path=checkpoint_path)
        model.to(CONFIG['device']);

        optimizer = get_optimizer(CONFIG["optimizer"], model)
        scheduler = fetch_scheduler(optimizer)

        _, _, _, save_model_path = train_model(model, train_loader, valid_loader, optimizer, criterion, CONFIG["device"], CONFIG["num_epochs"], scheduler)
        model.load_state_dict(torch.load(save_model_path))
        
        print("Validate on Holdout Set:")
        df_test = df_train[df_train["kfold"]==CONFIG["test_fold"]].reset_index(drop=True)
        df_test = test_on_holdout(model, CONFIG, df_test, TRAIN_DIR, val_size=1)
        
        mlflow.pytorch.log_model(model, "model")
        mlflow.log_artifact(save_model_path)
        df_test_file_path = "df_test_results.csv"
        df_test.to_csv(df_test_file_path, index=False)
        
    print_logged_info(mlflow.get_run(run_id=run.info.run_id))

Train on Fold: 1
Len Train Dataset: 387, Len Validation Dataset: 43


  model = create_fn(


Training model: epochs30_bs16_optAdam_schedCosineAnnealingLR_lr0.0001_wd0.0001
Path for saving model: best_model_checkpoint2023-11-14_08-31-40.pth
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1/30 - Train loss: 1.6529, Validation loss: 1.5421, Validation acc: 0.3256, Balanced acc: 0.3566, Weighted F1-Score: 0.2926
Validation loss decreased (inf --> 1.542073). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2/30 - Train loss: 1.5396, Validation loss: 1.5883, Validation acc: 0.2791, Balanced acc: 0.2484, Weighted F1-Score: 0.2696
EarlyStopping counter: 1 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3/30 - Train loss: 1.4895, Validation loss: 1.5660, Validation acc: 0.2791, Balanced acc: 0.2294, Weighted F1-Score: 0.2813
EarlyStopping counter: 2 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4/30 - Train loss: 1.4161, Validation loss: 1.3932, Validation acc: 0.4884, Balanced acc: 0.4101, Weighted F1-Score: 0.5305
Validation loss decreased (1.542073 --> 1.393169). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5/30 - Train loss: 1.3545, Validation loss: 1.3250, Validation acc: 0.4419, Balanced acc: 0.3887, Weighted F1-Score: 0.4580
Validation loss decreased (1.393169 --> 1.325038). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6/30 - Train loss: 1.2921, Validation loss: 1.4202, Validation acc: 0.5116, Balanced acc: 0.4871, Weighted F1-Score: 0.4738
EarlyStopping counter: 1 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7/30 - Train loss: 1.2686, Validation loss: 1.1163, Validation acc: 0.6279, Balanced acc: 0.5777, Weighted F1-Score: 0.6047
Validation loss decreased (1.325038 --> 1.116282). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8/30 - Train loss: 1.1763, Validation loss: 1.3190, Validation acc: 0.4884, Balanced acc: 0.3646, Weighted F1-Score: 0.4798
EarlyStopping counter: 1 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9/30 - Train loss: 1.0946, Validation loss: 1.1456, Validation acc: 0.6047, Balanced acc: 0.5229, Weighted F1-Score: 0.5885
EarlyStopping counter: 2 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 10/30 - Train loss: 1.0983, Validation loss: 0.9848, Validation acc: 0.6977, Balanced acc: 0.7442, Weighted F1-Score: 0.6922
Validation loss decreased (1.116282 --> 0.984759). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 11/30 - Train loss: 1.0749, Validation loss: 1.1005, Validation acc: 0.6047, Balanced acc: 0.6003, Weighted F1-Score: 0.6016
EarlyStopping counter: 1 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 12/30 - Train loss: 1.0723, Validation loss: 1.2385, Validation acc: 0.5116, Balanced acc: 0.4158, Weighted F1-Score: 0.5159
EarlyStopping counter: 2 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 13/30 - Train loss: 1.0282, Validation loss: 1.0634, Validation acc: 0.6047, Balanced acc: 0.4942, Weighted F1-Score: 0.6310
EarlyStopping counter: 3 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 14/30 - Train loss: 0.9692, Validation loss: 0.9161, Validation acc: 0.6512, Balanced acc: 0.5513, Weighted F1-Score: 0.6484
Validation loss decreased (0.984759 --> 0.916120). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 15/30 - Train loss: 0.9724, Validation loss: 1.1096, Validation acc: 0.5814, Balanced acc: 0.4990, Weighted F1-Score: 0.5739
EarlyStopping counter: 1 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 16/30 - Train loss: 0.9029, Validation loss: 1.0504, Validation acc: 0.6047, Balanced acc: 0.6086, Weighted F1-Score: 0.5942
EarlyStopping counter: 2 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 17/30 - Train loss: 0.8973, Validation loss: 0.9832, Validation acc: 0.5581, Balanced acc: 0.4740, Weighted F1-Score: 0.5415
EarlyStopping counter: 3 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 18/30 - Train loss: 0.9328, Validation loss: 0.8719, Validation acc: 0.7442, Balanced acc: 0.6406, Weighted F1-Score: 0.7201
Validation loss decreased (0.916120 --> 0.871907). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 19/30 - Train loss: 0.9001, Validation loss: 0.9875, Validation acc: 0.5814, Balanced acc: 0.5908, Weighted F1-Score: 0.5760
EarlyStopping counter: 1 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 20/30 - Train loss: 0.8845, Validation loss: 1.1167, Validation acc: 0.6744, Balanced acc: 0.6335, Weighted F1-Score: 0.6798
EarlyStopping counter: 2 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 21/30 - Train loss: 0.8257, Validation loss: 0.9394, Validation acc: 0.6047, Balanced acc: 0.5386, Weighted F1-Score: 0.5956
EarlyStopping counter: 3 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 22/30 - Train loss: 0.8251, Validation loss: 1.0740, Validation acc: 0.6512, Balanced acc: 0.6085, Weighted F1-Score: 0.6638
EarlyStopping counter: 4 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 23/30 - Train loss: 0.7854, Validation loss: 0.8587, Validation acc: 0.7442, Balanced acc: 0.6943, Weighted F1-Score: 0.7498
Validation loss decreased (0.871907 --> 0.858652). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 24/30 - Train loss: 0.8313, Validation loss: 0.8457, Validation acc: 0.7209, Balanced acc: 0.5943, Weighted F1-Score: 0.7149
Validation loss decreased (0.858652 --> 0.845698). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 25/30 - Train loss: 0.7602, Validation loss: 0.7700, Validation acc: 0.6047, Balanced acc: 0.6729, Weighted F1-Score: 0.5959
Validation loss decreased (0.845698 --> 0.769951). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 26/30 - Train loss: 0.8037, Validation loss: 0.7397, Validation acc: 0.7209, Balanced acc: 0.6884, Weighted F1-Score: 0.7198
Validation loss decreased (0.769951 --> 0.739732). Saving model ...
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 27/30 - Train loss: 0.7790, Validation loss: 0.7814, Validation acc: 0.6512, Balanced acc: 0.6538, Weighted F1-Score: 0.6440
EarlyStopping counter: 1 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 28/30 - Train loss: 0.7788, Validation loss: 0.8300, Validation acc: 0.7442, Balanced acc: 0.7134, Weighted F1-Score: 0.7445
EarlyStopping counter: 2 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 29/30 - Train loss: 0.7236, Validation loss: 0.7869, Validation acc: 0.6279, Balanced acc: 0.5764, Weighted F1-Score: 0.6188
EarlyStopping counter: 3 out of 7
[INFO] Using GPU: Tesla P100-PCIE-16GB



  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 30/30 - Train loss: 0.7738, Validation loss: 0.7564, Validation acc: 0.7209, Balanced acc: 0.6335, Weighted F1-Score: 0.7082
EarlyStopping counter: 4 out of 7
Validate on Holdout Set:


  0%|          | 0/7 [00:00<?, ?it/s]

Test Accuracy: 0.6203703703703703
Balanced Accuracy: 0.6264444444444445
Confusion Matrix: [[15  1  2  2  0]
 [ 3 14  3  3  2]
 [ 6  7 27  4  1]
 [ 2  0  1  6  0]
 [ 0  2  1  1  5]]




run_id: 2584ec3f124741a1b72b311cb0060e29
artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/data', 'model/python_env.yaml', 'model/requirements.txt']
params: {'is_submission': 'False', 'crop_vertical': 'True', 'weighted_loss': 'True', 'datetime_now': '2023-11-14_08-31-24', 'n_fold': '5', 'test_fold': '0', 'fold': '1', 'seed': '42', 'img_size': '512', 'center_crop_size': '1024', 'model_name': 'tf_efficientnet_b0_ns', 'num_classes': '5', 'train_batch_size': '16', 'valid_batch_size': '16', 'device': 'cuda:0', 'num_epochs': '30', 'early_stopping': 'True', 'patience': '7', 'optimizer': 'Adam', 'learning_rate': '0.0001', 'scheduler': 'CosineAnnealingLR', 'min_lr': '1e-06', 'T_max': '30', 'momentum': '0.9', 'weight_decay': '0.0001'}
metrics: {'epoch': 29.0, 'train_loss': 0.773763258789861, 'valid_loss': 0.756415780893592, 'valid_acc': 0.720930232558139, 'balanced_acc': 0.633516483516484, 'weighted_f1': 0.708199827734711, 'test_acc': 0.62037037037037, 'test_balanced_acc': 0.62644444444444

In [25]:
# model.load_state_dict(torch.load('/kaggle/working/best_model_checkpoint' + CONFIG["datetime_now"] + '.pth'))
df_test = df_train[df_train["kfold"]==CONFIG["test_fold"]].reset_index(drop=True)
df_test = test_on_holdout(model, CONFIG, df_test, TRAIN_DIR, val_size=1)
df_test

  0%|          | 0/7 [00:00<?, ?it/s]

Test Accuracy: 0.6666666666666666
Balanced Accuracy: 0.7248888888888888
Confusion Matrix: [[14  1  1  4  0]
 [ 0 17  4  3  1]
 [ 6  3 26  8  2]
 [ 0  0  0  9  0]
 [ 1  1  1  0  6]]


Unnamed: 0,image_id,label,image_width,image_height,is_tma,file_path,target_label,kfold,pred,pred_labels
0,431,HGSC,39991,40943,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,2,0.0,0,CC
1,1101,HGSC,26306,18403,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,2,0.0,0,CC
2,1943,CC,73730,34949,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,0,0.0,2,HGSC
3,2666,EC,53270,44031,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,1,0.0,2,HGSC
4,2706,HGSC,71289,22569,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,2,0.0,3,LGSC
...,...,...,...,...,...,...,...,...,...,...
103,63367,EC,62905,24783,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,1,0.0,3,LGSC
104,63429,EC,67783,29066,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,1,0.0,3,LGSC
105,63836,EC,17416,21934,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,1,0.0,0,CC
106,63941,HGSC,47123,36600,False,/kaggle/input/tiles-of-cancer-2048px-scale-0-2...,2,0.0,2,HGSC


# Append: Inference Methods

In [26]:
"""from joblib.externals.loky.backend.context import get_context




import os
import pyvips
import numpy as np
import random
from PIL import Image

def extract_image_tiles(
    p_img, size: int = 2048, scale: float = 0.5,
    drop_thr: float = 0.6, white_thr: int = 245, max_samples: int = 50
) -> list:
    im = pyvips.Image.new_from_file(p_img)
    w = h = size
    # https://stackoverflow.com/a/47581978/4521646
    idxs = [(y, y + h, x, x + w) for y in range(0, im.height, h) for x in range(0, im.width, w)]
    # random subsample
    max_samples = max_samples if isinstance(max_samples, int) else int(len(idxs) * max_samples)
    random.shuffle(idxs)
    images = []
    for y, y_, x, x_ in idxs:
        # https://libvips.github.io/pyvips/vimage.html#pyvips.Image.crop
        tile = im.crop(x, y, min(w, im.width - x), min(h, im.height - y)).numpy()[..., :3]
        if tile.shape[:2] != (h, w):
            tile_ = tile
            tile_size = (h, w) if tile.ndim == 2 else (h, w, tile.shape[2])
            tile = np.zeros(tile_size, dtype=tile.dtype)
            tile[:tile_.shape[0], :tile_.shape[1], ...] = tile_
        black_bg = np.sum(tile, axis=2) == 0
        tile[black_bg, :] = 255
        mask_bg = np.mean(tile, axis=2) > white_thr
        if np.sum(mask_bg) >= (np.prod(mask_bg.shape) * drop_thr):
            #print(f"skip almost empty tile: {k:06}_{int(x_ / w)}-{int(y_ / h)}")
            continue
        # print(tile.shape, tile.dtype, tile.min(), tile.max())
        new_size = int(size * scale), int(size * scale)
        images.append(np.array(
            Image.fromarray(tile).resize(new_size, Image.LANCZOS)
        ))
        # need to set counter check as some empty tiles could be skipped earlier
        if len(images) >= max_samples:
            break
    return images


class TilesInferenceDataset(Dataset):

    def __init__(
        self,
        img_path: str,
        size: int = 2048,
        scale: float = 0.25,
        drop_thr: float = 0.6,
        max_samples: int = 30,
        transforms = None
    ):
        assert os.path.isfile(img_path)
        self.transforms = transforms
        self.imgs = extract_image_tiles(
            img_path, size=size, scale=scale,
            drop_thr=drop_thr, max_samples=max_samples)

    def __getitem__(self, idx: int) -> tuple:
        img = self.imgs[idx]
        # filter background
        mask = np.sum(img, axis=2) == 0
        img[mask, :] = 255
        if np.max(img) < 1.5:
            img = np.clip(img * 255, 0, 255).astype(np.uint8)
        # augmentation
        if self.transforms:
            img = self.transforms(Image.fromarray(img))
        #print(f"img dim: {img.shape}")
        return img

    def __len__(self) -> int:
        return len(self.imgs)
    
    
def infer_single_image(idx_row, model, device="cuda") -> dict:
    row = dict(idx_row[1])
    # prepare data - cut and load tiles
    dataset = TilesImageDataset(
        os.path.join(DATASET_FOLDER, "test_images", f"{str(row['image_id'])}.png"),
        size=2048, scale=0.25, transforms=VALID_TRANSFORM)
    if not len(dataset):
        print (f"seem no tiles were cut for `{row['image_id']}`")
        return row
    preds = []
    model = model.to(device)
    dataloader = DataLoader(
        dataset, batch_size=4, num_workers=2, shuffle=False,
        # see: https://github.com/pytorch/pytorch/issues/44687#issuecomment-790842173
        multiprocessing_context=get_context('loky')
    )
    # iterate over images and collect predictions | 
    for imgs in dataloader:
        #print(f"{imgs.shape}")
        with torch.no_grad():
            pred = model(imgs.to(device))
        preds += pred.cpu().numpy().tolist()
    print(f"Sum contrinution from all tiles: {np.sum(preds, axis=0)}")
    print(f"Max contribution over all tiles: {np.max(preds, axis=0)}")
    # decide label
    lb = np.argmax(np.sum(preds, axis=0))
    row['label'] = labels[lb]
    print(row)
    return row"""

'from joblib.externals.loky.backend.context import get_context\n\n\n\n\nimport os\nimport pyvips\nimport numpy as np\nimport random\nfrom PIL import Image\n\ndef extract_image_tiles(\n    p_img, size: int = 2048, scale: float = 0.5,\n    drop_thr: float = 0.6, white_thr: int = 245, max_samples: int = 50\n) -> list:\n    im = pyvips.Image.new_from_file(p_img)\n    w = h = size\n    # https://stackoverflow.com/a/47581978/4521646\n    idxs = [(y, y + h, x, x + w) for y in range(0, im.height, h) for x in range(0, im.width, w)]\n    # random subsample\n    max_samples = max_samples if isinstance(max_samples, int) else int(len(idxs) * max_samples)\n    random.shuffle(idxs)\n    images = []\n    for y, y_, x, x_ in idxs:\n        # https://libvips.github.io/pyvips/vimage.html#pyvips.Image.crop\n        tile = im.crop(x, y, min(w, im.width - x), min(h, im.height - y)).numpy()[..., :3]\n        if tile.shape[:2] != (h, w):\n            tile_ = tile\n            tile_size = (h, w) if tile.ndim =