In [85]:
import torch
import torchvision
from torch.utils.data import DataLoader
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None  # Disable Pillow's decompression-bomb pixel limit (presumably so very large slide images can be opened); only open trusted images with this setting.
# loading all packages here to start
from eval_patch_features.logistic import test_saved_logistic_model
from eval_patch_features.ann import test_saved_ann_model
from eval_patch_features.knn import test_saved_knn_model
from eval_patch_features.protonet import test_saved_protonet_model
from eval_patch_features.metrics import get_eval_metrics, print_metrics
import warnings
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Configurations

In [86]:
# configs
from pathlib import Path

# Machine-specific locations. Each one can be overridden with an environment
# variable so the notebook runs on another machine without editing this cell;
# the defaults are the original paths.
BASE_DIR = Path(os.environ.get(
    "WSI_BASE_DIR",
    r"E:\Aamir Gulzar\WSI_Classification_Using_FM_Features\baseline Classifiers\baseline_2ClusterClassifiers",
))  # folder holding the saved per-fold model files
K_FOLDS_PATH = os.environ.get(
    "WSI_K_FOLDS_PATH",
    r"E:\Aamir Gulzar\dataset\paip_data\labels\paip_47slides.csv",
)  # CSV whose columns list the WSI ids of each fold
DATA_PATH = os.environ.get(
    "WSI_DATA_PATH",
    r"E:\Aamir Gulzar\dataset\paip_data\baseline_FiveCrop_Features",
)  # root folder with one sub-folder of .pt patch features per WSI

VECTOR_DIM = 512  # size of one patch feature vector
HIDDEN_DIM = 768  # size of the ANN hidden layer
BATCH_SIZE = 8
CLUSTERING_METHOD = 'kmeans'  # 'kmeans', 'dbscan' or 'pca'
NUM_CLUSTERS = 2  # clusters per WSI; the ANN input size is VECTOR_DIM * NUM_CLUSTERS
# Patches averaged per cluster, nearest to the cluster centre first; 0 = use every patch.
NUM_PATCHES_PER_CLUSTER = 0
FOLDS = 4  # number of per-fold saved models to load and evaluate (no training happens here)

### Dataset & Patch Clustering

In [87]:
from sklearn.model_selection import KFold
from sklearn.utils import shuffle
from typing import List, Tuple
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans, DBSCAN
from sklearn.decomposition import PCA

