In [1]:
# Make the parent directory importable; later cells import the project
# packages `utils`, `models` and `scripts` (presumably located there).
import random
import sys

if '../' not in sys.path:  # re-running this cell must not stack duplicate entries
    sys.path.append('../')

import numpy as np
import pandas as pd
import torch

# Seed the Python, NumPy and PyTorch RNGs up front so that
# Restart Kernel -> Run All starts training from the same random state.
# NOTE(review): project helpers may seed or shuffle on their own, and some CUDA
# kernels are nondeterministic, so bit-exact repeats are not guaranteed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Label spreadsheet for all slides; later cells look slides up by the "De-ID" column.
data_combined = pd.read_excel('../data/combined_labels_with_patient_id.xlsx')

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print("Number of samples:", len(data_combined))

Using device: cuda
Number of samples: 1943


In [2]:
from utils.train_utils import load_features_mmap, prepare_labels

# Feature store location: <major_dir>/features_<encoder>
major_dir = "../../combined_features/wsi_processed_no_penmarks"
encoder = "uni_v1"  # tag used to build the feature directory name
feature_dir = f"{major_dir}/features_{encoder}"

# Options handed to load_features_mmap; `verbose` and `precision` are reused by later cells.
precision = 16
verbose = True

# Load the slide features for the rows of data_combined (slide id column: "De-ID").
# Of the returned triple, D is used later as the model's input feature dimension and
# `preloaded` is forwarded to run_experiment together with `preloaded_features`.
preloaded_features, D, preloaded = load_features_mmap(
    data_combined,
    feature_dir=feature_dir,
    id_col="De-ID",
    verbose=verbose,
    precision=precision,
)

[load_features_mmap] Loading mmap store from: ../../combined_features/wsi_processed_no_penmarks/features_uni_v1/combined_mmap_16
[load_features_mmap] 1943 WSIs | feature_dim=1024 | total_patches=23825514
[load_features_mmap] Est preload for selected slides: 45.44 GB | Avail RAM: 114.91 GB | decision: PRELOAD


100%|██████████| 1943/1943 [00:00<00:00, 92620.96it/s]


In [3]:
from models.wsi_model import WSIModel

# Aggregator handed to WSIModel. Options noted by the author:
# "ABMIL", "TransMIL", "Mean", "Max", "WIKGMIL", "DSMIL", "CLAM", "ILRA"
encoder_type = "ABMIL"

# Per-aggregator attribute presets; a type without an entry here
# (e.g. "Mean", "Max", "ILRA") gets an empty attribute dict.
_encoder_attr_presets = {
    "ABMIL": {"attn_dim": 384, "gate": True},
    "TransMIL": {"num_attention_layers": 2, "num_heads": 4},
    "WIKGMIL": {"agg_type": "bi-interaction", "pool": "attn", "topk": 4},
    "DSMIL": {"attn_dim": 384, "dropout_v": 0.0},
    "CLAM": {
        "attention_dim": 384,
        "gate": True,
        "k_sample": 8,
        "subtyping": False,
        "instance_loss_fn": "svm",
        "bag_weight": 0.7,
    },
}

# Copy the selected preset so encoder_attrs is its own dict, independent of the table.
encoder_attrs = dict(_encoder_attr_presets.get(encoder_type, {}))

# Prediction target. Binary tasks noted by the author:
# "Binary Stage N", "Binary Stage T", "Binary TNM Stage", "died_within_5_years",
# "MSI", "BRAF", "KRAS", "NRAS", "RAS"
task = "Binary Stage N"

# Training labels are restricted to the "SR" cohort; "Patient ID" is supplied as the patient column.
train_df, train_labels, train_class_weights, n_classes, patient_id_mapping = prepare_labels(
    data_combined, "De-ID", task, verbose,
    cohorts=["SR"],
    patient_id_col="Patient ID",
)

# Held-out labels are restricted to the "RIH" cohort; no patient column is supplied,
# and the last two return values are discarded.
test_df, test_labels, test_class_weights, _, _ = prepare_labels(
    data_combined, "De-ID", task, verbose,
    cohorts=["RIH"],
    patient_id_col=None,
)

# Move the class weights onto the compute device when they exist; None stays None.
if train_class_weights is not None:
    train_class_weights = train_class_weights.to(device)
if test_class_weights is not None:
    test_class_weights = test_class_weights.to(device)

# --- Split settings ---
# NOTE(review): neither value is passed to run_experiment in the cells visible here;
# confirm they are still consumed somewhere.
n_splits = 5
test_size = 0.2

