In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from eval_patch_features.logistic import eval_linear
from eval_patch_features.ann import eval_ANN
from eval_patch_features.knn import eval_knn
from eval_patch_features.protonet import eval_protonet
from eval_patch_features.metrics import get_eval_metrics, print_metrics
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Configurations

In [2]:
# configs — every value a reader may need to tune lives in this cell.
VECTOR_DIM = 1024  # length of each slide-level feature vector (passed to eval_ANN as input_dim)
BATCH_SIZE = 32    # DataLoader batch size; batches are concatenated before fitting, so this only affects loading

# Defaults are this project's absolute Windows paths; set the environment variables
# to run the notebook on another machine without editing code.
K_FOLDS_PATH = os.environ.get("KSA_KFOLDS_PATH", r"E:\KSA Project\dataset\splits\kfolds.csv")
DATA_PATH = os.environ.get("KSA_DATA_PATH", r"E:\KSA Project\dataset\uni_features\all_data")

# Seed every RNG the notebook touches so Restart & Run All is reproducible:
# the training DataLoader shuffles, and the probes may be stochastic (e.g. ANN training).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

#### Optional: if you only have five-crop-level features saved, average them into one feature vector per WSI and save the results to a separate directory

In [None]:
# Source directory of patch-level (optionally five-crop) features and destination for per-WSI averages.
# Raw strings: in a normal string "\K", "\d", "\I" are invalid escape sequences (SyntaxWarning on Python 3.12+).
data = r"E:\KSA Project\dataset\IDARS_Fivecrop_Features"
data_save = r"E:\KSA Project\dataset\IDARS_Features"
# False = dry run that only reports shapes; set True to actually write the averaged .pt files.
SAVE_AVERAGED_FEATURES = False


def average_patch_features(patch_tensors):
    """
    Reduce one WSI's patch features to a single averaged feature vector.

    Each element is either one feature vector (1-D) or a stack of crop-level vectors
    (e.g. five-crop, crops along dim 0). Stacks are first averaged over their crops so
    every patch contributes exactly one vector; the per-patch vectors are then averaged.

    Checking ``dim()`` rather than ``shape[0] > 1`` matters: a plain 1-D vector has
    ``shape[0] == feature_dim`` and would otherwise be collapsed to a scalar, while a
    ``(1, D)`` stack would be left un-averaged and fail to stack with ``(D,)`` vectors.

    Args:
        patch_tensors: iterable of tensors, one per patch file.

    Returns:
        torch.Tensor: the mean of the per-patch vectors.

    Raises:
        ValueError: if ``patch_tensors`` is empty.
    """
    per_patch = [t.mean(dim=0) if t.dim() > 1 else t for t in patch_tensors]
    if not per_patch:
        raise ValueError("no patch features to average")
    return torch.stack(per_patch).mean(dim=0)


for folder in sorted(os.listdir(data)):
    folder_dir = j_(data, folder)
    if not os.path.isdir(folder_dir):
        continue  # ignore stray files at the top level of the source directory
    print(f'folder {folder}')
    for wsi in sorted(os.listdir(folder_dir)):
        print(f'wsi {wsi}')
        wsi_dir = j_(folder_dir, wsi)
        patch_files = sorted(os.listdir(wsi_dir))
        if not patch_files:
            # torch.stack([]) would raise; report and move on instead.
            print(f'skipping {wsi}: no patch features found')
            continue
        # map_location="cpu" so tensors saved from a GPU session can be stacked together on any machine.
        wsi_data = average_patch_features(
            [torch.load(j_(wsi_dir, patch), map_location="cpu") for patch in patch_files]
        )
        save_path = j_(data_save, folder, wsi + ".pt")
        print(f'save path {save_path}')
        print(wsi_data.shape)
        if SAVE_AVERAGED_FEATURES:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(wsi_data, save_path)
            print("saved")
        else:
            print("dry run: not saved (set SAVE_AVERAGED_FEATURES = True)")
    print("done")

### Data Loaders

In [3]:
import os
import torch
from torch.utils.data import Dataset
from typing import List

