In [1]:
import os
import json
import numpy as np
import argparse
import random
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.autograd import Variable
from sklearn.metrics import auc, roc_curve, roc_auc_score, f1_score, precision_recall_curve, average_precision_score, classification_report 
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, balanced_accuracy_score, accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import cv2
from scipy.stats.mstats import gmean
from custom_tranformations import RandomRotation
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import label_binarize
from utils import encode_onehot, custom_label_binarize

if torch.cuda.is_available():
    print(f'CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}')
else:
    print('CUDA is not available. Using CPU.')

CUDA is available. Using GPU: NVIDIA GeForce RTX 4090


In [2]:
################################################################################
#                             Global Configurations                             #
################################################################################

KFOLD_PATH = r"E:\Aamir Gulzar\dataset\splits\kfolds_IDARS.csv"
DATA_PATH = r"E:\Aamir Gulzar\dataset\patches"

# GPU memory limit (commented out in your code)
# torch.cuda.set_per_process_memory_fraction(0.9)

# Some global metrics placeholders
lmbda = 0.1
best_auc_v = 0
best_auc = 0
n_slides = 0
best_loss = 100000.
best_f1_v = 0.
best_Acc = 0.
best_ap_v = 0.

num_classes = 2
label_names = ['nonMSI', 'MSI']

In [3]:
parser = argparse.ArgumentParser(description='MSI And MSS Classification')
parser.add_argument('--data_lib', type=str, default=KFOLD_PATH, help='path to train ')
parser.add_argument('--val_lib', type=str, default=KFOLD_PATH, help='path to validation ')
parser.add_argument('--test_lib', type=str, default=KFOLD_PATH, help='path to validation ')
parser.add_argument('--problem', type=str, default='MSI_vs_MSS_T50R50', help='classification problem.')
parser.add_argument('--pos_label', type=int, default=1, help='positive label.')
parser.add_argument('--neg_label', type=int, default=0, help='negative label. If present.')
parser.add_argument('--output', type=str, default='CAIMAN_Fivecrop_4Folds', help='new settings suggested by Sir')
parser.add_argument('--folds', type=int, default=4, help='number of fold to execute')
parser.add_argument('--batch_size', type=int, default=512, help='mini-batch size (default: 128)') # larger batch size is better try 512
parser.add_argument('--nepochs', type=int, default=4, help='number of epochs')
parser.add_argument('--workers', default=0, type=int, help='number of data loading workers (default: 4)')
parser.add_argument('--test_every', default=1, type=int, help='test on val every (default: 10)')
parser.add_argument('--weights', default=0.5, type=float, help='unbalanced positive class weight (default: 0.5, balanced classes)')

parser.add_argument('--lr', type=float, default=0.001) # best 0.01
parser.add_argument('--l2_reg', type=float, default=5e-3)
parser.add_argument('--grad_bound', type=float, default=5.0)

parser.add_argument('--r', default=50, type=int, help='how many rand tiles to consider (default: 10)') # try 50 percent of tiles
parser.add_argument('--k', default=50, type=int, help='how many top k tiles to consider (default: 10)') # try 50 percent of tiles

parser.add_argument('--budget', type=float, default=0.8, metavar='N',
                    help='the budget for how often the network can get hints')

_StoreAction(option_strings=['--budget'], dest='budget', nargs=None, const=None, default=0.8, type=<class 'float'>, choices=None, required=False, help='the budget for how often the network can get hints', metavar='N')

In [4]:
################################################################################
#                     Aggregation Class (for tile-level merges)                #
################################################################################

