In [1090]:
import torch
import torchvision
from torch.utils.data import DataLoader
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from eval_patch_features.logistic import eval_linear
from eval_patch_features.ann import eval_ANN
from eval_patch_features.knn import eval_knn
from eval_patch_features.protonet import eval_protonet
from eval_patch_features.metrics import get_eval_metrics, print_metrics
import warnings
warnings.filterwarnings("ignore")
# Pick the GPU when CUDA is available, otherwise fall back to the CPU, and
# echo the choice so the cell output records which device was selected.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Configurations

In [1091]:
# configs
# Passed to eval_ANN as `input_dim`; must match the length of the feature vectors.
VECTOR_DIM = 1280  # size of input feature vector
# Passed to eval_ANN as `hidden_dim`.
HIDDEN_DIM = 768   # size of ANN hidden layer
# Batch size of the train/test DataLoaders that feed the evaluators.
BATCH_SIZE = 8
# CSV read with pandas; its 'Fold1' and 'Fold2' columns supply the WSI IDs of each split.
K_FOLDS_PATH = r"E:\Aamir Gulzar\dataset\paip_data\labels\TrainTest_paip.csv"
# Directory scanned for per-WSI `.pt` feature files.
DATA_PATH = r"E:\Aamir Gulzar\dataset\paip_data\virchow2_Features"
# Forwarded to the evaluators as their save-path argument.
# NOTE(review): None presumably disables model saving -- confirm in eval_patch_features.
MODEL_SAVE_PATH = None


In [1092]:
# data = r"E:\Aamir Gulzar\dataset\paip_data\CAIMAN_FiveCrop_Features"
# data_save = r"E:\Aamir Gulzar\dataset\paip_data\CAIMAN_Features"
# if not os.path.exists(data_save):
#     os.makedirs(data_save, exist_ok=True)
# # load the data
# for wsi in os.listdir(data):
#     wsi_data = []
#     for patch in os.listdir(j_(data, wsi)):
#         patch_data = torch.load(j_(data, wsi, patch))
#         # check if the loaded feature vector is five crop then average it first then append to the wsi_data
#         if patch_data.shape[0] > 1:
#             patch_data = patch_data.mean(dim=0)
#         wsi_data.append(patch_data)
#     wsi_data = torch.stack(wsi_data).mean(dim=0)
#     save_path = j_(data_save, wsi + ".pt")
#     torch.save(wsi_data, save_path)
# print("done")

In [1093]:
# import torch

# # Define the path to the .pt file
# file_path = r"E:\Aamir Gulzar\dataset\paip_data\UNI_Features\training_data_01_MSIH.pt"

# # Load the .pt file
# try:
#     data = torch.load(file_path)
#     shape = torch.tensor(data).shape  # Get shape
#     print("Shape of the data:", shape)
#     print("Values:\n", data)  # Print values
# except Exception as e:
#     print("Error loading file:", e)

### Data Loaders

In [1094]:
import os
import torch
from torch.utils.data import Dataset
from typing import List

class WSIDataset(Dataset):
    """Slide-level dataset built from per-WSI ``.pt`` feature files.

    Every ``<wsi_id>.pt`` file in ``save_dir`` whose file stem appears in
    ``fold_ids`` is loaded eagerly at construction time. A tensor with more
    than one dimension is reduced to a single vector by averaging over dim 0;
    a 1-D tensor is used as is. The label is 0 when the file name contains
    ``'_nonMSI'`` and 1 otherwise.

    Files are visited in sorted name order so that the sample order is the
    same on every machine (``os.listdir`` returns entries in arbitrary order).

    Args:
        save_dir: Directory containing the ``.pt`` feature files.
        fold_ids: WSI IDs (file names without extension) belonging to this split.

    Each item is a ``(features, label, wsi_id)`` tuple.
    """

    def __init__(self, save_dir: str, fold_ids: List[str]):
        self.data = []
        self.save_dir = save_dir
        self.fold_ids = fold_ids
        self._load_data()

    def _load_data(self):
        """Populate ``self.data`` with one entry per matching feature file.

        Files that cannot be loaded or processed are reported on stdout and
        skipped, so one corrupt file does not abort the whole split.
        """
        # A set makes the per-file membership test O(1).
        wanted_ids = set(self.fold_ids)

        for wsi_file in sorted(os.listdir(self.save_dir)):
            if not wsi_file.endswith('.pt'):
                continue  # Not a feature file
            # The WSI ID is the full file name without its extension.
            wsi_id = os.path.splitext(wsi_file)[0]
            if wsi_id not in wanted_ids:
                continue  # Skip if the WSI is not in the current fold

            wsi_path = os.path.join(self.save_dir, wsi_file)
            try:
                # map_location keeps the tensor on the CPU even when it was
                # saved from a CUDA device, so loading also works without a GPU.
                # NOTE: torch.load unpickles the file; only use trusted files.
                wsi_features = torch.load(wsi_path, map_location='cpu')

                # Average all feature vectors in the .pt file
                if wsi_features.dim() > 1:
                    averaged_features = torch.mean(wsi_features, dim=0)
                else:
                    averaged_features = wsi_features  # Already a single vector

                # Determine label based on WSI file name
                label = 0 if '_nonMSI' in wsi_file else 1

                self.data.append((averaged_features, label, wsi_id))
            except Exception as e:
                print(f"Error loading {wsi_path}: {e}")

    def __len__(self):
        """Return the number of loaded WSIs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label, wsi_id)`` tuple at position ``idx``."""
        features, label, wsi_id = self.data[idx]
        return features, label, wsi_id

