In [1]:
# Imports for the whole notebook: torch/torchvision, numerical stack, Pillow, and the
# project-local evaluation helpers.
# NOTE(review): torchvision, j_, time, random, transforms and Lambda are not used in the
# cells shown here; drop them if nothing else needs them.
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
# print(torch.version)
# print(torch.version.cuda)
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
# NOTE(review): PIL.Image is already imported above; this repeat is harmless but redundant.
from PIL import Image
# Setting the limit to None disables Pillow's decompression-bomb size check, so arbitrarily
# large images can be opened (only do this for trusted inputs).
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
# Project-local evaluators, one per classifier family selected in train_and_evaluate.
from eval_patch_features.logistic import eval_linear
from eval_patch_features.ann import eval_ANN
from eval_patch_features.knn import eval_knn
from eval_patch_features.protonet import eval_protonet
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import warnings
# NOTE(review): this silences every warning category for the whole kernel session, which can
# hide real problems; consider narrowing the filter to specific categories/modules.
warnings.filterwarnings("ignore")

### Configurations

In [2]:
# configs
# Length of a single feature vector — TODO confirm it matches the feature extractor's output.
VECTOR_DIM = 1024
# Algorithm passed to WSIDataset.apply_clustering: 'kmeans', 'dbscan' or 'pca'.
CLUSTERING_METHOD = 'kmeans'
# Number of clusters (k-means) / components (PCA) built per WSI.
NUM_CLUSTERS = 3
# Patches nearest each cluster centroid to average; 0 means "average every patch in the cluster".
NUM_PATCHES_PER_CLUSTER = 0
BATCH_SIZE = 32
# Data locations. The defaults are the original machine's absolute paths; set the environment
# variables KSA_KFOLDS_PATH / KSA_DATA_PATH to run the notebook elsewhere without editing code.
K_FOLDS_PATH = os.environ.get("KSA_KFOLDS_PATH", r"E:\KSA Project\dataset\splits\kfolds.csv")
DATA_PATH = os.environ.get("KSA_DATA_PATH", r"E:\KSA Project\dataset\uni_fivecrop_features")

### Dataset and Patch Clustering

In [3]:
from sklearn.model_selection import KFold
from sklearn.utils import shuffle
from typing import List, Tuple
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans, DBSCAN
from sklearn.decomposition import PCA