class WSIDataset(Dataset):
    """
    In-memory dataset of slide-level feature vectors for one set of WSI IDs.

    Every ``.pt`` file in ``save_dir`` whose first ``ID_LENGTH`` characters (the WSI ID)
    appear in ``fold_ids`` is loaded once at construction time. Tensors with more than
    one dimension are averaged over dim 0 into a single vector; anything else is kept
    as-is. Each sample is ``(features, label, wsi_id)``, where ``label`` is 0 when the
    file name contains ``'_nonMSI'`` and 1 otherwise.

    Args:
        save_dir: directory holding the per-WSI ``.pt`` feature files.
        fold_ids: WSI IDs (already truncated to ``ID_LENGTH`` characters) to include.
    """

    # Number of leading file-name characters forming the WSI ID; fold IDs must be truncated the same way.
    ID_LENGTH = 12

    def __init__(self, save_dir: str, fold_ids: List[str]):
        self.data = []
        self.save_dir = save_dir
        self.fold_ids = fold_ids
        self._load_data()

    def _load_data(self):
        """Populate ``self.data`` from ``self.save_dir`` as described in the class docstring."""
        # A set makes each membership test O(1) instead of a scan of the fold list.
        wanted_ids = set(self.fold_ids)
        # Sorted so sample order does not depend on the filesystem's listing order.
        for wsi_file in sorted(os.listdir(self.save_dir)):
            wsi_id = wsi_file[:self.ID_LENGTH]
            if wsi_id not in wanted_ids or not wsi_file.endswith('.pt'):
                continue
            wsi_path = os.path.join(self.save_dir, wsi_file)
            try:
                # map_location="cpu" lets features saved from a GPU session load on CPU-only machines.
                wsi_features = torch.load(wsi_path, map_location="cpu")
                if isinstance(wsi_features, torch.Tensor) and wsi_features.dim() > 1:
                    wsi_features = torch.mean(wsi_features, dim=0)
                label = 0 if '_nonMSI' in wsi_file else 1
                self.data.append((wsi_features, label, wsi_id))
            except Exception as e:
                # Best-effort loading: report the unreadable file and carry on with the rest.
                print(f"Error loading {wsi_path}: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Returns the stored (features, label, wsi_id) tuple.
        return self.data[idx]


### Evaluation Metrics

In [4]:
def calculate_metric_averages_by_index(all_fold_results, metric_indices):
    """
    Average selected metrics over cross-validation folds, locating each metric by position.

    Result keys carry a model-specific prefix (the linear probe reports ``lin_acc``,
    ``lin_bacc``, ...), so each metric is looked up by its position among a fold's keys
    instead of by name. This relies on every fold's dict listing its metrics in the
    same order.

    Args:
        all_fold_results (list of dicts): one metrics dict per fold.
        metric_indices (dict): maps an output metric name to its position among a
            result dict's keys.

    Returns:
        dict: for each name in ``metric_indices`` (same order), the mean of the values
        found at that position across folds. Only real numbers count (bools excluded);
        a metric with no numeric value in any fold is reported as ``nan`` rather than
        a misleading 0.
    """
    import numbers

    totals = {metric: 0.0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}

    for result in all_fold_results:
        keys = list(result.keys())  # built once per fold rather than once per metric
        for metric, index in metric_indices.items():
            try:
                value = result[keys[index]]
            except IndexError:
                # This result dict has fewer entries than the requested position.
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue
            # numbers.Real also admits numpy scalars such as float32; bool is an int subclass, so exclude it.
            if isinstance(value, numbers.Real) and not isinstance(value, bool):
                totals[metric] += value
                counts[metric] += 1

    return {
        metric: totals[metric] / counts[metric] if counts[metric] else float("nan")
        for metric in metric_indices
    }

 

def print_metrics(metrics):
    """
    Print each numeric entry of ``metrics`` as ``name: value`` with 4 decimal places.

    Non-numeric entries and booleans are skipped.

    NOTE: this definition shadows ``print_metrics`` imported from
    ``eval_patch_features.metrics`` in the first cell; every later call uses this one.
    """
    import numbers

    for key, value in metrics.items():
        # numbers.Real also admits numpy scalars such as float32; bool is an int subclass, so exclude it.
        if isinstance(value, numbers.Real) and not isinstance(value, bool):
            print(f"{key}: {value:.4f}")

### Trainer Function

In [5]:
def _stack_split(loader, split_name):
    """
    Concatenate every ``(features, label, _)`` batch of ``loader`` into ``(features, labels)`` tensors.

    Raises:
        ValueError: when the loader yields no batches (``torch.cat([])`` would
            otherwise fail with a less helpful error).
    """
    feats, labels = [], []
    for batch_feats, batch_labels, _ in loader:
        feats.append(batch_feats)
        labels.append(batch_labels)
    if not feats:
        raise ValueError(f"The {split_name} loader produced no samples")
    return torch.cat(feats), torch.cat(labels)


def train_and_evaluate(train_loader, val_loader, test_loader, model_type='linear'):
    """
    Fit one slide-level probe on the training split and score it on the test split.

    Args:
        train_loader, val_loader, test_loader: DataLoaders yielding
            ``(features, label, wsi_id)`` batches; each split is concatenated into a
            single features tensor and a single labels tensor.
        model_type (str): one of ``'linear'``, ``'ann'``, ``'knn'``, ``'protonet'``.

    Returns:
        The metrics object returned by the selected ``eval_*`` function.

    Raises:
        ValueError: for an unsupported ``model_type`` or a loader with no samples.

    Side effects:
        Publishes ``train_feats``, ``train_labels``, ``val_feats``, ``val_labels``,
        ``test_feats`` and ``test_labels`` as module-level globals. This is hidden
        notebook state kept for inspection from later cells.
    """
    # Fail fast, before loading any data.
    if model_type not in ('linear', 'ann', 'knn', 'protonet'):
        raise ValueError(f"Unsupported model type: {model_type}")

    global train_feats, train_labels, val_feats, val_labels, test_feats, test_labels
    train_feats, train_labels = _stack_split(train_loader, "train")
    val_feats, val_labels = _stack_split(val_loader, "validation")
    test_feats, test_labels = _stack_split(test_loader, "test")

    # The eval_* helpers do not share keyword names (valid_* vs val_*), hence one call per branch.
    if model_type == 'linear':
        eval_metrics, _ = eval_linear(
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=val_feats,
            valid_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            max_iter=300,
            verbose=False,
        )
    elif model_type == 'ann':
        eval_metrics, _ = eval_ANN(
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=val_feats,
            valid_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            combine_trainval=False,
            input_dim=VECTOR_DIM,
            max_iter=250,
            verbose=False,
        )
    elif model_type == 'knn':
        eval_metrics, _ = eval_knn(
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=val_feats,
            val_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            n_neighbors=5,
            normalize_feats=True,
            verbose=False
        )
    else:  # 'protonet' — the only remaining supported type
        eval_metrics, _ = eval_protonet(
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=val_feats,
            val_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            normalize_feats=True,
            verbose=False
        )

    return eval_metrics


### K-Folds

In [6]:
from collections import Counter

def count_classes(dataset):
    """
    Tally how many samples of each class label ``dataset`` holds.

    The dataset is walked one sample at a time through a non-shuffling DataLoader;
    each item must unpack as ``(features, label, wsi_id)``. Tensor labels are
    converted with ``.item()`` so the Counter keys are plain Python values.
    """
    tally = Counter()
    single_sample_loader = DataLoader(dataset, batch_size=1, shuffle=False)
    for _features, batch_label, _wsi_id in single_sample_loader:
        key = batch_label.item() if isinstance(batch_label, torch.Tensor) else batch_label
        tally[key] += 1
    return tally


def run_k_fold_cross_validation(save_dir: str, folds: List[List[str]], model_type: str = 'linear'):
    """
    Rotate through ``folds`` and evaluate one probe per rotation.

    In rotation i, fold i is the test set, fold (i + 1) mod k is the validation set,
    and all remaining folds are pooled into the training set.

    Args:
        save_dir: directory of per-WSI ``.pt`` feature files read by ``WSIDataset``.
        folds: one list of WSI IDs per fold.
        model_type: probe name passed to ``train_and_evaluate``.

    Returns:
        list: the metrics returned by ``train_and_evaluate`` for each fold, in fold order.

    Raises:
        ValueError: if fewer than 3 folds are given. Test, validation and a non-empty
            training split each need distinct folds.

    NOTE(review): the saved outputs of the main runner show "Train/Validation/Test:
    Counter(...)" lines that nothing here prints (``count_classes`` is never called).
    Either an ``eval_*`` helper prints them or those outputs are stale — confirm.
    """
    num_folds = len(folds)
    if num_folds < 3:
        raise ValueError(
            f"Need at least 3 folds (one test, one validation, >=1 training); got {num_folds}"
        )

    results_per_fold = []
    for i, test_ids in enumerate(folds):
        # The fold after the test fold (wrapping around) is held out for validation.
        val_index = (i + 1) % num_folds
        val_ids = folds[val_index]
        # Every remaining fold is pooled into the training set.
        train_ids = [
            wsi_id
            for j, fold_ids in enumerate(folds)
            if j not in (i, val_index)
            for wsi_id in fold_ids
        ]

        train_loader = DataLoader(WSIDataset(save_dir, train_ids), batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(WSIDataset(save_dir, val_ids), batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(WSIDataset(save_dir, test_ids), batch_size=BATCH_SIZE, shuffle=False)

        print(f"Running Fold {i + 1} with model {model_type}... \n")
        eval_metrics = train_and_evaluate(train_loader, val_loader, test_loader, model_type=model_type)
        print_metrics(eval_metrics)
        results_per_fold.append(eval_metrics)

    return results_per_fold

### Main Runner Function

In [7]:
# Load the k-fold split table: each FoldN column lists the slides assigned to that fold.
folds_df = pd.read_csv(K_FOLDS_PATH)
FOLD_COLUMNS = ['Fold1', 'Fold2', 'Fold3', 'Fold4']
# Keep only the first 12 characters of each entry so IDs match the prefix WSIDataset takes from file names.
folds = [folds_df[col].dropna().apply(lambda x: x[:12]).tolist() for col in FOLD_COLUMNS]
fold1_ids, fold2_ids, fold3_ids, fold4_ids = folds  # individual names kept in case later cells use them

# Probes to compare; each runs through the full k-fold rotation.
model_types = ['linear', 'ann', 'knn', 'protonet']

# Position of each metric among the keys of an eval_* result dict. Keys carry a
# model-specific prefix (the linear probe prints lin_acc, lin_bacc, ...), hence lookup by position.
# Position 4 is skipped: print_metrics shows no numeric value between weighted_f1 and auroc,
# so presumably it holds a non-numeric entry (e.g. a classification report) — confirm in eval_patch_features.
metric_indices = {
    'acc': 0,
    'bacc': 1,
    'macro_f1': 2,
    'weighted_f1': 3,
    'auroc': 5,
}

# Fold-averaged metrics per model, kept so every model's summary survives the loop.
average_results_by_model = {}
for model in model_types:
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_fold_cross_validation(DATA_PATH, folds, model_type=model)
    average_results = calculate_metric_averages_by_index(k_folds_results, metric_indices)
    average_results_by_model[model] = average_results
    print("\n\n Average results for all folds:")
    for metric, value in average_results.items():
        print(f"{metric}: {value:.4f}")



 ********* Training with model: linear********* 


Running Fold 1 with model linear... 

Train: Counter({0: 178, 1: 33})
Validation: Counter({0: 81, 1: 13})
Test: Counter({0: 85, 1: 15})

Confusion Matrix:
[[83  2]
 [ 7  8]]
lin_acc: 0.9100
lin_bacc: 0.7549
lin_macro_f1: 0.7943
lin_weighted_f1: 0.9023
lin_auroc: 0.8800
Running Fold 2 with model linear... 

Train: Counter({0: 172, 1: 28})
Validation: Counter({0: 91, 1: 20})
Test: Counter({0: 81, 1: 13})

Confusion Matrix:
[[78  3]
 [ 6  7]]
lin_acc: 0.9043
lin_bacc: 0.7507
lin_macro_f1: 0.7771
lin_weighted_f1: 0.8989
lin_auroc: 0.9069
Running Fold 3 with model linear... 

Train: Counter({0: 166, 1: 28})
Validation: Counter({0: 87, 1: 13})
Test: Counter({0: 91, 1: 20})

Confusion Matrix:
[[88  3]
 [11  9]]
lin_acc: 0.8739
lin_bacc: 0.7085
lin_macro_f1: 0.7444
lin_weighted_f1: 0.8608
lin_auroc: 0.8791
Running Fold 4 with model linear... 

Train: Counter({0: 172, 1: 33})
Validation: Counter({0: 85, 1: 15})
Test: Counter({0: 87, 1: 13})

