This is the finalised IDARS pipeline, using FiveCrop transformations.

In [1]:
import os
import json
import numpy as np
import argparse
import random
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.autograd import Variable
from sklearn.metrics import auc, roc_curve,roc_auc_score, f1_score, precision_recall_curve, average_precision_score, classification_report 
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, balanced_accuracy_score, accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import cv2
from scipy.stats.mstats import gmean
from custom_tranformations import RandomRotation
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import label_binarize
from utils import encode_onehot, custom_label_binarize
# Report which device this run will use.
if torch.cuda.is_available():
    print(f'CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}')
else:
    print('CUDA is not available. Using CPU.')
# NOTE(review): this only binds a Python variable named CUDA_LAUNCH_BLOCKING;
# it does NOT set the environment variable CUDA reads. If synchronous kernel
# launches are wanted it would presumably need to be
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1', set before CUDA is first
# initialised (the get_device_name call above may already do that) — confirm.
CUDA_LAUNCH_BLOCKING=1
# Release cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

CUDA is available. Using GPU: NVIDIA GeForce RTX 4090


In [2]:
################################################################################
#                             Global Configurations                             #
################################################################################

# CSV with one row per slide (Case_ID, WSI_Dir, label_desc, label_id, fold);
# used as the default library file for the dataset and the CLI arguments.
KFOLD_PATH = r"E:\Aamir Gulzar\dataset\splits\kfolds_IDARS.csv"
# Root directory of tile images; a slide's tiles are the *.png files in
# DATA_PATH/<WSI_Dir>.
DATA_PATH = r"E:\Aamir Gulzar\dataset\patches"

# Optional cap on the fraction of GPU memory this process may use (disabled).
# torch.cuda.set_per_process_memory_fraction(0.9)

# NOTE(review): lmbda is not referenced in the visible part of this file.
# The best_* / n_slides values are "best so far" trackers; main() declares
# them global and resets them at the start of each fold.
lmbda = 0.1
best_auc_v = 0
best_auc = 0
n_slides = 0
best_loss = 100000.
best_f1_v = 0.
best_Acc = 0.
best_ap_v = 0.

# Binary task: class index 0 = nonMSI, class index 1 = MSI.
num_classes = 2
label_names = ['nonMSI', 'MSI']

In [3]:
# Command-line configuration. main() parses it with parse_args(""), so in the
# notebook the defaults below are what actually get used.
parser = argparse.ArgumentParser(description='MSI And MSS Classification')
# All three library paths default to the same k-fold CSV; the actual split is
# selected by the 'fold' column (see get_split_indices).
# NOTE(review): main() builds the train, validation and test datasets from
# args.data_lib only, so --val_lib / --test_lib are not read there.
parser.add_argument('--data_lib', type=str, default=KFOLD_PATH, help='path to train ')
parser.add_argument('--val_lib', type=str, default=KFOLD_PATH, help='path to validation ')
parser.add_argument('--test_lib', type=str, default=KFOLD_PATH, help='path to test ')
parser.add_argument('--problem', type=str, default='MSI_vs_MSS', help='classification problem.')
parser.add_argument('--pos_label', type=int, default=1, help='positive label.')
parser.add_argument('--neg_label', type=int, default=0, help='negative label. If present.')
# Original author note on this run: "now i am using t=1 and r=10".
parser.add_argument('--output', type=str, default='Baseline_Fivecrop_4Folds', help='root output directory')
parser.add_argument('--folds', type=int, default=4, help='number of fold to execute')
parser.add_argument('--batch_size', type=int, default=512, help='mini-batch size (default: 512)')
parser.add_argument('--nepochs', type=int, default=4, help='number of epochs')
parser.add_argument('--workers', default=0, type=int, help='number of data loading workers (default: 0)')
parser.add_argument('--test_every', default=1, type=int, help='test on val every (default: 1)')
parser.add_argument('--weights', default=0.5, type=float, help='unbalanced positive class weight (default: 0.5, balanced classes)')

# NOTE(review): main() constructs Adam with a hard-coded lr=1e-4 and
# weight_decay=1e-4, so --lr / --l2_reg do not reach that optimizer.
parser.add_argument('--lr', type=float, default=0.001) # best 0.01
parser.add_argument('--l2_reg', type=float, default=5e-3)
parser.add_argument('--grad_bound', type=float, default=5.0)

parser.add_argument('--r', default=10, type=int, help='how many rand tiles to consider (default: 10)')
parser.add_argument('--k', default=1, type=int, help='how many top k tiles to consider (default: 1)')

parser.add_argument('--budget', type=float, default=0.8, metavar='N',
                    help='the budget for how often the network can get hints')

_StoreAction(option_strings=['--budget'], dest='budget', nargs=None, const=None, default=0.8, type=<class 'float'>, choices=None, required=False, help='the budget for how often the network can get hints', metavar='N')

In [4]:
################################################################################
#                     Aggregation Class (for tile-level merges)                #
################################################################################

