## Instructions before running this notebook
1. In the `dataloader_clustering` module, change line 29 from `wsi_id = wsi_file[:12]` to `wsi_id = wsi_folder`.
2. Edit the configuration in the second code cell to match your paths.
3. Use patch-level features for this notebook (FiveCrop or patch-level averaged).

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from dataloader_clustering import WSIDataset
from eval_patch_features.logistic import test_saved_logistic_model
from eval_patch_features.ann import test_saved_ann_model
from eval_patch_features.knn import test_saved_knn_model
from eval_patch_features.protonet import test_saved_protonet_model
from eval_patch_features.metrics import get_eval_metrics, print_metrics
from utility import calculate_metric_averages, average_confusion_matrices, write_data_in_excel, build_probs_df
import warnings
warnings.filterwarnings("ignore")
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Configurations

In [None]:
# configs
from pathlib import Path

# --- Model dimensions ---------------------------------------------------------
VECTOR_DIM = 512   # per-cluster feature length; the ANN input is VECTOR_DIM * NUM_CLUSTERS
HIDDEN_DIM = 768   # width of the ANN hidden layer
BATCH_SIZE = 8     # DataLoader batch size used during evaluation

# --- Clustering -----------------------------------------------------------------
CLUSTERING_METHOD = 'kmeans'
NUM_CLUSTERS = 2
NUM_PATCHES_PER_CLUSTER = 0   # forwarded to apply_clustering as num_selected_patches

# Number of saved per-fold models to evaluate; fold indices run 0..FOLDS-1 and
# select the fold{i}_* model files under BASE_DIR.
FOLDS = 4

# --- Inputs -----------------------------------------------------------------------
# Directory holding the saved per-fold classifier files.
BASE_DIR = Path(r"E:\Aamir Gulzar\WSI_Classification_Using_FM_Features\test\baseline_2Cluster_Classifiers")
# CSV whose fold column(s) list the slide IDs to evaluate.
K_FOLDS_PATH = r"E:\Aamir Gulzar\dataset\paip_data\labels\paip_78slides.csv"
# Feature directory handed to WSIDataset.
DATA_PATH = r"E:\Aamir Gulzar\dataset\paip_data\baseline_FiveCrop_Features"

# --- Outputs ----------------------------------------------------------------------
OUTPUT_SAVE_PATH = r"E:\Aamir Gulzar\WSI_Classification_Using_FM_Features\test\output"
os.makedirs(OUTPUT_SAVE_PATH, exist_ok=True)
# Only the file paths are defined here; the workbooks are written at the end of
# the run, into the sheet named by sheet_name.
EVAL_METRICS_EXCEL = os.path.join(OUTPUT_SAVE_PATH, "PAIP-EV78_2cluster_eval_metrics.xlsx")
PROBS_ALL_EXCEL = os.path.join(OUTPUT_SAVE_PATH, "PAIP-EV78_2cluster_probs_all.xlsx")
sheet_name = "baseline"

## Evaluation Function (runs saved fold models on the test set)

