# Interpretability

Sul fold migliore vado a fare forward pass e calcolarmi i gradienti, i positional embedding, correlazioni, curvatura e path degli input che danno quei risultati

Pensiero random (potrebbe essere carino per la tesi e quasi d’obbligo per il paper), lo lascerei comunque come ultimo step se abbiamo tempo: riusciamo sulla fold che va meglio, del dataset che va meglio per il nostro modello tenere traccia in validation di quali sono le coppie di feature (in modo da rintracciare la coppia rad-histo) che portano gradienti maggiori e minori e salvare le rispettive correlazioni (e calcolare il rischio)? Questo significa rintracciare le coppie la cui correlazione ha un’importanza o meno nel calcolo del rischio.
Quello che sarebbe interessante far vedere sono gli estremi opposti su due coppie di rad histo con:
Alta correlazione, alto gradiente, alto risk score, basso survival, alto grado di tumore: sulla radiologia calcolare tipo regolarità della forma (che ci i aspetta bassa, usando la segmentazione), grandezza tumore (che ci si aspetta alta, usando la segmentazione) e sulla histo il numero di cellule epiteliali cancerogene (che ci si aspetta alta, cosa che si può fare con una rete pretrainata) 
Alta correlazione, basso gradiente, basso risk score, alto survival, basso grado di tumore: sulla radiologia calcolare tipo regolarità della forma (che ci i aspetta alta), grandezza tumore (che ci si aspetta bassa, usando la segmentazione) e sulla histo il numero di cellule epiteliali cancerogene (che ci si aspetta bassa, cosa che si può fare con una rete pretrainata) 
Se tu ti occupi di tracciare i gradienti/correlazioni/file (non dovrebbe essere troppo complesso, basta salvarsi le combinazioni di risultati su un file), io mi occuperei di tutto il resto

### Define Functions for interpretability analysis

In [1]:
import os
import torch
import csv
import numpy as np
from torch.utils.data import DataLoader

from data.multimodal_features_surv import MultimodalCTWSIDatasetSurv
from models.dpe.main_model_nobackbone_surv_new_gcs import MADPENetNoBackbonesSurv
from torch.nn.modules.loss import _WeightedLoss


class CoxLoss(_WeightedLoss):
    def forward(self, hazard_pred: torch.Tensor, survtime: torch.Tensor, censor: torch.Tensor):
        censor = censor.float()
        n = len(survtime)
        # risk‑set matrix
        R_mat = survtime.reshape((1, n)) >= survtime.reshape((n, 1))
        theta = hazard_pred.reshape(-1)
        exp_theta = torch.exp(theta)
        # negative log‑partial likelihood
        loss = -torch.mean((theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) * censor)
        return loss


def compute_grad_and_curvature(fused_features, hazard, survtime, censor):
    hazard = hazard.view(-1)
    loss = CoxLoss()(hazard, survtime, censor)
    # first derivative
    grad_f = torch.autograd.grad(loss, fused_features, create_graph=True)[0]
    grad_norm = grad_f.flatten(1).norm(p=2, dim=1)

    # Hutchinson estimator for Hessian diagonal
    random_vec = grad_f.detach().clone().sign()
    grad_dot_random = torch.sum(grad_f * random_vec)
    hvp = torch.autograd.grad(grad_dot_random, fused_features, retain_graph=False)[0]
    curvature = torch.sum(hvp * random_vec, dim=list(range(1, hvp.ndim)))

    return grad_norm.detach().cpu().numpy(), curvature.detach().cpu().numpy()