### Evaluation Metrics

In [1095]:
def calculate_metric_averages_by_index(all_fold_results, metric_indices):
    """
    Calculate the average of specified metrics over multiple folds.

    Metrics are looked up by position rather than by key, because the key
    names in each fold's result dict can differ between models (for example
    'lin_acc' versus 'knn_acc') while their order stays the same.

    Args:
        all_fold_results (list of dicts): Results for each fold.
        metric_indices (dict): Mapping of metric names to their indices
            (position of the metric within a fold's result dict).
    Returns:
        dict: Averages of the specified metrics across folds. A metric that
        had no numeric value in any fold keeps the value 0.
    """
    import numbers  # local import: only needed for the numeric type check

    averages = {metric: 0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}  # folds contributing per metric

    for result in all_fold_results:
        for metric, index in metric_indices.items():
            try:
                value = result[list(result.keys())[index]]
            except IndexError:
                # Metric not present in this result due to model differences
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue
            # numbers.Real also accepts NumPy scalars (np.float32, np.int64, ...),
            # which a plain (int, float) check would silently skip.
            if isinstance(value, numbers.Real):
                averages[metric] += float(value)
                counts[metric] += 1

    # Compute average only for metrics with valid values
    for metric, count in counts.items():
        if count > 0:
            averages[metric] /= count
    return averages


### Trainer Function

In [1096]:
def train_and_evaluate(fold, train_loader, test_loader, model_type='linear'):
    """Fit one classifier head on the training features and score it on the test features.

    Both loaders are drained completely, their batches are concatenated into
    single feature/label tensors, and those are handed to the evaluator
    selected by ``model_type``.

    Args:
        fold: Fold identifier forwarded unchanged to the evaluator.
        train_loader: Iterable yielding ``(features, label, wsi_id)`` batches.
        test_loader: Iterable yielding ``(features, label, wsi_id)`` batches.
        model_type: One of ``'linear'``, ``'ann'``, ``'knn'`` or ``'protonet'``.

    Returns:
        The metrics object returned by the selected evaluator (the evaluator's
        second return value is discarded).

    Raises:
        ValueError: If ``model_type`` is not supported, or if either loader
            yields no batches.
    """
    # The concatenated tensors are published as module globals, as before.
    # NOTE(review): presumably for interactive inspection in later cells -- confirm.
    global train_feats, train_labels, test_feats, test_labels

    # Fail fast: reject an unknown model before spending time draining the loaders.
    if model_type not in ('linear', 'ann', 'knn', 'protonet'):
        raise ValueError(f"Unsupported model type: {model_type}")

    def _drain(loader):
        """Collect the feature and label batches of a loader into two lists."""
        feature_batches, label_batches = [], []
        for features, label, _ in loader:
            feature_batches.append(features)
            label_batches.append(label)
        return feature_batches, label_batches

    train_feat_batches, train_label_batches = _drain(train_loader)
    test_feat_batches, test_label_batches = _drain(test_loader)

    # torch.cat rejects an empty list with an opaque message; report the real cause.
    if not train_feat_batches or not test_feat_batches:
        raise ValueError(
            f"Cannot evaluate fold {fold}: the train or test loader yielded no batches."
        )

    # torch.cat already allocates new tensors, so no extra copies are needed.
    train_feats = torch.cat(train_feat_batches)
    train_labels = torch.cat(train_label_batches)
    test_feats = torch.cat(test_feat_batches)
    test_labels = torch.cat(test_label_batches)

    # Select the model based on the input argument
    if model_type == 'linear':
        eval_metrics, eval_dump = eval_linear(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=None,  # Optionally, use a separate validation set
            valid_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            max_iter=350,
            save_path=MODEL_SAVE_PATH,
            verbose=False,
        )
    elif model_type == 'ann':
        eval_metrics, eval_dump = eval_ANN(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=None,
            valid_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            input_dim=VECTOR_DIM,
            hidden_dim=HIDDEN_DIM,
            model_save_path=MODEL_SAVE_PATH,
            max_iter=350,
            verbose=False,
        )
    elif model_type == 'knn':
        eval_metrics, eval_dump = eval_knn(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=None,
            val_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            n_neighbors=5,
            normalize_feats=True,
            model_save_path=MODEL_SAVE_PATH,
            verbose=False,
        )
    else:  # 'protonet' (the only remaining supported value)
        eval_metrics, eval_dump = eval_protonet(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=None,
            val_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            normalize_feats=True,
            model_save_path=MODEL_SAVE_PATH,
        )
    return eval_metrics

