In [120]:
import torch
import torchvision
from torch.utils.data import DataLoader
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from eval_patch_features.logistic import test_saved_logistic_model
from eval_patch_features.ann import test_saved_ann_model
from eval_patch_features.knn import test_saved_knn_model
from eval_patch_features.protonet import test_saved_protonet_model
from eval_patch_features.metrics import get_eval_metrics, print_metrics
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including deprecation and numerical warnings raised by libraries.
warnings.filterwarnings("ignore")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced again in the visible cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Configurations

In [121]:
# Configuration: every value a reader may need to change lives in this cell.
from pathlib import Path
# Directory holding the saved per-fold classifiers (fold{N}_*.pkl / fold{N}_*.pth).
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
BASE_DIR = Path(r"E:\Aamir Gulzar\WSI_Classification_Using_FM_Features\caiman Classifiers\caiman_Average_Classifiers")

VECTOR_DIM = 512  # size of input feature vector
HIDDEN_DIM = 768   # size of ANN hidden layer
BATCH_SIZE = 8  # batch size of the evaluation DataLoader
FOLDS = 4  # number of saved per-fold models to load and evaluate
# CSV with the slide IDs; the main cell reads its 'Fold1' column.
K_FOLDS_PATH = r"E:\Aamir Gulzar\dataset\paip_data\labels\paip_78slides.csv"
# Directory of per-slide .pt feature files consumed by WSIDataset.
DATA_PATH = r"E:\Aamir Gulzar\dataset\paip_data\caiman_Features"

### Data Loaders

In [122]:
import os
import torch
from torch.utils.data import Dataset
from typing import List

