In [1]:
# add path to the notebook environment
import sys
sys.path.append('../')

import pandas as pd
import torch

# Load the combined label spreadsheet; the path is relative to the notebook's working directory.
data_combined = pd.read_excel('../data/combined_labels_with_patient_id.xlsx')
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print("Number of samples:", len(data_combined))

Using device: cuda
Number of samples: 1943


In [3]:
from utils.train_utils import load_features_mmap, prepare_labels

# Feature-store configuration for the pre-extracted slide features.
encoder = "uni_v2"  # only used here to build the feature directory name
verbose = True  # forwarded to load_features_mmap and prepare_labels
precision = 16  # forwarded to load_features_mmap and run_experiment; presumably the stored bit width — TODO confirm
major_dir = "../../combined_features/wsi_processed_no_penmarks"
feature_dir = f"{major_dir}/features_{encoder}"

# Slides are looked up through the "De-ID" column of data_combined.
# NOTE(review): D is presumably the feature dimension (the loader logs feature_dim=1536) and
# `preloaded` presumably tells whether features were loaded into RAM or only indexed — confirm
# against utils.train_utils.load_features_mmap.
preloaded_features, D, preloaded = load_features_mmap(data_combined, feature_dir=feature_dir, id_col="De-ID", verbose=verbose, precision=precision)

[load_features_mmap] Loading mmap store from: ../../combined_features/wsi_processed_no_penmarks/features_uni_v2/combined_mmap_16
[load_features_mmap] 1943 WSIs | feature_dim=1536 | total_patches=23825514
[load_features_mmap] Est preload for selected slides: 68.17 GB | Avail RAM: 50.10 GB | decision: INDEX-ONLY


100%|██████████| 1943/1943 [00:00<00:00, 900500.85it/s]

[load_features_mmap] Prepared 1943 slides with indices only. Use mmap slices later to load per-batch.





In [4]:
from models.wsi_model import WSIModel

# Aggregator to use. Names with an entry in the table below get encoder-specific
# settings; any other name (e.g. "Mean", "Max", "ILRA") gets an empty dict.
encoder_type = "ABMIL"

# Encoder-specific settings, keyed by encoder name.
_ENCODER_ATTRS_BY_TYPE = {
    "ABMIL": {
        "attn_dim": 384,
        "gate": True,
    },
    "TransMIL": {
        "num_attention_layers": 2,
        "num_heads": 4,
    },
    "WIKGMIL": {
        "agg_type": "bi-interaction",
        "pool": "attn",
        "topk": 4,
    },
    "DSMIL": {
        "attn_dim": 384,
        "dropout_v": 0.0,
    },
    "CLAM": {
        "attention_dim": 384,
        "gate": True,
        "k_sample": 8,
        "subtyping": False,
        "instance_loss_fn": "svm",
        "bag_weight": 0.7,
    },
}

# Copy the entry so later edits to encoder_attrs never touch the lookup table.
encoder_attrs = dict(_ENCODER_ATTRS_BY_TYPE.get(encoder_type, {}))

# Task name passed to prepare_labels and run_experiment.
task = "Binary Stage N"  # binary_tasks = ["Binary Stage N", "Binary Stage T", "Binary TNM Stage", "died_within_5_years", "MSI", "BRAF", "KRAS", "NRAS", "RAS"]

# Training labels come from the "SR" cohort (with a patient-ID column supplied); test labels come
# from the "RIH" cohort (no patient-ID column). n_classes and patient_id_mapping are taken from the
# training call only.
train_df, train_labels, train_class_weights, n_classes, patient_id_mapping = prepare_labels(data_combined, "De-ID", task, verbose, cohorts=["SR"], patient_id_col="Patient ID")
test_df, test_labels, test_class_weights, _, _ = prepare_labels(data_combined, "De-ID", task, verbose, cohorts=["RIH"], patient_id_col=None)
# Move the class-weight tensors to the compute device when they exist.
train_class_weights = train_class_weights.to(device) if train_class_weights is not None else None
# NOTE(review): test_class_weights is not used by any other cell shown in this notebook.
test_class_weights = test_class_weights.to(device) if test_class_weights is not None else None

# NOTE(review): n_splits and test_size are not passed to the run_experiment call shown in this
# notebook, so changing them here has no visible effect — confirm how run_experiment picks its
# fold count.
n_splits        = 5
test_size       = 0.20