class Aggregation:
    """
    This class holds all tile-level aggregation methods. 
    It can be configured via 'aggregation_config' to enable or disable certain
    aggregations or to conditionally save results to CSV, etc.
    """
    def __init__(self, config=None):
        """
        config: A dictionary to toggle aggregator methods 
                and saving behaviors. For example:
                {
                    'enable_binnedtopk': True,
                    'enable_compute_aggregated_probabilities': True,
                    'enable_compute_aggregated_predictions': True,
                    'enable_group_avg_df': True,
                    'save_csv': True
                }
        """
        if config is None:
            config = {}
        self.config = config

    @staticmethod
    def group_argtopk(groups, data, k=1):
        """
        Return the indices of the top-k elements by group.
        """
        order = np.lexsort((data, groups))
        groups = groups[order]
        data = data[order]
        index = np.empty(len(groups), 'bool')
        index[-k:] = True
        index[:-k] = groups[k:] != groups[:-k]
        return list(order[index])

    @staticmethod
    def group_max(groups, data, nmax):
        """
        Return the maximum value in each group. 
        nmax is the maximum group ID + 1, used to fill results in an array.
        """
        out = np.empty(nmax)
        out[:] = np.nan
        order = np.lexsort((data, groups))
        groups = groups[order]
        data = data[order]
        index = np.empty(len(groups), 'bool')
        index[-1] = True
        index[:-1] = groups[1:] != groups[:-1]
        out[groups[index]] = data[index]
        return out

    @staticmethod
    def group_avg(groups, data):
        """
        Compute mean of data for each unique group.
        """
        order = np.lexsort((data, groups))
        groups = groups[order]
        data = data[order]
        unames, idx, counts = np.unique(groups, return_inverse=True, return_counts=True)
        sum_pred = np.bincount(idx, weights=data)
        mean_pred = sum_pred / counts
        return mean_pred, sum_pred

    @staticmethod
    def get_binnedtopK_aggregation(group, data):
        """
        Aggregate predictions by taking the top-1, mean, top percentages, 
        and a custom top-10 threshold-based score.
        Returns topk_p and top10_sc:
          - topk_p: The aggregated top predictions per WSI (binned)
          - top10_sc: The special scoring mechanism for top 10 predictions 
                      above or below 0.5 threshold.
        """
        wsi_dict = {}
        for idx, g in enumerate(group):
            g_id = wsi_dict.get(g, -1)
            if g_id == -1:
                wsi_dict[g] = [data[idx]]
            else:
                temp_data = wsi_dict[g]
                temp_data.append(data[idx])
                wsi_dict[g] = temp_data

        topk_p = []  
        top10_sc = []  
        for each_wsi in wsi_dict.keys():
            aggregated = []
            wsi_predictions = wsi_dict[each_wsi]
            wsi_predictions = np.array(wsi_predictions, dtype='float64')
            wsi_predictions.sort()
            aggregated.append(wsi_predictions[-1])           # top-1
            aggregated.append(np.mean(wsi_predictions))      # average
            for k_in in [1, 4]:
                temp = wsi_predictions[-k_in:]
                aggregated.append(np.mean(temp))
            topk_p.append(np.mean(aggregated))

            # For the top 10 predictions
            top10_predictions = wsi_predictions[-10:]
            if sum(top10_predictions[top10_predictions > 0.5]) > 0:
                a = np.mean(top10_predictions[top10_predictions > 0.5])
            else:
                a = 0
            if sum(top10_predictions[top10_predictions < 0.5]) > 0:
                b = np.mean(top10_predictions[top10_predictions < 0.5])
            else:
                b = 0
            top10_sc.append((a + a + b) / 3)
        topk_p = np.array(topk_p, dtype='float64')
        top10_sc = np.array(top10_sc, dtype='float64')
        return topk_p, top10_sc

    @staticmethod
    def compute_aggregated_probabilities(group, data, k=10):
        """
        we need to use avg, max and top10
        Compute multiple forms of aggregated probabilities (avg, max, sum, 
        median, gmean, and top-half-mean) for each group.
        """
        wsi_dict = {}
        for idx, g in enumerate(group):
            g_id = wsi_dict.get(g, -1)
            if g_id == -1:
                wsi_dict[g] = [data[idx]]
            else:
                temp_data = wsi_dict[g]
                temp_data.append(data[idx])
                wsi_dict[g] = temp_data

        avg_p = []
        max_p = []
        sum_p = []
        md_p = []
        gm_p = []
        top_p = []

        for each_wsi in wsi_dict.keys():
            wsi_predictions = np.array(wsi_dict[each_wsi], dtype='float64')
            avg_p.append(np.mean(wsi_predictions))
            max_p.append(np.max(wsi_predictions))
            # sum_p.append(np.sum(wsi_predictions))
            # md = np.median(wsi_predictions)
            # md_p.append(md)
            # gm_p.append(gmean(wsi_predictions))
            # top_p.append(np.mean(wsi_predictions[wsi_predictions >= md]))
        avg_p = np.array(avg_p, dtype='float64')
        max_p = np.array(max_p, dtype='float64')
        # sum_p = np.array(sum_p, dtype='float64')
        # md_p = np.array(md_p, dtype='float64')
        # gm_p = np.array(gm_p, dtype='float64')
        # top_p = np.array(top_p, dtype='float64')
        return avg_p, max_p

    @staticmethod
    def compute_aggregated_predictions(group, data):
        """
        Compute majority-vote predictions for each WSI. 
        Returns:
          mv_pred: majority class for each WSI
          n_pred:  array of raw counts for each class
        """
        wsi_dict = {}
        for idx, g in enumerate(group):
            if g not in wsi_dict:
                wsi_dict[g] = [data[idx]]
            else:
                wsi_dict[g].append(data[idx])

        mv_pred = []
        n_pred = []
        for each_wsi in wsi_dict.keys():
            wsi_predictions = wsi_dict[each_wsi]
            wsi_pred_class = []
            for cl in range(num_classes):
                wsi_pred_class.append(wsi_predictions.count(cl))

            mj_vt = np.argmax(wsi_pred_class)
            mv_pred.append(mj_vt)
            n_pred.append(wsi_pred_class)

        n_pred = np.array(n_pred, dtype='float64')
        mv_pred = np.array(mv_pred, dtype='float64')
        return mv_pred, n_pred

    @staticmethod
    def group_avg_df(groups, data):
        """
        Group top-10 aggregator using pandas. 
        Returns the mean of nlargest(10) for each group.
        """
        df = pd.DataFrame({'Slide': groups, 'value': data})
        group_average_df = df.groupby('Slide')['value'].apply(lambda grp: grp.nlargest(10).mean())
        group_average = group_average_df.tolist()
        return group_average

    @staticmethod
    def get_topMedtraining(group, data):
        """
        Returns indices of tiles whose probabilities are >= the median 
        probability within each WSI group.
        """
        wsi_dict = {}
        for idx, g in enumerate(group):
            if g not in wsi_dict:
                wsi_dict[g] = [data[idx]]
            else:
                wsi_dict[g].append(data[idx])

        top_p = []
        for each_wsi in wsi_dict.keys():
            wsi_predictions = np.array(wsi_dict[each_wsi], dtype='float64')
            md = np.median(wsi_predictions)
            start_i = np.squeeze(np.argwhere(group == each_wsi))[0]
            indices = np.squeeze(np.argwhere(wsi_predictions >= md)) + start_i
            top_p.extend(indices)

        top_p = np.array(top_p, dtype='int64')
        return top_p

    @staticmethod
    def get_topKtraining(group, data, k=5):
        """
        Returns the indices of the top-k% (or top-k absolute if you want) 
        probabilities for each WSI group.
        """
        wsi_dict = {}
        for idx, g in enumerate(group):
            if g not in wsi_dict:
                wsi_dict[g] = [data[idx]]
            else:
                wsi_dict[g].append(data[idx])

        topk_p = []
        for each_wsi in wsi_dict.keys():
            wsi_predictions = np.array(wsi_dict[each_wsi], dtype='float64')
            start_i = np.squeeze(np.argwhere(group == each_wsi)[0])
            perc = int((k / 100) * len(wsi_predictions)) + 1
            topk = wsi_predictions.argsort()[-perc:]
            topk_p.extend(topk + start_i)

        topk_p = np.squeeze(np.array(topk_p, dtype='int64'))
        return topk_p

In [5]:
################################################################################
#                     Evaluation Class (metrics, plotting, etc.)               #
################################################################################

