In [1]:
!pip install --quiet mlflow dagshub

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ydata-profiling 4.3.1 requires dacite>=1.8, but you have dacite 1.6.0 which is incompatible.
ydata-profiling 4.3.1 requires scipy<1.11,>=1.4.1, but you have scipy 1.11.2 which is incompatible.[0m[31m
[0m

In [2]:


import copy
import datetime
import gc
import glob
import math
import os
import random
import shutil
import time

import cv2
from matplotlib import pyplot as plt

import mlflow
from mlflow import MlflowClient
import mlflow.pytorch 
import dagshub
# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import torchvision

from itertools import chain
import heapq

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score, accuracy_score

import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage import io

# For Image Models
import timm

from joblib.externals.loky.backend.context import get_context




import os
import numpy as np
import random
from PIL import Image

from tqdm.auto import tqdm
from joblib import Parallel, delayed

# Albumentations for augmentations
# import albumentations as A
# from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings


In [3]:


# Run-wide settings shared by the dataset, model, training and inference cells.
CONFIG = {
    # --- run bookkeeping ---
    "is_submission": False,
    "datetime_now": datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),  # evaluated once, when this cell runs
    # --- cross-validation layout ---
    "n_fold": 5,
    "fold": 1,
    "test_fold": 0,
    "seed": 42,
    # --- input images ---
    "img_size": 512,
    "crop_vertical": True,
    # --- backbone (alternatives tried: "tf_efficientnet_b0_ns") ---
    "model_name": "tf_efficientnetv2_s_in21ft1k",
    "checkpoint_path": "/kaggle/input/tf-efficientnetv2-s-in21ft1k/tf_efficientnetv2_s_in21ft1k.pth",
    "num_classes": 5,
    # --- loaders / hardware ---
    "valid_batch_size": 16,
    "train_batch_size": 16,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    # Earlier checkpoint, kept for reference:
    # "model_path": '/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-10-26_09-10-29.pth',
    "encoder_path": "/kaggle/input/effnet-version-28/label_encoder_2023-11-21_15-45-54.pkl",
}



In [4]:
def get_or_create_experiment_id(name):
    """Return the id of the MLflow experiment called ``name``, creating it when absent.

    Args:
        name: Experiment name to look up.

    Returns:
        The ``experiment_id`` of the existing experiment, or whatever
        ``mlflow.create_experiment`` returns for a newly created one.
    """
    existing = mlflow.get_experiment_by_name(name)
    if existing is not None:
        return existing.experiment_id
    # Nothing registered under this name yet.
    return mlflow.create_experiment(name)

# 1. Datasets & Preprocessing

In [5]:
def get_train_file_path(df_train_row, TRAIN_DIR):
    """Build the PNG path of one training image.

    Args:
        df_train_row: Row-like object exposing ``is_tma`` and ``image_id``.
        TRAIN_DIR: Directory that holds the training images.

    Returns:
        ``<TRAIN_DIR>/<image_id>_thumbnail.png`` when ``is_tma`` equals False,
        otherwise ``<TRAIN_DIR>/<image_id>.png``.
    """
    # Equality (not truthiness) is kept on purpose: only values equal to False
    # select the thumbnail, so e.g. None still maps to the plain file name.
    suffix = "_thumbnail" if df_train_row.is_tma == False else ""  # noqa: E712
    return f"{TRAIN_DIR}/{df_train_row.image_id}{suffix}.png"


def get_test_file_path(image_id, TEST_DIR, alt_test_dir=None):
    """Resolve the PNG path of a test image, falling back to an alternate directory.

    Args:
        image_id: Identifier used as the file stem.
        TEST_DIR: Primary directory; used when ``<TEST_DIR>/<image_id>.png`` exists.
        alt_test_dir: Directory used when the primary file is missing. When left
            as ``None`` the notebook-level global ``ALT_TEST_DIR`` is used, which
            keeps the previous behaviour for existing callers.

    Returns:
        The path in ``TEST_DIR`` if that file exists, otherwise the path in the
        alternate directory (its existence is not checked).

    Raises:
        NameError: If the fallback is needed, ``alt_test_dir`` is ``None`` and no
            global ``ALT_TEST_DIR`` has been defined.
    """
    primary_path = f"{TEST_DIR}/{image_id}.png"
    if os.path.exists(primary_path):
        return primary_path
    if alt_test_dir is None:
        # Backward compatibility with callers that rely on the global.
        alt_test_dir = ALT_TEST_DIR
    return f"{alt_test_dir}/{image_id}.png"

In [6]:
class UBCDataset(Dataset):
    """Image-level dataset: one file per row, read with OpenCV and returned as RGB.

    Args:
        df: DataFrame providing ``file_path`` and ``target_label`` columns.
        transforms: Optional callable invoked as ``transforms(image=img)`` whose
            result is indexed with ``"image"`` (albumentations-style).
        apply_vertical_crop: When True, each image is passed through
            ``crop_vertical`` before the transforms.
    """

    def __init__(self, df, transforms=None, apply_vertical_crop=True):
        self.df = df
        self.filenames = df.file_path.values
        self.labels = df.target_label.values
        self.transforms = transforms
        self.apply_vertical_crop = apply_vertical_crop

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "label": LongTensor}`` for row ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file at this index.
        """
        img_path = self.filenames[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the path rather than inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.apply_vertical_crop:
            img = crop_vertical(img)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return {
            "image": img,
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }

def crop_vertical(image):
    """Return the first sufficiently wide tissue section of an image.

    Sections are delimited by columns that are entirely zero. Only sections more
    than 200 columns wide are kept; all-zero rows are then dropped from each
    section and the first surviving section is returned.

    Args:
        image: Array indexed as (row, column, channel).

    Returns:
        The first cropped section, or ``image`` itself when no section qualifies.
        An image without any all-zero column is treated as a single section.
    """
    column_totals = image.sum(axis=(0, 2))
    black_cols = np.flatnonzero(column_totals == 0)

    if black_cols.size == 0:
        sections = [image]
    else:
        last_col = image.shape[1] - 1
        # Make sure the image borders act as section boundaries as well.
        if black_cols[0] != 0:
            black_cols = np.concatenate(([0], black_cols))
        if black_cols[-1] != last_col:
            black_cols = np.concatenate((black_cols, [last_col]))

        # Consecutive boundaries define a candidate section [left, right).
        sections = [
            image[:, left:right]
            for left, right in zip(black_cols[:-1], black_cols[1:])
            if right - left > 200
        ]

    # Drop rows that are completely black within each section.
    trimmed = []
    for section in sections:
        row_totals = section.sum(axis=(1, 2))
        trimmed.append(section[row_totals != 0])

    return trimmed[0] if trimmed else image


def custom_center_crop_or_resize(image, crop_size):
    """Center-crop ``image`` to ``crop_size`` when it is large enough, else resize it.

    Args:
        image: Array whose first two axes are height and width.
        crop_size: Sequence ``(height, width)`` of the desired output.

    Returns:
        The ``"image"`` entry produced by ``A.CenterCrop`` (both dimensions at
        least as large as requested) or by ``A.Resize`` (otherwise).
    """
    target_h, target_w = crop_size[0], crop_size[1]
    large_enough = image.shape[0] >= target_h and image.shape[1] >= target_w
    operation = A.CenterCrop(target_h, target_w) if large_enough else A.Resize(target_h, target_w)
    return operation(image=image)["image"]

In [7]:
def _color_means(img_path):
    img = np.array(Image.open(img_path))
    mask = np.sum(img[..., :3], axis=2) == 0
    img[mask, :] = 255
    if np.max(img) > 1.5:
        img = img / 255.0
    clr_mean = {i: np.mean(img[..., i]) for i in range(3)}
    clr_std = {i: np.std(img[..., i]) for i in range(3)}
    return clr_mean, clr_std

In [8]:
# Per-channel normalisation statistics (RGB order, values on a 0-1 scale).
img_color_mean = [0.8661704276539922, 0.7663107094675368, 0.8574260897185548]
img_color_std = [0.08670629753900036, 0.11646580094195522, 0.07164169171856792]

# Both pipelines take their size from CONFIG so that changing "img_size"
# can never leave the train and valid inputs at different resolutions.
_img_size = CONFIG["img_size"]
_max_hole = int(_img_size * 0.3)  # CoarseDropout hole: at most 30% of a side

data_transforms = {
    # Training: geometric + photometric augmentation, then normalise.
    "train": A.Compose([
        A.Resize(_img_size, _img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
            A.GaussNoise(var_limit=[10, 50]),
            A.GaussianBlur(),
            A.MotionBlur(),
        ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=1, max_width=_max_hole, max_height=_max_hole,
                        mask_fill_value=0, p=0.5),
        A.Normalize(img_color_mean, img_color_std),
        ToTensorV2()], p=1.),

    # Validation / inference: deterministic resize + normalise only.
    "valid": A.Compose([
        A.Resize(_img_size, _img_size),
        A.Normalize(img_color_mean, img_color_std),
        ToTensorV2()], p=1.)
}

## 2.2. Tiles Dataset

### 2.2.1. Train Dataset

In [9]:
class CancerTilesDataset(Dataset):
    """Serves one pre-cut tile per slide and pass, skipping mostly-background tiles.

    The frame is shuffled with a fixed seed and split into a leading part
    (modes ``"train"`` / ``"test"``) and the remainder (any other mode). Each
    slide's tiles are looked up as ``<path_img_dir>/<image_id>/*.png``. The slide
    list is repeated ``n_tiles`` times, so one epoch visits every slide
    ``n_tiles`` times.

    Args:
        df_data: DataFrame with ``image_id``, ``label`` and ``target_label`` columns.
        path_img_dir: Existing directory containing one sub-folder per image id.
        transforms: Optional callable invoked as ``transforms(image=tile)["image"]``.
        mode: ``"train"`` / ``"test"`` select the leading split, anything else the rest.
        labels_lut: Optional label -> index mapping; built from ``label`` if omitted.
        white_thr: Mean channel value above which a pixel counts as background.
        thr_max_bg: A tile is accepted when its background fraction is below this.
        split: Fraction of rows in the leading split (legacy name).
        n_tiles: How many times the slide list is repeated.
        train_val_split: Preferred name for ``split``; takes precedence when given.
            ``get_dataloaders`` passes the fraction under this keyword.
    """

    def __init__(
        self,
        df_data,
        path_img_dir: str = '',
        transforms=None,
        mode: str = 'train',
        labels_lut=None,
        white_thr: int = 225,
        thr_max_bg: float = 0.2,
        split: float = 0.90,
        n_tiles: int = 1,
        train_val_split: float = None,
    ):
        assert os.path.isdir(path_img_dir)
        self.path_img_dir = path_img_dir
        self.transforms = transforms
        self.mode = mode
        self.white_thr = white_thr
        self.thr_max_bg = thr_max_bg
        # Accept both spellings: the signature used to offer only `split`,
        # while the body and the callers use `train_val_split`.
        self.train_val_split = split if train_val_split is None else train_val_split
        self.n_tiles = n_tiles

        self.data = df_data
        self.labels_unique = sorted(self.data["label"].unique())
        self.labels_lut = labels_lut or {lb: i for i, lb in enumerate(self.labels_unique)}
        # Fixed-seed shuffle so every mode sees the same row order before splitting.
        self.data = self.data.sample(frac=1, random_state=42).reset_index(drop=True)

        # Split dataset.
        assert 0.0 <= self.train_val_split <= 1.0
        frac = int(self.train_val_split * len(self.data))
        self.data = self.data[:frac] if mode in ["train", "test"] else self.data[frac:]
        self.img_dirs = [glob.glob(os.path.join(path_img_dir, str(idx), "*.png")) for idx in self.data["image_id"]]
        self.img_dirs = self.img_dirs * self.n_tiles
        self.img_paths = []
        self.labels = np.array(self.data.target_label.values.tolist() * self.n_tiles)

    def __getitem__(self, idx: int) -> dict:
        """Return ``{"image": tile, "label": LongTensor}`` for position ``idx``.

        Tiles of the slide are tried in random order until one has less than
        ``thr_max_bg`` background; if none qualifies, the last tile tried is used.
        In ``"train"`` mode the order is freshly random on every call; otherwise
        it is a fixed function of ``CONFIG["seed"]`` and the repetition index, so
        evaluation picks the same tile every epoch.

        Raises:
            FileNotFoundError: If no tile files were found for this slide.
        """
        nth_iteration = idx // len(self.data)
        # A private generator keeps the global `random` state untouched.
        if self.mode == "train":
            rng = random.Random()
        else:
            rng = random.Random(CONFIG["seed"] + nth_iteration)

        # Shuffle a copy: the stored list is shared between repetitions, and
        # shuffling it in place would change the result of later seeded calls.
        candidates = list(self.img_dirs[idx])
        if not candidates:
            image_id = self.data["image_id"].iloc[idx % len(self.data)]
            raise FileNotFoundError(f"no tiles found for image_id {image_id} in {self.path_img_dir}")
        rng.shuffle(candidates)

        for img_path in candidates:
            assert os.path.isfile(img_path), f"missing: {img_path}"
            tile = cv2.imread(img_path)
            tile = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)

            # Treat pure black as background by painting it white first.
            black_bg = np.sum(tile, axis=2) == 0
            tile[black_bg, :] = 255
            mask_bg = np.mean(tile, axis=2) > self.white_thr
            if np.sum(mask_bg) < (np.prod(mask_bg.shape) * self.thr_max_bg):
                break

        # Augmentation.
        if self.transforms:
            tile = self.transforms(image=tile)["image"]
        return {
            "image": tile,
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }

    def __len__(self) -> int:
        return len(self.img_dirs)

### 2.2.2. Inference Dataset

In [10]:
def delete_tiles(directory_path):
    """Remove every regular file directly inside ``directory_path``.

    Sub-directories and their contents are left alone. Nothing happens when
    ``directory_path`` is not an existing directory.
    """
    if not os.path.isdir(directory_path):
        return
    for entry in os.listdir(directory_path):
        candidate = os.path.join(directory_path, entry)
        if os.path.isfile(candidate):
            os.remove(candidate)

def extract_image_tiles(
    p_img, img_id, tmp_dir, size: int = 2048, scale: float = 0.5,
    drop_thr: float = 0.8, white_thr: int = 245, max_samples: int = 50
) -> list:
    """Cut a slide into square tiles, keep tissue-rich ones and save them as PNG.

    The slide is covered by a ``size`` x ``size`` grid whose cells are visited
    in a fixed pseudo-random order. Border cells are zero-padded to full size.
    Pure-black pixels are painted white, and a tile is skipped when at least
    ``drop_thr`` of its pixels have a channel mean above ``white_thr``. Kept
    tiles are resized by ``scale`` and written to ``<tmp_dir>/<n>.png``, where
    ``n`` is the 1-based visiting position.

    Args:
        p_img: Path of the slide, opened with ``pyvips``.
        img_id: Unused; kept for interface compatibility.
        tmp_dir: Output directory; files already in it are deleted first.
        size: Side length of a grid cell in slide pixels.
        scale: Resize factor applied to kept tiles.
        drop_thr: Background fraction at or above which a tile is dropped.
        white_thr: Channel-mean threshold that defines a background pixel.
        max_samples: Maximum number of tiles to keep; a non-int value is
            interpreted as a fraction of the grid size.

    Returns:
        List of paths of the saved tiles, in visiting order.
    """
    delete_tiles(tmp_dir)  # empty directory from previous images
    im = pyvips.Image.new_from_file(p_img)
    w = h = size
    # https://stackoverflow.com/a/47581978/4521646
    origins = [(y, x) for y in range(0, im.height, h) for x in range(0, im.width, w)]
    # Random subsample.
    max_samples = max_samples if isinstance(max_samples, int) else int(len(origins) * max_samples)
    # A private generator yields the same order as `random.seed(42)` followed by
    # `random.shuffle`, but leaves the global random state alone.
    random.Random(42).shuffle(origins)

    new_size = int(size * scale), int(size * scale)
    images = []
    for position, (y, x) in enumerate(origins, start=1):
        img_path = f"{tmp_dir}/{position}.png"
        # https://libvips.github.io/pyvips/vimage.html#pyvips.Image.crop
        tile = im.crop(x, y, min(w, im.width - x), min(h, im.height - y)).numpy()[..., :3]
        if tile.shape[:2] != (h, w):
            # Border cell: pad with zeros (painted white just below).
            partial = tile
            tile_size = (h, w) if tile.ndim == 2 else (h, w, tile.shape[2])
            tile = np.zeros(tile_size, dtype=tile.dtype)
            tile[:partial.shape[0], :partial.shape[1], ...] = partial
        black_bg = np.sum(tile, axis=2) == 0
        tile[black_bg, :] = 255
        mask_bg = np.mean(tile, axis=2) > white_thr
        if np.sum(mask_bg) >= (np.prod(mask_bg.shape) * drop_thr):
            continue  # almost empty tile
        tile = Image.fromarray(tile).resize(new_size, Image.LANCZOS)
        tile.save(img_path)
        images.append(img_path)
        # Checked here because skipped tiles do not count towards the limit.
        if len(images) >= max_samples:
            break
    return images


class TilesInferenceDataset(Dataset):
    """Tiles of a single slide, prepared for inference.

    In submission mode the slide at ``img_path`` is tiled on the fly with
    ``extract_image_tiles``. Otherwise pre-cut tiles are read from
    ``<img_path>/<img_id>/*.png`` and filtered by background fraction.

    Args:
        img_path: Slide file (submission mode) or root folder of pre-cut tiles.
        img_id: Slide identifier; names the tile sub-folder in non-submission mode.
        tmp_dir: Directory that receives freshly cut tiles (submission mode).
        size: Grid cell size forwarded to ``extract_image_tiles``.
        scale: Resize factor forwarded to ``extract_image_tiles``.
        white_thr: Channel-mean threshold defining a background pixel.
        thr_max_bg: Maximum tolerated background fraction of a tile.
        max_samples: Maximum number of tiles kept.
        transforms: Optional callable invoked as ``transforms(image=img)["image"]``.
        is_submission: Selects between the two modes described above.
    """

    def __init__(
        self,
        img_path: str,
        img_id: str = None,
        tmp_dir: str = None,
        size: int = 2048,
        scale: float = 0.25,
        white_thr: int = 225,
        thr_max_bg: float = 0.6,
        max_samples: int = 30,
        transforms=None,
        is_submission: bool = True,
    ):
        self.max_samples = max_samples
        self.white_thr = white_thr
        self.thr_max_bg = thr_max_bg
        self.is_submission = is_submission

        self.transforms = transforms
        if self.is_submission:
            assert os.path.isfile(img_path)
            # NOTE(review): `white_thr` is not forwarded here, so this path uses
            # extract_image_tiles' own default (245) rather than self.white_thr.
            # Left unchanged to keep submission results stable; confirm intent.
            self.imgs = extract_image_tiles(
                img_path, img_id, tmp_dir, size=size, scale=scale,
                drop_thr=self.thr_max_bg, max_samples=max_samples)
        else:  # test on pre-cut tiles
            self.imgs = []
            limit = self.max_samples
            for tile_path in glob.glob(os.path.join(img_path, img_id, "*.png")):
                # Stop decoding as soon as enough tiles passed the filter; the
                # result equals filtering everything and truncating afterwards.
                if isinstance(limit, int) and 0 <= limit <= len(self.imgs):
                    break
                if self._has_enough_tissue(tile_path):
                    self.imgs.append(tile_path)
            self.imgs = self.imgs[:limit]

    def _has_enough_tissue(self, tile_path: str) -> bool:
        """Tell whether the tile's background fraction is at most ``thr_max_bg``."""
        img = cv2.imread(tile_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        black_bg = np.sum(img, axis=2) == 0
        img[black_bg, :] = 255  # pure black counts as background
        mask_bg = np.mean(img, axis=2) > self.white_thr
        return np.sum(mask_bg) <= (np.prod(mask_bg.shape) * self.thr_max_bg)

    def __getitem__(self, idx: int):
        """Load tile ``idx`` as RGB, whiten black pixels and apply the transforms."""
        img = cv2.imread(self.imgs[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Filter background.
        mask = np.sum(img, axis=2) == 0
        img[mask, :] = 255
        # Rescale images whose values all lie below 1.5 to the 0-255 range.
        if np.max(img) < 1.5:
            img = np.clip(img * 255, 0, 255).astype(np.uint8)
        if self.transforms:
            img = self.transforms(image=img)["image"]
        return img

    def __len__(self) -> int:
        return len(self.imgs)


# 3. Model Architecture

In [11]:
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions with a learnable exponent.

    Args:
        p: Initial value of the exponent (stored as a one-element parameter).
        eps: Lower clamp applied to the input before exponentiation.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Clamp, raise to ``p``, average over the full spatial extent, take the ``p``-th root."""
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        return F.avg_pool2d(powered, window).pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})"


class UBCModel(nn.Module):
    """timm backbone followed by GeM pooling and a linear classification head.

    The backbone's own ``classifier`` and ``global_pool`` are replaced with
    identities, and its output is pooled by ``GeM`` and classified here.
    """

    def __init__(self, model_name, num_classes, pretrained=False, checkpoint_path=None):
        """
        Args
            model_name : str - timm model identifier passed to ``timm.create_model``.
            num_classes : int - Number of output logits.
            pretrained : bool - Forwarded to ``timm.create_model``.
            checkpoint_path : str or None - Forwarded to ``timm.create_model``.
        """
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained, checkpoint_path=checkpoint_path)

        # Remember the feature width before the original head is removed.
        in_features = backbone.classifier.in_features
        backbone.classifier = nn.Identity()
        backbone.global_pool = nn.Identity()

        self.model = backbone
        self.pooling = GeM()
        self.linear = nn.Linear(in_features, num_classes)
        self.softmax = nn.Softmax(dim=1)  # not applied in forward(); logits are returned

    def forward(self, images):
        """
        Args
            images : batch handed to the backbone.
        Return
            Raw class logits, one row per input image.
        """
        feature_maps = self.model(images)
        pooled = self.pooling(feature_maps).flatten(1)
        return self.linear(pooled)

In [12]:
class EarlyStopping:
    """Stop training once the validation loss has not improved for ``patience`` calls.

    Every improvement (by at least ``delta``) saves the model's ``state_dict``
    to ``path`` and resets the counter; after ``patience`` consecutive calls
    without improvement ``early_stop`` becomes True.

    Args:
        patience: Number of non-improving calls tolerated.
        verbose: When True, report each checkpoint through ``trace_func``.
        delta: Minimum increase of the score (negated loss) that counts as improvement.
        path: File the checkpoint is written to.
        trace_func: Callable used for messages.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        # Running state.
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        score = -val_loss  # higher is better
        is_first_call = self.best_score is None

        if not is_first_call and score < self.best_score + self.delta:
            # No sufficient improvement.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = score
        self.save_checkpoint(val_loss, model)
        if not is_first_call:
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Write the model's state dict to ``path`` and record the new minimum loss.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# 4. Training

In [13]:


def fetch_scheduler(optimizer, CONFIG):
    """Build the learning-rate scheduler selected by ``CONFIG['scheduler']``.

    Supported values and the ``CONFIG`` entries they read:

    * ``'CosineAnnealingLR'`` - ``T_max``, ``min_lr``.
    * ``'CosineAnnealingWarmRestarts'`` - none; this branch writes ``T_0=10``,
      ``T_mult=2`` and ``min_lr=1e-6`` into ``CONFIG`` and uses them.
    * ``'ReduceLROnPlateau'`` - optional ``factor`` (0.1) and ``patience`` (5).
    * ``'LambdaLR'`` - ``lr_lambda``; a notebook-level ``lr_lambda`` function is
      used as fallback when the key is absent.
    * ``None`` - no scheduler.

    Args:
        optimizer: Optimizer to attach the scheduler to.
        CONFIG: Dict with the settings listed above.

    Returns:
        The scheduler instance, or ``None`` when ``CONFIG['scheduler']`` is None.

    Raises:
        KeyError: If ``'scheduler'`` or a required entry is missing.
        ValueError: For an unknown scheduler name or a missing ``lr_lambda``.
    """
    name = CONFIG['scheduler']
    if name is None:
        return None

    # `verbose=False` is left out everywhere: it is the default and newer
    # PyTorch releases deprecate the argument.
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['T_max'],
                                              eta_min=CONFIG['min_lr'])
    if name == 'CosineAnnealingWarmRestarts':
        # Recorded in CONFIG so the values show up wherever CONFIG is logged.
        CONFIG['T_0'] = 10
        CONFIG['T_mult'] = 2
        CONFIG['min_lr'] = 1e-6
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CONFIG['T_0'], T_mult=CONFIG['T_mult'],
                                                        eta_min=CONFIG['min_lr'])
    if name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                              factor=CONFIG.get('factor', 0.1),
                                              patience=CONFIG.get('patience', 5))
    if name == 'LambdaLR':
        lr_fn = CONFIG.get('lr_lambda', globals().get('lr_lambda'))
        if lr_fn is None:
            raise ValueError("LambdaLR requires CONFIG['lr_lambda'] (or a global lr_lambda function)")
        return lr_scheduler.LambdaLR(optimizer, lr_fn)

    raise ValueError(f"Unknown scheduler: {name!r}")

def get_optimizer(optimizer_name, model):
    """Create an optimizer for ``model`` and record its hyper-parameters in ``CONFIG``.

    Args:
        optimizer_name: One of ``"adam"``, ``"sgd"``, ``"radam"``, ``"rmsprop"``
            (case-insensitive).
        model: Module whose ``parameters()`` are optimised.

    Returns:
        The configured ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.

    Side effects:
        Writes the chosen hyper-parameters into the notebook-level ``CONFIG`` dict.
    """
    name = optimizer_name.lower()
    if name == "adam":
        CONFIG['learning_rate'] = 1e-4
        CONFIG['weight_decay'] = 1e-5
        CONFIG['betas'] = (0.9, 0.999)
        CONFIG['eps'] = 1e-8
        optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], betas=CONFIG['betas'],
                               eps=CONFIG['eps'], weight_decay=CONFIG['weight_decay'])
    elif name == "sgd":
        CONFIG['learning_rate'] = 1e-3
        CONFIG['weight_decay'] = 1e-3
        # NOTE(review): a momentum of 1e-3 is unusually small - confirm it is intended.
        CONFIG['momentum'] = 1e-3
        optimizer = optim.SGD(model.parameters(), lr=CONFIG['learning_rate'], momentum=CONFIG['momentum'],
                              weight_decay=CONFIG['weight_decay'])
    elif name == "radam":
        CONFIG['learning_rate'] = 1e-4
        CONFIG['weight_decay'] = 0
        CONFIG['betas'] = (0.9, 0.999)
        CONFIG['eps'] = 1e-8
        # PyTorch's built-in RAdam replaces the `torch_optimizer` package,
        # which this notebook never imports.
        optimizer = optim.RAdam(
            model.parameters(),
            lr=CONFIG['learning_rate'],
            betas=CONFIG['betas'],
            eps=CONFIG['eps'],
            weight_decay=CONFIG['weight_decay'],
        )
    elif name == "rmsprop":
        CONFIG['learning_rate'] = 0.256
        CONFIG['alpha'] = 0.9
        CONFIG['momentum'] = 0.9
        CONFIG['weight_decay'] = 1e-5
        # Each argument now receives its own setting; previously the learning
        # rate was passed for alpha, momentum and weight_decay as well.
        optimizer = optim.RMSprop(model.parameters(), lr=CONFIG['learning_rate'], alpha=CONFIG['alpha'],
                                  momentum=CONFIG['momentum'], weight_decay=CONFIG['weight_decay'])
    else:
        raise ValueError("Invalid Optimizer given!")
    return optimizer
    



In [14]:


def convert_dict_to_tensor(dict_):
    """Pack the values of ``dict_`` (in iteration order) into a 1-D tensor.

    Keys are ignored. The result uses PyTorch's default floating dtype and has
    one element per dictionary entry.
    """
    values = list(dict_.values())
    return torch.tensor(values, dtype=torch.get_default_dtype())

def get_class_weights(df_train):
    """Compute normalised inverse-frequency class weights.

    Each class present in ``df_train.target_label`` gets a weight proportional
    to ``1 / (class share)``; the weights are scaled to sum to one.

    Args:
        df_train: DataFrame with a ``target_label`` column.

    Returns:
        1-D tensor of weights ordered by ascending label value.
    """
    counts = df_train.target_label.value_counts().sort_index().to_dict()
    n_rows = df_train.shape[0]

    # Inverse of each class's share of the rows.
    inverse_share = {label: 1 / (count / n_rows) for label, count in counts.items()}

    total = 0
    for inverse in inverse_share.values():
        total += inverse

    normalised = {label: inverse / total for label, inverse in inverse_share.items()}
    return convert_dict_to_tensor(normalised)

def get_dataloaders(df, TRAIN_DIR, CONFIG, data_transforms, n_tiles=1, train_val_split=0.9):
    """Create the training and validation loaders over ``CancerTilesDataset``.

    Args:
        df: DataFrame describing the slides.
        TRAIN_DIR: Root directory of the pre-cut tiles.
        CONFIG: Dict providing ``train_batch_size`` and ``valid_batch_size``.
        data_transforms: Dict with ``"train"`` and ``"valid"`` transform pipelines.
        n_tiles: Forwarded to both datasets.
        train_val_split: Forwarded to both datasets.

    Returns:
        ``(train_loader, valid_loader, df)``; ``df`` is returned unchanged.
    """
    def build_dataset(mode):
        return CancerTilesDataset(df, TRAIN_DIR, transforms=data_transforms[mode], mode=mode,
                                  n_tiles=n_tiles, train_val_split=train_val_split)

    # Settings shared by both loaders.
    # NOTE(review): the training loader does not reshuffle between epochs
    # (shuffle=False) - confirm this is intended.
    loader_options = dict(num_workers=2, shuffle=False, pin_memory=True)

    train_dataset = build_dataset("train")
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'], **loader_options)
    valid_dataset = build_dataset("valid")
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], **loader_options)

    print(f"Len Train Dataset: {len(train_dataset)}, Len Validation Dataset: {len(valid_dataset)}" )
    return train_loader, valid_loader, df

def print_logged_info(r):
    """Print id, model artifacts, params, metrics and user tags of an MLflow run.

    Args:
        r: Run object exposing ``info.run_id`` and ``data.tags`` / ``params`` / ``metrics``.

    Tags whose key starts with ``"mlflow."`` are left out. Artifacts are listed
    from the run's ``"model"`` artifact path.
    """
    run_id = r.info.run_id
    user_tags = {}
    for key, value in r.data.tags.items():
        if not key.startswith("mlflow."):
            user_tags[key] = value
    artifacts = [entry.path for entry in MlflowClient().list_artifacts(run_id, "model")]

    print(f"run_id: {r.info.run_id}")
    print(f"artifacts: {artifacts}")
    print(f"params: {r.data.params}")
    print(f"metrics: {r.data.metrics}")
    print(f"tags: {user_tags}")



# 5. Inference & Evaluation

In [15]:
def eval_predictions(df):
    """Print overall, balanced and per-class accuracy plus weighted F1, then show the confusion matrix.

    Args:
        df: Mapping / DataFrame with ground truth in ``target_label`` and
            predictions in ``label``.
    """
    y_true = df['target_label']
    y_pred = df['label']

    # Aggregate scores.
    total_accuracy = accuracy_score(y_true, y_pred)
    balanced_accuracy = balanced_accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')

    # Per-class accuracy: correct predictions over the true count of each class.
    cm = confusion_matrix(y_true, y_pred)
    class_accuracy = cm.diagonal() / cm.sum(axis=1)

    print(f"Total Accuracy: {total_accuracy}")
    print(f"Balanced Accuracy: {balanced_accuracy}")
    print(f"F1 Score: {f1}")
    print(f"Accuracy Per Class: {class_accuracy}")
    display(cm)
    
def most_frequent(List):
    """Return the element that occurs most often in ``List`` (a list)."""
    distinct_values = set(List)
    return max(distinct_values, key=List.count)

def score_predictions(preds, method="sum", N=10):
    """Reduce per-tile (or per-model) class scores to a single class index.

    Args:
        preds: 2-D array-like, one row of class scores per tile/model.
        method: Aggregation strategy.
            ``"sum"`` - class with the largest summed score.
            ``"most_frequent"`` / ``"most_votes"`` / ``"majority_vote"`` - class
            predicted by the most rows.
            ``"n_highest_sum"`` - like ``"sum"`` but over the ``N`` rows whose top
            score is highest.
            Any other value prints a notice and falls back to ``"sum"``.
        N: Number of rows used by ``"n_highest_sum"``.

    Returns:
        The winning class index.
    """
    scores = np.asarray(preds)

    if method in ["most_frequent", "most_votes", "majority_vote"]:
        return most_frequent(np.argmax(scores, axis=1).tolist())

    if method == "n_highest_sum":
        confidence = scores.max(axis=1)
        # Rank row *indices* rather than values: looking values up again with
        # list.index() maps rows of equal confidence onto the first of them,
        # so that row would be counted several times.
        top_rows = heapq.nlargest(N, range(len(confidence)), key=lambda row: confidence[row])
        return np.argmax(scores[top_rows].sum(axis=0))

    if method != "sum":
        print("No method found: Apply Sum Method for Scoring predictions!")
    return np.argmax(scores.sum(axis=0))

In [16]:
@torch.no_grad()
def infer_with_softmax(model, images):
    """Run ``model`` on ``images`` without gradient tracking and return class probabilities.

    The model output is treated as logits and normalised with a softmax over
    dimension 1.
    """
    return torch.softmax(model(images), dim=1)
    
def apply_thr_outlier_detect(predictions, label, threshold = 0.60, is_logits=True):
    """Replace ``label`` with ``"Other"`` when the prediction is not confident enough.

    Args:
        predictions: Class scores for one image.
        label: Label to return when the prediction is confident.
        threshold: Minimum top probability required to keep ``label``.
        is_logits: When True (default) ``predictions`` are raw logits and a
            softmax over all elements is applied first; pass False for values
            that already are probabilities.

    Returns:
        ``label`` if the highest probability reaches ``threshold``, else ``"Other"``.
    """
    scores = np.asarray(predictions, dtype=float)
    if is_logits:
        # Numerically stable softmax implemented with numpy, so no extra import is needed.
        exponentials = np.exp(scores - np.max(scores))
        probabilities = exponentials / np.sum(exponentials)
    else:
        probabilities = scores

    # Apply the threshold.
    max_probabilities = np.max(probabilities)
    if max_probabilities < threshold:
        print("Outlier detected:", max_probabilities)
        return "Other"
    return label


def infer_single_image_ensemble(idx_row, models, CONFIG, encoder, score_method="sum", max_samples=30, thr_max_bg=0.1, is_submission=True, device="cuda") -> dict:
    """
    Predict the label of one slide from its tiles with one or several models.

    A tile dataset is built for the slide, every model predicts class
    probabilities for every tile, and ``score_predictions`` turns all of them
    into one class. If the tile-averaged top probability stays below 0.6 the
    label is replaced by "Other".

    Args:
        idx_row: ``(index, row)`` pair as produced by ``DataFrame.iterrows()``;
            the row provides ``image_id`` and ``target_label``.
        models: A model or a list of models.
        CONFIG: Dict providing ``valid_batch_size``.
        encoder: Label encoder used via ``inverse_transform``.
        score_method: Aggregation method forwarded to ``score_predictions``.
        max_samples: Maximum number of tiles per slide.
        thr_max_bg: Maximum background fraction of a tile.
        is_submission: True - cut tiles from the full slide into a temporary
            folder; False - read pre-cut tiles.
        device: Device the models and batches are moved to.

    Returns:
        Dict with ``image_id``, ``target_label``, ``label`` and - when tiles were
        available - ``predictions`` (class probabilities summed over all tiles
        and models).
    """
    row = dict(idx_row[1])
    img_id = str(row["image_id"])
    result = {"image_id": img_id}
    tmp_dir = "tmp_tiles_"+str(img_id)
    print("Image ID: ", img_id)
    result["target_label"] = encoder.inverse_transform(np.array(row["target_label"]).ravel())[0]

    try:
        if is_submission:
            # Start from an empty directory for the temporarily saved tiles.
            if os.path.exists(tmp_dir) and os.path.isdir(tmp_dir):
                shutil.rmtree(tmp_dir)
            os.mkdir(tmp_dir)
            dataset = TilesInferenceDataset(
                os.path.join("/kaggle/input/UBC-OCEAN/", "train_images", f"{img_id}.png"), tmp_dir=tmp_dir,
                size=2048, scale=0.25, thr_max_bg=thr_max_bg, transforms=data_transforms["valid"], max_samples=max_samples)
        else:
            dataset = TilesInferenceDataset(
                "/kaggle/input/tiles-of-cancer-2048px-scale-0-25", img_id,
                size=2048, scale=0.25, thr_max_bg=thr_max_bg, transforms=data_transforms["valid"], is_submission=is_submission, max_samples=max_samples)

        if not len(dataset):  # if no tiles available, set to "Other" class
            if not is_submission:
                print (f"seem no tiles were cut for `{row['image_id']}`. Set to label Other")
            result["label"] = "Other"
            return result

        dataloader = DataLoader(
            dataset, batch_size=CONFIG["valid_batch_size"], num_workers=2, shuffle=False,
            multiprocessing_context=get_context('loky')         # see: https://github.com/pytorch/pytorch/issues/44687#issuecomment-790842173
        )

        if not isinstance(models, list):
            models = [models]

        # One probability row per (model, tile) pair.
        tile_probabilities = []
        for model in models:
            model = model.to(device)
            model.eval()
            preds = []
            for imgs in dataloader:
                probabilities = infer_with_softmax(model, imgs.to(device))
                preds += probabilities.cpu().numpy().tolist()
            if not is_submission:
                print(f"Sum contrinution from all tiles: {np.sum(preds, axis=0)}")
                print(f"Max contribution over all tiles: {np.max(preds, axis=0)}")
            tile_probabilities.extend(preds)

        prediction = score_predictions(tile_probabilities, method=score_method)
        result["label"] = encoder.inverse_transform(np.array(prediction).reshape(-1,))[0]

        # Confidence check on the averaged probabilities. The class index itself
        # carries no confidence information, so it must not be passed here.
        mean_probabilities = np.mean(tile_probabilities, axis=0)
        result["label"] = apply_thr_outlier_detect(
            mean_probabilities, result["label"], threshold=0.6, is_logits=False)
        result["predictions"] = np.sum(tile_probabilities, axis=0).tolist()
    finally:
        # Remove temporary tiles on every exit path, including the early
        # return for slides without usable tiles and any exception.
        if os.path.exists(tmp_dir) and os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)

    print(result)
    return result