class Aggregation:
    """
    Tile-level -> slide-level aggregation helpers.

    Every method is a ``@staticmethod``; ``config`` is stored on the instance
    for callers that want to toggle behaviour, but no method here reads it.

    Ordering note: the dict-based aggregators
    (``compute_aggregated_probabilities`` / ``compute_aggregated_predictions``)
    return one entry per group in first-occurrence order, whereas
    ``group_avg`` and ``group_avg_df`` return groups in sorted-id order. The two
    agree when group ids appear in non-decreasing order in the input.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}
        self.config = config

    @staticmethod
    def _group_values(group, data):
        """Map each group id to the list of its ``data`` values (first-occurrence order)."""
        wsi_dict = {}
        for g, value in zip(group, data):
            wsi_dict.setdefault(g, []).append(value)
        return wsi_dict

    @staticmethod
    def group_max(groups, data, nmax):
        """
        Return the maximum of ``data`` for each group id.

        ``groups`` and ``data`` are equal-length numpy arrays; group ids are used
        directly as indices into the result, so they must be ints in
        ``[0, nmax)``. The result has length ``nmax``; ids with no data stay
        NaN. Empty input yields an all-NaN array.
        """
        out = np.full(nmax, np.nan)
        if len(groups) == 0:
            return out
        # Sort by group, then by value: the last element of each run of equal
        # group ids is that group's maximum.
        order = np.lexsort((data, groups))
        groups = groups[order]
        data = data[order]
        is_last = np.empty(len(groups), dtype=bool)
        is_last[-1] = True
        is_last[:-1] = groups[1:] != groups[:-1]
        out[groups[is_last]] = data[is_last]
        return out

    @staticmethod
    def group_avg(groups, data):
        """
        Return ``(mean, sum)`` of ``data`` per unique group id.

        Both arrays are ordered by sorted unique group id.
        """
        order = np.lexsort((data, groups))
        groups = groups[order]
        data = data[order]
        _, idx, counts = np.unique(groups, return_inverse=True, return_counts=True)
        sum_pred = np.bincount(idx, weights=data)
        mean_pred = sum_pred / counts
        return mean_pred, sum_pred

    @staticmethod
    def compute_aggregated_probabilities(group, data):
        """
        Return ``(avg, max)`` of ``data`` per group as float64 arrays.

        One entry per distinct group id, in first-occurrence order.
        """
        wsi_dict = Aggregation._group_values(group, data)
        per_slide = [np.array(values, dtype='float64') for values in wsi_dict.values()]
        avg_p = np.array([np.mean(v) for v in per_slide], dtype='float64')
        max_p = np.array([np.max(v) for v in per_slide], dtype='float64')
        return avg_p, max_p

    @staticmethod
    def compute_aggregated_predictions(group, data, n_classes=None):
        """
        Majority-vote the per-tile class predictions of each group.

        ``data`` holds predicted class ids (ints or integral floats).
        ``n_classes`` defaults to the module-level ``num_classes``.

        Returns ``(mv_pred, n_pred)`` as float64 arrays: the winning class per
        group (ties go to the lowest class id, via ``np.argmax``) and the
        per-class vote counts, one row per group in first-occurrence order.
        """
        if n_classes is None:
            n_classes = num_classes
        wsi_dict = Aggregation._group_values(group, data)

        mv_pred = []
        n_pred = []
        for wsi_predictions in wsi_dict.values():
            votes = [wsi_predictions.count(cl) for cl in range(n_classes)]
            mv_pred.append(np.argmax(votes))
            n_pred.append(votes)

        return np.array(mv_pred, dtype='float64'), np.array(n_pred, dtype='float64')

    @staticmethod
    def group_avg_df(groups, data, top_k=10):
        """
        Mean of the ``top_k`` largest values per group (default: top 10).

        Groups with fewer than ``top_k`` values use all of them. Returns a plain
        list ordered by sorted group id (pandas ``groupby`` default).
        """
        df = pd.DataFrame({'Slide': groups, 'value': data})
        group_average_df = df.groupby('Slide')['value'].apply(
            lambda grp: grp.nlargest(top_k).mean()
        )
        return group_average_df.tolist()


In [5]:
################################################################################
#                     Evaluation Class (metrics, plotting, etc.)               #
################################################################################

class Evaluation:
    """
    Evaluation helpers: ROC / precision-recall metrics, Youden's-J cutoff,
    tile accuracy and slide-level AUC, plus ROC/PR plotting.

    Every method is a ``@staticmethod``; ``config`` is stored on the instance
    but not read by any method here.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}
        self.config = config

    @staticmethod
    def cutoff_youdens_j(fpr, tpr, thresholds):
        """
        Return the ROC threshold with the largest Youden's J (TPR - FPR).

        The (J, threshold) pairs are sorted ascending and the last is taken, so
        ties in J go to the larger threshold.
        """
        j_scores = tpr - fpr
        j_ordered = sorted(zip(j_scores, thresholds))
        return j_ordered[-1][1]

    @staticmethod
    def cal_f1_score(targets, prediction, cutoff):
        """
        Weighted F1 of ``prediction`` against ``targets``.

        ``prediction`` may be hard 0/1 labels (used as-is) or continuous
        scores, which are first binarised as ``score >= cutoff`` (the same rule
        ``roc_curve`` uses for its thresholds). Passing raw scores straight to
        ``f1_score`` would raise a ValueError (mixed binary/continuous targets).
        """
        prediction = np.array(prediction)
        targets = np.array(targets)
        if not np.isin(prediction, (0, 1)).all():
            prediction = (prediction >= cutoff).astype(int)
        return f1_score(targets, prediction, average='weighted')

    @staticmethod
    def calc_metrics(target, prediction):
        """
        Compute binary metrics from positive-class scores.

        Returns ``(weighted_f1_at_cutoff, average_precision, roc_auc, cutoff)``
        where ``cutoff`` is the Youden's-J optimal ROC threshold.
        """
        fpr, tpr, thresholds = roc_curve(target, prediction)
        cutoff = Evaluation.cutoff_youdens_j(fpr, tpr, thresholds)
        roc_auc = auc(fpr, tpr)
        f1score = Evaluation.cal_f1_score(target, prediction, cutoff)
        # precision_recall_curve / average_precision_score take no
        # zero_division argument; passing one raised a TypeError.
        average_precision = average_precision_score(target, prediction)
        return f1score, average_precision, roc_auc, cutoff

    @staticmethod
    def calculate_accuracy(output, target):
        """Fraction of rows whose arg-max logit equals ``target`` (0-dim float tensor)."""
        preds = output.max(1, keepdim=True)[1]
        correct = preds.eq(target.view_as(preds)).sum()
        acc = correct.float() / preds.shape[0]
        return acc

    @staticmethod
    def compute_auc(labels, predictions):
        """ROC AUC of ``labels`` using column 1 of ``predictions`` as the positive-class score."""
        return roc_auc_score(labels, predictions[:, 1])

    @staticmethod
    def plot_metrics(target, prediction, set, output_dir=None):
        """
        Save side-by-side ROC and precision-recall plots as
        ``<output_dir>/roc_pr<set>.png``.

        ``output_dir`` defaults to the module-level ``args.output``.
        """
        import matplotlib.pyplot as plt
        if output_dir is None:
            output_dir = args.output

        fpr, tpr, thresholds = roc_curve(target, prediction)
        roc_auc = auc(fpr, tpr)
        print('roc_auc is:', roc_auc)

        precision, recall, _ = precision_recall_curve(target, prediction)
        average_precision = average_precision_score(target, prediction)
        print('Average precision-recall score: {0:0.2f}'.format(average_precision))

        fig = plt.figure(figsize=(12, 4))
        lw = 2

        # Subplot 1: ROC
        plt.subplot(121)
        plt.plot(fpr, tpr, color='darkorange',
                 lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve, AUC={0:0.2f}'.format(roc_auc))
        plt.legend(loc="lower right")

        # Subplot 2: Precision-Recall
        plt.subplot(122)
        plt.step(recall, precision, alpha=0.4, color='darkorange', where='post')
        plt.fill_between(recall, precision, alpha=0.2, color='navy', step='post')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.ylim([0, 1.05])
        plt.xlim([0, 1])
        plt.title('Precision-Recall curve: AP={0:0.2f}'.format(average_precision))

        plt.savefig(os.path.join(output_dir, 'roc_pr' + set + '.png'))
        plt.close(fig)

In [6]:
################################################################################
#                            Training/Testing Utilities                        #
################################################################################

def inference(run, loader, model, criterion):
    """
    Run ``model`` over ``loader`` in eval mode and collect per-sample outputs.

    Batches are ``(inputs, target)`` with ``inputs`` of shape
    ``(batch, ncrops, C, H, W)`` (the stacked crops produced by the dataset
    transform). Crops are flattened into the batch for the forward pass and
    the logits are averaged back over crops, giving one prediction per sample.

    Args:
        run: 0-based epoch index; only used in progress messages.
        loader: DataLoader whose dataset order defines the row order of the
            returned arrays (it must not shuffle).
        model, criterion: network and loss, expected on CUDA.

    Returns:
        ``(probs, mean_loss, mean_acc, preds)``: softmax probabilities
        ``(N, num_classes)``, loss and accuracy averaged over the samples that
        were processed, and the arg-max class per sample (as float).

    A batch that raises RuntimeError (e.g. CUDA OOM) is reported and skipped;
    its rows stay zero, and later batches still land on their own rows.
    """
    model.eval()
    n_samples = len(loader.dataset)
    # Zero-initialised (torch.FloatTensor(...) would be uninitialised memory),
    # so rows of a skipped batch hold a defined value rather than garbage.
    probs = torch.zeros(n_samples, num_classes)
    preds = torch.zeros(n_samples)

    running_acc = 0.
    running_loss = 0.
    n_seen = 0        # samples successfully processed
    index_offset = 0  # row in probs/preds where the current batch starts
    with torch.no_grad():
        for i, (inputs, target) in enumerate(loader):
            b = inputs.shape[0]
            try:
                _, ncrops, c, h, w = inputs.shape
                inputs = inputs.view(-1, c, h, w).cuda()    # (b * ncrops, C, H, W)
                target = target.cuda()

                output = model(inputs)                      # (b * ncrops, num_classes)
                # Average logits over the crops: one prediction per sample.
                output = output.view(b, ncrops, -1).mean(1) # (b, num_classes)
                loss = criterion(output, target)
                acc = Evaluation.calculate_accuracy(output, target)
                y = F.softmax(output, dim=-1)
                pred = output.argmax(dim=1)

                preds[index_offset:index_offset + b] = pred.float().cpu()
                probs[index_offset:index_offset + b] = y.float().cpu()

                running_acc += acc.item() * b
                running_loss += loss.item() * b
                n_seen += b

                if i % 100 == 0:
                    print('Inference\tEpoch: [{:3d}/{:3d}]\tBatch: [{:3d}/{}]\tValidation Loss: {:.4f}, acc: {:0.2f}%'
                          .format(run + 1, args.nepochs, i + 1, len(loader),
                                  running_loss / n_seen,
                                  (100 * running_acc) / n_seen))
            except RuntimeError as e:
                print(f"RuntimeError in batch {i}: {e}")
                torch.cuda.empty_cache()
            finally:
                # Always advance, so later batches stay aligned with dataset
                # order (slideIDX) even when this batch was skipped.
                index_offset += b

    if n_seen < n_samples:
        print(f"Warning: {n_samples - n_seen} of {n_samples} samples were skipped; "
              f"their probs/preds rows are zero.")
    denom = max(n_seen, 1)
    return (
        probs.numpy(),
        running_loss / denom,
        running_acc / denom,
        preds.numpy()
    )


def train(run, loader, model, criterion, optimizer):
    """
    Train ``model`` for one pass over ``loader``.

    Batches are ``(inputs, target)`` with ``inputs`` of shape
    ``(batch, ncrops, C, H, W)``; crops are flattened into the batch for the
    forward pass and the logits averaged back over crops before the loss.

    Args:
        run: 0-based epoch index; only used in progress messages.
        loader, model, criterion, optimizer: standard training objects; the
            model and criterion are expected on CUDA.

    Returns:
        ``(mean_loss, mean_acc)`` averaged over the samples that were actually
        trained on. A batch that raises RuntimeError (e.g. CUDA OOM) is
        reported and skipped.
    """
    model.train()
    running_loss = 0.
    running_acc = 0.
    n_seen = 0     # samples successfully trained on
    n_failed = 0   # batches skipped because of a RuntimeError

    for i, (inputs, target) in enumerate(loader):
        try:
            b, ncrops, c, h, w = inputs.shape
            inputs = inputs.view(-1, c, h, w).cuda()     # (b * ncrops, C, H, W)
            target = target.cuda()

            optimizer.zero_grad()
            output = model(inputs)                       # (b * ncrops, num_classes)
            # Average logits over the crops: one prediction per sample.
            output = output.view(b, ncrops, -1).mean(1)  # (b, num_classes)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            acc = Evaluation.calculate_accuracy(output, target)
            running_loss += loss.item() * b
            running_acc += acc.item() * b
            n_seen += b

            if i % 100 == 0:
                print("Train Epoch: [{:3d}/{:3d}] Batch: {:3d}, Training Loss: {:.4f}, Acc: {:.2f}%"
                      .format(run + 1, args.nepochs, i + 1,
                              running_loss / n_seen,
                              100 * running_acc / n_seen))
        except RuntimeError as e:
            print(f"RuntimeError in batch {i}: {e}")
            n_failed += 1
            # Drop any partial gradients from the failed batch before freeing memory.
            optimizer.zero_grad()
            torch.cuda.empty_cache()

    if n_failed:
        print(f"Warning: {n_failed} training batches were skipped this epoch.")
    denom = max(n_seen, 1)
    return (running_loss / denom,
            running_acc / denom)



In [7]:
################################################################################
#                                Dataset Class                                 #
################################################################################
def get_split_indices(lib, test_fold, val_fold):
    """
    Partition the rows of ``lib`` into train / validation / test by fold id.

    Args:
        lib: DataFrame with a ``fold`` column holding each row's fold id.
        test_fold: fold id reserved for testing.
        val_fold: fold id reserved for validation. (main() passes ids 1..4.)

    Returns:
        ``(train_idx, val_idx, test_idx)`` — lists of positional row indices.
        A row goes to train only if its fold is neither the test nor the
        validation fold.
    """
    train_rows, val_rows, test_rows = [], [], []
    for row, fold_id in enumerate(lib['fold'].values.tolist()):
        if fold_id == test_fold:
            test_rows.append(row)
        if fold_id == val_fold:
            val_rows.append(row)
        if fold_id not in (test_fold, val_fold):
            train_rows.append(row)
    return train_rows, val_rows, test_rows

class MILdataset(data.Dataset):
    """
    Tile-level dataset built from a fold-assignment CSV.

    Each CSV row (columns Case_ID, WSI_Dir, label_desc, label_id, fold) is one
    slide. Its tiles are the ``*.png`` files in ``DATA_PATH/<WSI_Dir>``, and
    every tile inherits the slide's ``label_id``. Slides whose directory is
    missing, or that have fewer than 3 tiles, are skipped.

    Two iteration modes, chosen with :meth:`setmode`:
      * mode 1 — every tile in ``self.tiles`` (used for inference);
      * mode 2 — the ``(slide_idx, tile, target)`` tuples in ``self.t_data``
        built by :meth:`maketraindata` (used for training).
    Items are ``(transform(img), slide_target)``. With the FiveCrop transforms
    built in main(), ``img`` becomes a stacked crop tensor.

    ``mult``, ``s`` and ``shuffle`` are stored but not used by this class.
    """

    def __init__(self, libraryfile=KFOLD_PATH, transform=None, mult=2, s=10,
                 shuffle=False, set='', test_fold=None, val_fold=None):
        path_dir = DATA_PATH
        lib = pd.read_csv(libraryfile, usecols=[
            'Case_ID', 'WSI_Dir', 'label_desc', 'label_id', 'fold'
        ])
        # Drop incomplete rows first: the split indices below are positional
        # and must line up with the per-column lists taken from this frame.
        lib.dropna(inplace=True)

        allcases = lib['Case_ID'].values.tolist()
        allslides = lib['WSI_Dir'].values.tolist()
        tar = lib['label_id'].values.tolist()
        label_desc = lib['label_desc'].values.tolist()

        # Minimum number of tiles a slide needs to be kept (same for every split).
        thresh_tiles = 3
        train_idx, val_idx, test_idx = get_split_indices(lib, test_fold, val_fold)
        if set == 'train':
            print(f"##--Targets ==> {len(tar)} | {tar.count(0)} | {tar.count(1)} --##")
            indices = train_idx
            print(f"##--Train Split ==> {len(indices)} Slides")
        elif set == 'valid':
            indices = val_idx
            print(f"##--Val Split ==> {len(indices)} Slides")
        elif set == 'test':
            indices = test_idx
            print(f"##--Test Split ==> {len(indices)} Slides")
        else:
            raise ValueError("Invalid set_type. Must be 'train', 'valid', or 'test'.")

        cases = []
        tiles = []
        ntiles = []
        slideIDX = []    # per tile: index of its slide in `slides`
        targetIDX = []   # per tile: its slide's label
        targets = []     # per slide: label
        label_descriptions = []
        slides = []

        for i in indices:
            path = os.path.join(path_dir, str(allslides[i]))
            slide_label = int(tar[i])
            # isdir (not exists): a stray file with the slide's name would
            # otherwise make os.listdir raise.
            if not os.path.isdir(path):
                continue
            # sorted(): tile order no longer depends on os.listdir's arbitrary order.
            t = [
                os.path.join(path, f)
                for f in sorted(os.listdir(path))
                if f.endswith('.png')
            ]
            if len(t) < thresh_tiles:
                continue
            slide_no = len(slides)  # index this slide gets in self.slides
            cases.append(allcases[i])
            slides.append(allslides[i])
            tiles.extend(t)
            ntiles.append(len(t))
            slideIDX.extend([slide_no] * len(t))
            targetIDX.extend([slide_label] * len(t))
            targets.append(slide_label)
            label_descriptions.append(label_desc[i])

        print('-------------------------')
        print('Number of Slides: {}'.format(len(slides)))
        print('Number of tiles: {}'.format(len(tiles)))
        print('Max tiles: ', max(ntiles) if len(ntiles) else 0)
        print('Min tiles: ', min(ntiles) if len(ntiles) else 0)
        print('Average tiles: ', np.mean(ntiles) if len(ntiles) else 0)
        print('nonMSI: ', targets.count(0))
        print('MSI: ', targets.count(1))

        self.slideIDX = slideIDX
        self.ntiles = ntiles
        self.tiles = tiles
        self.targets = targets
        self.label_desc = label_descriptions
        self.slides = slides
        self.cases = cases

        self.transform = transform
        self.mult = mult
        self.s = s
        self.mode = None
        self.shuffle = shuffle
        self.targetIDX = targetIDX

    def setmode(self, mode):
        """Select iteration mode: 1 = all tiles, 2 = ``t_data`` (see class docstring)."""
        self.mode = mode

    def maketraindata(self, idxs):
        """Build ``t_data`` as ``(slide_idx, tile_path, slide_target)`` for the given tile indices."""
        self.t_data = [(self.slideIDX[x], self.tiles[x], self.targets[self.slideIDX[x]])
                       for x in idxs]

    def shuffletraindata(self):
        """Replace ``t_data`` with a random permutation of itself (``random`` module RNG)."""
        self.t_data = random.sample(self.t_data, len(self.t_data))

    def __getitem__(self, index):
        if self.mode == 1:
            tile = self.tiles[index]
            target = self.targets[self.slideIDX[index]]
        elif self.mode == 2:
            _, tile, target = self.t_data[index]
        else:
            raise ValueError(
                f"MILdataset mode must be 1 or 2 (call setmode first), got {self.mode!r}"
            )

        # Context manager guarantees the file handle is closed after decoding.
        with Image.open(str(tile)) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        if self.mode == 1:
            return len(self.tiles)
        if self.mode == 2:
            return len(self.t_data)
        raise ValueError(
            f"MILdataset mode must be 1 or 2 (call setmode first), got {self.mode!r}"
        )


In [8]:

################################################################################
#                                Main Pipeline                                 #
################################################################################
class StackAndNormalize:
    """
    Turn a sequence of PIL crops (e.g. the tuple produced by
    ``transforms.FiveCrop``) into a single stacked tensor.

    Each crop goes through ``transforms.ToTensor()`` and then ``normalize``;
    the per-crop tensors are stacked along a new leading dimension.
    """

    def __init__(self, normalize):
        self.normalize = normalize

    def __call__(self, crops):
        to_tensor = transforms.ToTensor()
        normalized_crops = [self.normalize(to_tensor(crop)) for crop in crops]
        return torch.stack(normalized_crops)


def main():
    """
    Main training/validation/test pipeline.
    """

    global args, best_auc_v, tr_batch_size, n_slides, best_auc, best_f1_v

    config = {
        'aggregation_methods': ['average', 'top10', 'max'],
        'save_val_tile_csv': True,
        'save_test_tile_csv': True,
        'save_val_slide_csv': True,
        'save_test_slide_csv': True,
        'enable_group_avg_df': True,
        'enable_plot_metrics': True,
        'enable_confusion_matrix': True
    }

    aggregator = Aggregation(config=config)
    evaluator = Evaluation(config=config)

    args = parser.parse_args("")
    if not os.path.exists(args.output):
        os.mkdir(args.output)

    args.output = os.path.join(args.output, args.problem)
    if not os.path.exists(args.output):
        os.mkdir(args.output)

    temp_output = args.output

    AUC_SCORES = []
    F1_SCORES = []
    # Add these variables to control resume functionality
    resume_fold = 4  # Specify the fold to resume from (1-indexed)
    resume_epoch = 2  # Specify the epoch to resume from (0-indexed)
    for fold in range(args.folds):
        if fold + 1 < resume_fold:  # Skip folds before the resume point
            continue
        test_fold = fold + 1
        val_fold  = ((fold + 1) % 4) + 1
        args.output = os.path.join(temp_output, 'fold' + str(fold + 1))
        if not os.path.exists(args.output):
            os.mkdir(args.output)
        path_fold = args.output

        for sets in range(1):
            torch.cuda.empty_cache()

            args.output = os.path.join(path_fold, 'best' + str(sets))
            if not os.path.exists(args.output):
                os.mkdir(args.output)

            global best_auc_v, best_auc, n_slides, best_loss, best_f1_v, best_Acc, best_ap_v
            best_auc_v = 0
            best_auc = 0
            n_slides = 0
            best_loss = 100000.
            best_f1_v = 0.
            best_Acc = 0.
            best_ap_v = 0.

            model = models.resnet34(weights='DEFAULT')
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, num_classes)
            model.cuda()

            optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
            cudnn.benchmark = True
            criterion = nn.CrossEntropyLoss().cuda()

               # --- 5-crop change: define transforms so we return 5-crops (224×224) for each 512×512 image
            normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.1, 0.1, 0.1])
            trans = transforms.Compose([
                RandomRotation([0, 90, 180, 270]),
                transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05),
                transforms.RandomAdjustSharpness(sharpness_factor=2),
                transforms.FiveCrop(224),   # returns tuple of 5 PIL images
                StackAndNormalize(normalize)])

            trans_Valid = transforms.Compose([
                transforms.FiveCrop(224),
                StackAndNormalize(normalize)])
            # --- end 5-crop changes

            # Load Data
            train_dset = MILdataset(args.data_lib, trans, set='train',  test_fold=test_fold, val_fold=val_fold)
            train_loader = torch.utils.data.DataLoader(
                train_dset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)

            val_dset = MILdataset(args.data_lib, trans_Valid, set='valid',  test_fold=test_fold, val_fold=val_fold)
            val_loader = torch.utils.data.DataLoader(
                val_dset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)

            test_dset = MILdataset(args.data_lib, trans_Valid, set='test',  test_fold=test_fold, val_fold=val_fold)
            test_loader = torch.utils.data.DataLoader(
                test_dset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)

            #  if you wanted to see only dataset distribution and data loaders etc then put continue here
            # continue
            fconv = open(os.path.join(args.output, 'train_convergence.csv'), 'w')
            fconv.write('epoch,loss,accuracy\n')
            fconv.close()

            fconv = open(os.path.join(args.output, 'valid_convergence.csv'), 'w')
            # We removed sum, median, gmean columns; only keep relevant columns
            fconv.write('epoch,tile_loss,tile_acc,best_F1,F1_AVG,F1_Max,F1_T10,best_BAcc,Bacc_AVG,Bacc_Max,Bacc_T10,best_AUC,Avg_AUC,Max_AUC,Top_AUC\n')
            fconv.close()

            num_tiles = len(train_dset.slideIDX)
            n_slides = len(train_dset.slides)
            print(n_slides)
            train_dset.maketraindata(np.arange(num_tiles))
            # Resume logic
            # start_epoch = 0
            # if fold + 1 == resume_fold:
            #     checkpoint_path = r"E:\Aamir Gulzar\existing_approaches\Baseline_Fivecrop_4Folds\MSI_vs_MSS\fold3\best0\checkpoint_best_AUC.pth"
            #     if os.path.exists(checkpoint_path):
            #         checkpoint = torch.load(checkpoint_path)
            #         model.load_state_dict(checkpoint['state_dict'])
            #         optimizer.load_state_dict(checkpoint['optimizer'])
            #         start_epoch = resume_epoch
            #         print(f"Resuming training from Fold {resume_fold}, Epoch {resume_epoch + 1}")

            for epoch in range(args.nepochs):
                train_dset.shuffletraindata()
                train_dset.setmode(2)
                loss, acc = train(epoch, train_loader, model, criterion, optimizer)
                print('Training\tEpoch: [{}/{}]\tLoss: {:0.4f}\tAccuracy: {:3d}'
                      .format(epoch+1, args.nepochs, loss, int(acc * 100)))

                fconv = open(os.path.join(args.output, 'train_convergence.csv'), 'a')
                fconv.write('{},{:0.4f},{:3d}\n'.format(epoch, loss, int(acc * 100)))
                fconv.close()

                # Validation
                if (epoch + 1) % args.test_every == 0:
                    val_dset.setmode(1)
                    val_probs, val_loss, val_acc, val_preds = inference(epoch, val_loader, model, criterion)
                    val_slide_mjvt, _ = aggregator.compute_aggregated_predictions(np.array(val_dset.slideIDX), val_preds)

                    # We only keep avg, max, top10
                    val_slide_avg = []
                    val_slide_max = []
                    val_slide_avgt10 = []

                    # We'll store each class aggregator
                    for cl in range(num_classes):
                        t_avg, t_max = aggregator.compute_aggregated_probabilities(
                            np.array(val_dset.slideIDX), val_probs[:, cl]
                        )
                        t_t10 = aggregator.group_avg_df(
                            np.array(val_dset.slideIDX), val_probs[:, cl]
                        )
                        val_slide_avg.append(t_avg)
                        val_slide_max.append(t_max)
                        val_slide_avgt10.append(t_t10)

                    val_slide_avg = np.array(val_slide_avg).transpose()
                    val_slide_max = np.array(val_slide_max).transpose()
                    val_slide_avgt10 = np.array(val_slide_avgt10).transpose()

                    val_slide_avg_m = np.argmax(val_slide_avg, axis=1)
                    val_slide_max_m = np.argmax(val_slide_max, axis=1)
                    val_slide_avgt10_m = np.argmax(val_slide_avgt10, axis=1)

                    from sklearn.metrics import f1_score
                    f1_mv = f1_score(val_dset.targets, val_slide_mjvt, average='macro')
                    f1_avg = f1_score(val_dset.targets, val_slide_avg_m, average='macro')
                    f1_max = f1_score(val_dset.targets, val_slide_max_m, average='macro')
                    f1_a10 = f1_score(val_dset.targets, val_slide_avgt10_m, average='macro')
                    bacc_mv = balanced_accuracy_score(val_dset.targets, val_slide_mjvt)
                    bacc_avg = balanced_accuracy_score(val_dset.targets, val_slide_avg_m)
                    bacc_max = balanced_accuracy_score(val_dset.targets, val_slide_max_m)
                    bacc_avgt10 = balanced_accuracy_score(val_dset.targets, val_slide_avgt10_m)

                    auc_val_avg = evaluator.compute_auc(val_dset.targets, val_slide_avg)
                    auc_val_max = evaluator.compute_auc(val_dset.targets, val_slide_max)
                    auc_val_top10 = evaluator.compute_auc(val_dset.targets, val_slide_avgt10)

                    fconv = open(os.path.join(args.output, 'valid_convergence.csv'), 'a')
                    fconv.write('{},{:0.4f},{:3d},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f},{:0.4f}\n'.format(
                        epoch, val_loss, int(val_acc * 100),
                        max(f1_mv,f1_avg, f1_max, f1_a10),
                        f1_avg, f1_max, f1_a10,
                        # save max and individual balanced accuracies
                        max(bacc_mv, bacc_avg, bacc_max, bacc_avgt10),
                        bacc_avg, bacc_max, bacc_avgt10,
                        # also save best auc values
                        max(auc_val_avg, auc_val_max, auc_val_top10),
                        auc_val_avg, auc_val_max, auc_val_top10
                    ))
                    fconv.close()

                    print("\n----------------------Validation Results -------------------------------------")
                    print(f"F1-scores (Val): MV={f1_mv:.3f}  AVG={f1_avg:.3f}  MAX={f1_max:.3f}  top10={f1_a10:.3f}")
                    print(f"Balanced Acc (Val): MV={bacc_mv:.3f}  AVG={bacc_avg:.3f}  MAX={bacc_max:.3f}  top10={bacc_avgt10:.3f}")
                    print(f"AUC-scores (Val): AVG={auc_val_avg:.3f} MAX={auc_val_max:.3f} top10={auc_val_top10:.3f}")
                    print("------------------------------------------------------------------\n")

                    best_f1_candidate = max(f1_mv, f1_avg, f1_max, f1_a10)
                    if best_f1_candidate > best_f1_v:
                        best_f1_v = best_f1_candidate
                        obj = {
                            'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'best_ap_v': best_f1_v,
                            'best_auc_v': max(auc_val_max, auc_val_top10, auc_val_avg),
                            'optimizer': optimizer.state_dict()
                        }
                        torch.save(obj, os.path.join(args.output, 'checkpoint_best_F1.pth'))
                        print("Saved checkpoint_best_F1.pth")
                        # save the tile level predictions for validation slides columns slideidx, prob non msi, prob msi and pred tile
                        df2_f1 = pd.DataFrame({
                            'slideidx': val_dset.slideIDX,
                            'nonMSI_prob': val_probs[:, 0],
                            'MSI_prob': val_probs[:, 1],
                            'pred_tile': val_preds
                        })
                        df2_f1.to_csv(os.path.join(args.output, 'val_tile_pred_F1.csv'), index=False)

                    best_auc_candidate = max(auc_val_avg, auc_val_max, auc_val_top10)
                    if best_auc_candidate > best_auc_v:
                        best_auc_v = best_auc_candidate
                        obj = {
                            'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'best_ap_v': best_f1_v,
                            'best_auc_v': best_auc_v,
                            'optimizer': optimizer.state_dict()
                        }
                        torch.save(obj, os.path.join(args.output, 'checkpoint_best_AUC.pth'))
                        print("Saved checkpoint_best_AUC.pth")

                        # save the tile level predictions for validation slides columns slideidx, prob non msi, prob msi and pred tile
                        df2_auc = pd.DataFrame({
                            'slideidx': val_dset.slideIDX,
                            'nonMSI_prob': val_probs[:, 0],
                            'MSI_prob': val_probs[:, 1],
                            'pred_tile': val_preds
                        })
                        df2_auc.to_csv(os.path.join(args.output, 'val_tile_pred_AUC.csv'), index=False)

            # Save ground truth for validation slides
            df1 = pd.DataFrame({
                'Case_ID': val_dset.cases,
                'WSI_Id': val_dset.slides,
                'n_tiles': val_dset.ntiles,
                'label_desc': val_dset.label_desc,
                'label_id': val_dset.targets
            })
            df1.to_csv(os.path.join(args.output, 'val_GT.csv'), index=False)
            df2_f1.to_csv(os.path.join(args.output, 'val_tile_pred_F1.csv'), index=False)
            df2_auc.to_csv(os.path.join(args.output, 'val_tile_pred_AUC.csv'), index=False)
        
            ############## Test ##############
            ch = torch.load(os.path.join(args.output, 'checkpoint_best_AUC.pth'))
            model.load_state_dict(ch['state_dict'])

            test_dset.setmode(1)
            test_probs, test_loss, test_acc, test_preds = inference(epoch, test_loader, model, criterion)

            print(f'Predicted Tiles: {len(test_probs[:, 1])}, classes: {len(test_probs[0, :])}, total targets: {len(test_dset.targets)}')

            test_slide_mjvt, _ = aggregator.compute_aggregated_predictions(
                np.array(test_dset.slideIDX), test_preds
            )
            test_slide_avg = []
            test_slide_max = []
            test_slide_avgt10 = []
            for cl in range(num_classes):
                t_avg, t_max = aggregator.compute_aggregated_probabilities(
                    np.array(test_dset.slideIDX), test_probs[:, cl]
                )
                t_t10 = aggregator.group_avg_df(
                    np.array(test_dset.slideIDX), test_probs[:, cl]
                )
                test_slide_avg.append(t_avg)
                test_slide_max.append(t_max)
                test_slide_avgt10.append(t_t10)

            test_slide_avg = np.array(test_slide_avg).transpose()
            test_slide_max = np.array(test_slide_max).transpose()
            test_slide_avgt10 = np.array(test_slide_avgt10).transpose()

            test_slide_avg_m = np.argmax(test_slide_avg, axis=1)
            test_slide_max_m = np.argmax(test_slide_max, axis=1)
            test_slide_avgt10_m = np.argmax(test_slide_avgt10, axis=1)

            f1_mv = f1_score(test_dset.targets, test_slide_mjvt, average='macro')
            f1_avg = f1_score(test_dset.targets, test_slide_avg_m, average='macro')
            f1_max = f1_score(test_dset.targets, test_slide_max_m, average='macro')
            f1_t10 = f1_score(test_dset.targets, test_slide_avgt10_m, average='macro')
            # balanced accuracy
            bacc_mv = balanced_accuracy_score(test_dset.targets, test_slide_mjvt)
            bacc_avg = balanced_accuracy_score(test_dset.targets, test_slide_avg_m)
            bacc_max = balanced_accuracy_score(test_dset.targets, test_slide_max_m)
            bacc_avgt10 = balanced_accuracy_score(test_dset.targets, test_slide_avgt10_m)
            # AUC

            auc_test_avg = evaluator.compute_auc(test_dset.targets, test_slide_avg)
            auc_test_max = evaluator.compute_auc(test_dset.targets, test_slide_max)
            auc_test_top10 = evaluator.compute_auc(test_dset.targets, test_slide_avgt10)
      
            print("\n----------------------Test Set Results-------------------------------")
            print(f"Test F1-scores: MV={f1_mv:.3f}, AVG={f1_avg:.3f}, MAX={f1_max:.3f}, top10={f1_t10:.3f}")
            print(f"Test Balanced Acc: MV={bacc_mv:.3f}, AVG={bacc_avg:.3f}, MAX={bacc_max:.3f}, top10={bacc_avgt10:.3f}")
            print(f"Test AUC-scores: AVG={auc_test_avg:.3f}, MAX={auc_test_max:.3f}, top10={auc_test_top10:.3f}")
            # confusion matrix of average predictions only
            print("Confusion Matrix (Average Predictions):")
            print(confusion_matrix(test_dset.targets, test_slide_avg_m))
            # classification report of average predictions only
            print("Classification Report (Average Predictions):")
            print(classification_report(test_dset.targets, test_slide_avg_m))
            print("------------------------------------------------------\n")
            # save the ground truth for test slides
            df1 = pd.DataFrame({
                'Case_ID': test_dset.cases,
                'WSI_Id': test_dset.slides,
                'n_tiles': test_dset.ntiles,
                'label_desc': test_dset.label_desc,
                'label_id': test_dset.targets
            })
            df1.to_csv(os.path.join(args.output, 'test_GT.csv'), index=False)
            # save the average predictions of the test slides
            df2 = pd.DataFrame({
                'wsi_id': test_dset.slides,
                'nonMSI_prob': test_slide_avg[:, 0],
                'MSI_prob': test_slide_avg[:, 1]
            })
            df2.to_csv(os.path.join(args.output, 'test_pred_avg_AUC.csv'), index=False)
            # save the max predictions of the test slides
            df3 = pd.DataFrame({
                'wsi_id': test_dset.slides,
                'nonMSI_prob': test_slide_max[:, 0],
                'MSI_prob': test_slide_max[:, 1]
            })
            df3.to_csv(os.path.join(args.output, 'test_pred_max_AUC.csv'), index=False)
            # save the top10 predictions of the test slides
            df4 = pd.DataFrame({
                'wsi_id': test_dset.slides,
                'nonMSI_prob': test_slide_avgt10[:, 0],
                'MSI_prob': test_slide_avgt10[:, 1]
            })
            df4.to_csv(os.path.join(args.output, 'test_pred_t10_AUC.csv'), index=False)
            # save original tile level predictions for test slides
            df5 = pd.DataFrame({
                'slideidx': test_dset.slideIDX,
                'nonMSI_prob': test_probs[:, 0],
                'MSI_prob': test_probs[:, 1],
                'pred_tile': test_preds
            })
            df5.to_csv(os.path.join(args.output, 'test_tile_pred_AUC.csv'), index=False)
            print('..............Test set done ...............')

