In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from eval_patch_features.logistic import test_saved_logistic_model
from eval_patch_features.ann import test_saved_ann_model
from eval_patch_features.knn import test_saved_knn_model
from eval_patch_features.protonet import test_saved_protonet_model
from eval_patch_features.metrics import get_eval_metrics, print_metrics
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

cuda


### Configurations

In [2]:
# configs
from pathlib import Path

# --- Input data -------------------------------------------------------------
# CSV listing the slide IDs to evaluate, and the folder of per-slide feature files.
K_FOLDS_PATH = r"E:\Aamir Gulzar\dataset\paip_data\labels\paip_47slides.csv"
DATA_PATH = r"E:\Aamir Gulzar\dataset\paip_data\UNI_Features"

# --- Model / loader settings ------------------------------------------------
VECTOR_DIM = 1024  # length of each input feature vector
HIDDEN_DIM = 768   # width of the ANN hidden layer
BATCH_SIZE = 8
FOLD = 3  # only selects which fold's saved models are loaded below

# --- Saved classifiers for the selected fold ---------------------------------
BASE_DIR = Path(r"E:\Aamir Gulzar\WSI_Classification_Using_FM_Features\UNI_Average_Classifiers")

ANN_MODEL_PATH = BASE_DIR.joinpath(f"fold{FOLD}_trained_ann_model_{VECTOR_DIM}.pth")
LOGISTIC_MODEL_PATH = BASE_DIR.joinpath(f"fold{FOLD}_logistic_regression.pkl")
KNN_MODEL_PATH = BASE_DIR.joinpath(f"fold{FOLD}_knn_model.pkl")
PROTONET_MODEL_PATH = BASE_DIR.joinpath(f"fold{FOLD}_protonet_model.pkl")

In [3]:
# data = r"E:\Aamir Gulzar\dataset\paip_data\Baseline_FiveCrop_Features"
# data_save = r"E:\Aamir Gulzar\dataset\paip_data\Baseline_Features"
# if not os.path.exists(data_save):
#     os.makedirs(data_save, exist_ok=True)

# # load the data
# for wsi in os.listdir(data):
#     wsi_data = []
#     for patch in os.listdir(j_(data, wsi)):
#         patch_data = torch.load(j_(data, wsi, patch))
#         # check if the loaded feature vector is five crop then average it first then append to the wsi_data
#         if patch_data.shape[0] > 1:
#             patch_data = patch_data.mean(dim=0)
#         wsi_data.append(patch_data)
#     wsi_data = torch.stack(wsi_data).mean(dim=0)
#     save_path = j_(data_save, wsi + ".pt")
#     torch.save(wsi_data, save_path)
# print("done")

In [4]:
# import torch

# # Define the path to the .pt file
# file_path = r"E:\Aamir Gulzar\dataset\paip_data\UNI_Features\training_data_01_MSIH.pt"

# # Load the .pt file
# try:
#     data = torch.load(file_path)
#     shape = torch.tensor(data).shape  # Get shape
#     print("Shape of the data:", shape)
#     print("Values:\n", data)  # Print values
# except Exception as e:
#     print("Error loading file:", e)

### Data Loaders

In [5]:
import os
import torch
from torch.utils.data import Dataset
from typing import List

