In [None]:
### Cell 1 — Path & environment check

In [None]:
import os
from pathlib import Path

# Directory holding the copied .pt feature files.
# Override it without editing the notebook by exporting PT_DIR in the environment;
# the hard-coded path below is only the fallback.
PT_DIR = Path(
    os.environ.get("PT_DIR", "/home/stat-jijianxin1997/SCI_GC_OS/pt_file")
).expanduser()

print("PT_DIR:", PT_DIR)
print("Exists:", PT_DIR.exists())

# Fail fast: globbing a missing directory silently yields nothing, which would
# otherwise surface later as a misleading "no .pt files" error.
assert PT_DIR.is_dir(), f"PT_DIR not found (or not a directory): {PT_DIR}"

# File names only (not full paths), sorted so the demo subset below is deterministic.
pt_files = sorted(p.name for p in PT_DIR.glob("*.pt"))
print("Number of .pt files:", len(pt_files))
print("First 5 files:", pt_files[:5])

assert pt_files, f"No .pt files found under: {PT_DIR}"


In [None]:
### Cell 2 — Sanity check: load one .pt and inspect shape

In [None]:
import torch

# Load the first file (sorted order) to confirm the on-disk format before
# building the Dataset.
sample_path = PT_DIR / pt_files[0]
print("Loading:", sample_path)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load .pt files from a trusted source. If these files hold only tensors
# (or dicts of tensors), consider passing weights_only=True explicitly.
x = torch.load(sample_path, map_location="cpu")
print("Type:", type(x))

if isinstance(x, torch.Tensor):
    print("Tensor shape:", tuple(x.shape))
    print("Dtype:", x.dtype)
elif isinstance(x, dict):
    # Summarize containers instead of dumping every value (feature tensors can be large).
    for key, value in x.items():
        if isinstance(value, torch.Tensor):
            print(f"  [{key!r}] Tensor shape={tuple(value.shape)} dtype={value.dtype}")
        else:
            print(f"  [{key!r}] {type(value).__name__}")
else:
    # Any other saved object: fall back to printing it directly.
    print("Loaded object:", x)


In [None]:
### Cell 3 — Build a minimal demo CSV (so the Dataset can run)

In [None]:
import pandas as pd
from pathlib import Path

# Pair up to four feature files with toy clinical values for a minimal, runnable demo.
demo_pt = pt_files[:4]
n_demo = len(demo_pt)

# Toy per-case values; every column is truncated to the number of demo files.
template_columns = {
    "gender": ["male", "female", "male", "female"],
    "age_at_index": [60, 55, 70, 49],
    "label": [0, 1, 2, 3],                  # discrete interval label (example)
    "survival_months": [10.0, 20.0, 5.0, 18.0],
    "censor": [0, 1, 0, 1],                 # 0=event, 1=censored (per your loss code)
}

demo_df = pd.DataFrame(
    {
        "case_id": [f"DEMO_{idx}" for idx in range(n_demo)],
        **{column: values[:n_demo] for column, values in template_columns.items()},
        # IMPORTANT: stored as str() of a one-element Python list, e.g. "['x.pt']";
        # presumably parsed back by the Dataset — keep this format.
        "slide_id": [str([name]) for name in demo_pt],
    }
)

demo_csv_path = Path("./demo_minimal.csv")
demo_df.to_csv(demo_csv_path, index=False)
print("Saved demo CSV to:", demo_csv_path.resolve())

demo_df


In [None]:
### Cell 4 — Create Dataset + DataLoader (PT mode) with padding mask

In [None]:
from torch.utils.data import DataLoader

# Update import paths to match your repo file names
from dataset_position import SwinPrognosisDataset
from model_utils import custom_collate_fn

load_mode = "pt"


def collate_pt_batch(samples):
    """Collate a list of dataset items using the project's collate function in `load_mode`."""
    return custom_collate_fn(samples, load_mode)


dataset_kwargs = dict(
    df=str(demo_csv_path),
    pt_dir=str(PT_DIR),
    load_mode=load_mode,
)
dataset = SwinPrognosisDataset(**dataset_kwargs)

loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_pt_batch,
    num_workers=0,  # keep 0 in notebooks for stability
)

batch = next(iter(loader))
print("Batch tuple length:", len(batch))


In [None]:
### Cell 5 — Inspect batch tensors (features + mask)

In [None]:
patient, gender, age, label, sur_time, censor, feature, coords, num_patches, mask = batch

# Per-sample metadata: show shape, dtype and values of each field.
print("patient:", patient)
for field_name, field in (
    ("gender", gender),
    ("age", age),
    ("label", label),
    ("sur_time", sur_time),
    ("censor", censor),
):
    print(f"{field_name}:", field.shape, field.dtype, field)

# Padded patch features and the matching mask (expected layouts in the comments).
print("feature:", feature.shape, feature.dtype)  # (B, max_patches, D)
print("mask:", mask.shape, mask.dtype)           # (B, max_patches)
print("num_patches:", num_patches)
print("coords is None:", coords is None)


In [None]:
### Cell 6 — Forward pass with Transformer + compute CombinedSurvLoss

In [None]:
import torch

from transformer_context import Transformer
from loss_func import CombinedSurvLoss

# Fixed seed so the randomly initialised model yields reproducible outputs/loss.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Number of discrete survival intervals; every value in `label` must be < NUM_CLASSES.
NUM_CLASSES = 4
assert int(label.min()) >= 0 and int(label.max()) < NUM_CLASSES, (
    f"labels must lie in [0, {NUM_CLASSES}); got {label.tolist()}"
)

model_params = {
    "num_classes": NUM_CLASSES,
    # Taken from the loaded features so the model always matches the data
    # (UNI features are typically 1024-d).
    "input_dim": int(feature.shape[-1]),
    "dim": 512,
    "depth": 1,
    "heads": 2,
    "mlp_dim": 128,
    "pool": "cls",
    "dim_head": 128,
    "dropout": 0.3,
    "emb_dropout": 0.3,
}
print("input_dim (from features):", model_params["input_dim"])

criterion_params = {"alpha": 0.5}

model = Transformer(**model_params).to(device)
criterion = CombinedSurvLoss(**criterion_params).to(device)

model.eval()  # disable dropout for a deterministic sanity check

feature, mask, age, gender, label, sur_time, censor = (
    tensor.to(device) for tensor in (feature, mask, age, gender, label, sur_time, censor)
)

with torch.no_grad():
    # IMPORTANT: Transformer.forward signature is (x, age, gender, mask=None)
    outputs = model(feature, age, gender, mask)
    # Evaluation only: no gradients are needed for the loss either.
    loss = criterion(outputs=outputs, y=label, t=sur_time, c=censor)

print("outputs shape:", tuple(outputs.shape))
print("loss:", float(loss))


In [None]:
### Cell 7 (optional) — Quick check: mask correctness

In [None]:
import torch

# For each sample, the count of valid (nonzero) mask positions should equal num_patches.
# NOTE(review): assumes mask is nonzero/True at real patches and zero/False at padding —
# confirm against custom_collate_fn.
mask_sums = mask.sum(dim=1).cpu()
# torch.as_tensor also accepts a plain list, in case the collate fn returns one.
num_patches_cpu = torch.as_tensor(num_patches).cpu()

print("mask sums:", mask_sums.tolist())
print("num_patches:", num_patches_cpu.tolist())

assert torch.equal(mask_sums.to(num_patches_cpu.dtype), num_patches_cpu), (
    "mask does not match num_patches: "
    f"{mask_sums.tolist()} vs {num_patches_cpu.tolist()}"
)