class WSIDataset(Dataset):
    """Slide-level dataset built from per-patch feature files (``.pt``) on disk.

    Layout read by ``_load_data``::

        save_dir/
            <wsi_id>/        # one folder per slide; the folder name is the WSI id
                *.pt         # one feature tensor per file

    After construction ``self.data`` holds one ``(features, label, wsi_id)`` tuple
    per ``.pt`` file. ``apply_clustering`` then collapses those tuples into a single
    fixed-order feature vector per WSI.
    """

    def __init__(self, save_dir: str, fold_ids: List[str], min_files_per_wsi: int = 0):
        """
        Args:
            save_dir: Root directory containing one sub-folder per WSI.
            fold_ids: WSI ids (folder names) to include; other folders are ignored.
            min_files_per_wsi: Skip WSI folders holding fewer than this many ``.pt``
                files. The default 0 keeps every folder (the previous effective
                behaviour, since the old ``<= 15`` check never applied to folders).
        """
        self.data = []
        self.save_dir = save_dir
        self.fold_ids = fold_ids
        self.min_files_per_wsi = min_files_per_wsi
        self._load_data()

    def _load_data(self):
        """Fill ``self.data`` with one ``(features, label, wsi_id)`` tuple per ``.pt`` file.

        Tensors with more than one dimension are averaged over dimension 0 so each
        entry is a single vector. The label is 0 when the file name contains
        ``'_nonMSI'`` and 1 otherwise. Files that fail to load are reported and skipped.
        """
        wanted_ids = set(self.fold_ids)  # O(1) membership instead of scanning a list
        for wsi_folder in os.listdir(self.save_dir):
            wsi_folder_path = os.path.join(self.save_dir, wsi_folder)
            # Ignore stray files next to the WSI folders. (The old check,
            # `not isdir(p) and len(listdir(p)) <= 15`, called listdir on them and
            # raised NotADirectoryError.)
            if not os.path.isdir(wsi_folder_path):
                continue
            wsi_id = wsi_folder
            if wsi_id not in wanted_ids:
                continue
            pt_files = [name for name in os.listdir(wsi_folder_path) if name.endswith('.pt')]
            if len(pt_files) < self.min_files_per_wsi:
                continue
            for wsi_file in pt_files:
                file_path = os.path.join(wsi_folder_path, wsi_file)
                try:
                    # NOTE(review): torch.load unpickles; only point this at trusted files.
                    wsi_features = torch.load(file_path)
                    # Several vectors in one file are mean-pooled into one vector.
                    if isinstance(wsi_features, torch.Tensor) and wsi_features.dim() > 1:
                        wsi_features = torch.mean(wsi_features, dim=0)
                    label = 0 if '_nonMSI' in wsi_file else 1
                    self.data.append((wsi_features, label, wsi_id))
                except Exception as e:
                    print(f"Error loading {file_path}: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features, label, wsi_id = self.data[idx]
        return features, label, wsi_id

    @staticmethod
    def _cluster_patches(patch_array, clustering_algorithm, num_clusters):
        """Cluster one WSI's patch matrix; return ``(labels, centroids, cluster_count)``.

        Raises:
            ValueError: for an unsupported ``clustering_algorithm``.
        """
        if clustering_algorithm == 'kmeans':
            model = KMeans(n_clusters=num_clusters, random_state=42)
            model.fit(patch_array)
            return model.labels_, model.cluster_centers_, num_clusters
        if clustering_algorithm == 'dbscan':
            labels = DBSCAN(eps=0.1, min_samples=2).fit_predict(patch_array)
            # Label -1 marks noise points and is not counted as a cluster.
            found = len(set(labels)) - (1 if -1 in labels else 0)
            print(f"Unique clusters: {found}")
            centroids = np.array([np.mean(patch_array[labels == i], axis=0) for i in range(found)])
            return labels, centroids, found
        if clustering_algorithm == 'pca':
            pca = PCA(n_components=num_clusters)
            transformed = pca.fit_transform(patch_array)
            # Each patch goes to the component with its largest projection.
            # NOTE(review): pca.components_ are unit directions, not centroids, so the
            # nearest-to-centre ranking used with num_selected_patches is only a heuristic.
            return np.argmax(transformed, axis=1), pca.components_, num_clusters
        raise ValueError(f"Unsupported clustering algorithm: {clustering_algorithm}")

    def apply_clustering(self, clustering_algorithm: str, num_clusters: int = 3, num_selected_patches: int = 0):
        """
        Apply clustering on the patches of each WSI and create a consistent WSI representation.

        For each WSI, its patch vectors are clustered and each cluster is averaged.
        Averaging can optionally use only the ``num_selected_patches`` patches
        closest to the cluster centre. The cluster means are then ordered by their
        mean value and concatenated. ``self.data`` is replaced by one
        ``(features, label, wsi_id)`` tuple per WSI, in first-appearance order.

        Args:
        - clustering_algorithm: The clustering algorithm to use ('kmeans', 'dbscan', 'pca').
        - num_clusters: Number of clusters (k-means) / components (PCA). WSIs with fewer
          patches than this are dropped. DBSCAN chooses its own cluster count per WSI,
          so its output vectors can differ in length between WSIs.
        - num_selected_patches: Number of top patches to use for averaging within each
          cluster; 0 averages every patch.

        Raises:
        - ValueError: for an unsupported clustering_algorithm.
        """
        # Group patches by WSI in one pass. Insertion order keeps the output order
        # deterministic (iterating a set of strings is not).
        patches_by_wsi = {}
        label_by_wsi = {}
        for features, label, wsi_id in self.data:
            patches_by_wsi.setdefault(wsi_id, []).append(features)
            label_by_wsi.setdefault(wsi_id, label)  # first label seen for the WSI

        clustered_data = []
        for wsi_id, patches in patches_by_wsi.items():
            patch_array = torch.stack(patches).cpu().numpy()
            # Too few patches to form the requested clusters: drop this WSI.
            if len(patch_array) < num_clusters:
                continue

            # Use a local count so DBSCAN's per-WSI result never leaks into the
            # next WSI's skip check (the parameter used to be overwritten here).
            cluster_labels, cluster_centroids, n_clusters = self._cluster_patches(
                patch_array, clustering_algorithm, num_clusters)

            cluster_means = []
            cluster_scores = []
            for cluster_idx in range(n_clusters):
                cluster_patches = patch_array[cluster_labels == cluster_idx]
                if len(cluster_patches) == 0:
                    continue
                if num_selected_patches > 0:
                    distances = cdist(cluster_patches, [cluster_centroids[cluster_idx]], metric='euclidean').flatten()
                    cluster_patches = cluster_patches[np.argsort(distances)[:num_selected_patches]]
                cluster_mean = np.mean(cluster_patches, axis=0)
                cluster_means.append(cluster_mean)
                # Ordering key: mean of the cluster-mean vector (same order as its sum).
                cluster_scores.append(np.average(cluster_mean))

            if not cluster_means:
                # e.g. DBSCAN labelled every patch as noise; nothing to concatenate.
                continue

            order = np.argsort(cluster_scores)
            concatenated_features = np.concatenate(np.array(cluster_means)[order])
            clustered_data.append((torch.tensor(concatenated_features), label_by_wsi[wsi_id], wsi_id))

        self.data = clustered_data

### Evaluation Metrics

In [88]:
def calculate_metric_averages_by_index(all_fold_results, metric_indices):
    """
    Calculate the average of specified metrics over multiple folds.

    Values are looked up by *position* in each result dict's key order, because
    the key names differ per model (e.g. 'lin_acc' vs 'ann_acc').

    Args:
        all_fold_results (list of dicts): Results for each fold.
        metric_indices (dict): Mapping of output metric names to the position of
            that metric in each result dict's key order.

    Returns:
        dict: For each metric, the mean over the folds in which it was present and
        numeric (Python or NumPy int/float). A metric never found is NaN rather
        than a misleading 0.
    """
    # NOTE(review): an identical definition in the K-Folds cell runs later and
    # shadows this one; keep the two in sync or delete one of them.
    totals = {metric: 0.0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}

    for result in all_fold_results:
        for metric, index in metric_indices.items():
            try:
                value = result[list(result.keys())[index]]
            except IndexError:
                # Metric not present in this result due to model differences
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue
            # np.float32 / np.int64 are not subclasses of float / int, so they are
            # listed explicitly; otherwise such values were silently skipped.
            if isinstance(value, (int, float, np.integer, np.floating)):
                totals[metric] += float(value)
                counts[metric] += 1

    return {
        metric: totals[metric] / counts[metric] if counts[metric] else float('nan')
        for metric in metric_indices
    }


## Evaluation Function

In [89]:
def evaluate(fold, test_loader, model_type='linear'):
    """Score one fold's saved model on every batch yielded by ``test_loader``.

    Args:
        fold: Fold index, used only to build the saved-model file name
            (``BASE_DIR / f"fold{fold}_..."``); run_k_folds passes 0-based indices.
        test_loader: Iterable of ``(features, labels, wsi_ids)`` batches, e.g. a
            DataLoader over WSIDataset.
        model_type: One of 'linear', 'ann', 'knn', 'protonet'.

    Returns:
        Whatever the matching ``test_saved_*_model`` helper returns (the metrics
        object passed on to print_metrics).

    Raises:
        ValueError: for an unsupported ``model_type`` or a loader with no batches.

    Side effects:
        Stores the concatenated inputs in the module-level globals ``test_feats``
        and ``test_labels`` (kept for inspection from other cells).
    """
    global test_feats, test_labels

    # Fail fast, before touching the data, on an unknown model type.
    if model_type not in ('linear', 'ann', 'knn', 'protonet'):
        raise ValueError(f"Unsupported model type: {model_type}")

    all_test_feats, all_test_labels = [], []
    for features, labels, _wsi_ids in test_loader:
        all_test_feats.append(features)
        # as_tensor accepts collated label tensors as well as plain lists of ints.
        all_test_labels.append(torch.as_tensor(labels))
    if not all_test_feats:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    test_feats = torch.cat(all_test_feats)
    test_labels = torch.cat(all_test_labels)

    if model_type == 'linear':
        eval_metrics = test_saved_logistic_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_logistic_regression.pkl"
        )
    elif model_type == 'ann':
        # Input size assumes every WSI vector concatenates NUM_CLUSTERS cluster means
        # of size VECTOR_DIM (true for kmeans/pca; DBSCAN can vary per WSI).
        eval_metrics = test_saved_ann_model(
            input_dim=VECTOR_DIM * NUM_CLUSTERS,
            hidden_dim=HIDDEN_DIM,
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_trained_ann_model_{VECTOR_DIM * NUM_CLUSTERS}.pth"
        )
    elif model_type == 'knn':
        eval_metrics = test_saved_knn_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_knn_model.pkl"
        )
    else:  # 'protonet' (validated above)
        eval_metrics = test_saved_protonet_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=BASE_DIR / f"fold{fold}_protonet_model.pkl"
        )
    return eval_metrics