class WSIDataset(Dataset):
    """In-memory dataset of one averaged feature vector per whole-slide image.

    Every ``*.pt`` file in ``save_dir`` whose file name (without extension)
    appears in ``fold_ids`` is loaded eagerly at construction time. Files are
    visited in sorted name order so the sample order is reproducible.

    Args:
        save_dir: Directory containing the per-slide ``.pt`` feature files.
        fold_ids: Slide IDs (file names without extension) to include.

    Each item is a ``(features, label, wsi_id)`` tuple, where ``label`` is 0
    when the file name contains ``_nonMSI`` and 1 otherwise.
    """

    def __init__(self, save_dir: str, fold_ids: List[str]):
        self.data = []
        self.save_dir = save_dir
        self.fold_ids = fold_ids
        self._load_data()

    def _load_data(self):
        """Populate ``self.data`` from the feature files selected by ``fold_ids``."""
        # Set membership keeps the per-file check constant-time.
        wanted_ids = set(self.fold_ids)

        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        for wsi_file in sorted(os.listdir(self.save_dir)):
            # The slide ID is the full file name minus its extension.
            wsi_id = os.path.splitext(wsi_file)[0]
            if wsi_id not in wanted_ids:
                continue  # Skip if the WSI is not in the current fold
            if not wsi_file.endswith('.pt'):
                continue

            wsi_path = os.path.join(self.save_dir, wsi_file)
            try:
                # map_location="cpu" lets features saved from a GPU load on a
                # machine without CUDA, and guarantees a CPU tensor either way.
                wsi_features = torch.load(wsi_path, map_location="cpu")

                # Collapse a stack of feature vectors into one by averaging
                # over dim 0; a 1-D (or 0-D) tensor is used as-is.
                if wsi_features.dim() > 1:
                    averaged_features = torch.mean(wsi_features, dim=0)
                else:
                    averaged_features = wsi_features

                # Determine label based on WSI file name
                label = 0 if '_nonMSI' in wsi_file else 1

                self.data.append((averaged_features, label, wsi_id))
            except Exception as e:
                # Best effort: report the unreadable file and keep loading the rest.
                print(f"Error loading {wsi_path}: {e}")

    def __len__(self):
        """Return the number of slides that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label, wsi_id)`` tuple stored at ``idx``."""
        features, label, wsi_id = self.data[idx]
        return features, label, wsi_id

### Evaluation Metrics

In [6]:
def calculate_metric_averages_by_index(all_fold_results, metric_indices):
    """
    Calculate the average of specified metrics over multiple folds.

    Metrics are looked up by *position* in each result dict rather than by
    key, so results whose keys differ (e.g. ``lin_acc`` vs ``knn_acc``) can
    be averaged under one name.

    Args:
        all_fold_results (list of dicts): Results for each fold.
        metric_indices (dict): Mapping of metric names to their positional
            index within each result dict.
    Returns:
        dict: One entry per key of ``metric_indices``. Each value is the mean
        over the folds in which that position held a numeric value; a metric
        with no numeric value in any fold is reported as 0.
    """
    # Python numbers plus NumPy scalar types that do not subclass int/float
    # (e.g. np.float32, np.int64). Note that bool counts as int here.
    numeric_types = (int, float, np.integer, np.floating)

    averages = {metric: 0 for metric in metric_indices}
    counts = {metric: 0 for metric in metric_indices}  # folds contributing to each metric

    for result in all_fold_results:
        for metric, index in metric_indices.items():
            try:
                value = list(result.values())[index]  # positional lookup
            except IndexError:
                # Metric not present in this result due to model differences
                continue
            except Exception as e:
                print(f"Error processing metric '{metric}': {e}")
                continue

            if isinstance(value, numeric_types):
                averages[metric] += value
                counts[metric] += 1

    # Compute average only for metrics with valid values
    for metric, count in counts.items():
        if count > 0:
            averages[metric] /= count
    return averages

def print_metrics(metrics):
    """Print each numeric entry of ``metrics`` as ``key: value`` with 4 decimals.

    Entries whose value is not an ``int`` or ``float`` are skipped.

    NOTE(review): this definition shadows the ``print_metrics`` imported from
    ``eval_patch_features.metrics`` at the top of the notebook.
    """
    numeric_entries = (
        (name, number)
        for name, number in metrics.items()
        if isinstance(number, (int, float))
    )
    for name, number in numeric_entries:
        print(f"{name}: {number:.4f}")

### Evaluation Function