In [None]:
def evaluate(fold, test_loader, model_type='linear'):
    """Run one fold's saved classifier over every batch of ``test_loader``.

    Args:
        fold: Zero-based fold index; selects which ``fold{fold}_*`` model file
            under ``BASE_DIR`` is loaded.
        test_loader: Iterable yielding ``(features, label, wsi_id)`` batches,
            e.g. a ``torch.utils.data.DataLoader`` over ``WSIDataset``.
        model_type: ``'lin'`` (alias ``'linear'``), ``'ann'``, ``'knn'`` or
            ``'proto'``.

    Returns:
        ``(eval_metrics, eval_preds, all_test_ids)``. The first two are passed
        through unchanged from the matching ``test_saved_*_model`` helper.
        ``all_test_ids`` is a flat list of slide IDs in loader order.

    Raises:
        ValueError: if ``model_type`` is unsupported or the loader yields no
            batches.
    """
    # The concatenated tensors are kept as module globals, as in the original
    # notebook (presumably so the last evaluated fold can be inspected).
    global test_feats, test_labels

    # 'linear' is the signature default but the dispatch keys on 'lin';
    # accept both so calling with the default does not raise.
    if model_type == 'linear':
        model_type = 'lin'
    # Validate before touching the loader so a typo fails fast instead of
    # after a full pass over the data.
    if model_type not in ('lin', 'ann', 'knn', 'proto'):
        raise ValueError(f"Unsupported model type: {model_type}")

    all_test_feats, all_test_labels, all_test_ids = [], [], []
    for features, label, wsi_id in test_loader:
        all_test_feats.append(features)
        all_test_labels.append(label)
        # Keep exactly one ID per sample. The default collate turns a batch of
        # string IDs into a list and a batch of numeric IDs into a tensor;
        # flatten both so the IDs stay aligned with the labels.
        if isinstance(wsi_id, torch.Tensor):
            all_test_ids.extend(wsi_id.reshape(-1).tolist())
        elif isinstance(wsi_id, (list, tuple)):
            all_test_ids.extend(wsi_id)
        else:
            all_test_ids.append(wsi_id)

    if not all_test_feats:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    test_feats = torch.cat(all_test_feats)
    # torch.cat already returns a fresh tensor, so no per-batch clone is needed.
    test_labels = torch.cat(all_test_labels)

    if model_type == 'lin':
        eval_metrics, eval_preds = test_saved_logistic_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_logistic_regression.pkl",
        )
    elif model_type == 'ann':
        input_dim = VECTOR_DIM * NUM_CLUSTERS
        eval_metrics, eval_preds = test_saved_ann_model(
            input_dim=input_dim,
            hidden_dim=HIDDEN_DIM,
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_trained_ann_model_{input_dim}.pth",
        )
    elif model_type == 'knn':
        eval_metrics, eval_preds = test_saved_knn_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_knn_model.pkl",
        )
    else:  # 'proto' -- the only remaining value after validation above
        eval_metrics, eval_preds = test_saved_protonet_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_protonet_model.pkl",
        )
    return eval_metrics, eval_preds, all_test_ids

### K-Folds

In [None]:
from typing import List, Dict

def run_k_folds(save_dir: str, slides: List[str], folds: int, model: str = 'linear'):
    """Evaluate every saved fold model of type ``model`` on the same test slides.

    Args:
        save_dir: Feature directory passed straight to ``WSIDataset``.
        slides: Slide IDs to evaluate on (e.g. one column of the folds CSV).
        folds: Number of saved fold models to evaluate (fold indices
            ``0..folds-1``).
        model: Classifier type forwarded to ``evaluate``.

    Returns:
        One dict per fold. Each merges the metric and prediction dicts returned
        by ``evaluate`` with ``"wsi_ids"`` and the 1-based ``"fold"`` number.
    """
    if folds <= 0:
        return []

    # The test set does not depend on the fold, so build and cluster it once
    # instead of once per fold. This avoids repeated loading and clustering
    # and ensures every fold model sees the same clustered features.
    test_dataset = WSIDataset(save_dir, slides)
    test_dataset.apply_clustering(
        clustering_algorithm=CLUSTERING_METHOD,
        num_clusters=NUM_CLUSTERS,
        num_selected_patches=NUM_PATCHES_PER_CLUSTER,
    )
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    results_per_fold = []
    for fold_idx in range(folds):
        print(f"Running Fold {fold_idx + 1} with model {model}...")
        eval_metrics, eval_preds, all_test_ids = evaluate(fold_idx, test_loader, model_type=model)
        print_metrics(eval_metrics)
        results_per_fold.append({
            **eval_metrics,
            **eval_preds,
            "wsi_ids": all_test_ids,
            "fold": fold_idx + 1,
        })
    return results_per_fold


### Main Runner Function

In [None]:
# Example usage:
# Every saved fold model is evaluated on the same slides: those listed in the
# 'Fold1' column of K_FOLDS_PATH.
# NOTE(review): confirm 'Fold1' is meant to hold the full evaluation set.
slides = pd.read_csv(K_FOLDS_PATH)
slides = slides['Fold1'].dropna().values.tolist()

# Classifier heads to evaluate; each needs its saved per-fold models in BASE_DIR.
model_types = ['lin', 'ann', 'knn', 'proto']
# Metric-name suffixes (result keys are f"{model}_{suffix}") mapped to an index.
# The mapping is passed to calculate_metric_averages; confirm in utility.py how
# the indices are used there.
metric_indices = {
    'acc': 0,
    'bacc': 1,
    'macro_f1': 2,
    'weighted_f1': 3,
    'auroc': 4,
}
# One Excel column per fold. The header is derived from FOLDS so it always
# matches the number of per-fold values in each row; a hard-coded Fold1..Fold4
# header makes pd.DataFrame raise whenever FOLDS != 4.
fold_columns = [f"Fold{i}" for i in range(1, FOLDS + 1)]