class WSIDataset(Dataset):
    """In-memory dataset of one averaged feature vector per slide.

    Every ``.pt`` file in ``save_dir`` whose base name (file name without the
    extension) appears in ``fold_ids`` is loaded once, at construction time.
    A loaded tensor with more than one dimension is averaged over dim 0; a
    one-dimensional tensor is kept as is.

    The label is derived from the file name: 0 when it contains ``'_nonMSI'``,
    otherwise 1.

    Args:
        save_dir: Directory containing the per-slide ``.pt`` feature files.
        fold_ids: Slide IDs (file names without extension) to include.
    """

    def __init__(self, save_dir: str, fold_ids: List[str]):
        self.data = []
        self.save_dir = save_dir
        self.fold_ids = fold_ids
        self._load_data()

    def _load_data(self):
        """Fill ``self.data`` with ``(features, label, wsi_id)`` tuples."""
        # Sorted so the sample order is reproducible; os.listdir order is arbitrary.
        for wsi_file in sorted(os.listdir(self.save_dir)):
            if not wsi_file.endswith('.pt'):
                continue

            # The slide ID is the full file name without its extension.
            wsi_id = os.path.splitext(wsi_file)[0]
            if wsi_id not in self.fold_ids:
                continue  # Skip slides that are not part of the requested set

            wsi_path = os.path.join(self.save_dir, wsi_file)
            try:
                # map_location="cpu" lets GPU-saved tensors load on machines
                # without CUDA and always leaves the features on the CPU.
                wsi_features = torch.load(wsi_path, map_location="cpu")

                # Collapse all feature vectors of the slide into their mean.
                if wsi_features.dim() > 1:
                    averaged_features = torch.mean(wsi_features, dim=0)
                else:
                    averaged_features = wsi_features  # already a single vector

                label = 0 if '_nonMSI' in wsi_file else 1
                self.data.append((averaged_features, label, wsi_id))
            except Exception as e:
                # Best effort: report the unreadable file and keep going.
                print(f"Error loading {wsi_path}: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features, label, wsi_id = self.data[idx]
        return features, label, wsi_id

### Evaluation Metrics

In [123]:
def calculate_metric_averages_by_index(all_fold_results, metric_indices):
    """
    Calculate the average of specified metrics over multiple folds.

    Metrics are located by position: for every fold result, the key found at
    position ``index`` (dict insertion order) is read. A fold contributes to
    a metric only when that position exists and holds a real number; Python
    and NumPy scalars both count, anything else is skipped.

    NOTE(review): this notebook defines a function with this name in two
    cells; when the cells run in order the later definition is the one used.

    Args:
        all_fold_results (list of dicts): Results for each fold.
        metric_indices (dict): Mapping of metric names to their indices.
    Returns:
        dict: Averages of the specified metrics across folds. A metric with
        no valid value in any fold is reported as 0.
    """
    from numbers import Real  # local import keeps the cell self-contained

    totals = {metric: 0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}  # folds with a valid value

    for result in all_fold_results:
        for metric, index in metric_indices.items():
            try:
                metric_name = list(result.keys())[index]  # key at this position
                value = result[metric_name]
            except IndexError:
                # Metric not present in this result due to model differences
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue
            # numbers.Real also matches NumPy scalars such as np.float32,
            # which a plain (int, float) check would silently drop.
            if isinstance(value, Real):
                totals[metric] += float(value)
                counts[metric] += 1

    # Average only the metrics that had at least one valid value.
    return {
        metric: (totals[metric] / counts[metric] if counts[metric] > 0 else totals[metric])
        for metric in totals
    }


### Evaluation Function

In [124]:
def evaluate(fold, test_loader, model_type='linear'):
    """Evaluate one saved per-fold classifier on the batches of ``test_loader``.

    All batches are concatenated along dim 0 and handed to the
    ``test_saved_*_model`` helper that matches ``model_type``, together with
    the path of the saved model under ``BASE_DIR``.

    Args:
        fold: Fold identifier interpolated into the model file name
            (``fold{fold}_...``).
        test_loader: Iterable yielding ``(features, label, wsi_id)`` batches
            whose features and labels are tensors.
        model_type: One of ``'linear'``, ``'ann'``, ``'knn'``, ``'protonet'``.

    Returns:
        Whatever the selected ``test_saved_*_model`` helper returns.

    Raises:
        ValueError: If ``model_type`` is not supported, or if ``test_loader``
            yields no batches.
    """
    # Kept for backward compatibility: the concatenated tensors are exported
    # as notebook globals so they can be inspected after the call.
    global test_feats, test_labels

    # Fail fast, before any batch is collected.
    if model_type not in ('linear', 'ann', 'knn', 'protonet'):
        raise ValueError(f"Unsupported model type: {model_type}")

    all_test_feats, all_test_labels = [], []
    for features, label, _wsi_id in test_loader:
        all_test_feats.append(features)
        all_test_labels.append(label)

    if not all_test_feats:
        # torch.cat on an empty list fails with an opaque message.
        raise ValueError(
            "test_loader yielded no batches; check that the slide IDs match "
            "the feature file names"
        )

    # torch.cat allocates new tensors, so no extra clone/detach is needed.
    test_feats = torch.cat(all_test_feats)
    test_labels = torch.cat(all_test_labels)

    # Select the model based on the input argument
    if model_type == 'linear':
        eval_metrics = test_saved_logistic_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_logistic_regression.pkl"
        )
    elif model_type == 'ann':
        eval_metrics = test_saved_ann_model(
            input_dim=VECTOR_DIM,
            hidden_dim=HIDDEN_DIM,
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_trained_ann_model_{VECTOR_DIM}.pth"
        )
    elif model_type == 'knn':
        eval_metrics = test_saved_knn_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_knn_model.pkl"
        )
    elif model_type == 'protonet':
        eval_metrics = test_saved_protonet_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_protonet_model.pkl"
        )
    else:
        # Defensive: only reachable if the check above and this chain diverge.
        raise ValueError(f"Unsupported model type: {model_type}")
    return eval_metrics

### K-Folds

In [125]:
# Cross-validation function
def run_k_folds(save_dir: str, slides: List[str], folds: int, model: str = 'linear'):
    """Evaluate the saved model of every fold on the same set of slides.

    The dataset and loader do not depend on the fold, so they are built once
    and reused; only the saved model changes between iterations.

    Args:
        save_dir: Directory with the per-slide ``.pt`` feature files.
        slides: Slide IDs to evaluate (passed to ``WSIDataset`` as ``fold_ids``).
        folds: Number of fold models to evaluate. ``evaluate`` receives the
            zero-based index ``0 .. folds - 1``; the progress message prints
            the one-based number.
        model: Model type forwarded to ``evaluate``.

    Returns:
        list: One metrics result per fold, in fold order.
    """
    if folds <= 0:
        # Nothing to evaluate; skip reading the feature directory.
        return []

    # Loop-invariant: load the features a single time for all folds.
    test_dataset = WSIDataset(save_dir, slides)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    results_per_fold = []
    for i in range(folds):
        # Evaluate only: the per-fold models are loaded from disk, not trained here.
        print(f"Running Fold {i + 1} with model {model}...")
        eval_metrics = evaluate(i, test_loader, model_type=model)
        print_metrics(eval_metrics)
        results_per_fold.append(eval_metrics)

    return results_per_fold