# Optimizer / LR schedule
lr              = 1e-04           # learning rate passed to run_experiment
l2_reg          = 5e-03           # passed to run_experiment as l2_reg; NOTE(review): whether bias/LayerNorm are excluded is decided inside the trainer — confirm
scheduler       = "cosine"        # NOTE(review): not passed to the run_experiment call shown in this notebook
epochs          = 40
warmup_epochs   = 0              # no warmup epochs
step_on_epochs  = True           # NOTE(review): name suggests the LR schedule steps per epoch, not per update — confirm in run_experiment
accum_steps     = 1              # passed to run_experiment; presumably gradient accumulation steps (1 = none) — confirm
early_stopping  = 8              # passed to run_experiment; presumably the patience in epochs — confirm

Filtered to 704 samples with valid labels for task 'Binary Stage N'.
Detected 2 classes: [0.0, 1.0]
Computed class weights: [1.1578947305679321, 0.8799999952316284]
Class distribution -> '0.0' (304), '1.0' (400)
Filtered to 175 samples with valid labels for task 'Binary Stage N'.
Detected 2 classes: [0.0, 1.0]
Computed class weights: [1.75, 0.7000000476837158]
Class distribution -> '0.0' (50), '1.0' (125)


In [None]:
from scripts.train import run_experiment, set_global_seed

SEED = 42
# Presumably seeds the relevant RNGs; see scripts.train.set_global_seed for what it covers.
set_global_seed(SEED)

# Train and evaluate with the configuration defined in the earlier cells.
# NOTE(review): n_splits, test_size and scheduler are defined in the configuration cell but are
# not passed here; run_experiment must be relying on its own defaults for them — confirm.
test_metrics, ci_dict, final_model, cv_results = run_experiment(
    # Zero-argument factory; presumably called whenever a fresh model is needed — confirm.
    model_builder=lambda: WSIModel(
        input_feature_dim=1536,  # hard-coded; equals the feature_dim=1536 logged by load_features_mmap. NOTE(review): consider using D if it is the feature dim
        n_classes=n_classes,
        encoder_type=encoder_type,
        head_dropout=0.35,
        head_dim=512,
        num_fc_layers=1,
        hidden_dim=128,
        ds_dropout=0.3,
        simple_mlp=False,
        freeze_encoder=False,
        encoder_attrs=encoder_attrs
    ),
    preloaded_features=preloaded_features,
    train_labels=train_labels,
    test_labels=test_labels,
    patient_id_mapping=patient_id_mapping,
    device=device,
    epochs=epochs,
    task=task,
    lr=lr,
    l2_reg=l2_reg,
    early_stopping=early_stopping,
    bag_size=None,
    replacement=False,
    class_weights=train_class_weights,
    # ROC AUC for binary tasks, balanced accuracy otherwise.
    key_metric="roc_auc" if n_classes == 2 else "balanced_accuracy",
    precision=precision,
    warmup_epochs=warmup_epochs,
    accum_steps=accum_steps,
    step_on_epochs=step_on_epochs,
    preloaded=preloaded,
    feature_dir=feature_dir
)


=== Running Experiment on Device: cuda ===
Train set size: 704
Test set size: 175

=== Running Experiment with 5-Fold Cross-Validation ===
Task: Binary Stage N, Sample Size: 704 (train+val)
[Fold 1 | Epoch 5/40] TrainLoss=0.5202 | ValScore=0.7103 | ValLoss=0.8473 | accuracy=0.6000; precision=0.6964; recall=0.5000; f1=0.5821; balanced_accuracy=0.6129; log_loss=0.8473; roc_auc=0.7103
[Fold 1 | Epoch 10/40] TrainLoss=0.3024 | ValScore=0.6990 | ValLoss=1.0015 | accuracy=0.7071; precision=0.7126; recall=0.7949; f1=0.7515; balanced_accuracy=0.6958; log_loss=1.0015; roc_auc=0.6990
[Fold 1] Early stopping at epoch 11; Saving best model at epoch 3

[Fold 1] Best Val Score: 0.7654
[Fold 1] [Binary Stage N] accuracy: 0.7357
[Fold 1] [Binary Stage N] precision: 0.6990
[Fold 1] [Binary Stage N] recall: 0.9231
[Fold 1] [Binary Stage N] f1: 0.7956
[Fold 1] [Binary Stage N] balanced_accuracy: 0.7115
[Fold 1] [Binary Stage N] log_loss: 0.6331
[Fold 1] [Binary Stage N] roc_auc: 0.7654