eval_metrics__for_excel = []
probs_all_for_excel = None
for model in model_types:
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_folds(DATA_PATH, slides=slides, folds=FOLDS, model=model)
    model_df = build_probs_df(k_folds_results, model_name=model)
    # === Merge per-slide predictions across models (outer join keeps rows
    # present in only one of the frames).
    if probs_all_for_excel is None:
        probs_all_for_excel = model_df
    else:
        probs_all_for_excel = pd.merge(
            probs_all_for_excel, model_df, on=["Fold", "WSI_ID", "Target"], how="outer"
        )

    # === Average metrics: pass only the scalar metric entries of each result
    # dict. The allowed-key set is built once per model, not once per result.
    metric_keys = {f"{model}_{m}" for m in metric_indices}
    average_results = calculate_metric_averages(
        [{k: v for k, v in result.items() if k in metric_keys} for result in k_folds_results],
        metric_indices,
        model_prefix=model,
    )
    # === Confusion matrices (folds that did not report one are skipped)
    confusion_matrices = [
        np.array(result[f"{model}_conf_matrix"])
        for result in k_folds_results
        if f"{model}_conf_matrix" in result
    ]
    avg_conf_matrix = average_confusion_matrices(confusion_matrices)

    print("\n\n Average results for all folds:")
    for metric, value in average_results.items():
        print(f"{metric}: {value:.4f}")

    # One row per metric: name, one value per fold, then the cross-fold average.
    for metric in metric_indices:
        row = [f"{model}_{metric}"]
        row.extend(result.get(f"{model}_{metric}", 'N/A') for result in k_folds_results)
        row.append(average_results.get(f"{model}_{metric}", 'N/A'))
        eval_metrics__for_excel.append(row)

    # Confusion matrices are stored as strings (per fold, then the average).
    row = [f"{model}_conf_matrix"]
    row.extend(str(result.get(f"{model}_conf_matrix", "N/A")) for result in k_folds_results)
    row.append(str(avg_conf_matrix))
    eval_metrics__for_excel.append(row)

eval_metrics_df = pd.DataFrame(
    eval_metrics__for_excel, columns=["Metric", *fold_columns, "AvgFolds"]
)
write_data_in_excel(EVAL_METRICS_EXCEL, eval_metrics_df, sheet_name=sheet_name)
write_data_in_excel(PROBS_ALL_EXCEL, probs_all_for_excel, sheet_name=sheet_name)




 ********* Training with model: linear********* 


Running Fold 1 with model linear...
lin_acc: 0.7660
lin_bacc: 0.6238
lin_macro_f1: 0.6372
lin_weighted_f1: 0.7430
lin_auroc: 0.7167
lin_conf_matrix: [[32  3]
 [ 8  4]]
Running Fold 2 with model linear...
lin_acc: 0.7447
lin_bacc: 0.5000
lin_macro_f1: 0.4268
lin_weighted_f1: 0.6357
lin_auroc: 0.5119
lin_conf_matrix: [[35  0]
 [12  0]]
Running Fold 3 with model linear...
lin_acc: 0.7447
lin_bacc: 0.6095
lin_macro_f1: 0.6189
lin_weighted_f1: 0.7260
lin_auroc: 0.7571
lin_conf_matrix: [[31  4]
 [ 8  4]]
Running Fold 4 with model linear...
lin_acc: 0.7872
lin_bacc: 0.5833
lin_macro_f1: 0.5804
lin_weighted_f1: 0.7245
lin_auroc: 0.7452
lin_conf_matrix: [[35  0]
 [10  2]]


 Average results for all folds:
acc: 0.7606
bacc: 0.5792
macro_f1: 0.5658
weighted_f1: 0.7073
auroc: 0.6827


 ********* Training with model: ann********* 


Running Fold 1 with model ann...
ann_acc: 0.7660
ann_bacc: 0.8155
ann_macro_f1: 0.7432
ann_weighted_f1: 0.7806
ann_