class WSIDataset(Dataset):
    """
    Dataset of pre-extracted feature files (``.pt``) grouped under one sub-folder per WSI.

    Layout read by ``_load_data``::

        save_dir/
            <wsi_folder>/
                <12-char wsi_id>...(_nonMSI)....pt

    Every matching ``.pt`` file becomes one ``(features, label, wsi_id)`` entry in
    ``self.data``: ``wsi_id`` is the first 12 characters of the file name and ``label`` is
    0 when the file name contains ``'_nonMSI'`` and 1 otherwise. Several entries may share
    a ``wsi_id``; ``apply_clustering`` collapses them into one entry per ``wsi_id``.
    """

    def __init__(self, save_dir: str, fold_ids: List[str], min_files: int = 0):
        """
        Args:
        - save_dir: Root directory holding one sub-folder per WSI.
        - fold_ids: WSI IDs (12-character prefixes) to keep; files with other IDs are ignored.
        - min_files: Skip sub-folders containing fewer than this many entries. The default 0
          keeps every folder, matching the previous effective behaviour (the old "<= 15" guard
          was combined with ``and`` and could never trigger for a directory). Pass 16 to apply
          the originally intended "more than 15 files" rule.
        """
        self.data = []
        self.save_dir = save_dir
        self.fold_ids = fold_ids
        self.min_files = min_files
        self._load_data()

    def _load_data(self):
        """Append one entry to ``self.data`` per readable ``.pt`` file whose ID is in ``fold_ids``."""
        wanted_ids = set(self.fold_ids)  # O(1) membership instead of scanning the list per file
        min_files = getattr(self, "min_files", 0)
        # Sorted listings make the entry order reproducible across machines/filesystems.
        for wsi_folder in sorted(os.listdir(self.save_dir)):
            wsi_folder_path = os.path.join(self.save_dir, wsi_folder)
            # Stray files next to the WSI folders are ignored (os.listdir on a file raises).
            if not os.path.isdir(wsi_folder_path):
                continue
            folder_files = sorted(os.listdir(wsi_folder_path))
            if len(folder_files) < min_files:
                continue
            for wsi_file in folder_files:
                if not wsi_file.endswith('.pt'):
                    continue
                wsi_id = wsi_file[:12]
                if wsi_id not in wanted_ids:
                    continue
                file_path = os.path.join(wsi_folder_path, wsi_file)
                try:
                    # map_location='cpu' so tensors saved from a GPU can later be turned into numpy arrays.
                    wsi_features = torch.load(file_path, map_location='cpu')
                    # A multi-row tensor is averaged over its first dimension so each file
                    # contributes a single vector (for 2-D inputs).
                    if isinstance(wsi_features, torch.Tensor) and wsi_features.dim() > 1:
                        wsi_features = torch.mean(wsi_features, dim=0)
                    label = 0 if '_nonMSI' in wsi_file else 1
                    self.data.append((wsi_features, label, wsi_id))
                except Exception as e:
                    # Best-effort loading: report and skip unreadable files instead of aborting.
                    print(f"Error loading {file_path}: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features, label, wsi_id = self.data[idx]
        return features, label, wsi_id

    def apply_clustering(self, clustering_algorithm: str, num_clusters: int = 3, num_selected_patches: int = 0):
        """
        Collapse all entries sharing a ``wsi_id`` into one ordered WSI-level vector.

        For each WSI the entry vectors are clustered, every non-empty cluster is averaged
        (optionally over only the ``num_selected_patches`` vectors closest to its centroid),
        the cluster means are ordered by their mean value, and the ordered means are
        concatenated. ``self.data`` is replaced by one ``(vector, label, wsi_id)`` entry per
        WSI, in first-appearance order, using the label of that WSI's first entry.

        Args:
        - clustering_algorithm: 'kmeans', 'dbscan' or 'pca'.
        - num_clusters: Clusters (k-means) or components (PCA) per WSI; ignored by DBSCAN,
          which determines its own number of clusters.
        - num_selected_patches: If > 0, average only this many vectors nearest each centroid.

        Raises:
        - ValueError: for an unsupported algorithm, or when a WSI yields no non-empty cluster.

        Note: the output length is (non-empty clusters) x (feature size), so it is only
        consistent across WSIs when every WSI yields the same number of non-empty clusters,
        which DBSCAN and PCA do not guarantee.
        """
        if clustering_algorithm not in ('kmeans', 'dbscan', 'pca'):
            raise ValueError(f"Unsupported clustering algorithm: {clustering_algorithm}")

        # Group once; a dict keeps first-seen order, unlike iterating a set of IDs.
        grouped = {}
        for features, label, wsi_id in self.data:
            patches, _ = grouped.setdefault(wsi_id, ([], label))
            patches.append(features)

        clustered_data = []
        for wsi_id, (patches, label) in grouped.items():
            patch_array = torch.stack(patches).detach().cpu().numpy()
            cluster_labels, cluster_centroids, n_clusters = self._cluster_patches(
                patch_array, clustering_algorithm, num_clusters
            )
            wsi_vector = self._aggregate_clusters(
                patch_array, cluster_labels, cluster_centroids, n_clusters, num_selected_patches, wsi_id
            )
            clustered_data.append((torch.tensor(wsi_vector), label, wsi_id))

        self.data = clustered_data

    @staticmethod
    def _cluster_patches(patch_array, clustering_algorithm, num_clusters):
        """Cluster one WSI's vectors; return (labels per vector, centroids, number of clusters)."""
        if clustering_algorithm == 'kmeans':
            model = KMeans(n_clusters=num_clusters, random_state=42)
            model.fit(patch_array)
            return model.labels_, model.cluster_centers_, num_clusters
        if clustering_algorithm == 'dbscan':
            cluster_labels = DBSCAN(eps=0.1, min_samples=2).fit_predict(patch_array)
            # DBSCAN labels noise as -1; noise is not counted as a cluster.
            found = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
            centroids = np.array([np.mean(patch_array[cluster_labels == i], axis=0) for i in range(found)])
            return cluster_labels, centroids, found
        # 'pca': each vector goes to the component with the largest projection score.
        # NOTE(review): pca.components_ are principal directions rather than cluster centres,
        # so nearest-centroid selection is only a heuristic in this mode.
        pca = PCA(n_components=num_clusters)
        projected = pca.fit_transform(patch_array)
        return np.argmax(projected, axis=1), pca.components_, num_clusters

    @staticmethod
    def _aggregate_clusters(patch_array, cluster_labels, cluster_centroids, n_clusters, num_selected_patches, wsi_id):
        """Average each non-empty cluster, order the means by their mean value, and concatenate."""
        cluster_means = []
        cluster_scores = []
        for cluster_idx in range(n_clusters):
            members = patch_array[cluster_labels == cluster_idx]
            if len(members) == 0:
                continue
            if num_selected_patches > 0:
                distances = cdist(members, [cluster_centroids[cluster_idx]], metric='euclidean').flatten()
                members = members[np.argsort(distances)[:num_selected_patches]]
            mean_vector = np.mean(members, axis=0)
            cluster_means.append(mean_vector)
            # Sorting key: mean activation of the cluster (orders identically to the sum).
            cluster_scores.append(np.mean(mean_vector))
        if not cluster_means:
            raise ValueError(f"WSI {wsi_id}: clustering produced no non-empty clusters")
        order = np.argsort(cluster_scores)
        return np.concatenate(np.array(cluster_means)[order])

### Evaluation Metric

In [4]:
def calculate_metric_averages_by_index(all_fold_results, metric_indices, missing_value=0):
    """
    Average selected metrics across folds, locating each metric by its position in a fold's result dict.

    Args:
        all_fold_results (list of dicts): One result dict per fold; entries are addressed by
            insertion order.
        metric_indices (dict): Maps an output metric name to the position of that metric's
            entry in each fold's result dict.
        missing_value: Reported for a metric that was never found as a number in any fold.
            Defaults to 0 (previous behaviour); pass float('nan') to make absence explicit.

    Returns:
        dict: metric name -> mean over the folds where that position existed and held a
        real number (Python or numpy scalar).
    """
    import numbers

    totals = {metric: 0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}

    for result in all_fold_results:
        keys = list(result.keys())  # built once per fold instead of once per metric
        for metric, index in metric_indices.items():
            try:
                value = result[keys[index]]
            except IndexError:
                # This fold's result dict is shorter (e.g. the model reports fewer metrics).
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue
            # numbers.Real also covers numpy scalars such as np.float32, which are not float subclasses.
            if isinstance(value, numbers.Real):
                totals[metric] += value
                counts[metric] += 1

    return {
        metric: (totals[metric] / counts[metric]) if counts[metric] else missing_value
        for metric in metric_indices
    }

 

def print_metrics(metrics):
    """Print each numeric entry of ``metrics`` as ``name: value`` (4 decimals); skip everything else."""
    numeric_entries = (
        (name, val) for name, val in metrics.items() if isinstance(val, (int, float))
    )
    for name, val in numeric_entries:
        print(f"{name}: {val:.4f}")

### Trainer Function

In [5]:
def train_and_evaluate(train_loader, val_loader, test_loader, model_type='linear'):
    """
    Gather features/labels from the three loaders and run one evaluator family.

    Args:
        train_loader, val_loader, test_loader: Iterables yielding ``(features, labels, ids)``
            batches (e.g. DataLoaders over WSIDataset).
        model_type: 'linear', 'ann', 'knn' or 'protonet'.

    Returns:
        The metrics object returned by the selected ``eval_*`` helper (its dump is discarded).

    Raises:
        ValueError: for an unsupported ``model_type`` (checked before any data is gathered),
            or when a loader yields no batches.
    """
    if model_type not in ('linear', 'ann', 'knn', 'protonet'):
        raise ValueError(f"Unsupported model type: {model_type}")

    def _collect(loader, split_name):
        """Concatenate every batch of ``loader`` into one feature tensor and one label tensor."""
        feats, labels = [], []
        for batch_feats, batch_labels, _ in loader:
            feats.append(batch_feats)
            labels.append(torch.as_tensor(batch_labels))
        if not feats:
            raise ValueError(f"The {split_name} loader yielded no batches")
        return torch.cat(feats), torch.cat(labels)

    # NOTE(review): these module-level globals are hidden state; they are kept only so cells
    # that inspect them after a run keep working. Remove once nothing reads them.
    global train_feats, train_labels, val_feats, val_labels, test_feats, test_labels
    train_feats, train_labels = _collect(train_loader, "train")
    val_feats, val_labels = _collect(val_loader, "validation")
    test_feats, test_labels = _collect(test_loader, "test")

    if model_type == 'linear':
        eval_metrics, eval_dump = eval_linear(
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=val_feats,
            valid_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            max_iter=250,
            verbose=False,
        )
    elif model_type == 'ann':
        eval_metrics, eval_dump = eval_ANN(
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=val_feats,
            valid_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            combine_trainval=False,
            # Size the network from the data: clustered features are several vectors
            # concatenated, unclustered ones keep their native length, so a fixed
            # VECTOR_DIM * NUM_CLUSTERS is wrong whenever clustering is not applied.
            input_dim=train_feats.shape[-1],
            max_iter=250,
            verbose=False,
        )
    elif model_type == 'knn':
        eval_metrics, eval_dump = eval_knn(
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=val_feats,
            val_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            n_neighbors=5,
            normalize_feats=True,
            verbose=False,
        )
    else:  # 'protonet'
        eval_metrics, eval_dump = eval_protonet(
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=val_feats,
            val_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            normalize_feats=True,
            verbose=False,
        )

    return eval_metrics


### K-Fold Cross-Validation

In [6]:
from collections import Counter

def count_classes(dataset):
    """
    Count label occurrences in ``dataset`` at two granularities.

    ``dataset`` must support ``len()`` and indexing, with items ``(features, label, wsi_id)``;
    tensor labels are converted with ``.item()``.

    Returns:
        tuple(Counter, Counter):
            - label -> number of distinct ``wsi_id`` values carrying that label;
            - label -> number of dataset entries carrying that label.
        The two differ when a WSI owns several entries (e.g. before clustering) and
        coincide once there is one entry per WSI.
    """
    entry_counts = Counter()
    wsi_label_pairs = set()
    for idx in range(len(dataset)):
        _, label, wsi_id = dataset[idx]
        if isinstance(label, torch.Tensor):
            label = label.item()
        entry_counts[label] += 1
        wsi_label_pairs.add((wsi_id, label))
    wsi_counts = Counter(label for _, label in wsi_label_pairs)
    return wsi_counts, entry_counts

# Cross-validation function
def run_k_fold_cross_validation(save_dir: str, folds: List[List[str]], model_type: str = 'linear',
                                evaluate: bool = False, cluster: bool = False):
    """
    Rotate through ``folds``: fold i is the test split, fold i+1 (cyclically) the validation
    split, and all remaining folds form the training split. Class counts are printed per split.

    Args:
        save_dir: Root directory passed to WSIDataset.
        folds: One list of WSI IDs per fold. At least three are required so that train,
            validation and test are disjoint and training is non-empty.
        model_type: Evaluator family forwarded to train_and_evaluate.
        evaluate: If True, build DataLoaders, run train_and_evaluate and collect each fold's
            metrics. Defaults to False, matching the previous behaviour (training was disabled).
        cluster: If True, call apply_clustering on every split using CLUSTERING_METHOD,
            NUM_CLUSTERS and NUM_PATCHES_PER_CLUSTER.

    Returns:
        list: One metrics object per fold; empty when ``evaluate`` is False.

    Raises:
        ValueError: if fewer than three folds are given.
    """
    num_folds = len(folds)
    if num_folds < 3:
        # With 2 folds the training split is empty; with 1 fold validation and test coincide.
        raise ValueError(f"Need at least 3 folds for train/val/test rotation, got {num_folds}")

    results_per_fold = []
    for i in range(num_folds):
        val_index = (i + 1) % num_folds
        test_ids = folds[i]
        val_ids = folds[val_index]
        train_ids = [
            wsi_id
            for j, fold in enumerate(folds)
            if j not in (i, val_index)
            for wsi_id in fold
        ]

        train_dataset = WSIDataset(save_dir, train_ids)
        val_dataset = WSIDataset(save_dir, val_ids)
        test_dataset = WSIDataset(save_dir, test_ids)
        if cluster:
            for split in (train_dataset, val_dataset, test_dataset):
                split.apply_clustering(
                    clustering_algorithm=CLUSTERING_METHOD,
                    num_clusters=NUM_CLUSTERS,
                    num_selected_patches=NUM_PATCHES_PER_CLUSTER,
                )

        print(f"Running Fold {i + 1} with model {model_type}...")
        print(f"Train: {count_classes(train_dataset)}")
        print(f"Validation: {count_classes(val_dataset)}")
        print(f"Test: {count_classes(test_dataset)}")

        if evaluate:
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            eval_metrics = train_and_evaluate(train_loader, val_loader, test_loader, model_type=model_type)
            print_metrics(eval_metrics)
            results_per_fold.append(eval_metrics)

    return results_per_fold



### Main Runner

In [None]:
# Example usage:
folds_df = pd.read_csv(K_FOLDS_PATH)
# Each FoldN column lists one fold's slides; keep the 12-character ID prefix, the same
# prefix WSIDataset takes from feature file names.
FOLD_COLUMNS = ['Fold1', 'Fold2', 'Fold3', 'Fold4']
folds = [[raw_id[:12] for raw_id in folds_df[column].dropna()] for column in FOLD_COLUMNS]
fold1_ids, fold2_ids, fold3_ids, fold4_ids = folds  # kept for cells that refer to them by name

# Run k-fold cross-validation with different models
model_types = ['linear', 'ann', 'knn', 'protonet']
# Position of each metric inside a fold's result dict (insertion order), as consumed by
# calculate_metric_averages_by_index.
# NOTE(review): the previous comment placed AUROC at index 4 while the value used is 5 —
# verify all positions against the dicts the eval_* helpers actually return.
metric_indices = {
    'acc': 0,
    'bacc': 1,
    'macro_f1': 2,
    'weighted_f1': 3,
    'auroc': 5,
}

for model in model_types:
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_fold_cross_validation(DATA_PATH, folds, model_type=model)
    print("\n\n Average results for all folds:")
    if not k_folds_results:
        # No per-fold metrics (e.g. training disabled); averaging would report every metric as 0.
        print("No per-fold metrics were returned; nothing to average.")
        continue
    average_results = calculate_metric_averages_by_index(k_folds_results, metric_indices)
    for metric, value in average_results.items():
        print(f"{metric}: {value:.4f}")



 ********* Training with model: linear********* 


Running Fold 1 with model linear...
Confusion Matrix:
[[14  0]
 [ 4  0]]
lin_acc: 0.7778
lin_bacc: 0.5000
lin_kappa: 0.0000
lin_weighted_f1: 0.6806
lin_auroc: 0.5000
Running Fold 2 with model linear...
Confusion Matrix:
[[14  0]
 [ 5  2]]
lin_acc: 0.7619
lin_bacc: 0.6429
lin_kappa: 0.3478
lin_weighted_f1: 0.7138
lin_auroc: 0.8265
Running Fold 3 with model linear...
Confusion Matrix:
[[15  1]
 [ 3  0]]
lin_acc: 0.7895
lin_bacc: 0.4688
lin_kappa: -0.0857
lin_weighted_f1: 0.7430
lin_auroc: 0.6250
Running Fold 4 with model linear...
Confusion Matrix:
[[15  0]
 [ 0  1]]
lin_acc: 1.0000
lin_bacc: 1.0000
lin_kappa: 1.0000
lin_weighted_f1: 1.0000
lin_auroc: 1.0000


 Average results for all folds:
acc: 0.8323
bacc: 0.6529
kappa: 0.3155
weighted_f1: 0.7843
auroc: 0.7379


 ********* Training with model: ann********* 


Running Fold 1 with model ann...
Confusion Matrix:
[[11  3]
 [ 1  3]]
ann_acc: 0.7778
ann_bacc: 0.7679
ann_kappa: 0.4545
ann_weighted_f1: 0.7915
ann_auroc: 0.9107
Running Fold 2 with model ann...
Confusion Matrix:
[[12  2]
 [ 2  5]]
ann_acc: 0.8095
ann_bacc: 0.7857
ann_kappa: 0.5714
ann_weighted_f1: 0.8095
ann_auroc: 0.8980
Running Fold 3 with model ann...
Confusion Matrix:
[[9 7]
 [2 1]]
ann_acc: 0.5263
ann_bacc: 0.4479
ann_kappa: -0.0621
ann_weighted_f1: 0.5901
ann_auroc: 0.6250
Running Fold 4 with model ann...
Confusion Matrix:
[[9 6]
 [0 1]]
ann_acc: 0.6250
ann_bacc: 0.8000
ann_kappa: 0.1579
ann_weighted_f1: 0.7188
ann_auroc: 1.0000


 Average results for all folds:
acc: 0.6847
bacc: 0.7004
kappa: 0.2804
weighted_f1: 0.7275
auroc: 0.8584


 ********* Training with model: knn********* 


Running Fold 1 with model knn...
knn5_acc: 0.7778
knn5_bacc: 0.5000
knn5_kappa: 0.0000
knn5_weighted_f1: 0.6806
Running Fold 2 with model knn...
knn5_acc: 0.6667
knn5_bacc: 0.5000
knn5_kappa: 0.0000
knn5_weighted_f1: 0.5333
Running Fold 3 with model knn...
knn5_acc: 0.8421
knn5_bacc: 0.5000
knn5_kappa: 0.0000
knn5_weighted_f1: 0.7699
Running Fold 4 with model knn...
knn5_acc: 0.8750
knn5_bacc: 0.4667
knn5_kappa: -0.0667
knn5_weighted_f1: 0.8750


 Average results for all folds:
acc: 0.7904
bacc: 0.4917
kappa: -0.0167
weighted_f1: 0.7147
auroc: 0.0000


 ********* Training with model: protonet********* 


Running Fold 1 with model protonet...
proto_acc: 0.6667
proto_bacc: 0.4286
proto_kappa: -0.1739
proto_weighted_f1: 0.6222
Running Fold 2 with model protonet...
proto_acc: 0.6190
proto_bacc: 0.5714
proto_kappa: 0.1429
proto_weighted_f1: 0.6190
Running Fold 3 with model protonet...
proto_acc: 0.8421
proto_bacc: 0.5000
proto_kappa: 0.0000
proto_weighted_f1: 0.7699
Running Fold 4 with model protonet...
proto_acc: 0.8750
proto_bacc: 0.9333
proto_kappa: 0.4483
proto_weighted_f1: 0.9018


 Average results for all folds:
acc: 0.7507
bacc: 0.6083
kappa: 0.1043
weighted_f1: 0.7282
auroc: 0.0000