class Evaluation:
    """
    This class holds evaluation-related methods such as metrics calculation, 
    confusion matrices, classification reports, etc. 
    """
    def __init__(self, config=None):
        """
        config: A dictionary to toggle certain evaluation steps or CSV saving.
                For example:
                {
                    'enable_plot_metrics': True,
                    'enable_confusion_matrix': True,
                    'save_csv': True
                }
        """
        if config is None:
            config = {}
        self.config = config
    ############################################################################
    #                          Original Evaluation Methods                     #
    ############################################################################

    @staticmethod
    def cutoff_youdens_j(fpr, tpr, thresholds):
        """
        Return the cutoff threshold that maximizes Youden's J statistic (tpr - fpr).
        """
        j_scores = tpr - fpr
        j_ordered = sorted(zip(j_scores, thresholds))
        return j_ordered[-1][1]

    @staticmethod
    def cal_f1_score(targets, prediction, cutoff):
        """
        Calculate F1 score (weighted) for binary predictions 
        given a cutoff threshold.
        """
        prediction = np.array(prediction)
        targets = np.array(targets)
        f1score = f1_score(targets, prediction, average='weighted')
        return f1score

    @staticmethod
    def calc_metrics(target, prediction):
        """
        Calculate multiple metrics: 
        - AUC (from ROC), 
        - F1 (with Youden J threshold), 
        - Average Precision (PR AUC).
        """
        fpr, tpr, thresholds = roc_curve(target, prediction)
        cutoff = Evaluation.cutoff_youdens_j(fpr, tpr, thresholds)
        roc_auc = auc(fpr, tpr)
        f1score = Evaluation.cal_f1_score(target, prediction, cutoff)
        precision, recall, _ = precision_recall_curve(target, prediction, zero_division=1)
        average_precision = average_precision_score(target, prediction, zero_division=1)
        return f1score, average_precision, roc_auc, cutoff

    @staticmethod
    def calculate_accuracy(output, target):
        """
        Compute classification accuracy given output logits and ground truth target.
        """
        preds = output.max(1, keepdim=True)[1]
        correct = preds.eq(target.view_as(preds)).sum()
        acc = correct.float() / preds.shape[0]
        return acc
    
    @staticmethod
    def compute_auc(labels, predictions):
        # binary_labels = custom_label_binarize(np.array(labels), classes=[i for i in range(num_classes)])
        # auc_list = []
        # for cl in range(num_classes):
        #     fpr, tpr, thresholds = roc_curve(binary_labels[:, cl], predictions[:, cl])
        #     auc_list.append(auc(fpr, tpr))
        aucscore = roc_auc_score(labels, predictions[:, 1])
        return aucscore
    
    
    @staticmethod
    def plot_metrics(target, prediction, set):
        """
        Plot and save ROC and Precision-Recall curves.
        """
        import matplotlib.pyplot as plt  # local import to match old structure

        fpr, tpr, thresholds = roc_curve(target, prediction)
        roc_auc = auc(fpr, tpr)
        print('roc_auc is:', roc_auc)

        precision, recall, _ = precision_recall_curve(target, prediction, zero_division=1)
        average_precision = average_precision_score(target, prediction, zero_division=1)
        pr_auc = auc(recall, precision)
        print('Average precision-recall score: {0:0.2f}'.format(average_precision))

        plt.figure(figsize=(12, 4))
        lw = 2

        # Subplot 1: ROC
        plt.subplot(121)
        plt.plot(fpr, tpr, color='darkorange',
                 lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve, AUC={0:0.2f}'.format(roc_auc))
        plt.legend(loc="lower right")

        # Subplot 2: Precision-Recall
        plt.subplot(122)
        plt.step(recall, precision, alpha=0.4, color='darkorange', where='post')
        plt.fill_between(recall, precision, alpha=0.2, color='navy', step='post')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.ylim([0, 1.05])
        plt.xlim([0, 1])
        plt.title('Precision-Recall curve: AP={0:0.2f}'.format(average_precision))

        plt.savefig(os.path.join(args.output, 'roc_pr' + set + '.png'))
        plt.close(plt.gcf())

In [6]:
###############################################################################
#                            Model Definition                                 #
###############################################################################
import copy
class CNN(nn.Module):
    """
    A standard ResNet34-based feature extractor + classifier + confidence head,
    with no internal logic for multi-crop.
    """
    def __init__(self, num_classes):
        super(CNN, self).__init__()
        self.model_resnet = models.resnet34(weights='DEFAULT')
        num_ftrs = self.model_resnet.fc.in_features
        self.model_resnet.fc = nn.Identity()  # remove the original final FC
        self.classifier = nn.Linear(num_ftrs, num_classes)
        self.conf = nn.Linear(num_ftrs, 1)

    def forward(self, x):
        """
        x => shape [N, 3, H, W]
        N can be any batch size (including B*5 if flattened outside).
        """
        features = self.model_resnet(x)        # => [N, num_ftrs]
        logit = self.classifier(features)      # => [N, num_classes]
        conf = self.conf(features)             # => [N, 1]
        return logit, conf

def targets_for_wsi_loss(tile_labels, wsi_ids, wsi, cl):
    tile_labels[tile_labels == cl] = -1
    tile_labels[wsi_ids == wsi] = cl
    return tile_labels

def encode_onehot_new(labels, n_classes):
    device = labels.device
    eye = torch.eye(n_classes, device=device)
    return eye[labels]

# from old CAIMAN implementation
def encode_onehot(labels, n_classes):
    onehot = torch.FloatTensor(labels.size()[0], n_classes)
    labels = labels.data
    if labels.is_cuda:
        onehot = onehot.cuda()
    onehot.zero_()
    onehot.scatter_(1, labels.view(-1, 1), 1)
    return onehot

def train(run, loader, model, criterion, optimizer):
    model.train()
    xentropy_loss_avg = 0.
    confidence_loss_avg = 0.
    slide_loss_avg = 0.
    running_loss = 0.
    running_acc = 0.

    for i, (inputs, templabels, wsi_ids) in enumerate(loader):
        # inputs => shape [batch_size, 5, 3, 224, 224]
        # templabels => shape [batch_size]
        # 1) Flatten the 5-crop dimension into the batch dimension
        tile_labels = copy.deepcopy(templabels)
        b, ncrops, c, h, w = inputs.shape      # e.g. (B, 5, 3, 224, 224)
        inputs = inputs.view(b * ncrops, c, h, w).cuda()  # => (B*5, 3, 224, 224)

        # Move labels to GPU
        templabels = templabels.cuda()

        optimizer.zero_grad()

        # 2) Forward pass (model sees a batch of size B*5)
        output, conf = model(inputs)           # => (B*5, num_classes), (B*5, 1)

        # 3) Reshape back into (B, 5, ...)
        output = output.view(b, ncrops, -1)    # => (B, 5, num_classes)
        conf = conf.view(b, ncrops, -1)        # => (B, 5, 1)

        # 4) Average over the 5 crops
        output = output.mean(dim=1)           # => (B, num_classes)
        conf = conf.mean(dim=1)               # => (B, 1)

        # 5) Proceed with the existing logic
        pred_original = F.softmax(output, dim=-1)
        confidence = torch.sigmoid(conf)

        eps = 1e-12
        pred_original = torch.clamp(pred_original, eps, 1. - eps)
        confidence = torch.clamp(confidence, eps, 1. - eps)

        # Convert to one-hot for your custom weighting logic
        labels_onehot = encode_onehot(templabels, num_classes)  # => (B, num_classes)

        # Weighted combination of predicted distribution and the one-hot label
        pred_new = pred_original * confidence + labels_onehot * (1 - confidence)
        pred_new = torch.log(pred_new)

        xentropy_loss = criterion(pred_new, templabels)
        confidence_loss = torch.mean(-torch.log(confidence))

        # Slide-level consistency loss (if needed, placeholder)
        slide_loss = 0.
        for cl in range(1, num_classes):
            wsi_ids_temp = wsi_ids[tile_labels==cl]            
            wsi_ids_unique = wsi_ids_temp.unique()
            if len(wsi_ids_unique) >= 1:
                for wsi in wsi_ids_unique:
                    target_group_tiles = targets_for_wsi_loss(tile_labels, wsi_ids, wsi, cl)
                    target_group_tiles = target_group_tiles.cuda()
                    sloss = criterion(pred_new, target_group_tiles)
                    slide_loss += sloss
                slide_loss = slide_loss / len(wsi_ids_unique)
                slide_loss_avg = slide_loss.item() * b

        total_loss = xentropy_loss + confidence_loss + slide_loss
        total_loss.backward()
        optimizer.step()

        # 6) Track stats
        running_loss += total_loss.item() * b
        acc = Evaluation.calculate_accuracy(output, templabels)
        running_acc += acc.item() * b

        xentropy_loss_avg += xentropy_loss.item() * b
        confidence_loss_avg += confidence_loss.item() * b

        if i % 100 == 0:
            print(
                "Train Epoch: [{:3d}/{:3d}] Batch: {:3d}/{:3d}, Loss: {:.4f}, acc: {:.2f}%, xent: {:.4f}, conf: {:.4f}, slide: {:.4f}".format(
                    run + 1, 
                    args.nepochs, 
                    i + 1, 
                    len(loader),
                    running_loss / ((i + 1) * b),
                    (100 * running_acc) / ((i + 1) * b),
                    xentropy_loss_avg / ((i + 1) * b),
                    confidence_loss_avg / ((i + 1) * b),
                    slide_loss_avg / ((i + 1) * b)
                )
            )

    # If you need to return predictions for some reason, adapt accordingly
    return running_loss / len(loader.dataset), running_acc / len(loader.dataset)


In [7]:
def inference(run, loader, model, criterion):
    model.eval()
    running_acc = 0.
    running_loss = 0.

    # We'll store per-sample probabilities, predictions, and confidence.
    # 'len(loader.dataset)' should be the total number of *images* (where each image has 5 crops).
    probs = torch.FloatTensor(len(loader.dataset), num_classes)
    preds = torch.FloatTensor(len(loader.dataset))
    confidence = torch.FloatTensor(len(loader.dataset))

    index_offset = 0  # --- 5-crop change ---
    with torch.no_grad():
        for i, (inputs, target, wsi_id) in enumerate(loader):
            """
            inputs => shape [batch_size, 5, 3, 224, 224]
            target => shape [batch_size]
            We want to flatten the 5 crops per sample into a single batch dimension,
            pass them to the model, then reshape back and average.
            """
            b, ncrops, c, h, w = inputs.shape  # e.g. (B, 5, 3, 224, 224)
            inputs = inputs.view(-1, c, h, w).cuda()  # (B*5, 3, 224, 224)
            target = target.cuda()

            # Forward pass: the model now treats this as a batch of B*5 single images
            output, conf = model(inputs)      # => (B*5, num_classes), (B*5, 1)

            # Reshape back to (B, 5, ...) and average
            output = output.view(b, ncrops, -1).mean(dim=1)  # => (B, num_classes)
            conf = conf.view(b, ncrops, -1).mean(dim=1)      # => (B, 1)

            # Compute loss and accuracy
            loss = criterion(output, target)
            acc = Evaluation.calculate_accuracy(output, target)

            # Softmax over classes
            y = F.softmax(output, dim=-1)         # => (B, num_classes)
            # Sigmoid confidence (already (B,1)), reshape to 1D
            conf = torch.sigmoid(conf).view(-1)   # => (B,)

            # Predictions
            pred_value, pred = torch.max(output.data, 1)
            #store the probabilities, predictions and confidence
            batch_size_now = b
            probs[index_offset:index_offset + batch_size_now] = y.detach().clone()
            preds[index_offset:index_offset + batch_size_now] = pred.detach().clone()
            confidence[index_offset:index_offset + batch_size_now] = conf.detach().clone()
            index_offset += batch_size_now

            running_loss += loss.item() * batch_size_now
            running_acc += acc.item() * batch_size_now

            if i % 100 == 0:
                print('Inference\tEpoch: [{:3d}/{:3d}]\tBatch: [{:3d}/{}]\t'
                      'Validation: Loss: {:.4f}, acc: {:0.2f}%'.format(
                    run + 1, 
                    args.nepochs, 
                    i + 1, 
                    len(loader),
                    running_loss / ((i + 1) * batch_size_now),
                    (100. * running_acc) / ((i + 1) * batch_size_now) 
                ))

        # Print some confidence stats
        confidence_np = confidence.cpu().numpy()
        print('Confidence\tMin: {:0.4f}\tAverage: {:.4f}\tMax: {:0.4f}'.format(
            np.min(confidence_np), np.mean(confidence_np), np.max(confidence_np)
        ))

    return (
        probs.cpu().numpy(),
        running_loss / len(loader.dataset),
        running_acc / len(loader.dataset),
        preds.cpu().numpy(),
        confidence_np
    )


In [8]:
################################################################################
#                                Dataset Class                                 #
################################################################################
def get_split_indices(lib, test_fold, val_fold):
    """
    lib: DataFrame with a column named 'fold' 
         that has values in [0, 1, 2, 3].
    test_fold, val_fold: integers from 0..3
    returns: three lists => train_idx, val_idx, test_idx
    """
    # Extract all folds
    all_folds = lib['fold'].values.tolist()

    test_idx = [i for i, f in enumerate(all_folds) if f == test_fold]
    val_idx  = [i for i, f in enumerate(all_folds) if f == val_fold]
    train_idx = [i for i, f in enumerate(all_folds)
                 if f not in (test_fold, val_fold)]
    
    return train_idx, val_idx, test_idx


class MILdataset(data.Dataset):
    """
    A custom Dataset to handle multiple WSI tiles from a library CSV.
    Each tile is read at 512×512, then we apply a transform that does
    FiveCrop(224). This results in a 5-crop for each tile.
    """
    def __init__(self, libraryfile=KFOLD_PATH, transform=None,mult=2, s=10, shuffle=False, set='', test_fold=None, val_fold=None):
        path_dir = DATA_PATH
        lib = pd.DataFrame(pd.read_csv(libraryfile, usecols=[
            'Case_ID', 'WSI_Dir', 'label_desc', 'label_id', f'fold'
        ]))
        lib.dropna(inplace=True)

        allcases = lib['Case_ID'].values.tolist()
        allslides = lib['WSI_Dir'].values.tolist()
        tar = lib['label_id'].values.tolist()
        label_desc = lib['label_desc'].values.tolist()
        split = lib[f'fold'].values.tolist()
        
        train_idx, val_idx, test_idx = get_split_indices(lib, test_fold, val_fold)
        if set == 'train':
            indices = train_idx
            print(f"##--Targets ==> {len(tar)} | {tar.count(0)} | {tar.count(1)} --##")
            thresh_tiles = 4
        elif set == 'valid':
            indices = val_idx
            print(f"##--Val Split ==> {len(indices)} Slides")
            thresh_tiles = 4
        elif set == 'test':
            indices = test_idx
            print(f"##--Test Split ==> {len(indices)} Slides")
            thresh_tiles = 4
        else:
            raise ValueError("Invalid set_type. Must be 'train', 'valid', or 'test'.")

        cases = []
        tiles = []
        ntiles = []
        slideIDX = []
        targetIDX = []
        targets = []
        label_desciption = []
        slides = []

        j = 0
        for i in indices:
            path = os.path.join(path_dir, str(allslides[i]))
            slide_label = int(tar[i])
            if os.path.exists(path):
                Max_Patches = 10
                t = [
                    os.path.join(path, f)
                    for f in os.listdir(path)
                    if f.endswith('.png')
                ]
                if len(t) >= thresh_tiles:
                    cases.append(allcases[i])
                    slides.append(allslides[i])
                    tiles.extend(t)
                    ntiles.append(len(t))
                    slideIDX.extend([j]*len(t))
                    targetIDX.extend([slide_label]*len(t))
                    targets.append(slide_label)
                    label_desciption.append(label_desc[i])
                    j += 1

        self.slideIDX = slideIDX
        self.ntiles = ntiles
        self.tiles = tiles
        self.targets = targets
        self.label_desc = label_desciption
        self.slides = slides
        self.cases = cases
        self.transform = transform
        self.mult = mult
        self.s = s
        self.mode = None
        self.shuffle = shuffle
        self.targetIDX = targetIDX

        print('-------------------------')
        print('Number of Slides: {}'.format(len(slides)))
        print('Number of tiles: {}'.format(len(tiles)))
        if len(ntiles) > 0:
            print('Max tiles: ', max(ntiles))
            print('Min tiles: ', min(ntiles))
            print('Average tiles: ', np.mean(ntiles))
        print('nonMSI: ', targets.count(0))
        print('MSI: ', targets.count(1))

    def setmode(self, mode):
        """Set self.mode to 1 or 2 depending on training iteration usage."""
        self.mode = mode

    def maketraindata(self, idxs):
        self.t_data = [(self.slideIDX[x], self.tiles[x], self.targets[self.slideIDX[x]]) 
                       for x in idxs]

    def shuffletraindata(self):
        self.t_data = random.sample(self.t_data, len(self.t_data))

    def __getitem__(self, index):
        """
        If mode=1, we sample from the full tile list; if mode=2, from t_data.
        Each tile is loaded, optionally transformed, which includes FiveCrop.
        So the transform outputs shape [5, 3, 224, 224].
        """
        if self.mode == 1:
            tile_path = self.tiles[index]
            slide_idx = self.slideIDX[index]
            img = Image.open(tile_path).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)  # => shape (5, 3, 224, 224)
            target = self.targets[slide_idx]
            return img, target, slide_idx

        elif self.mode == 2:
            slideIDX, tile_path, target = self.t_data[index]
            img = Image.open(tile_path).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)  # => shape (5, 3, 224, 224)
            return img, target, slideIDX

    def __len__(self):
        if self.mode == 1:
            return len(self.tiles)
        elif self.mode == 2:
            return len(self.t_data)

