In [None]:
import os, copy, random, numpy as np, pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import densenet121
from PIL import Image

from pycox.evaluation import EvalSurv
from pycox.models.loss import CoxPHLoss
from torchtuples import optim as ttoptim
# Let cuDNN auto-tune convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True


def seed_everything(s=2025):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs with ``s``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(s)


seed_everything(2025)
# Device string used by every ``.to(device)`` call in the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Clinical feature matrices for the three splits, loaded from the .npy paths
# defined elsewhere in the notebook and cast to float32.
X_train_clin, X_val_clin, X_test_clin = (
    np.load(npy_path).astype("float32")
    for npy_path in (NPY_TRAIN, NPY_VAL, NPY_TEST)
)

In [None]:
# Square side length (pixels) every image is resized to.
IMSIZE = 456
# Per-channel normalisation statistics (these are the commonly used ImageNet values).
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]


def _make_pipeline(*augmentations):
    """Build resize -> [augmentations] -> to-tensor -> normalise as one ``Compose``."""
    return transforms.Compose([
        transforms.Resize((IMSIZE, IMSIZE)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


# Training pipeline: identical to evaluation plus a random horizontal flip.
t_train = _make_pipeline(transforms.RandomHorizontalFlip(p=0.5))

# Deterministic pipeline used for the validation and test splits.
t_eval = _make_pipeline()

In [None]:
class SurvDataset(Dataset):
    """Image + clinical-feature dataset for survival modelling.

    Row ``i`` of ``df_part`` is paired *positionally* with row ``i`` of
    ``X_clin``. The frame must provide the columns ``PIC_Name`` (image file
    name, with or without extension), ``duration_time`` and ``end``.

    Args:
        df_part: One split of the metadata frame.
        X_clin: Clinical feature matrix with exactly one row per frame row.
        img_dir: Directory that contains the images of this split.
        transform: Callable applied to the RGB ``PIL.Image``.

    Raises:
        ValueError: If ``df_part`` and ``X_clin`` have different lengths.
    """

    # Extensions tried, in this order, when ``PIC_Name`` is not a file as given.
    _FALLBACK_EXTS = (".jpg", ".png", ".jpeg", ".bmp", ".tif", ".tiff")

    def __init__(self, df_part: pd.DataFrame, X_clin: np.ndarray, img_dir: str, transform):
        # The pairing is purely positional, so a length mismatch would silently
        # attach the wrong clinical features to an image (or fail mid-epoch).
        if len(df_part) != len(X_clin):
            raise ValueError(
                f"df_part has {len(df_part)} rows but X_clin has {len(X_clin)}; "
                "they must be aligned one-to-one"
            )
        self.df = df_part
        self.Xc = X_clin
        self.img_dir = img_dir
        self.tfm = transform

    def __len__(self):
        return len(self.df)

    def _resolve_image_path(self, fname):
        """Return the existing path for ``fname``, trying the fallback extensions.

        Raises:
            FileNotFoundError: If no candidate exists; the message lists every
                path that was tried.
        """
        base = os.path.join(self.img_dir, str(fname))
        if os.path.isfile(base):
            return base
        tried = [base]
        for ext in self._FALLBACK_EXTS:
            candidate = base + ext
            tried.append(candidate)
            if os.path.isfile(candidate):
                return candidate
        raise FileNotFoundError(
            f"could not find image {str(fname)!r} in {self.img_dir!r}; tried: {tried}"
        )

    def __getitem__(self, i):
        """Return ``(image, clinical, duration, event, image_path)`` for row ``i``.

        ``duration`` is a float32 scalar tensor floored at 1e-3 (non-finite
        values are also mapped to 1e-3); ``event`` is ``int(row["end"])`` as a
        float32 scalar tensor.
        """
        row = self.df.iloc[i]
        img_path = self._resolve_image_path(row["PIC_Name"])

        # ``convert`` decodes the pixels into a new image, so the file can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        xi = self.tfm(img)

        xc = torch.from_numpy(self.Xc[i])

        dval = float(row["duration_time"])
        if not np.isfinite(dval):
            dval = 1e-3
        dur = torch.tensor(max(dval, 1e-3), dtype=torch.float32)
        evt = torch.tensor(int(row["end"]), dtype=torch.float32)
        return xi, xc, dur, evt, img_path

In [None]:
# Loader configuration.
BATCH_TRAIN = 24
BATCH_EVAL  = 32
NUM_WORKERS = 4
PIN_MEMORY  = True

# One dataset per split; only the training split uses the augmenting transform.
ds_train = SurvDataset(train_df, X_train_clin, IMG_DIRS["train"], t_train)
ds_val   = SurvDataset(val_df,   X_val_clin,   IMG_DIRS["valid"], t_eval)
ds_test  = SurvDataset(test_df,  X_test_clin,  IMG_DIRS["test"],  t_eval)

# Keyword arguments shared by all three loaders.
_loader_opts = dict(num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# NOTE(review): dl_train keeps a trailing partial batch; if that batch ever has
# a single sample, the BatchNorm1d layers will fail in training mode.
dl_train = DataLoader(ds_train, batch_size=BATCH_TRAIN, shuffle=True, **_loader_opts)
dl_val   = DataLoader(ds_val, batch_size=BATCH_EVAL, shuffle=False, **_loader_opts)
dl_test  = DataLoader(ds_test, batch_size=BATCH_EVAL, shuffle=False, **_loader_opts)

In [None]:
# Placeholder: replace with the path of the checkpoint that is passed to
# load_backbone_from_cls() to initialise the DenseNet-121 backbone
# (presumably a previously trained classification model - confirm).
CKPT_3Y = "your own model"

class ImgBranch(nn.Module):
    """Image encoder: backbone features -> ReLU -> global average pool -> 512-d embedding.

    Args:
        backbone: Object exposing a ``features`` module whose output has 1024 channels.
        drop: Dropout probability applied after the projection.
    """

    def __init__(self, backbone, drop=0.4):
        super().__init__()
        self.features = backbone.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        projection_layers = [
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(drop),
        ]
        self.proj = nn.Sequential(*projection_layers)

    def forward(self, x):
        """Map an image batch to a ``(batch, 512)`` embedding."""
        feature_map = nn.functional.relu(self.features(x), inplace=True)
        pooled = self.pool(feature_map)
        return self.proj(pooled)

class ClinBranch(nn.Module):
    """Clinical-feature encoder: two Linear/BatchNorm/ReLU/Dropout stages (in_dim -> 128 -> 64).

    Args:
        in_dim: Number of clinical input features.
        drop: Dropout probability used after each stage.
    """

    def __init__(self, in_dim, drop=0.3):
        super().__init__()
        widths = (in_dim, 128, 64)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(inplace=True),
                nn.Dropout(drop),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map ``(batch, in_dim)`` clinical features to a ``(batch, 64)`` embedding."""
        return self.net(x)

class MultiModalCox(nn.Module):
    """Fuse image and clinical embeddings into one Cox log-risk score per sample.

    Args:
        img_backbone: Backbone handed to ``ImgBranch``.
        clin_in: Number of clinical input features.
        comb: Width of the fused hidden layer.
        drop: Dropout probability of the fusion layer only; both branches keep
            their own default dropout rates.
    """

    def __init__(self, img_backbone, clin_in, comb=128, drop=0.3):
        super().__init__()
        self.img = ImgBranch(img_backbone)
        self.clin = ClinBranch(clin_in)
        fused_in = 512 + 64  # image embedding width + clinical embedding width
        self.fuse = nn.Sequential(
            nn.Linear(fused_in, comb),
            nn.BatchNorm1d(comb),
            nn.ReLU(inplace=True),
            nn.Dropout(drop),
        )
        # Bias-free scalar output interpreted as the log-risk.
        self.head = nn.Linear(comb, 1, bias=False)  # log-risk

    def forward(self, xi, xc):
        """Return the log-risk for images ``xi`` and clinical features ``xc``, shape ``(batch,)``."""
        img_feat = self.img(xi)
        clin_feat = self.clin(xc)
        fused = self.fuse(torch.cat([img_feat, clin_feat], dim=1))
        log_risk = self.head(fused)
        return log_risk.squeeze(1)

def load_backbone_from_cls(ckpt_path):
    """Build a DenseNet-121 and initialise it from the checkpoint at ``ckpt_path``.

    The checkpoint may be a raw state dict or a dict holding it under
    ``'state_dict'``. A leading ``'module.'`` prefix is stripped from each key
    and every key containing ``'classifier'`` is dropped, so only the
    non-classifier weights are loaded.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        The ``densenet121`` instance with the checkpoint weights loaded.

    Raises:
        RuntimeError: If not a single non-classifier tensor of the network was
            found in the checkpoint (the backbone would otherwise stay randomly
            initialised without any indication).
    """
    import warnings

    bb = densenet121(weights=None)
    # SECURITY: torch.load unpickles the file; only load checkpoints from a trusted source.
    sd = torch.load(ckpt_path, map_location='cpu')
    if isinstance(sd, dict) and 'state_dict' in sd:
        sd = sd['state_dict']

    prefix = 'module.'
    cleaned = {}
    for key, value in sd.items():
        if 'classifier' in key:
            continue
        # Strip the wrapper prefix only where it actually is a prefix.
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value

    # strict=False tolerates the dropped classifier, but it also hides key
    # mismatches - so inspect what was really loaded.
    result = bb.load_state_dict(cleaned, strict=False)
    expected = [k for k in bb.state_dict() if 'classifier' not in k]
    missing = [k for k in result.missing_keys if 'classifier' not in k]
    unexpected = list(result.unexpected_keys)
    if expected and len(missing) == len(expected):
        raise RuntimeError(
            f"no backbone weights from {ckpt_path!r} matched densenet121; "
            f"example checkpoint keys: {list(cleaned)[:5]}"
        )
    if missing or unexpected:
        warnings.warn(
            f"backbone checkpoint partially loaded: {len(missing)} missing and "
            f"{len(unexpected)} unexpected keys (e.g. missing={missing[:3]}, "
            f"unexpected={unexpected[:3]})"
        )
    return bb

backbone = load_backbone_from_cls(CKPT_3Y)
# Size the clinical branch from the data so it cannot drift from the feature matrix.
model = MultiModalCox(backbone, clin_in=int(X_train_clin.shape[1])).to(device)

# Freeze the image backbone: no gradients, and switch it out of training mode.
for p in model.img.features.parameters():
    p.requires_grad = False
model.img.features.train(False)

cox_loss = CoxPHLoss()
# Only the still-trainable parameters are handed to the optimiser.
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=3e-4, weight_decay=1e-6)

# Smoke test: one forward/backward/step on a single training batch.
xi, xc, d, e, _ = next(iter(dl_train))
xi, xc, d, e = xi.to(device), xc.to(device), d.to(device), e.to(device)
opt.zero_grad(set_to_none=True)
log_risk = model(xi, xc)
loss = cox_loss(log_risk, d, e)
if torch.isfinite(loss):
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    print(f"Smoke test OK. one-batch loss={loss.item():.4f}")
else:
    # A non-finite loss (e.g. a batch without any event) must not reach the
    # optimiser: stepping on NaN gradients would corrupt the weights before
    # training even starts. The training loop applies the same guard.
    opt.zero_grad(set_to_none=True)
    print(f"[WARN] Smoke test: non-finite one-batch loss ({loss.item()}); optimizer step skipped")

In [None]:
@torch.no_grad()
def _predict(model, loader):
    """Collect log-risks, durations and events over ``loader`` as NumPy arrays.

    Puts ``model`` into eval mode and returns ``(risks, durations, events)``,
    each concatenated over all batches in loader order.
    """
    model.eval()
    risks, durations, events = [], [], []
    for images, clinical, dur, evt, _paths in loader:
        images, clinical = images.to(device), clinical.to(device)
        scores = model(images, clinical)
        risks.append(scores.detach().cpu().numpy())
        durations.append(dur.numpy())
        events.append(evt.numpy())
    return tuple(np.concatenate(chunks) for chunks in (risks, durations, events))

def _surv_collect(model, loader):
    """Run ``model`` (eval mode) over ``loader``; return finite ``(log_risk, duration, event)`` arrays.

    Rows with a non-finite value in any of the three arrays are dropped and
    durations are floored at 1e-3. An empty loader yields three empty arrays.
    """
    model.eval()
    risks, durations, events = [], [], []
    for xi, xc, d, e, _ in loader:
        xi, xc = xi.to(device), xc.to(device)
        risks.append(model(xi, xc).detach().cpu().numpy().astype(np.float64))
        durations.append(d.detach().cpu().numpy().astype(np.float64))
        events.append(e.detach().cpu().numpy().astype(np.int64))
    if not risks:
        return (np.empty(0, dtype=np.float64),
                np.empty(0, dtype=np.float64),
                np.empty(0, dtype=np.int64))
    r = np.concatenate(risks)
    d = np.concatenate(durations)
    e = np.concatenate(events)
    keep = np.isfinite(r) & np.isfinite(d) & np.isfinite(e)
    r, d, e = r[keep], d[keep], e[keep]
    return r, np.maximum(d, 1e-3), e


def _breslow_cumhaz(r_tr, d_tr, e_tr):
    """Return ``(event_times, cumulative_baseline_hazard)`` estimated from the training arrays."""
    order = np.argsort(d_tr, kind='mergesort')
    t_sorted, e_sorted, r_sorted = d_tr[order], e_tr[order], r_tr[order]
    exp_r = np.exp(r_sorted)  # loop-invariant, computed once
    event_times = np.unique(t_sorted[e_sorted == 1])
    increments = []
    for ut in event_times:
        at_risk = exp_r[t_sorted >= ut].sum()
        n_events = ((t_sorted == ut) & (e_sorted == 1)).sum()
        increments.append(n_events / max(at_risk, 1e-12))
    H0_vals = np.cumsum(np.asarray(increments, dtype=np.float64))
    T_vals = np.asarray(event_times, dtype=np.float64)
    return T_vals, H0_vals


def _default_time_grid(T_vals, d_ev):
    """Return 80 evenly spaced times inside the overlap of training event times and eval durations."""
    lo_candidates = [1.0]
    if T_vals.size: lo_candidates.append(T_vals.min())
    if d_ev.size:   lo_candidates.append(d_ev.min())
    lo = max(lo_candidates)

    hi_candidates = []
    if T_vals.size: hi_candidates.append(T_vals.max())
    if d_ev.size:   hi_candidates.append(d_ev.max())
    hi = min(hi_candidates) if hi_candidates else lo + 1.0

    # Degenerate overlap: fall back to the range of the training event times.
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        lo = float(T_vals.min()) if T_vals.size else 1.0
        hi = float(T_vals.max()) if T_vals.size else lo + 1.0
        if hi <= lo: hi = lo + 1.0

    return np.linspace(lo, hi, 80).astype(np.float64)


def _surv_frame(times, T_vals, H0_vals, r_ev):
    """Return the survival table (index = ``times``, one column per eval sample)."""
    Ht_grid = np.interp(times, T_vals, H0_vals, left=0.0,
                        right=(H0_vals[-1] if H0_vals.size else 0.0))
    # S(t | x) = exp(-H0(t) * exp(log_risk)), all samples at once.
    surv = np.exp(-np.outer(Ht_grid, np.exp(r_ev)))
    return pd.DataFrame(surv, index=times)


@torch.no_grad()
def evaluate_surv(model, dl_train, dl_eval, times=None):
    """Evaluate the Cox model on ``dl_eval`` using a baseline hazard fitted on ``dl_train``.

    Args:
        model: Network returning one log-risk per sample for ``(xi, xc)``.
        dl_train: Loader used to estimate the cumulative baseline hazard.
        dl_eval: Loader whose samples are scored.
        times: Optional evaluation time grid (any array-like); when ``None`` a
            default 80-point grid is derived from the data.

    Returns:
        ``(uno, ibs)``. ``uno`` is pycox's time-dependent concordance with
        ``method='adj_antolini'`` - despite the historical name it is not
        Uno's C. ``ibs`` is the integrated Brier score. Either value is NaN
        when it cannot be computed; both are NaN when the training data has no
        events or the evaluation set is empty.
    """
    r_tr, d_tr, e_tr = _surv_collect(model, dl_train)
    n_evt_tr = int((e_tr == 1).sum())
    if n_evt_tr == 0:
        return float('nan'), float('nan')

    T_vals, H0_vals = _breslow_cumhaz(r_tr, d_tr, e_tr)

    r_ev, d_ev, e_ev = _surv_collect(model, dl_eval)
    n_evt_ev = int((e_ev == 1).sum())
    if r_ev.size == 0:
        # Nothing to score; avoid the opaque crash from stacking zero rows.
        return float('nan'), float('nan')

    if times is None:
        times = _default_time_grid(T_vals, d_ev)
    else:
        # Accept lists/tuples as well and make sure the grid is increasing.
        times = np.sort(np.asarray(times, dtype=np.float64).ravel())

    es = EvalSurv(_surv_frame(times, T_vals, H0_vals, r_ev), d_ev, e_ev, censor_surv='km')

    # pycox implements only 'antolini' and 'adj_antolini'. The previous 'uno'
    # request could only ever succeed through its fallback to the default
    # method, so that method is now requested explicitly.
    try:
        uno = float(es.concordance_td('adj_antolini'))
    except Exception:
        uno = float('nan')

    try:
        ibs = float(es.integrated_brier_score(times))
    except Exception:
        # Best-effort retry on the 5th-95th percentile range of the eval durations.
        try:
            lo2, hi2 = np.percentile(d_ev, [5, 95])
            if not np.isfinite(lo2) or not np.isfinite(hi2) or hi2 <= lo2:
                raise ValueError
            times2 = np.linspace(lo2, hi2, 60).astype(np.float64)
            es2 = EvalSurv(_surv_frame(times2, T_vals, H0_vals, r_ev), d_ev, e_ev, censor_surv='km')
            ibs = float(es2.integrated_brier_score(times2))
        except Exception:
            ibs = float('nan')

    print(f"[Eval] train_events={n_evt_tr}, eval_events={n_evt_ev}, "
          f"time_range=[{times.min():.1f},{times.max():.1f}], uno={uno:.4f}, ibs={ibs if np.isfinite(ibs) else np.nan}")

    return uno, ibs

In [None]:
def _nan_guard_batch(xi, xc, d, e, name="train"):

    import torch
    def clamp_finite(t, tname):
        t = t.clone()

        mask = ~torch.isfinite(t)
        if mask.any():
            t[mask] = 0
            print(f"[WARN] {name}:{tname} had {mask.sum().item()} non-finite; replaced with 0")
        return t
    xi = clamp_finite(xi, "xi")
    xc = clamp_finite(xc, "xc")
    d  = clamp_finite(d,  "d")
    e  = clamp_finite(e,  "e")

    e = (e > 0.5).float()
    d = torch.clamp(d, min=1e-3)

    if (e.sum() == 0) or (e.sum() == e.numel()):
        pass

    return xi, xc, d, e

In [None]:
# Sanity-check one training batch: label ranges, NaN/Inf in the clinical
# features, and the range of the model's log-risk output.
xi, xc, d, e, _ = next(iter(dl_train))
print("d stats:", d.min().item(), d.max().item(), torch.isnan(d).any().item())
print("e stats:", e.unique(return_counts=True))
print("xc stats:", torch.isnan(xc).any().item(), torch.isinf(xc).any().item(),
      float(xc.min()), float(xc.max()))

# Probe in eval mode: in training mode the forward pass would update the
# BatchNorm running statistics (no_grad does not prevent that) and dropout
# would randomise the printed numbers. Every module's mode is restored after.
_prev_modes = [(m, m.training) for m in model.modules()]
model.eval()
try:
    with torch.no_grad():
        out = model(xi.to(device), xc.to(device)).detach().cpu()
finally:
    for _m, _was_training in _prev_modes:
        _m.training = _was_training
print("log_risk stats:", float(out.min()), float(out.max()),
      torch.isnan(out).any().item(), torch.isinf(out).any().item())

In [None]:
def train_phase_A(epochs=6, save_path="save best weights"):
    """Train the notebook-level ``model`` for ``epochs`` epochs and keep the best weights.

    Uses the notebook-level ``model``, ``opt``, ``cox_loss``, ``dl_train`` and
    ``dl_val``. Batches without any event, or with a non-finite output or
    loss, are skipped. After every epoch the model is scored with
    ``evaluate_surv``; the best epoch is the one with the highest concordance,
    ties (within 1e-5) broken by the lower IBS.

    Args:
        epochs: Number of passes over ``dl_train``.
        save_path: File the best state dict is written to with ``torch.save``.

    Returns:
        Dict with keys ``'uno'``, ``'ibs'`` and ``'state'`` (``None`` if no
        epoch produced a finite concordance).
    """
    best = {'uno': -1, 'ibs': 1e9, 'state': None}
    # model.train() below switches *every* sub-module to training mode. A frozen
    # backbone must be put back into eval mode, otherwise its BatchNorm running
    # statistics keep drifting even though its weights receive no gradient.
    backbone_frozen = not any(p.requires_grad for p in model.img.features.parameters())

    for ep in range(1, epochs+1):
        model.train()
        if backbone_frozen:
            model.img.features.eval()
        run_loss = 0.0
        num_ok = 0

        for xi, xc, d, e, _ in dl_train:

            xi, xc, d, e = _nan_guard_batch(xi, xc, d, e, name=f"ep{ep}")
            xi, xc, d, e = xi.to(device), xc.to(device), d.to(device), e.to(device)

            # The Cox partial likelihood is undefined without at least one event.
            if e.sum().item() < 1:
                continue

            opt.zero_grad(set_to_none=True)
            log_risk = model(xi, xc)

            if not torch.isfinite(log_risk).all():
                continue

            loss = cox_loss(log_risk, d, e)

            if not torch.isfinite(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            run_loss += loss.item()
            num_ok += 1

        if num_ok == 0:
            print(f"[WARN] Epoch {ep}: no valid batches (likely too many no-event batches). "
                  f"Try larger batch_size or event-stratified sampler.")
            tr_loss_str = "nan"
        else:
            # Mean over the batches that were actually used, so the value is
            # comparable across epochs with different numbers of skipped batches.
            tr_loss_str = f"{run_loss / num_ok:.3f}"

        try:
            uno, ibs = evaluate_surv(model, dl_train, dl_val)
        except Exception as ex:
            print(f"[WARN] evaluate_surv failed at epoch {ep}: {ex}")
            uno, ibs = float('nan'), float('nan')

        print(f"Epoch {ep}: train_loss={tr_loss_str} | Val UnoC={uno} | IBS={ibs}")

        if np.isfinite(uno) and (uno > best['uno'] or (abs(uno-best['uno'])<1e-5 and np.isfinite(ibs) and ibs < best['ibs'])):
            best.update({'uno': float(uno), 'ibs': float(ibs), 'state': copy.deepcopy(model.state_dict())})

    if best['state'] is not None:
        torch.save(best['state'], save_path)
    else:
        print("[WARN] No valid best state to save (all epochs invalid).")

    return best

best = train_phase_A(epochs=6)