# --- Optimizer / LR schedule ---
# Author's note on lr: "stable with eff. BS=16; see notes for alt".
# NOTE(review): accum_steps is 1 below, so confirm the effective batch size that note refers to.
lr = 1e-4
# Author's note on l2_reg: "use param groups: no WD on bias/LayerNorm".
l2_reg = 1e-2
# NOTE(review): `scheduler` is not passed to run_experiment in the cells visible here.
scheduler = "cosine"
epochs = 40
# NOTE(review): the author's note said "~10% of total", but 0 is what is passed as warmup_epochs.
warmup_epochs = 0
# NOTE(review): the author's note said "per-update schedule (correct with accumulation)",
# which reads oddly next to a flag named step_on_epochs set to True; confirm in run_experiment.
step_on_epochs = True
accum_steps = 1
# Passed as early_stopping; presumably the patience in epochs (the fold-1 log stops at
# epoch 10 with the best model from epoch 2). Confirm in run_experiment.
early_stopping = 8

Filtered to 704 samples with valid labels for task 'Binary Stage N'.
Detected 2 classes: [0.0, 1.0]
Computed class weights: [1.1578947305679321, 0.8799999952316284]
Class distribution -> '0.0' (304), '1.0' (400)
Filtered to 175 samples with valid labels for task 'Binary Stage N'.
Detected 2 classes: [0.0, 1.0]
Computed class weights: [1.75, 0.7000000476837158]
Class distribution -> '0.0' (50), '1.0' (125)


In [4]:
from scripts.experiment import run_experiment

def build_model():
    """Return a new WSIModel configured from the notebook-level settings.

    Handed to run_experiment as `model_builder`. The notebook variables
    (D, n_classes, encoder_type, encoder_attrs) are read when it is called.
    """
    return WSIModel(
        input_feature_dim=D,  # feature dimension returned by load_features_mmap
        n_classes=n_classes,
        encoder_type=encoder_type,
        head_dropout=0.35,
        head_dim=512,
        num_fc_layers=1,
        hidden_dim=128,
        ds_dropout=0.3,
        simple_mlp=False,
        freeze_encoder=False,
        encoder_attrs=encoder_attrs,
    )


# ROC-AUC is the key metric for two classes, balanced accuracy otherwise.
key_metric = "roc_auc" if n_classes == 2 else "balanced_accuracy"

test_metrics, ci_dict, final_model, cv_results = run_experiment(
    model_builder=build_model,
    # data
    preloaded_features=preloaded_features,
    train_labels=train_labels,
    test_labels=test_labels,
    patient_id_mapping=patient_id_mapping,
    # training setup
    device=device,
    epochs=epochs,
    task=task,
    lr=lr,
    l2_reg=l2_reg,
    early_stopping=early_stopping,
    bag_size=None,
    replacement=False,
    class_weights=train_class_weights,  # training-cohort weights only
    key_metric=key_metric,
    precision=precision,
    # schedule
    warmup_epochs=warmup_epochs,
    accum_steps=accum_steps,
    step_on_epochs=step_on_epochs,
    # feature store
    preloaded=preloaded,
    feature_dir=feature_dir,
    optimal_threshold=True,
)


=== Running Experiment on Device: cuda ===
Train set size: 704
Test set size: 175

=== Running Experiment with 5-Fold Cross-Validation ===
Task: Binary Stage N, Sample Size: 704 (train+val)

=== Fold 1/5 ===
[Epoch 5/40] TrainLoss=0.6045 | ValScore=0.7296 | ValLoss=0.8935 | accuracy=0.6786; precision=0.7324; recall=0.6667; f1=0.6980; balanced_accuracy=0.6801; log_loss=0.8712; roc_auc=0.7296
[Epoch 10/40] TrainLoss=0.3194 | ValScore=0.6935 | ValLoss=1.6046 | accuracy=0.6429; precision=0.6591; recall=0.7436; f1=0.6988; balanced_accuracy=0.6299; log_loss=1.5070; roc_auc=0.6935
[Early stopping at epoch 10; Saving best model at epoch 2

[Fold 1] Best Val Score: 0.7661
[Fold 1] [Binary Stage N] accuracy: 0.6786
[Fold 1] [Binary Stage N] precision: 0.6435
[Fold 1] [Binary Stage N] recall: 0.9487
[Fold 1] [Binary Stage N] f1: 0.7668
[Fold 1] [Binary Stage N] balanced_accuracy: 0.6437
[Fold 1] [Binary Stage N] log_loss: 0.8526
[Fold 1] [Binary Stage N] roc_auc: 0.7661


=== Fold 2/5 ===
[Epoch