In [9]:
###############################################################################
#                               Main Training                                 #
###############################################################################
def run_train_step(epoch,data_indices, dataset,loader,model, criterion, optimizer, scheduler, lmbda,args):
    """
    1) Sets the dataset mode
    2) Loads the selected data indices (tiles) into the dataset e.g random or topk
    3) Trains for one epoch
    4) Logs the result
    """
    # 1) Prepare data
    dataset.setmode(1)
    dataset.maketraindata(data_indices)
    dataset.shuffletraindata()
    dataset.setmode(2)
    # 2) Train
    loss, acc = train(epoch, loader, model, criterion, optimizer)
    # 3) Print & log
    print('Training Epoch: [{}/{}]\tLoss: {:0.4f}\tAccuracy: {:3d}%'.format(
            epoch + 1, args.nepochs, loss, int(acc * 100)))
    log_file = os.path.join(args.output, 'train_convergence.csv')
    with open(log_file, 'a') as fconv:
        fconv.write('{},{:0.4f},{:3d}\n'.format(epoch, loss, int(acc * 100)))
    return loss, acc
def update_topk(epoch,data_indices,dataset,loader,model,criterion,topk_list=None,invert=False):
    """
    1) Performs inference on the specified data_indices
    2) Computes top_prob = 1 - prob[:, 0] (or some variant)
    3) Collects top-50% or top-k% tiles
    4) Optionally merges with an existing 'topk_list'
    Returns:
        updated_topk (list): new or merged topk
    """
    # 1) Inference
    dataset.maketraindata(data_indices)
    dataset.setmode(1)
    trn_probs, _, _, _, trn_conf = inference(epoch, loader, model, criterion)
    # shape = [n_tiles, n_classes], and 'trn_conf' is [n_tiles]
    # 2) Probability adjustment
    trn_probs_conf = np.transpose(trn_probs.transpose() * trn_conf)  # same line as original
    top_prob = 1. - trn_probs_conf[:, 0] if not invert else 1. - (1. - trn_probs_conf[:, 0])
    # 3) Gather topk
    slide_idx = [dataset.slideIDX[item] for item in data_indices]
    local_topk = Aggregation.get_topKtraining(np.array(slide_idx), top_prob, args.k)
    local_topk = np.array(local_topk, dtype='int64')
    # Convert from local index to the overall dataset index
    local_topk = [data_indices[item] for item in local_topk]
    print(f'Number of topk found: {len(local_topk)}')
    # 4) Merge with any existing topk
    if topk_list is None:
        updated_topk = local_topk
    else:
        updated_topk = list(set(topk_list + local_topk))
    return updated_topk