In [7]:
def evaluate(test_loader, model_type='linear'):
    """Score one saved classifier on every batch yielded by ``test_loader``.

    Args:
        test_loader: Iterable yielding ``(features, labels, wsi_ids)`` batches.
        model_type: One of ``'ann'``, ``'linear'``, ``'knn'`` or ``'protonet'``.

    Returns:
        The metrics returned by the matching ``test_saved_*_model`` function
        (for ``'ann'``, the first element of the pair it returns).

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # The concatenated tensors are written to module-level globals, so they
    # stay available in the notebook namespace after the call returns.
    global test_feats, test_labels

    feature_batches = []
    label_batches = []
    for batch_feats, batch_labels, _batch_ids in test_loader:
        feature_batches.append(batch_feats)
        label_batches.append(batch_labels)

    # Stack the per-batch tensors along dim 0.
    test_feats = torch.cat(feature_batches)
    test_labels = torch.cat([lbl.clone().detach() for lbl in label_batches])

    # Dispatch to the tester for the requested model.
    if model_type == 'ann':
        # The ANN tester returns a pair; only the metrics are passed on.
        eval_metrics, _ = test_saved_ann_model(
            input_dim=VECTOR_DIM,
            hidden_dim=HIDDEN_DIM,
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=ANN_MODEL_PATH
        )
        return eval_metrics

    if model_type == 'linear':
        return test_saved_logistic_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=LOGISTIC_MODEL_PATH
        )

    if model_type == 'knn':
        return test_saved_knn_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=KNN_MODEL_PATH
        )

    if model_type == 'protonet':
        return test_saved_protonet_model(
            test_feats=test_feats,
            test_labels=test_labels,
            model_path=PROTONET_MODEL_PATH
        )

    raise ValueError(f"Unsupported model type: {model_type}")

### K-Folds

### Main Runner Function

In [8]:
# Example usage:
# Slide IDs to evaluate are read from the 'Fold1' column of the labels CSV.
slides_df = pd.read_csv(K_FOLDS_PATH)
slides = slides_df['Fold1'].dropna().values.tolist()

# Saved classifiers to evaluate, in this order.
model_types = ['linear','ann','knn','protonet']

# Positional indices of metrics inside a result dict, in the form expected by
# calculate_metric_averages_by_index. Not used by the loop below.
# NOTE(review): the recorded output prints acc, bacc, macro_f1, weighted_f1 and
# (for 'linear' and 'ann') auroc, in that order. Non-numeric entries are not
# printed, so the true position of auroc cannot be confirmed from this notebook.
metric_indices = {
    'acc': 0,
    'bacc': 1,
    'macro_f1': 2,
    'weighted_f1': 3,
    'auroc': 5         # NOTE(review): old comment said index 4 - confirm which is right
}

# The test set is the same for every model, so load the feature files and
# build the loader once instead of once per model.
test_dataset = WSIDataset(DATA_PATH, slides)
if len(test_dataset) == 0:
    # Fail with a clear message instead of an obscure torch.cat error later.
    raise RuntimeError(
        f"No feature files in {DATA_PATH!r} matched the slide IDs from {K_FOLDS_PATH!r}"
    )
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

for model in model_types:
    # NOTE(review): the banner says "Training" but this loop only evaluates
    # saved models; the text is left unchanged to match the recorded output.
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    eval_metrics = evaluate(test_loader, model_type=model)
    print_metrics(eval_metrics)



 ********* Training with model: linear********* 


Confusion Matrix:
[[35  0]
 [12  0]]
lin_acc: 0.7447
lin_bacc: 0.5000
lin_macro_f1: 0.4268
lin_weighted_f1: 0.6357
lin_auroc: 0.9024


 ********* Training with model: ann********* 


Confusion Matrix:
[[29  6]
 [ 2 10]]
acc: 0.8298
bacc: 0.8310
macro_f1: 0.7965
weighted_f1: 0.8368
auroc: 0.9095


 ********* Training with model: knn********* 


Confusion Matrix:
[[35  0]
 [11  1]]
knn_acc: 0.7660
knn_bacc: 0.5417
knn_macro_f1: 0.5090
knn_weighted_f1: 0.6828


 ********* Training with model: protonet********* 


Confusion Matrix:
[[34  1]
 [ 9  3]]
proto_acc: 0.7872
proto_bacc: 0.6107
proto_macro_f1: 0.6234
proto_weighted_f1: 0.7450


  prototypes = torch.tensor(prototypes)
  labels_proto = torch.tensor(labels_proto)