if __name__ == '__main__':
    # Run the complete train / validate / test pipeline.
    # empty_cache() is placed in `finally` so that unoccupied blocks held by
    # PyTorch's CUDA caching allocator are released even when main() raises
    # (e.g. a CUDA out-of-memory error). This matters when the code runs in a
    # long-lived notebook kernel, where the process - and its cached GPU
    # memory - outlives the failed run. Tensors still referenced elsewhere
    # (e.g. by traceback frames) are not freed by empty_cache().
    try:
        main()
    finally:
        torch.cuda.empty_cache()


##--Targets ==> 405 | 344 | 61 --##
##--Train Split ==> 205 Slides
-------------------------
Number of Slides: 205
Number of tiles: 123284
Max tiles:  1849
Min tiles:  6
Average tiles:  601.3853658536585
nonMSI:  172
MSI:  33
##--Val Split ==> 100 Slides
-------------------------
Number of Slides: 100
Number of tiles: 55849
Max tiles:  1942
Min tiles:  30
Average tiles:  558.49
nonMSI:  85
MSI:  15
##--Test Split ==> 100 Slides
-------------------------
Number of Slides: 100
Number of tiles: 64050
Max tiles:  1875
Min tiles:  18
Average tiles:  640.5
nonMSI:  87
MSI:  13
205
Train Epoch: [  1/  4] Batch:   1, Training Loss: 0.7160, Acc: 48.63%
Train Epoch: [  1/  4] Batch: 101, Training Loss: 0.3127, Acc: 87.36%
Train Epoch: [  1/  4] Batch: 201, Training Loss: 0.2722, Acc: 89.01%
Training	Epoch: [1/4]	Loss: 0.2610	Accuracy:  89
Inference	Epoch: [  1/  4]	Batch: [  1/110]	Validation Loss: 0.6878, acc: 77.15%
Inference	Epoch: [  1/  4]	Batch: [101/110]	Validation Loss: 0.6829, acc: 77.9