class StackAndNormalize:
    def __init__(self, normalize):
        self.normalize = normalize

    def __call__(self, crops):
        return torch.stack([
            self.normalize(transforms.ToTensor()(crop)) 
            for crop in crops
        ])

def main():
    global args, best_auc_v, tr_batch_size, n_slides, best_auc, best_f1_v

    config = {
        'aggregation_methods': ['average', 'top10', 'max'],
        'save_val_tile_csv': True,
        'save_test_tile_csv': True,
        'save_val_slide_csv': True,
        'save_test_slide_csv': True,
        'enable_group_avg_df': True,
        'enable_plot_metrics': True,
        'enable_confusion_matrix': True
    }
    aggregator = Aggregation(config)
    evaluator = Evaluation(config)
    
    args = parser.parse_args(args=[])
    if not os.path.exists(args.output):
        os.mkdir(args.output)
    args.output = os.path.join(args.output, args.problem)
    if not os.path.exists(args.output):
        os.mkdir(args.output)
    temp_output = args.output
# --- 5-crop change: define transforms so we return 5-crops (224×224) for each 512×512 image
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                        std=[0.1, 0.1, 0.1])
    trans = transforms.Compose([
        RandomRotation([0, 90, 180, 270]),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05),
        transforms.RandomAdjustSharpness(sharpness_factor=2),
        transforms.FiveCrop(224),
        StackAndNormalize(normalize)
    ])

    trans_Valid = transforms.Compose([
        transforms.FiveCrop(224),
        StackAndNormalize(normalize)])

    for fold in range(args.folds):
        test_fold = fold + 1
        val_fold  = ((fold + 1) % 4) + 1
        args.output = os.path.join(temp_output, 'fold' + str(fold + 1))
        if not os.path.exists(args.output):
            os.mkdir(args.output)
        path_fold = args.output

        for sets in range(1):
            torch.cuda.empty_cache()
            args.output = os.path.join(path_fold, 'best' + str(sets))
            if not os.path.exists(args.output):
                os.mkdir(args.output)
            global best_auc_v, best_auc, n_slides, best_loss, best_f1_v, best_Acc, best_ap_v
            # reset
            best_auc_v = 0
            best_auc = 0
            n_slides = 0
            best_loss = 100000.
            best_f1_v = 0.
            best_Acc = 0.
            best_ap_v = 0.

            model = CNN(num_classes=num_classes)
            model.cuda()

            optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
            scheduler = []
            cudnn.benchmark = True
            criterion = nn.NLLLoss(ignore_index=-1).cuda()
            # --- end 5-crop changes
            # Load Data
            train_dset = MILdataset(args.data_lib, trans, set='train',  test_fold=test_fold, val_fold=val_fold)
            train_loader = torch.utils.data.DataLoader(
                train_dset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)

            val_dset = MILdataset(args.data_lib, trans_Valid, set='valid',  test_fold=test_fold, val_fold=val_fold)
            val_loader = torch.utils.data.DataLoader(
                val_dset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)

            test_dset = MILdataset(args.data_lib, trans_Valid, set='test',  test_fold=test_fold, val_fold=val_fold)
            test_loader = torch.utils.data.DataLoader(
                test_dset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)

            #  if you wanted to see only dataset distribution and data loaders etc then put continue here
            # continue
            fconv = open(os.path.join(args.output, 'train_convergence.csv'), 'w')
            fconv.write('epoch,loss,accuracy\n')
            fconv.close()

            fconv = open(os.path.join(args.output, 'valid_convergence.csv'), 'w')
            # We removed sum, median, gmean columns; only keep relevant columns
            fconv.write('epoch,tile_loss,tile_acc,best_F1,F1_AVG,F1_Max,F1_T10,best_BAcc,Bacc_AVG,Bacc_Max,Bacc_T10,best_AUC,Avg_AUC,Max_AUC,Top_AUC\n')
            fconv.close()

            num_tiles = len(train_dset.slideIDX)
            n_slides = len(train_dset.slides)
            print(f'Number of slides: {n_slides} | Number of tiles: {num_tiles}')
            # 1) Generate initial random selection
            r0 = np.random.rand(num_tiles)
            randk = aggregator.get_topKtraining(np.array(train_dset.slideIDX), r0, args.r)
            print('Number of randk1:', len(randk))
            # ---- Epoch 0 Training
            loss0, acc0 = run_train_step(epoch=0,data_indices=randk,dataset=train_dset,loader=train_loader,model=model,criterion=criterion,optimizer=optimizer,scheduler=scheduler,lmbda=lmbda,args=args)
            # 2) Inference + top prob => local topk
            topk = update_topk(epoch=0,data_indices=randk,dataset=train_dset,loader=train_loader,model=model,criterion=criterion,topk_list=None,invert=False)
            print('Number of topk1:', len(topk))

            # 3) Another random selection
            randk2 = aggregator.get_topKtraining(np.array(train_dset.slideIDX), 1 - r0, args.r)
            print('Number of randk2:', len(randk2))
            # ---- Epoch 1 Training
            loss1, acc1 = run_train_step(epoch=1,data_indices=randk2,dataset=train_dset,loader=train_loader,model=model,criterion=criterion,optimizer=optimizer,scheduler=scheduler,lmbda=lmbda,args=args)
            # 4) Inference + top prob => local topk, merge with existing topk
            temp = update_topk(epoch=1,data_indices=randk2,dataset=train_dset,loader=train_loader,model=model,criterion=criterion,topk_list=None)
            print('Number of topk from second update:', len(temp))
            print(f' these two should be the same {len(topk)} == {len(temp)}')
            topk = list(set(topk + temp))
            print('Overall topk after 2 iterations:', len(topk))

            # 5) Continue training for remaining epochs
            for epoch in range(2, args.nepochs):
                randk = list(aggregator.get_topKtraining(np.array(train_dset.slideIDX), 
                                                    np.random.rand(num_tiles), args.r))
                print('Number of randk k', len(randk))
                randk = list(set(randk + topk))
                print('Number of randk and topk combined ', len(randk))
                train_dset.setmode(1)
                train_dset.maketraindata(randk)
                train_dset.shuffletraindata()
                train_dset.setmode(2)
                loss, acc = train(epoch, train_loader, model, criterion, optimizer)
                print('Training\tEpoch: [{}/{}]\tLoss: {:0.4f}\tAccuracy: {:3d}'
                      .format(epoch+1, args.nepochs, loss, int(acc * 100)))

                fconv = open(os.path.join(args.output, 'train_convergence.csv'), 'a')
                fconv.write('{},{:0.4f},{:3d}\n'.format(epoch, loss, int(acc * 100)))
                fconv.close()

                # if there are no of epochs left to train then update topk
                if (epoch + 1) < args.nepochs:
                    trn_probs, _, _, _, trn_conf = inference(epoch, train_loader, model, criterion)
                    trn_probs_conf = np.transpose(trn_probs.transpose() * trn_conf)                
                    top_prob = 1.-trn_probs_conf[:,0]
                    slide_idx = [train_dset.slideIDX[item] for item in randk]
                    local_topk = aggregator.get_topKtraining(np.array(slide_idx), top_prob, args.k)
                    local_topk = [randk[item] for item in local_topk]
                    topk = list(set(topk + local_topk))
                    print(f'Number of topk {len(topk)} in epoch {epoch}')

                # Validation
                if (epoch + 1) % args.test_every == 0:
                    val_dset.setmode(1)
                    val_probs, val_loss, val_acc, val_preds,val_conf = inference(epoch, val_loader, model, criterion)
                    val_probs = np.transpose(val_probs.transpose() * val_conf)
                    val_slide_mjvt, _ = aggregator.compute_aggregated_predictions(
                        np.array(val_dset.slideIDX), val_preds
                    )

                    # We only keep avg, max, top10
                    val_slide_avg = []
                    val_slide_max = []
                    val_slide_avgt10 = []

                    # We'll store each class aggregator
                    for cl in range(num_classes):
                        t_avg, t_max = aggregator.compute_aggregated_probabilities(
                            np.array(val_dset.slideIDX), val_probs[:, cl]
                        )
                        t_t10 = aggregator.group_avg_df(
                            np.array(val_dset.slideIDX), val_probs[:, cl]
                        )
                        val_slide_avg.append(t_avg)
                        val_slide_max.append(t_max)
                        val_slide_avgt10.append(t_t10)

                    val_slide_avg = np.array(val_slide_avg).transpose()
                    val_slide_max = np.array(val_slide_max).transpose()
                    val_slide_avgt10 = np.array(val_slide_avgt10).transpose()

                    val_slide_avg_m = np.argmax(val_slide_avg, axis=1)
                    val_slide_max_m = np.argmax(val_slide_max, axis=1)
                    val_slide_avgt10_m = np.argmax(val_slide_avgt10, axis=1)

                    from sklearn.metrics import f1_score
                    f1_mv = f1_score(val_dset.targets, val_slide_mjvt, average='macro')
                    f1_avg = f1_score(val_dset.targets, val_slide_avg_m, average='macro')
                    f1_max = f1_score(val_dset.targets, val_slide_max_m, average='macro')
                    f1_a10 = f1_score(val_dset.targets, val_slide_avgt10_m, average='macro')
                    bacc_mv = balanced_accuracy_score(val_dset.targets, val_slide_mjvt)
                    bacc_avg = balanced_accuracy_score(val_dset.targets, val_slide_avg_m)
                    bacc_max = balanced_accuracy_score(val_dset.targets, val_slide_max_m)
                    bacc_avgt10 = balanced_accuracy_score(val_dset.targets, val_slide_avgt10_m)
                    
                    auc_val_avg = evaluator.compute_auc(val_dset.targets, val_slide_avg)
                    auc_val_max = evaluator.compute_auc(val_dset.targets, val_slide_max)
                    auc_val_top10 = evaluator.compute_auc(val_dset.targets, val_slide_avgt10)

                    fconv = open(os.path.join(args.output, 'valid_convergence.csv'), 'a')
                    fconv.write('{},{:0.4f},{:3d},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f}\n'.format(
                        epoch, val_loss, int(val_acc * 100),
                        max(f1_mv,f1_avg, f1_max, f1_a10),
                        f1_avg, f1_max, f1_a10,
                        # save max and individual balanced accuracies
                        max(bacc_mv, bacc_avg, bacc_max, bacc_avgt10),
                        bacc_avg, bacc_max, bacc_avgt10,
                        # also save best auc values
                        max(auc_val_avg, auc_val_max, auc_val_top10),
                        auc_val_avg, auc_val_max, auc_val_top10
                    ))
                    fconv.close()

                    print("\n----------------------Validation Results -------------------------------------")
                    print(f"F1-scores (Val): MV={f1_mv:.3f}  AVG={f1_avg:.3f}  MAX={f1_max:.3f}  top10={f1_a10:.3f}")
                    print(f"Balanced Acc (Val): MV={bacc_mv:.3f}  AVG={bacc_avg:.3f}  MAX={bacc_max:.3f}  top10={bacc_avgt10:.3f}")
                    print(f"AUC-scores (Val): AVG={auc_val_avg:.3f}  MAX={auc_val_max:.3f}  top10={auc_val_top10:.3f}")
                    print("------------------------------------------------------------------\n")
                    
                    best_f1_candidate = max(f1_mv, f1_avg, f1_max, f1_a10)
                    if best_f1_candidate > best_f1_v:
                        best_f1_v = best_f1_candidate
                        obj = {
                            'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'best_ap_v': best_f1_v,
                            'best_auc_v': max(auc_val_max, auc_val_top10, auc_val_avg),
                            'optimizer': optimizer.state_dict()
                        }
                        torch.save(obj, os.path.join(args.output, 'checkpoint_best_F1.pth'))
                        print("Saved checkpoint_best_F1.pth")
                        # save the tile level predictions for validation slides columns slideidx, prob non msi, prob msi and pred tile
                        df2_f1 = pd.DataFrame({
                            'slideidx': val_dset.slideIDX,
                            'nonMSI_prob': val_probs[:, 0],
                            'MSI_prob': val_probs[:, 1],
                            'pred_tile': val_preds
                        })
                        df2_f1.to_csv(os.path.join(args.output, 'val_tile_pred_F1.csv'), index=False)

                    best_auc_candidate = max(auc_val_avg, auc_val_max, auc_val_top10)
                    if best_auc_candidate > best_auc_v:
                        best_auc_v = best_auc_candidate
                        obj = {
                            'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'best_ap_v': best_f1_v,
                            'best_auc_v': best_auc_v,
                            'optimizer': optimizer.state_dict()
                        }
                        torch.save(obj, os.path.join(args.output, 'checkpoint_best_AUC.pth'))
                        print("Saved checkpoint_best_AUC.pth")

                        # save the tile level predictions for validation slides columns slideidx, prob non msi, prob msi and pred tile
                        df2_auc = pd.DataFrame({
                            'slideidx': val_dset.slideIDX,
                            'nonMSI_prob': val_probs[:, 0],
                            'MSI_prob': val_probs[:, 1],
                            'pred_tile': val_preds
                        })
                        df2_auc.to_csv(os.path.join(args.output, 'val_tile_pred_AUC.csv'), index=False)

            # Save ground truth for validation slides
            df1 = pd.DataFrame({
                'Case_ID': val_dset.cases,
                'WSI_Id': val_dset.slides,
                'n_tiles': val_dset.ntiles,
                'label_desc': val_dset.label_desc,
                'label_id': val_dset.targets
            })
            df1.to_csv(os.path.join(args.output, 'val_GT.csv'), index=False)
            df2_f1.to_csv(os.path.join(args.output, 'val_tile_pred_F1.csv'), index=False)
            df2_auc.to_csv(os.path.join(args.output, 'val_tile_pred_AUC.csv'), index=False)
            ############## Test ##############
            ch = torch.load(os.path.join(args.output, 'checkpoint_best_AUC.pth'))
            model.load_state_dict(ch['state_dict'])

            test_dset.setmode(1)
            test_probs, test_loss, test_acc, test_preds,test_conf = inference(epoch, test_loader, model, criterion)
            test_probs = np.transpose(test_probs.transpose() * test_conf)
            print(f'Predicted Tiles: {len(test_probs[:, 1])}, classes: {len(test_probs[0, :])}, total targets: {len(test_dset.targets)}')

            test_slide_mjvt, _ = aggregator.compute_aggregated_predictions(
                np.array(test_dset.slideIDX), test_preds
            )
            test_slide_avg = []
            test_slide_max = []
            test_slide_avgt10 = []
            for cl in range(num_classes):
                t_avg, t_max = aggregator.compute_aggregated_probabilities(
                    np.array(test_dset.slideIDX), test_probs[:, cl]
                )
                t_t10 = aggregator.group_avg_df(
                    np.array(test_dset.slideIDX), test_probs[:, cl]
                )
                test_slide_avg.append(t_avg)
                test_slide_max.append(t_max)
                test_slide_avgt10.append(t_t10)

            test_slide_avg = np.array(test_slide_avg).transpose()
            test_slide_max = np.array(test_slide_max).transpose()
            test_slide_avgt10 = np.array(test_slide_avgt10).transpose()

            test_slide_avg_m = np.argmax(test_slide_avg, axis=1)
            test_slide_max_m = np.argmax(test_slide_max, axis=1)
            test_slide_avgt10_m = np.argmax(test_slide_avgt10, axis=1)

            f1_mv = f1_score(test_dset.targets, test_slide_mjvt, average='macro')
            f1_avg = f1_score(test_dset.targets, test_slide_avg_m, average='macro')
            f1_max = f1_score(test_dset.targets, test_slide_max_m, average='macro')
            f1_t10 = f1_score(test_dset.targets, test_slide_avgt10_m, average='macro')
            # balanced accuracy
            bacc_mv = balanced_accuracy_score(test_dset.targets, test_slide_mjvt)
            bacc_avg = balanced_accuracy_score(test_dset.targets, test_slide_avg_m)
            bacc_max = balanced_accuracy_score(test_dset.targets, test_slide_max_m)
            bacc_avgt10 = balanced_accuracy_score(test_dset.targets, test_slide_avgt10_m)
            # AUC
            auc_test_avg = evaluator.compute_auc(test_dset.targets, test_slide_avg)
            auc_test_max = evaluator.compute_auc(test_dset.targets, test_slide_max)
            auc_test_top10 = evaluator.compute_auc(test_dset.targets, test_slide_avgt10)
            print("\n----------------------Test Set Results-------------------------------")
            print(f"Test F1-scores: MV={f1_mv:.3f}, AVG={f1_avg:.3f}, MAX={f1_max:.3f}, top10={f1_t10:.3f}")
            print(f"Test Balanced Acc: MV={bacc_mv:.3f}, AVG={bacc_avg:.3f}, MAX={bacc_max:.3f}, top10={bacc_avgt10:.3f}")
            print(f"Test AUC-scores: AVG={auc_test_avg:.3f}, MAX={auc_test_max:.3f}, top10={auc_test_top10:.3f}")
            # confusion matrix of average predictions only
            print("Confusion Matrix (Average Predictions):")
            print(confusion_matrix(test_dset.targets, test_slide_avg_m))
            # classification report of average predictions only
            print("Classification Report (Average Predictions):")
            print(classification_report(test_dset.targets, test_slide_avg_m))
            print("------------------------------------------------------\n")
            # save the ground truth for test slides
            df1 = pd.DataFrame({
                'Case_ID': test_dset.cases,
                'WSI_Id': test_dset.slides,
                'n_tiles': test_dset.ntiles,
                'label_desc': test_dset.label_desc,
                'label_id': test_dset.targets
            })
            df1.to_csv(os.path.join(args.output, 'test_GT.csv'), index=False)
            # save the average predictions of the test slides
            df2 = pd.DataFrame({
                'wsi_id': test_dset.slides,
                'nonMSI_prob': test_slide_avg[:, 0],
                'MSI_prob': test_slide_avg[:, 1]
            })
            df2.to_csv(os.path.join(args.output, 'test_pred_avg_AUC.csv'), index=False)
            # save the max predictions of the test slides
            df3 = pd.DataFrame({
                'wsi_id': test_dset.slides,
                'nonMSI_prob': test_slide_max[:, 0],
                'MSI_prob': test_slide_max[:, 1]
            })
            df3.to_csv(os.path.join(args.output, 'test_pred_max_AUC.csv'), index=False)
            # save the top10 predictions of the test slides
            df4 = pd.DataFrame({
                'wsi_id': test_dset.slides,
                'nonMSI_prob': test_slide_avgt10[:, 0],
                'MSI_prob': test_slide_avgt10[:, 1]
            })
            df4.to_csv(os.path.join(args.output, 'test_pred_t10_AUC.csv'), index=False)
            # save original tile level predictions for test slides
            df5 = pd.DataFrame({
                'slideidx': test_dset.slideIDX,
                'nonMSI_prob': test_probs[:, 0],
                'MSI_prob': test_probs[:, 1],
                'pred_tile': test_preds
            })
            df5.to_csv(os.path.join(args.output, 'test_tile_pred_AUC.csv'), index=False)
            
            print('..............Test set done ...............')

            print("Done setting up 5-crop pipeline for fold ", fold + 1)