### K-Folds

In [90]:
# Cross-validation function
def run_k_folds(save_dir: str, slides: List[str], folds: int, model: str = 'linear'):
    """Evaluate each fold's saved ``model`` on the same held-out slide set.

    The test set does not depend on the fold, so it is loaded and clustered once
    and reused for every fold. It used to be rebuilt from disk and re-clustered on
    every iteration, with identical results.

    Args:
        save_dir: Root folder of per-WSI feature files (see WSIDataset).
        slides: WSI ids that make up the test set.
        folds: Number of per-fold saved models to evaluate (indices 0..folds-1).
        model: Model type forwarded to ``evaluate``.

    Returns:
        List with one evaluation result per fold, in fold order.
    """
    if folds <= 0:
        return []

    test_dataset = WSIDataset(save_dir, slides)
    test_dataset.apply_clustering(
        clustering_algorithm=CLUSTERING_METHOD,
        num_clusters=NUM_CLUSTERS,
        num_selected_patches=NUM_PATCHES_PER_CLUSTER,
    )
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    results_per_fold = []
    for fold_idx in range(folds):
        print(f"Running Fold {fold_idx + 1} with model {model}...")
        eval_metrics = evaluate(fold_idx, test_loader, model_type=model)
        print_metrics(eval_metrics)
        results_per_fold.append(eval_metrics)
    return results_per_fold