def calculate_metric_averages_by_index(all_fold_results, metric_indices):
    """
    Calculate the average of specified metrics over multiple folds.

    Metrics are located by position: for every fold result, the key found at
    position ``index`` (dict insertion order) is read. A fold contributes to
    a metric only when that position exists and holds a real number; Python
    and NumPy scalars both count, anything else is skipped.

    NOTE(review): this notebook defines a function with this name in two
    cells; when the cells run in order this later definition is the one used.
    Keep a single definition.

    Args:
        all_fold_results (list of dicts): Results for each fold.
        metric_indices (dict): Mapping of metric names to their indices.

    Returns:
        dict: Averages of the specified metrics across folds. A metric with
        no valid value in any fold is reported as 0.
    """
    from numbers import Real  # local import keeps the cell self-contained

    totals = {metric: 0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}  # folds with a valid value

    for result in all_fold_results:
        for metric, index in metric_indices.items():
            try:
                metric_name = list(result.keys())[index]  # key at this position
                value = result[metric_name]
            except IndexError:
                # Metric not present in this result due to model differences
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue
            # numbers.Real also matches NumPy scalars such as np.float32,
            # which a plain (int, float) check would silently drop.
            if isinstance(value, Real):
                totals[metric] += float(value)
                counts[metric] += 1

    # Average only the metrics that had at least one valid value.
    return {
        metric: (totals[metric] / counts[metric] if counts[metric] > 0 else totals[metric])
        for metric in totals
    }



### Main Runner Function

In [126]:
# Load the slide IDs to evaluate: the non-null entries of the 'Fold1' column.
slides = pd.read_csv(K_FOLDS_PATH)
slides = slides['Fold1'].dropna().values.tolist()
# Evaluate each saved classifier type on the same slides and average
# its metrics over the per-fold models.
model_types = ['linear','ann','knn','protonet']
# Metrics are read by position in each result dict, so the per-model key
# prefix ('lin_', 'ann_', ...) does not matter. Positions follow the order
# of the printed metrics.
metric_indices = {
    'acc': 0,          # e.g. 'lin_acc' is at index 0
    'bacc': 1,         # e.g. 'lin_bacc' is at index 1
    'macro_f1': 2,        # e.g. 'lin_macro_f1' is at index 2
    'weighted_f1': 3,  # e.g. 'lin_weighted_f1' is at index 3
    'auroc': 4         # e.g. 'lin_auroc' is at index 4
}

for model in model_types:
    # NOTE(review): the banner says "Training", but run_k_folds only
    # evaluates previously saved models.
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_folds(DATA_PATH,slides=slides, folds=FOLDS,model=model)
    average_results = calculate_metric_averages_by_index(k_folds_results, metric_indices)
    print("\n\n Average results for all folds:")
    for metric, value in average_results.items():
        print(f"{metric}: {value:.4f}")




 ********* Training with model: linear********* 


Running Fold 1 with model linear...
lin_acc: 0.7857
lin_bacc: 0.5606
lin_macro_f1: 0.5625
lin_weighted_f1: 0.7411
lin_auroc: 0.8333
lin_conf_matrix: [[21  1]
 [ 5  1]]
Running Fold 2 with model linear...
lin_acc: 0.7857
lin_bacc: 0.5000
lin_macro_f1: 0.4400
lin_weighted_f1: 0.6914
lin_auroc: 0.7576
lin_conf_matrix: [[22  0]
 [ 6  0]]
Running Fold 3 with model linear...
lin_acc: 0.8214
lin_bacc: 0.5833
lin_macro_f1: 0.5918
lin_weighted_f1: 0.7668
lin_auroc: 0.8258
lin_conf_matrix: [[22  0]
 [ 5  1]]
Running Fold 4 with model linear...
lin_acc: 0.8214
lin_bacc: 0.5833
lin_macro_f1: 0.5918
lin_weighted_f1: 0.7668
lin_auroc: 0.7576
lin_conf_matrix: [[22  0]
 [ 5  1]]


 Average results for all folds:
acc: 0.8036
bacc: 0.5568
macro_f1: 0.5465
weighted_f1: 0.7415
auroc: 0.7936


 ********* Training with model: ann********* 


Running Fold 1 with model ann...
ann_acc: 0.5000
ann_bacc: 0.6212
ann_macro_f1: 0.4896
ann_weighted_f1: 0.5312
ann_