def evaluate_and_log(folds_dir, ct_path, wsi_path, test_path, output_dir, batch_size=16, n_folds=5):
    os.makedirs(output_dir, exist_ok=True)

    for fold in range(n_folds):
        print(f"\n=== Fold {fold} ===")
        model = MADPENetNoBackbonesSurv(
            rad_input_dim=1024, histo_input_dim=768,
            inter_dim=256, token_dim=256, dim_hider=256
        )
        model.eval()

        # Load checkpoint
        fold_dir = os.path.join(folds_dir, f"fold_{fold}")
        model_subdir = next(d for d in os.listdir(fold_dir)
                            if os.path.isdir(os.path.join(fold_dir, d)))
        model_file = next(f for f in os.listdir(os.path.join(fold_dir, model_subdir))
                          if f.endswith("mixed_missing.pth"))
        checkpoint = torch.load(
            os.path.join(fold_dir, model_subdir, model_file), map_location="cpu"
        )
        model.load_state_dict(checkpoint["model_state_dict"])

        dataset = MultimodalCTWSIDatasetSurv(
            fold=fold,
            split="test",
            ct_path=ct_path,
            wsi_path=wsi_path,
            labels_splits_path=test_path,
            missing_modality_prob=0.0,
            require_both_modalities=False,
            pairing_mode="one_to_one",
            allow_repeats=True,
            pairs_per_patient=None,
            missing_modality="wsi"
        )

        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)

        # CSV Logger
        log_file = os.path.join(output_dir, f"fold_{fold}_log.csv")
        fold_tensor_dir = os.path.join(output_dir, f"fold_{fold}")
        os.makedirs(fold_tensor_dir, exist_ok=True)

        with open(log_file, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([
                "patient_id", "ct_path", "wsi_path",
                "ct_available", "wsi_available",
                "hazard_score", "survtime", "censor",
                "grad_norm", "curvature",
                "fused_path", "pe_path", "hazard_tensor_path"
            ])

            sample_counter = 0
            for batch in loader:
                # ensure floats
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        batch[k] = v.float()

                with torch.enable_grad():
                    for v in batch.values():
                        if isinstance(v, torch.Tensor):
                            v.requires_grad = True

                    outputs = model(
                        batch["ct_feature"],
                        batch["wsi_feature"],
                        modality_flag=batch["modality_mask"],
                        output_layers=["hazard", "fused_features", "positional_embeddings"]
                    )
                    hazard = outputs["hazard"].squeeze(1)  # [B]
                    fused_features = outputs["fused_features"]    # [B, D]
                    pos_emb = outputs["positional_embeddings"]    # [B, D]

                    grad_norms, curvatures = compute_grad_and_curvature(
                        fused_features, hazard,
                        batch["survtime"], batch["censor"]
                    )

                # save per-sample
                B = hazard.size(0)
                for j in range(B):
                    sample_id = f"sample_{sample_counter:06d}"
                    fused_path = os.path.join(fold_tensor_dir, f"{sample_id}_fused.npy")
                    pe_path = os.path.join(fold_tensor_dir, f"{sample_id}_pe.npy")
                    hazard_path = os.path.join(fold_tensor_dir, f"{sample_id}_hazard.npy")

                    np.save(fused_path, fused_features[j].detach().cpu().numpy())
                    np.save(pe_path, pos_emb[j].detach().cpu().numpy())
                    np.save(hazard_path, hazard[j].detach().cpu().numpy())

                    ct_flag = int(batch["modality_mask"][j, 0].item())
                    wsi_flag = int(batch["modality_mask"][j, 1].item())
                    if batch["censor"][j].item()==0: continue
                    writer.writerow([
                        batch["patient_id"][j],
                        dataset.samples[sample_counter]["ct_path"],
                        dataset.samples[sample_counter]["wsi_feature"],
                        ct_flag, wsi_flag,
                        hazard[j].item(),
                        batch["survtime"][j].item(),
                        batch["censor"][j].item(),
                        grad_norms[j], curvatures[j],
                        fused_path, pe_path, hazard_path
                    ])

                    sample_counter += 1
                    

import pandas as pd
import os
def find_topk_gradient_extremes(
    csv_path: str,
    k: int = 10,
    output_dir: str = None,
):
    """
    Finds top-k highest and lowest |gradient norm| samples from a CSV log.

    Args:
        csv_path (str): Path to fold_X_log.csv generated during eval.
        k (int): Number of top and bottom samples to retrieve.
        output_dir (str): Optional folder to save CSVs. If None, only prints results.
    """
    df = pd.read_csv(csv_path)

    if "grad_norm" not in df.columns:
        raise ValueError("The CSV does not contain 'grad_norm' column.")

    # Compute absolute gradient norms
    df["abs_grad_norm"] = df["grad_norm"].abs()

    # Sort by absolute gradient norm
    df_sorted = df.sort_values(by="abs_grad_norm", ascending=True)

    topk_low = df_sorted.head(k).copy()
    topk_high = df_sorted.tail(k).copy()[::-1]  # High to low

    print("\n🔹 Top-k LOW |Gradient Norm| Samples:")
    print(topk_low[["patient_id", "grad_norm", "abs_grad_norm", "ct_available", "wsi_available"]])

    print("\n🔸 Top-k HIGH |Gradient Norm| Samples:")
    print(topk_high[["patient_id", "grad_norm", "abs_grad_norm", "ct_available", "wsi_available"]])

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        low_path = os.path.join(output_dir, "topk_low_gradient.csv")
        high_path = os.path.join(output_dir, "topk_high_gradient.csv")
        topk_low.to_csv(low_path, index=False)
        topk_high.to_csv(high_path, index=False)
        print(f"\n📁 Saved to:\n  {low_path}\n  {high_path}")

    return topk_low, topk_high


#if __name__ == "__main__":
#    # Example usage
#    for i in range(4):
#        csv_input = f"./interpretability/interpretability_UCEC_mixed50/fold_{i}_log.csv"
#        find_topk_gradient_extremes(
#            csv_path=csv_input,
#            k=15,
#            output_dir=f"./interpretability/interpretability_UCEC_mixed50/fold_{i}_topk"
#        )


## SET GLOBAL SEED

In [2]:
import torch
import random
import numpy as np
SEED = 0

def set_global_seed(seed=SEED):
    """
    Set a global seed for reproducibility across different libraries and random number generators.

    Args:
        seed (int): Seed value to be used
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed()

## CPTAC-PDA train-mixed 50%

In [14]:
folds_dir = "./models/ckpts/CPTACPDA_trainmixed50_multival_Titan_MedImSight_redone"
test_path = "./data/processed/processed_CPTAC_PDA_survival/k=all.tsv"
ct_path = "../MedImageInsights/embeddings_output_cptacpda_93"
wsi_path = "../../TitanCPTACPDA/20x_512px_0px_overlap/slide_features_titan"
output_dir = "./interpretability/interpretability_pda_mixed50"
evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16)
for i in range(5):
    csv_input = f"./interpretability/interpretability_pda_mixed50/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=10,
        output_dir=f"./interpretability/interpretability_pda_mixed50/fold_{i}_topk"
        )


=== Fold 0 ===

=== Fold 1 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 2 ===

=== Fold 3 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 4 ===

🔹 Top-k LOW |Gradient Norm| Samples:
   patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
1   C3L-04479   0.012782       0.012782             1              1
5   C3N-02573   0.017344       0.017344             0              1
3   C3N-02573   0.017345       0.017345             0              1
4   C3N-02573   0.017345       0.017345             0              1
10  C3N-03839   0.018279       0.018279             0              1
12  C3N-03839   0.018280       0.018280             0              1
11  C3N-03839   0.018280       0.018280             0              1
9   C3N-03839   0.018288       0.018288             0              1
2   C3L-04479   0.021653       0.021653             1              1
0   C3L-02897   0.022701       0.022701             0              1

🔸 Top-k HIGH |Gradient Norm| Samples:
   patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
6   C3N-02940   0.024397       0.024397             0              1
8   C3N-02

  checkpoint = torch.load(


## CPTAC-PDA train-mixed 30%

In [15]:
folds_dir = "./models/ckpts/CPTACPDA_trainmixed30_multival_Titan_MedImSight"
test_path = "./data/processed/processed_CPTAC_PDA_survival/k=all.tsv"
ct_path = "../MedImageInsights/embeddings_output_cptacpda_93"
wsi_path = "../../TitanCPTACPDA/20x_512px_0px_overlap/slide_features_titan"
output_dir = "./interpretability/interpretability_PDA_mixed30"
evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16)
# Example usage
for i in range(5):
    csv_input = f"./interpretability/interpretability_PDA_mixed30/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_PDA_mixed30/fold_{i}_topk"
    )



=== Fold 0 ===


  checkpoint = torch.load(



=== Fold 1 ===

=== Fold 2 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 3 ===

=== Fold 4 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



🔹 Top-k LOW |Gradient Norm| Samples:
   patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
1   C3L-04479   0.000917       0.000917             1              1
10  C3N-03839   0.008852       0.008852             0              1
12  C3N-03839   0.008852       0.008852             0              1
11  C3N-03839   0.008852       0.008852             0              1
9   C3N-03839   0.008852       0.008852             0              1
0   C3L-02897   0.010575       0.010575             0              1
5   C3N-02573   0.011518       0.011518             0              1
4   C3N-02573   0.011518       0.011518             0              1
3   C3N-02573   0.011518       0.011518             0              1
6   C3N-02940   0.011803       0.011803             0              1
8   C3N-02940   0.011803       0.011803             0              1
7   C3N-02940   0.011803       0.011803             0              1
2   C3L-04479   0.111472       0.111472             1            

## CPTAC-PDA Mixed 15%

In [16]:

folds_dir = "./models/ckpts/CPTACPDA_trainmixed15_multival_Titan_MedImSight"
test_path = "./data/processed/processed_CPTAC_PDA_survival/k=all.tsv"
ct_path = "../MedImageInsights/embeddings_output_cptacpda_93"
wsi_path = "../../TitanCPTACPDA/20x_512px_0px_overlap/slide_features_titan"
output_dir = "./interpretability/interpretability_PDA_mixed15"
evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16)
for i in range(5):
    csv_input = f"./interpretability/interpretability_PDA_mixed15/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_PDA_mixed15/fold_{i}_topk"
    )



=== Fold 0 ===


  checkpoint = torch.load(



=== Fold 1 ===

=== Fold 2 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 3 ===

=== Fold 4 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



🔹 Top-k LOW |Gradient Norm| Samples:
   patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
2   C3L-04479   0.001864       0.001864             1              1
1   C3L-04479   0.036011       0.036011             1              1
5   C3N-02573   0.045232       0.045232             0              1
4   C3N-02573   0.045249       0.045249             0              1
3   C3N-02573   0.045285       0.045285             0              1
10  C3N-03839   0.047283       0.047283             0              1
12  C3N-03839   0.047340       0.047340             0              1
11  C3N-03839   0.047428       0.047428             0              1
9   C3N-03839   0.047582       0.047582             0              1
0   C3L-02897   0.055466       0.055466             0              1
6   C3N-02940   0.062714       0.062714             0              1
8   C3N-02940   0.063409       0.063409             0              1
7   C3N-02940   0.063509       0.063509             0            

## CPTAC PDA mixed 5%

In [17]:

folds_dir = "./models/ckpts/CPTACPDA_trainmixed5_multival_Titan_MedImSight"
test_path = "./data/processed/processed_CPTAC_PDA_survival/k=all.tsv"
ct_path = "../MedImageInsights/embeddings_output_cptacpda_93"
wsi_path = "../../TitanCPTACPDA/20x_512px_0px_overlap/slide_features_titan"
output_dir = "./interpretability/interpretability_PDA_mixed5"

evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16)
for i in range(5):
    csv_input = f"./interpretability/interpretability_PDA_mixed5/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_PDA_mixed5/fold_{i}_topk"
    )


=== Fold 0 ===


  checkpoint = torch.load(



=== Fold 1 ===

=== Fold 2 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 3 ===

=== Fold 4 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



🔹 Top-k LOW |Gradient Norm| Samples:
   patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
2   C3L-04479   0.011065       0.011065             1              1
1   C3L-04479   0.013125       0.013125             1              1
9   C3N-03839   0.013275       0.013275             0              1
11  C3N-03839   0.013275       0.013275             0              1
12  C3N-03839   0.013275       0.013275             0              1
10  C3N-03839   0.013275       0.013275             0              1
8   C3N-02940   0.017700       0.017700             0              1
7   C3N-02940   0.017700       0.017700             0              1
6   C3N-02940   0.017701       0.017701             0              1
3   C3N-02573   0.018294       0.018294             0              1
4   C3N-02573   0.018294       0.018294             0              1
5   C3N-02573   0.018294       0.018294             0              1
0   C3L-02897   0.022016       0.022016             0            

## CPTAC UCEC 5% 

In [3]:
folds_dir = "./models/ckpts/CPTAC_UCEC_titan_medimsight_trainmixed5"
ct_path= "../MedImageInsights/embeddings_cptacucec"
wsi_path= "../../trident_processed_UCEC_titan/20x_512px_0px_overlap/slide_features_titan"
test_path =  "./data/processed/processed_CPTACUCEC_survival/k=all.tsv"
output_dir = "./interpretability/interpretability_UCEC_mixed5"

evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16, n_folds=4)
for i in range(4):
    csv_input = f"./interpretability/interpretability_UCEC_mixed5/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_UCEC_mixed5/fold_{i}_topk"
    )


=== Fold 0 ===


  checkpoint = torch.load(



=== Fold 1 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 2 ===

=== Fold 3 ===

🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
0  C3N-02631   0.024910       0.024910             1              1
1  C3N-02631   0.025809       0.025809             1              1
2  C3N-02631   0.059789       0.059789             1              1

🔸 Top-k HIGH |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
2  C3N-02631   0.059789       0.059789             1              1
1  C3N-02631   0.025809       0.025809             1              1
0  C3N-02631   0.024910       0.024910             1              1

📁 Saved to:
  ./interpretability/interpretability_UCEC_mixed5/fold_0_topk/topk_low_gradient.csv
  ./interpretability/interpretability_UCEC_mixed5/fold_0_topk/topk_high_gradient.csv

🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
1  C3N-02678    0.21507        0.21507             1       

  checkpoint = torch.load(


## CPTAC UCEC 15%

In [4]:
folds_dir = "./models/ckpts/CPTAC_UCEC_titan_medimsight_trainmixed15"
ct_path= "../MedImageInsights/embeddings_cptacucec"
wsi_path= "../../trident_processed_UCEC_titan/20x_512px_0px_overlap/slide_features_titan"
test_path =  "./data/processed/processed_CPTACUCEC_survival/k=all.tsv"
output_dir = "./interpretability/interpretability_UCEC_mixed15"

evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16, n_folds=4)
for i in range(4):
    csv_input = f"./interpretability/interpretability_UCEC_mixed15/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_UCEC_mixed15/fold_{i}_topk"
    )


=== Fold 0 ===


  checkpoint = torch.load(



=== Fold 1 ===

=== Fold 2 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 3 ===

🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
2  C3N-02631   0.000066       0.000066             1              1
0  C3N-02631   0.003875       0.003875             1              1
1  C3N-02631   0.057459       0.057459             1              1

🔸 Top-k HIGH |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
1  C3N-02631   0.057459       0.057459             1              1
0  C3N-02631   0.003875       0.003875             1              1
2  C3N-02631   0.000066       0.000066             1              1

📁 Saved to:
  ./interpretability/interpretability_UCEC_mixed15/fold_0_topk/topk_low_gradient.csv
  ./interpretability/interpretability_UCEC_mixed15/fold_0_topk/topk_high_gradient.csv

🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
1  C3N-02678   0.008612       0.008612             1              1
0  C3

  checkpoint = torch.load(


## CPTAC UCEC 30%

In [5]:
folds_dir = "./models/ckpts/CPTAC_UCEC_titan_medimsight_trainmixed30"
ct_path= "../MedImageInsights/embeddings_cptacucec"
wsi_path= "../../trident_processed_UCEC_titan/20x_512px_0px_overlap/slide_features_titan"
test_path =  "./data/processed/processed_CPTACUCEC_survival/k=all.tsv"
output_dir = "./interpretability/interpretability_UCEC_mixed30"

evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16, n_folds=4)
for i in range(4):
    csv_input = f"./interpretability/interpretability_UCEC_mixed30/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_UCEC_mixed30/fold_{i}_topk"
    )


=== Fold 0 ===


  checkpoint = torch.load(



=== Fold 1 ===

=== Fold 2 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 3 ===

🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
1  C3N-02631   0.042108       0.042108             1              1
0  C3N-02631   0.046808       0.046808             1              1
2  C3N-02631   0.129161       0.129161             1              1

🔸 Top-k HIGH |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
2  C3N-02631   0.129161       0.129161             1              1
0  C3N-02631   0.046808       0.046808             1              1
1  C3N-02631   0.042108       0.042108             1              1

📁 Saved to:
  ./interpretability/interpretability_UCEC_mixed30/fold_0_topk/topk_low_gradient.csv
  ./interpretability/interpretability_UCEC_mixed30/fold_0_topk/topk_high_gradient.csv

🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
0  C3N-02678   0.010099       0.010099             1              1
1  C3

  checkpoint = torch.load(


## CPTAC UCEC 50%

In [6]:
folds_dir = "./models/ckpts/CPTACUCEC_trainmixed50_multival_Titan_MedImSight_fullepochs"
ct_path = "../MedImageInsights/embeddings_cptacucec"
wsi_path = "../../trident_processed_UCEC_titan/20x_512px_0px_overlap/slide_features_titan"
test_path = "./data/processed/processed_CPTACUCEC_survival/k=all.tsv"
output_dir = "./interpretability/interpretability_UCEC_mixed50"

evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16, n_folds=4)
for i in range(4):
    csv_input = f"./interpretability/interpretability_UCEC_mixed50/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_UCEC_mixed50/fold_{i}_topk"
    )


=== Fold 0 ===

=== Fold 1 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



=== Fold 2 ===

=== Fold 3 ===


  checkpoint = torch.load(
  checkpoint = torch.load(



🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
1  C3N-02631   0.028265       0.028265             1              1
0  C3N-02631   0.028285       0.028285             1              1
2  C3N-02631   0.054864       0.054864             1              1

🔸 Top-k HIGH |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
2  C3N-02631   0.054864       0.054864             1              1
0  C3N-02631   0.028285       0.028285             1              1
1  C3N-02631   0.028265       0.028265             1              1

📁 Saved to:
  ./interpretability/interpretability_UCEC_mixed50/fold_0_topk/topk_low_gradient.csv
  ./interpretability/interpretability_UCEC_mixed50/fold_0_topk/topk_high_gradient.csv

🔹 Top-k LOW |Gradient Norm| Samples:
  patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
0  C3N-02678        0.0            0.0             1              1
1  C3N-02678        0

-------------------------

## CPTAC PDA test

In [3]:

def evaluate_and_log(folds_dir, ct_path, wsi_path, test_path, output_dir, batch_size=16, n_folds=5):
    os.makedirs(output_dir, exist_ok=True)

    for fold in range(n_folds):
        print(f"\n=== Fold {fold} ===")
        model = MADPENetNoBackbonesSurv(
            rad_input_dim=1024, histo_input_dim=768,
            inter_dim=256, token_dim=256, dim_hider=256
        )
        model.eval()

        # Load checkpoint
        fold_dir = os.path.join(folds_dir, f"fold_{fold}")
        model_subdir = next(d for d in os.listdir(fold_dir)
                            if os.path.isdir(os.path.join(fold_dir, d)))
        model_file = next(f for f in os.listdir(os.path.join(fold_dir, model_subdir))
                          if f.endswith("mixed_missing.pth"))
        checkpoint = torch.load(
            os.path.join(fold_dir, model_subdir, model_file), map_location="cpu"
        )
        model.load_state_dict(checkpoint["model_state_dict"])

        dataset = MultimodalCTWSIDatasetSurv(
            fold=0,
            split="train",
            ct_path=ct_path,
            wsi_path=wsi_path,
            labels_splits_path=test_path,
            missing_modality_prob=0.0,
            require_both_modalities=False,
            pairing_mode="one_to_one",
            allow_repeats=True,
            pairs_per_patient=None,
            missing_modality="wsi"
        )

        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)

        # CSV Logger
        log_file = os.path.join(output_dir, f"fold_{fold}_log.csv")
        fold_tensor_dir = os.path.join(output_dir, f"fold_{fold}")
        os.makedirs(fold_tensor_dir, exist_ok=True)

        with open(log_file, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([
                "patient_id", "ct_path", "wsi_path",
                "ct_available", "wsi_available",
                "hazard_score", "survtime", "censor",
                "grad_norm", "curvature",
                "fused_path", "pe_path", "hazard_tensor_path"
            ])

            sample_counter = 0
            for batch in loader:
                # ensure floats
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        batch[k] = v.float()

                with torch.enable_grad():
                    for v in batch.values():
                        if isinstance(v, torch.Tensor):
                            v.requires_grad = True

                    outputs = model(
                        batch["ct_feature"],
                        batch["wsi_feature"],
                        modality_flag=batch["modality_mask"],
                        output_layers=["hazard", "fused_features", "positional_embeddings"]
                    )
                    hazard = outputs["hazard"].squeeze(1)  # [B]
                    fused_features = outputs["fused_features"]    # [B, D]
                    pos_emb = outputs["positional_embeddings"]    # [B, D]

                    grad_norms, curvatures = compute_grad_and_curvature(
                        fused_features, hazard,
                        batch["survtime"], batch["censor"]
                    )

                # save per-sample
                B = hazard.size(0)
                for j in range(B):
                    sample_id = f"sample_{sample_counter:06d}"
                    fused_path = os.path.join(fold_tensor_dir, f"{sample_id}_fused.npy")
                    pe_path = os.path.join(fold_tensor_dir, f"{sample_id}_pe.npy")
                    hazard_path = os.path.join(fold_tensor_dir, f"{sample_id}_hazard.npy")

                    np.save(fused_path, fused_features[j].detach().cpu().numpy())
                    np.save(pe_path, pos_emb[j].detach().cpu().numpy())
                    np.save(hazard_path, hazard[j].detach().cpu().numpy())

                    ct_flag = int(batch["modality_mask"][j, 0].item())
                    wsi_flag = int(batch["modality_mask"][j, 1].item())
                    if batch["censor"][j].item()==0: continue
                    writer.writerow([
                        batch["patient_id"][j],
                        dataset.samples[sample_counter]["ct_path"],
                        dataset.samples[sample_counter]["wsi_feature"],
                        ct_flag, wsi_flag,
                        hazard[j].item(),
                        batch["survtime"][j].item(),
                        batch["censor"][j].item(),
                        grad_norms[j], curvatures[j],
                        fused_path, pe_path, hazard_path
                    ])

                    sample_counter += 1
                    

import pandas as pd
import os
def find_topk_gradient_extremes(
    csv_path: str,
    k: int = 10,
    output_dir: str = None,
):
    """
    Finds top-k highest and lowest |gradient norm| samples from a CSV log.

    Args:
        csv_path (str): Path to fold_X_log.csv generated during eval.
        k (int): Number of top and bottom samples to retrieve.
        output_dir (str): Optional folder to save CSVs. If None, only prints results.
    """
    df = pd.read_csv(csv_path)

    if "grad_norm" not in df.columns:
        raise ValueError("The CSV does not contain 'grad_norm' column.")

    # Compute absolute gradient norms
    df["abs_grad_norm"] = df["grad_norm"].abs()

    # Sort by absolute gradient norm
    df_sorted = df.sort_values(by="abs_grad_norm", ascending=True)

    topk_low = df_sorted.head(k).copy()
    topk_high = df_sorted.tail(k).copy()[::-1]  # High to low

    print("\n🔹 Top-k LOW |Gradient Norm| Samples:")
    print(topk_low[["patient_id", "grad_norm", "abs_grad_norm", "ct_available", "wsi_available"]])

    print("\n🔸 Top-k HIGH |Gradient Norm| Samples:")
    print(topk_high[["patient_id", "grad_norm", "abs_grad_norm", "ct_available", "wsi_available"]])

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        low_path = os.path.join(output_dir, "topk_low_gradient.csv")
        high_path = os.path.join(output_dir, "topk_high_gradient.csv")
        topk_low.to_csv(low_path, index=False)
        topk_high.to_csv(high_path, index=False)
        print(f"\n📁 Saved to:\n  {low_path}\n  {high_path}")

    return topk_low, topk_high

import torch
import random
import numpy as np
SEED = 0

def set_global_seed(seed=SEED):
    """
    Set a global seed for reproducibility across different libraries and random number generators.

    Args:
        seed (int): Seed value to be used
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed()

In [4]:
ct_path="../MedImageInsights/embeddings_output_cptacpda_93"
wsi_path="../../TitanCPTACPDA/20x_512px_0px_overlap/slide_features_titan"
output_dir = "./interpretability/interpretability_PDA_mixed50TEST"
folds_dir = "./models/ckpts/CPTACPDA_trainmixed50_multival_Titan_MedImSight_new"
test_path = "./data/processed/processed_CPTAC_PDA_test/k=all.tsv"


evaluate_and_log(folds_dir,ct_path,wsi_path,test_path,output_dir,batch_size=16, n_folds=5)
for i in range(5):
    csv_input = f"./interpretability/interpretability_PDA_mixed50TEST/fold_{i}_log.csv"
    find_topk_gradient_extremes(
        csv_path=csv_input,
        k=15,
        output_dir=f"./interpretability/interpretability_PDA_mixed50TEST/fold_{i}_topk"
    )


=== Fold 0 ===


  checkpoint = torch.load(



=== Fold 1 ===


  checkpoint = torch.load(



=== Fold 2 ===


  checkpoint = torch.load(



=== Fold 3 ===


  checkpoint = torch.load(



=== Fold 4 ===


  checkpoint = torch.load(



🔹 Top-k LOW |Gradient Norm| Samples:
    patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
99   C3N-02768   0.000245       0.000245             0              1
98   C3N-02768   0.000246       0.000246             0              1
97   C3N-02768   0.000253       0.000253             0              1
161  C3N-03190   0.000846       0.000846             0              1
162  C3N-03190   0.000853       0.000853             0              1
163  C3N-03190   0.000855       0.000855             0              1
191  C3N-01379   0.001462       0.001462             0              1
192  C3N-01379   0.001470       0.001470             0              1
193  C3N-01379   0.001484       0.001484             0              1
108  C3N-01168   0.005592       0.005592             0              1
109  C3N-01168   0.005593       0.005593             0              1
146  C3L-01158   0.005970       0.005970             0              1
166  C3N-03086   0.005977       0.005977            

## CPTAC UCEC test

In [52]:
import os
import torch
import csv
import numpy as np
from torch.utils.data import DataLoader
from data.multimodal_features_surv import MultimodalCTWSIDatasetSurv
from models.dpe.main_model_nobackbone_surv_new_gcs import MADPENetNoBackbonesSurv

from torch.nn.modules.loss import _WeightedLoss

class CoxLoss(_WeightedLoss):
    def forward(self, hazard_pred: torch.Tensor, survtime: torch.Tensor, censor: torch.Tensor):
        censor = censor.float()
        n = len(survtime)
        R_mat = survtime.reshape((1, n)) >= survtime.reshape((n, 1))
        theta = hazard_pred.reshape(-1)
        exp_theta = torch.exp(theta)
        loss = -torch.mean((theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) * censor)
        return loss


cox_loss_fn = CoxLoss()

def compute_grad_and_curvature(fused_features, hazard, survtime, censor):
    hazard = hazard.view(-1)
    loss = cox_loss_fn(hazard, survtime, censor)
    grad_f = torch.autograd.grad(loss, fused_features, create_graph=True)[0]
    grad_norm = grad_f.flatten(1).norm(p=2, dim=1)

    random_vec = grad_f.detach().clone().sign()
    grad_dot_random = torch.sum(grad_f * random_vec)
    hvp = torch.autograd.grad(grad_dot_random, fused_features, retain_graph=False)[0]
    curvature = torch.sum(hvp * random_vec, dim=list(range(1, hvp.ndim)))
    
    #print("grad_f_att:", grad_f)
    #print("grad_norm:", grad_f.flatten(1).norm(p=2, dim=1))

    
    return grad_norm.detach().cpu().numpy(), curvature.detach().cpu().numpy()   

def evaluate_and_log():
    folds_dir = "./models/ckpts/CPTACPDA_trainmixed50_multival_Titan_MedImSight"
    test_path = "./data/processed/processed_CPTAC_PDA_test/k=all.tsv"
    ct_path = "../MedImageInsights/embeddings_output_cptacpda_93"
    wsi_path = "../../TitanCPTACPDA/20x_512px_0px_overlap/slide_features_titan"
    output_dir = "./interpretability/interpretability_PDA_testmixed50"
    os.makedirs(output_dir, exist_ok=True)

    for fold in range(1):
        print(f"\n=== Fold {fold} ===")
        model = MADPENetNoBackbonesSurv(
            rad_input_dim=1024, histo_input_dim=768,
            inter_dim=256, token_dim=256, dim_hider=256
        )
        model.eval()

        # Load checkpoint
        fold_dir = os.path.join(folds_dir, f"fold_{fold}")
        model_subdir = next(d for d in os.listdir(fold_dir) if os.path.isdir(os.path.join(fold_dir, d)))
        model_file = next(f for f in os.listdir(os.path.join(fold_dir, model_subdir)) if f.endswith("mixed_missing.pth"))
        checkpoint = torch.load(os.path.join(fold_dir, model_subdir, model_file), map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])

        dataset = MultimodalCTWSIDatasetSurv(
            fold=fold,
            split="train",
            ct_path=ct_path,
            wsi_path=wsi_path,
            labels_splits_path=test_path,
            missing_modality_prob=0.0,
            require_both_modalities=False,
            pairing_mode="one_to_one",
            allow_repeats=True,
            pairs_per_patient=None,
            missing_modality="wsi"
        )

        loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

        # CSV Logger
        log_file = os.path.join(output_dir, f"fold_{fold}_log.csv")
        fold_tensor_dir = os.path.join(output_dir, f"fold_{fold}")
        os.makedirs(fold_tensor_dir, exist_ok=True)

        with open(log_file, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([
                "patient_id", "ct_path", "wsi_path",
                "ct_available", "wsi_available",
                "hazard_score", "survtime", "censor",
                "grad_norm", "curvature",
                "fused_path", "pe_path", "hazard_tensor_path"
            ])

            for i, batch in enumerate(loader):
                for k in batch:
                    if isinstance(batch[k], torch.Tensor):
                        batch[k] = batch[k].float()

                with torch.enable_grad():
                    for k in batch:
                        if isinstance(batch[k], torch.Tensor):
                            batch[k].requires_grad = True

                    outputs = model(
                        batch["ct_feature"],
                        batch["wsi_feature"],
                        modality_flag=batch["modality_mask"],
                        output_layers=["hazard", "fused_features", "positional_embeddings"]
                    )
                    fused_features = outputs["fused_features"]
                    fused_features.requires_grad_(True)

                    # Recompute hazard *outside* the model to ensure it is explicitly tied to fused_features
                    hazard = model.hazard_net(fused_features)

                    #fused_features = outputs["fused_features"]
                    pos_emb = outputs["positional_embeddings"].squeeze()
                    #fused_features.requires_grad_(True)
                    grad_norm, curvature = compute_grad_and_curvature(
                        fused_features, hazard, batch["survtime"], batch["censor"]
                    )

                # Save tensor files
                sample_id = f"sample_{i:04d}"
                fused_path = os.path.join(fold_tensor_dir, f"{sample_id}_fused.npy")
                pe_path = os.path.join(fold_tensor_dir, f"{sample_id}_pe.npy")
                hazard_path = os.path.join(fold_tensor_dir, f"{sample_id}_hazard.npy")

                np.save(fused_path, fused_features.detach().cpu().numpy())
                np.save(pe_path, pos_emb.detach().cpu().numpy())
                np.save(hazard_path, hazard.detach().cpu().numpy())

                ct_flag = int(batch["modality_mask"][0, 0].item())
                wsi_flag = int(batch["modality_mask"][0, 1].item())

                writer.writerow([
                    batch["patient_id"][0],
                    dataset.samples[i]["ct_path"],
                    dataset.samples[i]["wsi_feature"],
                    ct_flag, wsi_flag,
                    hazard.item(),
                    batch["survtime"].item(),
                    batch["censor"].item(),
                    grad_norm[0],
                    curvature[0],
                    fused_path,
                    pe_path,
                    hazard_path
                ])


if __name__ == "__main__":
    evaluate_and_log()

import pandas as pd
import os
def find_topk_gradient_extremes(
    csv_path: str,
    k: int = 10,
    output_dir: str = None,
):
    """
    Finds top-k highest and lowest |gradient norm| samples from a CSV log.

    Args:
        csv_path (str): Path to fold_X_log.csv generated during eval.
        k (int): Number of top and bottom samples to retrieve.
        output_dir (str): Optional folder to save CSVs. If None, only prints results.
    """
    df = pd.read_csv(csv_path)

    if "grad_norm" not in df.columns:
        raise ValueError("The CSV does not contain 'grad_norm' column.")

    # Compute absolute gradient norms
    df["abs_grad_norm"] = df["grad_norm"].abs()

    # Sort by absolute gradient norm
    df_sorted = df.sort_values(by="abs_grad_norm", ascending=True)

    topk_low = df_sorted.head(k).copy()
    topk_high = df_sorted.tail(k).copy()[::-1]  # High to low

    print("\n🔹 Top-k LOW |Gradient Norm| Samples:")
    print(topk_low[["patient_id", "grad_norm", "abs_grad_norm", "ct_available", "wsi_available"]])

    print("\n🔸 Top-k HIGH |Gradient Norm| Samples:")
    print(topk_high[["patient_id", "grad_norm", "abs_grad_norm", "ct_available", "wsi_available"]])

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        low_path = os.path.join(output_dir, "topk_low_gradient.csv")
        high_path = os.path.join(output_dir, "topk_high_gradient.csv")
        topk_low.to_csv(low_path, index=False)
        topk_high.to_csv(high_path, index=False)
        print(f"\n📁 Saved to:\n  {low_path}\n  {high_path}")

    return topk_low, topk_high


if __name__ == "__main__":
    # Example usage
    for i in range(1):
        csv_input = f"./interpretability/interpretability_PDA_testmixed50/fold_{i}_log.csv"
        find_topk_gradient_extremes(
            csv_path=csv_input,
            k=15,
            output_dir=f"./interpretability/interpretability_PDA_testmixed50/fold_{i}_topk"
        )



=== Fold 0 ===


  checkpoint = torch.load(os.path.join(fold_dir, model_subdir, model_file), map_location="cpu")



🔹 Top-k LOW |Gradient Norm| Samples:
    patient_id  grad_norm  abs_grad_norm  ct_available  wsi_available
0    C3L-04473        0.0            0.0             0              1
177  C3L-01662        0.0            0.0             0              1
178  C3L-01662        0.0            0.0             0              1
179  C3L-01662        0.0            0.0             0              1
180  C3L-01662        0.0            0.0             0              1
181  C3L-01160        0.0            0.0             0              1
182  C3L-01160        0.0            0.0             0              1
183  C3L-01160        0.0            0.0             0              1
184  C3N-03884        0.0            0.0             0              1
185  C3N-03884        0.0            0.0             0              1
186  C3N-03884        0.0            0.0             0              1
187  C3L-01031        0.0            0.0             0              1
188  C3L-01031        0.0            0.0            