[Fold 2 | Epoch 

In [6]:
from sklearn.metrics import roc_curve

def oof_positive_probs(oof_logits):
    """Convert logits into positive-class probabilities as a NumPy array.

    Args:
        oof_logits: torch.Tensor of shape [N], [N, 1] (single logit per sample)
            or [N, C] with the positive class in column 1 (e.g. [N, 2]).

    Returns:
        1-D NumPy array: sigmoid of the logit for single-logit input, otherwise
        column 1 of the row-wise softmax.
    """
    import torch

    single_logit = oof_logits.ndim == 1 or oof_logits.size(1) == 1
    if single_logit:
        # One logit per sample: drop the trailing axis (if any) and squash.
        positive = torch.sigmoid(oof_logits.squeeze(-1))
    else:
        # One logit per class: normalise across classes, keep the class-1 column.
        positive = torch.softmax(oof_logits, dim=1)[:, 1]
    return positive.cpu().numpy()

def tune_threshold_youden(y_true, p1):
    """Pick the decision threshold on p1 that maximises Youden's J (TPR - FPR).

    Args:
        y_true: Binary ground-truth labels, in any form accepted by
            sklearn.metrics.roc_curve.
        p1: Positive-class scores/probabilities, same length as y_true.

    Returns:
        (threshold, j): the selected threshold and Youden's J at that
        threshold, both as Python floats.
    """
    # Imported locally: in this notebook numpy is only imported in a later cell,
    # so relying on a global `np` breaks Restart & Run All.
    import numpy as np

    fpr, tpr, thr = roc_curve(y_true, p1)
    j = tpr - fpr
    # thr[0] is roc_curve's artificial "predict nothing as positive" point
    # (np.inf in recent scikit-learn, max score + 1 in older releases). It is not
    # a usable cut-off, so skip it whenever a real candidate exists.
    first = 1 if len(thr) > 1 else 0
    best = first + int(np.argmax(j[first:]))
    return float(thr[best]), float(j[best])

# Presumably the out-of-fold predictions gathered during cross-validation — confirm the
# structure of cv_results["oof"] against run_experiment.
oof = cv_results["oof"]
y = oof["y"]
# Positive-class probabilities derived from the stored logits.
p1 = oof_positive_probs(oof["logits"])

# Only the threshold is kept; the J value at that threshold is discarded.
t_j,   _ = tune_threshold_youden(y, p1)
print(f"Optimal threshold by Youden's J statistic: {t_j:.4f}")

Optimal threshold by Youden's J statistic: 0.6631


In [7]:
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, roc_auc_score

def op_metrics(y_true, p1, t=0.9922):
    """Operating-point metrics for a binary classifier at threshold t.

    A sample is predicted positive when p1 >= t. Labels equal to 1 are treated
    as positive and every other label as negative.

    Args:
        y_true: Array-like of binary ground-truth labels (0/1, ints or floats).
        p1: Array-like of positive-class probabilities, same length as y_true.
        t: Decision threshold. The default is a legacy value kept for
            backward compatibility; pass the tuned threshold explicitly.

    Returns:
        dict with keys threshold, sensitivity, specificity, ppv and npv.
        A metric whose denominator is zero is reported as 0.0.

    Raises:
        ValueError: if y_true and p1 differ in length.
    """
    truth = np.asarray(y_true).ravel()
    scores = np.asarray(p1).ravel()
    if truth.shape[0] != scores.shape[0]:
        raise ValueError(
            f"y_true and p1 must have the same length, got {truth.shape[0]} and {scores.shape[0]}"
        )

    is_pos = truth == 1
    pred_pos = scores >= t

    # Count the four cells directly so the result is well defined even when only
    # one class is present (a confusion matrix would collapse to 1x1 there).
    tp = int(np.count_nonzero(pred_pos & is_pos))
    fp = int(np.count_nonzero(pred_pos & ~is_pos))
    fn = int(np.count_nonzero(~pred_pos & is_pos))
    tn = int(np.count_nonzero(~pred_pos & ~is_pos))

    def _ratio(num, den):
        # Exact division with one zero-denominator convention for every metric.
        return num / den if den else 0.0

    return dict(
        threshold=t,
        sensitivity=_ratio(tp, tp + fn),  # recall
        specificity=_ratio(tn, tn + fp),
        ppv=_ratio(tp, tp + fp),          # precision
        npv=_ratio(tn, tn + fn),
    )

# p1 must hold probabilities of the positive class, not raw logits:
#   single-logit output -> sigmoid(oof_logits.squeeze(1))
#   two-logit output    -> softmax(oof_logits, 1)[:, 1]
# NOTE(review): t_j was tuned on these same (y, p1) pairs, so the metrics printed here are
# optimistic; apply the threshold to held-out predictions for an unbiased estimate.
op_dict = op_metrics(y, p1, t=t_j)
print(f"Optimal threshold metrics: {op_dict}")

Optimal threshold metrics: {'threshold': 0.6631237864494324, 'sensitivity': np.float64(0.7149999999999982), 'specificity': np.float64(0.6842105263157872), 'ppv': 0.7486910994764397, 'npv': np.float64(0.6459627329192527)}