def calculate_metric_averages_by_index(all_fold_results, metric_indices):
    """
    Calculate the average of specified metrics over multiple folds.

    Values are looked up by *position* in each result dict's key order, because
    the key names differ per model (e.g. 'lin_acc' vs 'ann_acc').

    Args:
        all_fold_results (list of dicts): Results for each fold.
        metric_indices (dict): Mapping of output metric names to the position of
            that metric in each result dict's key order.

    Returns:
        dict: For each metric, the mean over the folds in which it was present and
        numeric (Python or NumPy int/float). A metric never found is NaN rather
        than a misleading 0.
    """
    # NOTE(review): this duplicates the definition in the Evaluation Metrics cell.
    # Because this cell runs later, this is the definition that takes effect;
    # keep the two in sync or delete one of them.
    totals = {metric: 0.0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}

    for result in all_fold_results:
        for metric, index in metric_indices.items():
            try:
                value = result[list(result.keys())[index]]
            except IndexError:
                # Metric not present in this result due to model differences
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue
            # np.float32 / np.int64 are not subclasses of float / int, so they are
            # listed explicitly; otherwise such values were silently skipped.
            if isinstance(value, (int, float, np.integer, np.floating)):
                totals[metric] += float(value)
                counts[metric] += 1

    return {
        metric: totals[metric] / counts[metric] if counts[metric] else float('nan')
        for metric in metric_indices
    }



### Main Runner

In [91]:
# Example usage:
# Evaluate every saved model type on the test WSIs listed in one column of the
# fold CSV, then report each metric's mean across the per-fold models.
# (Only saved models are loaded here; nothing is trained.)
TEST_FOLD_COLUMN = 'Fold1'  # CSV column whose non-empty entries are the test WSI ids

fold_table = pd.read_csv(K_FOLDS_PATH)
slides = fold_table[TEST_FOLD_COLUMN].dropna().values.tolist()
assert slides, f"No slide ids found in column {TEST_FOLD_COLUMN!r} of {K_FOLDS_PATH}"

model_types = ['linear', 'ann', 'knn', 'protonet']
# Position of each metric in the result dict returned for one fold. The key
# names there carry a model prefix ('lin_acc', 'ann_acc', ...), so metrics are
# matched by position rather than by name.
metric_indices = {
    'acc': 0,          # '<prefix>_acc'
    'bacc': 1,         # '<prefix>_bacc'
    'macro_f1': 2,     # '<prefix>_macro_f1'
    'weighted_f1': 3,  # '<prefix>_weighted_f1'
    'auroc': 4         # '<prefix>_auroc'
}

for model in model_types:
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_folds(DATA_PATH, slides=slides, folds=FOLDS, model=model)
    average_results = calculate_metric_averages_by_index(k_folds_results, metric_indices)
    print("\n\n Average results for all folds:")
    for metric, value in average_results.items():
        print(f"{metric}: {value:.4f}")




 ********* Training with model: linear********* 


Running Fold 1 with model linear...
lin_acc: 0.7660
lin_bacc: 0.6238
lin_macro_f1: 0.6372
lin_weighted_f1: 0.7430
lin_auroc: 0.7167
lin_conf_matrix: [[32  3]
 [ 8  4]]
Running Fold 2 with model linear...
lin_acc: 0.7447
lin_bacc: 0.5000
lin_macro_f1: 0.4268
lin_weighted_f1: 0.6357
lin_auroc: 0.5119
lin_conf_matrix: [[35  0]
 [12  0]]
Running Fold 3 with model linear...
lin_acc: 0.7447
lin_bacc: 0.6095
lin_macro_f1: 0.6189
lin_weighted_f1: 0.7260
lin_auroc: 0.7571
lin_conf_matrix: [[31  4]
 [ 8  4]]
Running Fold 4 with model linear...
lin_acc: 0.7872
lin_bacc: 0.5833
lin_macro_f1: 0.5804
lin_weighted_f1: 0.7245
lin_auroc: 0.7452
lin_conf_matrix: [[35  0]
 [10  2]]


 Average results for all folds:
acc: 0.7606
bacc: 0.5792
macro_f1: 0.5658
weighted_f1: 0.7073
auroc: 0.6827


 ********* Training with model: ann********* 


Running Fold 1 with model ann...
ann_acc: 0.7660
ann_bacc: 0.8155
ann_macro_f1: 0.7432
ann_weighted_f1: 0.7806
ann_