if __name__ == '__main__':
    main()
    torch.cuda.empty_cache()

##--Targets ==> 405 | 344 | 61 --##
-------------------------
Number of Slides: 211
Number of tiles: 131716
Max tiles:  1875
Min tiles:  16
Average tiles:  624.2464454976304
nonMSI:  178
MSI:  33
##--Val Split ==> 94 Slides
-------------------------
Number of Slides: 94
Number of tiles: 55618
Max tiles:  1682
Min tiles:  6
Average tiles:  591.6808510638298
nonMSI:  81
MSI:  13
##--Test Split ==> 100 Slides
-------------------------
Number of Slides: 100
Number of tiles: 55849
Max tiles:  1942
Min tiles:  30
Average tiles:  558.49
nonMSI:  85
MSI:  15
Number of slides: 211 | Number of tiles: 131716
Number of randk1: 66011
Train Epoch: [  1/  4] Batch:   1/129, Loss: 1.3619, acc: 73.44%, xent: 0.1702, conf: 1.0372, slide: 0.1544
Train Epoch: [  1/  4] Batch: 101/129, Loss: 0.4907, acc: 83.99%, xent: 0.2563, conf: 0.1903, slide: 0.0002
Training Epoch: [1/4]	Loss: 0.4645	Accuracy:  84%
Inference	Epoch: [  1/  4]	Batch: [  1/258]	Validation: Loss: -2.0527, acc: 94.53%
Inference	Epoch: [  1/

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


##--Targets ==> 405 | 344 | 61 --##
-------------------------
Number of Slides: 194
Number of tiles: 111467
Max tiles:  1942
Min tiles:  6
Average tiles:  574.5721649484536
nonMSI:  166
MSI:  28
##--Val Split ==> 100 Slides
-------------------------
Number of Slides: 100
Number of tiles: 64050
Max tiles:  1875
Min tiles:  18
Average tiles:  640.5
nonMSI:  87
MSI:  13
##--Test Split ==> 111 Slides
-------------------------
Number of Slides: 111
Number of tiles: 67666
Max tiles:  1849
Min tiles:  16
Average tiles:  609.6036036036036
nonMSI:  91
MSI:  20
Number of slides: 194 | Number of tiles: 111467
Number of randk1: 55885
Train Epoch: [  1/  4] Batch:   1/110, Loss: 1.5152, acc: 19.53%, xent: 0.2657, conf: 0.9665, slide: 0.2831
Train Epoch: [  1/  4] Batch: 101/110, Loss: 0.4563, acc: 84.67%, xent: 0.2184, conf: 0.1917, slide: 0.0002
Training Epoch: [1/4]	Loss: 0.4459	Accuracy:  85%
Inference	Epoch: [  1/  4]	Batch: [  1/218]	Validation: Loss: -1.2871, acc: 78.32%
Inference	Epoch: [  1