### K-Folds

In [1097]:
from collections import Counter

def count_classes(dataset):
    """
    Helper function to count class occurrences in a dataset.

    The dataset is indexed directly instead of being wrapped in a DataLoader,
    which avoids per-sample collation and also works for labels that cannot
    be collated into tensors (for example strings).

    Args:
        dataset: Map-style dataset supporting ``len()`` and integer indexing,
            whose items are ``(features, label, wsi_id)`` tuples.
    Returns:
        collections.Counter: Mapping of label value to number of samples.
    """
    tally = Counter()
    for idx in range(len(dataset)):
        _, label, _ = dataset[idx]
        # Tensor labels are unwrapped to plain Python scalars so they hash by value.
        tally[label.item() if isinstance(label, torch.Tensor) else label] += 1
    return tally

def run_k_fold_cross_validation(save_dir: str, folds: List[List[str]], model_type: str = 'linear'):
    """Train on one fold and evaluate on the next, collecting the metrics.

    NOTE: despite the name, only a single split is currently evaluated:
    ``folds[0]`` is used for training and ``folds[1]`` for testing. Any
    further folds are ignored.

    Args:
        save_dir: Directory containing the per-WSI ``.pt`` feature files.
        folds: List of WSI-ID lists; at least two are required.
        model_type: Classifier head forwarded to ``train_and_evaluate``.

    Returns:
        list: One metrics object per evaluated split (currently one element).

    Raises:
        ValueError: If fewer than two folds are supplied.
    """
    if len(folds) < 2:
        raise ValueError(
            f"At least two folds are required (train and test), got {len(folds)}."
        )

    results_per_fold = []
    num_splits = 1  # only the first train/test split is run

    for i in range(num_splits):
        # Fold i is the training set; the next fold in sequence is the test set.
        train_ids = folds[i]
        test_ids = folds[i + 1]
        print(f"Running Fold {i + 1} with model {model_type}...")

        # Create datasets and loaders; order is kept fixed (no shuffling).
        train_dataset = WSIDataset(save_dir, train_ids)
        test_dataset = WSIDataset(save_dir, test_ids)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        eval_metrics = train_and_evaluate(i, train_loader, test_loader, model_type=model_type)
        print_metrics(eval_metrics)
        results_per_fold.append(eval_metrics)

    return results_per_fold

### Main Runner Function

In [1098]:
# Entry point: evaluate every classifier head on the Fold1 (train) / Fold2 (test) split.
# Load the fold assignment table.
folds_df = pd.read_csv(K_FOLDS_PATH)
# Collect the non-null entries of each fold column; these are matched against
# the feature file names (without extension) when the datasets are built.
fold1_ids = folds_df['Fold1'].dropna().tolist()
fold2_ids = folds_df['Fold2'].dropna().tolist()
folds = [fold1_ids, fold2_ids]
# Classifier heads to evaluate, in this order
model_types = ['linear','ann','knn','protonet']
# Position of each metric inside a fold's result dict. Positions are used
# because the key names carry a model-specific prefix (see the printed output).
metric_indices = {
    'acc': 0,          # accuracy, e.g. 'lin_acc'
    'bacc': 1,         # balanced accuracy, e.g. 'lin_bacc'
    'macro_f1': 2,        # macro F1, e.g. 'lin_macro_f1'
    'weighted_f1': 3,  # weighted F1, e.g. 'lin_weighted_f1'
    'auroc': 4         # AUROC, e.g. 'lin_auroc'
}

for model in model_types:
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_fold_cross_validation(DATA_PATH, folds, model_type=model)
    # Averaging across folds is disabled while only a single split is run.
    # average_results = calculate_metric_averages_by_index(k_folds_results, metric_indices)
    # print("\n\n Average results for all folds:")
    # for metric, value in average_results.items():
    #     print(f"{metric}: {value:.4f}")



 ********* Training with model: linear********* 


Running Fold 1 with model linear...
lin_acc: 0.6129
lin_bacc: 0.5982
lin_macro_f1: 0.5571
lin_weighted_f1: 0.6433
lin_auroc: 0.6310
lin_conf_matrix: [[15  9]
 [ 3  4]]


 ********* Training with model: ann********* 


Running Fold 1 with model ann...
acc: 0.5806
bacc: 0.6280
macro_f1: 0.5507
weighted_f1: 0.6143
auroc: 0.6250
conf_matrix: [[13 11]
 [ 2  5]]


 ********* Training with model: knn********* 


Running Fold 1 with model knn...
knn_acc: 0.7742
knn_bacc: 0.6012
knn_macro_f1: 0.6132
knn_weighted_f1: 0.7500
knn_auroc: 0.5595
knn_conf_matrix: [[22  2]
 [ 5  2]]


 ********* Training with model: protonet********* 


Running Fold 1 with model protonet...
proto_acc: 0.5484
proto_bacc: 0.5060
proto_macro_f1: 0.4833
proto_weighted_f1: 0.5839
proto_auroc: 0.5357
proto_conf_matrix: [[14 10]
 [ 4  3]]
