Imports

In [None]:
import sys, json, random, math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Add project root to PYTHONPATH so relative imports work
sys.path.insert(0, "./")                      # adjust if notebook lives elsewhere

from data.multimodal_features_surv import MultimodalCTWSIDatasetSurv
from models.dpe.main_model_nobackbone_surv_new_gcs import madpe_nobackbone
from training.losses import CoxLoss


: 

In [None]:
# ---------------------------------------------------------------------------
# User-editable settings: checkpoint files, evaluation data and landscape grid.
# ---------------------------------------------------------------------------

# Checkpoints to compare, keyed by a short label.
CHECKPOINTS = {
    "baseline": "checkpoints/model_no_gcs.pt",
    "gcs": "checkpoints/model_with_gcs.pt",
}

# Dataset arguments; keep them in sync with the JSON config used for training.
DATA_CFG = {
    "fold": 0,
    "split": "test",  # held-out split so the checkpoints are compared fairly
    "ct_path": "../MedImageInsights/embeddings_output_cptacpda",
    "wsi_path": "../trident/trident_processed/20x_224px_0px_overlap/slide_features_titan",
    "labels_splits_path": "./data/processed/processed_CPTACUCEC_survival/k=all.tsv",
    "missing_modality_prob": 0.0,
    "missing_modality": "both",
    "require_both_modalities": True,
    "pairing_mode": "one_to_one",
    "allow_repeats": False,
    "pairs_per_patient": None,
}

# Evaluation and grid settings.
BATCH_SIZE_EVAL = 16   # larger batches run faster but need more GPU memory
GRID_EXTENT = 1.0      # maximum offset along each direction, in ‖θ‖ units
GRID_STEPS = 51        # points per axis; an odd count puts θ* on the centre cell
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
SEED = 0


In [None]:

def set_global_seed(seed=None):
    """
    Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices). Also switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed (int | None): Seed value. ``None`` (the default) uses the
            module-level ``SEED`` as it is *at call time*. Editing ``SEED`` in
            the configuration cell is therefore honoured; it is no longer reset
            to 0 by a second assignment in this cell.
    """
    if seed is None:
        seed = SEED

    random.seed(seed)
    np.random.seed(seed)

    # torch.manual_seed already covers every device; manual_seed_all makes the
    # CUDA seeding explicit (and is a no-op without CUDA).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def params_to_vec(model: nn.Module) -> torch.Tensor:
    """Flatten all of *model*'s parameters into one detached 1-D tensor.

    Entries follow ``model.parameters()`` order, which is the layout
    ``vec_to_params_`` reads back. The concatenation allocates a new tensor,
    and the result is detached from the autograd graph.
    """
    flat = nn.utils.parameters_to_vector(model.parameters())
    return flat.detach()

def vec_to_params_(model: nn.Module, vec: torch.Tensor):
    """Overwrite *model*'s parameters in place with the entries of *vec*.

    *vec* is read in ``model.parameters()`` order, the same layout that
    ``params_to_vec`` produces. Values are *copied* into the existing
    parameter tensors. Each parameter therefore keeps its own device and
    dtype, and never shares storage with *vec*.

    ``torch.nn.utils.vector_to_parameters`` behaves differently: it rebinds
    ``param.data`` to views of the vector. Later in-place edits to *vec* would
    then silently change the model.

    Args:
        model: module whose parameters are overwritten.
        vec: tensor with exactly as many elements as the model has parameter
            entries; it is flattened before use.

    Raises:
        TypeError: if *vec* is not a tensor.
        ValueError: if ``vec.numel()`` differs from the total parameter count
            (a longer vector used to be silently truncated).
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vec)}")

    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    flat = vec.reshape(-1)
    if flat.numel() != expected:
        raise ValueError(
            f"vec has {flat.numel()} elements but the model has {expected} parameter entries"
        )

    offset = 0
    with torch.no_grad():
        for p in params:
            n = p.numel()
            # copy_ casts to p's dtype/device and writes into p's own storage.
            p.copy_(flat[offset:offset + n].reshape_as(p))
            offset += n

@torch.no_grad()
def evaluate_loss(model: nn.Module, loader, loss_fn, device=None) -> float:
    """Return the sample-weighted average of per-batch losses over *loader*.

    Runs without gradient tracking. Puts *model* in eval mode and does not
    switch it back afterwards.

    Each batch must be a mapping with the keys ``"ct_feature"``,
    ``"wsi_feature"``, ``"modality_mask"``, ``"survtime"`` and ``"censor"``.
    The model is called as ``model(ct, wsi, modality_flag=mask)``, and its
    ``"hazard"`` output is passed to ``loss_fn(hazard, survtime, censor)``.

    NOTE: each per-batch loss is weighted by its batch size. If ``loss_fn``
    couples the samples inside a batch (as a Cox partial-likelihood loss
    does), this is a batch-wise approximation of the full-set loss. The value
    can then depend on how the loader batches the data. Presumably the case
    for CoxLoss — TODO confirm it is acceptable for landscape comparisons.

    Args:
        model: network to evaluate.
        loader: iterable of batch dicts (e.g. a DataLoader).
        loss_fn: callable ``(hazard, survtime, censor) -> scalar tensor``.
        device: where inputs are moved; ``None`` uses the module-level DEVICE.

    Returns:
        Batch-size-weighted mean loss as a Python float.

    Raises:
        ValueError: if *loader* yields no samples (previously a bare
            ZeroDivisionError).
    """
    if device is None:
        device = DEVICE
    model.eval()
    tot_loss, nsamples = 0.0, 0
    for batch in loader:
        # Move inputs to the target device; the model expects float features.
        ct = batch["ct_feature"].float().to(device)
        wsi = batch["wsi_feature"].float().to(device)
        mask = batch["modality_mask"].to(device)
        surv = batch["survtime"].to(device)
        cens = batch["censor"].to(device)
        out = model(ct, wsi, modality_flag=mask)["hazard"]
        loss = loss_fn(out, surv, cens)
        bs = ct.size(0)
        tot_loss += loss.item() * bs
        nsamples += bs
    if nsamples == 0:
        raise ValueError("evaluate_loss: loader yielded no samples")
    return tot_loss / nsamples
# Seed every RNG once, before any evaluation, so results are reproducible.
set_global_seed(seed=SEED)
