In [None]:
import os
import copy
import random

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms
from torchvision.models import densenet121

import optuna

In [None]:
# Seed every RNG used here (python, numpy, torch) so weight init, DataLoader shuffling
# and random augmentations repeat across runs (cuDNN kernels may still be non-deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[Env] device={device}, cuda_available={torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"[Env] GPU: {torch.cuda.get_device_name(0)}")

# Placeholders: PHASEA_BEST is the Phase-A state_dict loaded by build_model_from_phaseA().
PHASEA_BEST = "your best weights"
CKPT_3Y = "your previous deep learning model"

In [None]:
from torch.utils.data import Dataset, DataLoader

# Square input resolution and the standard ImageNet channel mean / std.
IMSIZE = 456
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]


def _resize_then_normalize(*augment):
    """Compose: resize to IMSIZE x IMSIZE, then `augment` (if any), then tensorize + normalize."""
    return transforms.Compose([
        transforms.Resize((IMSIZE, IMSIZE)),
        *augment,
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


# Training pipeline: light geometric and photometric augmentation.
t_train = _resize_then_normalize(
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
)

# Evaluation pipeline: deterministic, no augmentation.
t_eval = _resize_then_normalize()

#Dataset
class SurvDataset(Dataset):
    """Paired image / clinical-feature dataset for survival training.

    Item ``i`` is ``(image, clinical, duration, event, path)``:
      * ``image``    -- ``transform`` applied to the RGB image named by ``image_name``;
      * ``clinical`` -- row ``i`` of ``X_clin`` as a float32 tensor;
      * ``duration`` -- ``duration_months`` as float32, floored at 1e-3;
      * ``event``    -- ``event`` as float32;
      * ``path``     -- the resolved image path.

    Args:
        df_part: DataFrame with ``image_name``, ``duration_months`` and ``event``
            columns. Its index is discarded; rows are matched to ``X_clin`` by position.
        X_clin: clinical feature matrix with one row per row of ``df_part`` (a numpy
            array, or anything ``np.asarray`` accepts such as a DataFrame).
        img_dir: directory holding the images.
        transform: callable applied to the loaded RGB ``PIL.Image``.

    Raises:
        ValueError: if ``df_part`` and ``X_clin`` have different lengths.
    """

    # Extensions tried, in order, when ``image_name`` does not exist as given.
    _IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, df_part, X_clin, img_dir, transform):
        self.df = df_part.reset_index(drop=True)
        # np.asarray keeps an existing float32 ndarray without copying (like the previous
        # astype(copy=False)) and also accepts DataFrames, whose integer __getitem__ would
        # otherwise select a column instead of a row below.
        self.Xc = np.ascontiguousarray(np.asarray(X_clin, dtype=np.float32))
        self.img_dir = img_dir
        self.tfm = transform
        if len(self.df) != len(self.Xc):
            raise ValueError(
                f"df_part has {len(self.df)} rows but X_clin has {len(self.Xc)} rows")

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, fname):
        """Return the path of ``fname`` in ``img_dir``, trying ``_IMG_EXTS`` if it is missing."""
        path = os.path.join(self.img_dir, fname)
        if os.path.isfile(path):
            return path
        for ext in self._IMG_EXTS:
            candidate = path + ext
            if os.path.isfile(candidate):
                return candidate
        raise FileNotFoundError(
            f"image {fname!r} not found in {self.img_dir!r} "
            f"(also tried extensions {', '.join(self._IMG_EXTS)})")

    def __getitem__(self, i):
        row = self.df.iloc[i]
        path = self._resolve_path(str(row["image_name"]))
        # Context manager closes the underlying file handle promptly (convert() loads pixels).
        with Image.open(path) as img:
            xi = self.tfm(img.convert("RGB"))
        xc = torch.from_numpy(self.Xc[i])
        d  = torch.tensor(max(float(row["duration_months"]), 1e-3), dtype=torch.float32)
        e  = torch.tensor(int(row["event"]), dtype=torch.float32)
        return xi, xc, d, e, path

#DataLoaders
BATCH_TRAIN = 32
BATCH_EVAL  = 32
NUM_WORKERS = 4
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it is useless
# (recent PyTorch versions warn about it), so enable it only when CUDA is available.
PIN_MEMORY  = torch.cuda.is_available()

ds_train = SurvDataset(train_df, X_train_clin, IMG_DIRS["train"], t_train)
ds_val   = SurvDataset(val_df,   X_val_clin,   IMG_DIRS["valid"], t_eval)
ds_test  = SurvDataset(test_df,  X_test_clin,  IMG_DIRS["test"],  t_eval)

dl_train = DataLoader(ds_train, batch_size=BATCH_TRAIN, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
dl_val   = DataLoader(ds_val,   batch_size=BATCH_EVAL,  shuffle=False,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
dl_test  = DataLoader(ds_test,  batch_size=BATCH_EVAL,  shuffle=False,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"[DL] train={len(ds_train)} | val={len(ds_val)} | test={len(ds_test)} | clin_dim={clin_dim}")

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import densenet121

class ImgBranch(nn.Module):
    """Image branch: CNN feature maps -> ReLU -> global average pool -> 512-d embedding.

    Args:
        backbone: a model exposing a ``features`` sub-module (e.g. torchvision
            ``densenet121``); only ``backbone.features`` is kept.
        drop: dropout probability applied after the projection.
        in_features: channel count produced by ``backbone.features``. When ``None`` it
            is taken from ``backbone.classifier.in_features`` if that attribute exists,
            otherwise 1024 (the DenseNet-121 width this branch was written for).
    """
    def __init__(self, backbone, drop=0.4, in_features=None):
        super().__init__()
        if in_features is None:
            classifier = getattr(backbone, "classifier", None)
            in_features = getattr(classifier, "in_features", 1024)
        self.features = backbone.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(drop),
        )

    def forward(self, x):
        x = self.features(x)
        # ReLU before pooling, applied to the raw backbone feature maps.
        x = F.relu(x, inplace=True)
        x = self.pool(x)
        return self.proj(x)

class ClinBranch(nn.Module):
    """Clinical-feature MLP: in_dim -> 128 -> 64.

    Each stage is Linear -> BatchNorm1d -> ReLU -> Dropout(drop); the output is the
    64-d clinical embedding.

    Args:
        in_dim: number of clinical input features.
        drop: dropout probability after each stage.
    """
    def __init__(self, in_dim, drop=0.3):
        super().__init__()
        stages = []
        width_in = in_dim
        for width_out in (128, 64):
            stages.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(inplace=True),
                nn.Dropout(drop),
            ])
            width_in = width_out
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class MultiModalCox(nn.Module):
    """Late-fusion Cox model over an image and a clinical-feature input.

    The 512-d image embedding (ImgBranch) and the 64-d clinical embedding (ClinBranch)
    are concatenated, passed through one Linear -> BatchNorm1d -> ReLU -> Dropout fusion
    layer of width ``comb``, and mapped by a bias-free linear head to one log-risk score
    per sample.

    Args:
        img_backbone: backbone handed to ImgBranch.
        clin_in: number of clinical input features.
        comb: width of the fusion layer.
        drop: dropout probability of the fusion layer only.
    """
    def __init__(self, img_backbone, clin_in, comb=128, drop=0.3):
        super().__init__()
        self.img = ImgBranch(img_backbone)
        self.clin = ClinBranch(clin_in)
        joint_width = 512 + 64
        self.fuse = nn.Sequential(
            nn.Linear(joint_width, comb),
            nn.BatchNorm1d(comb),
            nn.ReLU(inplace=True),
            nn.Dropout(drop),
        )
        # No bias: a constant offset cancels out of the Cox partial likelihood.
        self.head = nn.Linear(comb, 1, bias=False)

    def forward(self, xi, xc):
        """Return a 1-D tensor of log-risk scores, one per sample."""
        joint = torch.cat((self.img(xi), self.clin(xc)), dim=1)
        return self.head(self.fuse(joint)).squeeze(1)

In [None]:
def build_model_from_phaseA():
    """Create a MultiModalCox on `device` and load the Phase-A weights into it.

    The DenseNet-121 backbone is built without pretrained weights; the state_dict at
    PHASEA_BEST is then loaded non-strictly, and any missing or unexpected keys are
    printed rather than raised.

    Returns:
        The model, on `device`.
    """
    model = MultiModalCox(densenet121(weights=None), clin_in=clin_dim).to(device)
    state = torch.load(PHASEA_BEST, map_location='cpu')
    report = model.load_state_dict(state, strict=False)
    for label, keys in (("Missing keys:", report.missing_keys),
                        ("Unexpected keys:", report.unexpected_keys)):
        if keys:
            print("[INFO]", label, keys)
    return model

def set_bn_eval(m):
    """Put a BatchNorm1d/BatchNorm2d module into eval mode; leave other modules unchanged.

    Intended for ``module.apply(set_bn_eval)``: in eval mode BatchNorm normalizes with its
    stored running statistics and stops updating them.
    """
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.eval()

def apply_unfreeze(model: nn.Module, mode: str):
    """Freeze the image backbone except for its last block(s).

    Every parameter of ``model.img.features`` is frozen and ``set_bn_eval`` is applied to
    the whole backbone. The named children below are then made trainable again, and
    ``set_bn_eval`` is applied to them as well:
      * any mode:          ``transition3``, ``denseblock4``
      * ``mode="block43"``: additionally ``transition2``, ``denseblock3``
    Any mode string other than "block43" behaves like "block4".

    NOTE(review): ``norm5`` (the BatchNorm that follows denseblock4 in torchvision
    DenseNets) is not in the list, so it stays frozen in every mode -- confirm intended.
    """
    features = model.img.features
    features.requires_grad_(False)
    features.apply(set_bn_eval)

    trainable = {"denseblock4", "transition3"}
    if mode == "block43":
        trainable |= {"denseblock3", "transition2"}

    for child_name, child in features.named_children():
        if child_name not in trainable:
            continue
        child.requires_grad_(True)
        child.apply(set_bn_eval)

from pycox.models.loss import CoxPHLoss
cox_loss = CoxPHLoss()

# Smoke test: build the Phase-A model with mode "block4" (denseblock4/transition3
# trainable), push one training batch through the forward pass and the Cox loss
# without gradients, and print the log-risk range and loss.
model_smoke = build_model_from_phaseA()
apply_unfreeze(model_smoke, mode="block4")

xi, xc, d, e, _ = next(iter(dl_train))
xi, xc, d, e = (t.to(device) for t in (xi, xc, d, e))

with torch.no_grad():
    log_risk = model_smoke(xi, xc)
    loss = cox_loss(log_risk, d, e)

risk_lo, risk_hi = log_risk.min().item(), log_risk.max().item()
print(f"[Smoke Model] log_risk=({risk_lo:.3f},{risk_hi:.3f})  loss={loss.item():.3f}")

In [None]:
def _breslow_baseline(durations, events, log_risk):
    """Breslow estimate of the baseline cumulative hazard H0.

    Args:
        durations: 1-D float array of follow-up times.
        events: 1-D int array, 1 where the event was observed.
        log_risk: 1-D float array of log-risk scores aligned with `durations`.

    Returns:
        (event_times, cum_hazard): the sorted unique event times and H0 at each of them,
        i.e. the running sum over event times t of  d(t) / sum_{j: T_j >= t} exp(r_j).
    """
    order = np.argsort(durations, kind="mergesort")
    t_sorted = durations[order]
    # Suffix sums: risk_suffix[i] = sum of exp(r) over sorted positions i..n-1.
    risk_suffix = np.cumsum(np.exp(log_risk[order])[::-1])[::-1]
    event_times, n_events = np.unique(durations[events == 1], return_counts=True)
    # First sorted position with duration >= t is where t's risk set starts.
    start = np.searchsorted(t_sorted, event_times, side="left")
    at_risk = np.maximum(risk_suffix[start], 1e-12)
    cum_hazard = np.cumsum(n_events / at_risk)
    return event_times.astype(np.float64), cum_hazard.astype(np.float64)


def _default_time_grid(event_times, eval_durations, n_points=80):
    """Evenly spaced grid over the time range shared by training events and eval durations.

    The range is [max(1.0, first event time, min eval duration),
    min(last event time, max eval duration)]. If that interval is empty or non-finite,
    it falls back to the training event-time range (widened to [lo, lo + 1] if degenerate).
    """
    lo_candidates = [1.0]
    if event_times.size:
        lo_candidates.append(event_times.min())
    if eval_durations.size:
        lo_candidates.append(eval_durations.min())
    lo = max(lo_candidates)

    hi_candidates = []
    if event_times.size:
        hi_candidates.append(event_times.max())
    if eval_durations.size:
        hi_candidates.append(eval_durations.max())
    hi = min(hi_candidates) if hi_candidates else lo + 1.0

    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        lo = float(event_times.min()) if event_times.size else 1.0
        hi = float(event_times.max()) if event_times.size else lo + 1.0
        if hi <= lo:
            hi = lo + 1.0
    return np.linspace(lo, hi, n_points).astype(np.float64)


def _survival_frame(times, event_times, cum_hazard, log_risk):
    """Survival curves S_i(t) = exp(-H0(t) * exp(r_i)) on the grid `times`.

    H0 is linearly interpolated between event times, 0 before the first one and held at
    its last value after the last one. Returns a DataFrame indexed by `times` with one
    column per subject (the layout passed to EvalSurv).
    """
    h0 = np.interp(times, event_times, cum_hazard, left=0.0,
                   right=(cum_hazard[-1] if cum_hazard.size else 0.0))
    return pd.DataFrame(np.exp(-np.outer(h0, np.exp(log_risk))), index=times)


@torch.no_grad()
def evaluate_surv(model, dl_train, dl_eval, times=None):
    """Evaluate a Cox model's predictions on `dl_eval` with pycox's EvalSurv.

    The baseline cumulative hazard is estimated with Breslow's method from the model's
    predictions on `dl_train`. Eval-set survival curves are then S_i(t) = exp(-H0(t) exp(r_i)).
    Log-risk predictions are clipped to [-20, 20], non-finite samples are dropped, and
    durations are floored at 1e-3. Puts `model` in eval mode and leaves it there.

    Args:
        model: module called as model(images, clinical) -> one log-risk per sample.
        dl_train: loader yielding (image, clinical, duration, event, path) batches, used
            for the baseline hazard. It is iterated as given, so any random augmentation
            it applies makes the estimate vary between calls.
        dl_eval: loader with the same batch layout for the set being scored.
        times: optional time grid (array-like) for the survival curves and IBS. Defaults
            to 80 points over the range shared by training event times and eval durations.

    Returns:
        (c_index, ibs): pycox's time-dependent concordance ``concordance_td()`` (callers
        label it "uno") and the integrated Brier score. Either is NaN if it cannot be
        computed; both are NaN if the training set has no events or the eval set has no
        valid samples.
    """
    from pycox.evaluation import EvalSurv

    CLIP = 20.0

    def _predict(loader):
        """Run `model` over `loader`; return finite (log_risk, duration, event) arrays."""
        model.eval()
        R, D, E = [], [], []
        for xi, xc, d, e, _ in loader:
            lr = model(xi.to(device), xc.to(device)).detach().cpu().numpy().astype(np.float64)
            R.append(np.clip(lr, -CLIP, CLIP))
            D.append(d.detach().cpu().numpy().astype(np.float64))
            E.append(e.detach().cpu().numpy().astype(np.int64))
        if not R:  # empty loader: np.concatenate([]) would raise
            return np.empty(0), np.empty(0), np.empty(0, dtype=np.int64)
        r, d, e = np.concatenate(R), np.concatenate(D), np.concatenate(E)
        keep = np.isfinite(r) & np.isfinite(d) & np.isfinite(e)
        return r[keep], np.maximum(d[keep], 1e-3), e[keep]

    r_tr, d_tr, e_tr = _predict(dl_train)
    n_evt_tr = int((e_tr == 1).sum())
    if n_evt_tr == 0:
        print("[Eval] no events in training set after filtering → return NaN")
        return float('nan'), float('nan')
    T_vals, H0_vals = _breslow_baseline(d_tr, e_tr, r_tr)

    r_ev, d_ev, e_ev = _predict(dl_eval)
    n_evt_ev = int((e_ev == 1).sum())
    if r_ev.size == 0:
        print("[Eval] no valid samples in evaluation set → return NaN")
        return float('nan'), float('nan')

    if times is None:
        times = _default_time_grid(T_vals, d_ev)
    else:
        times = np.asarray(times, dtype=np.float64)

    es = EvalSurv(_survival_frame(times, T_vals, H0_vals, r_ev), d_ev, e_ev, censor_surv='km')

    # pycox's concordance_td implements Antolini's C^td ('antolini' and the default
    # 'adj_antolini'); it has no 'uno' method, so ask for the default directly.
    try:
        uno = float(es.concordance_td())
    except Exception:
        uno = float('nan')

    try:
        ibs = float(es.integrated_brier_score(times))
    except Exception:
        # Retry on a narrower grid: 5th-95th percentile of the eval durations.
        try:
            lo2, hi2 = np.percentile(d_ev, [5, 95])
            if not np.isfinite(lo2) or not np.isfinite(hi2) or hi2 <= lo2:
                raise ValueError("degenerate percentile time grid")
            times2 = np.linspace(lo2, hi2, 60).astype(np.float64)
            es2 = EvalSurv(_survival_frame(times2, T_vals, H0_vals, r_ev), d_ev, e_ev,
                           censor_surv='km')
            ibs = float(es2.integrated_brier_score(times2))
        except Exception:
            ibs = float('nan')

    print(f"[Eval] train_events={n_evt_tr}, eval_events={n_evt_ev}, "
          f"time_range=[{times.min():.1f},{times.max():.1f}], uno={uno:.4f}, ibs={ibs:.4f}")
    return uno, ibs

In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import optuna
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pycox.models.loss import CoxPHLoss
import os, copy

# Module-level pycox Cox partial-likelihood loss; train_one_trial looks up `cox_loss` at call time.
cox_loss = CoxPHLoss()

def _nan_guard_batch(xi, xc, d, e):

    for t in (xi, xc, d, e):
        bad = ~torch.isfinite(t)
        if bad.any(): t[bad] = 0
    e = (e > 0.5).float()
    d = torch.clamp(d, min=1e-3)
    return xi, xc, d, e

In [None]:
def train_one_trial(model, dl_train, dl_val,
                    lr_back, lr_head, weight_decay,
                    max_epochs=12, patience=5, trial=None,
                    freeze_backbone_bn=True):
    """Fine-tune `model` with the Cox loss and early stopping on the validation C-index.

    Trainable parameters under ``img.features`` use learning rate `lr_back`; all other
    trainable parameters use `lr_head`. After each epoch the model is scored with
    ``evaluate_surv(model, dl_train, dl_val)``. The learning rates are halved when that
    score stops improving for 2 epochs, and training stops after `patience` epochs
    without improvement.

    Args:
        model: MultiModalCox-style model (``model.img.features`` is the backbone).
        dl_train, dl_val: loaders yielding (image, clinical, duration, event, path).
        lr_back, lr_head: learning rates of the backbone / remaining parameter groups.
        weight_decay: AdamW weight decay for both groups.
        max_epochs: maximum number of epochs.
        patience: epochs without improvement before stopping.
        trial: optional Optuna trial; the epoch score is reported to it and
            ``optuna.exceptions.TrialPruned`` is raised when it should be pruned.
        freeze_backbone_bn: keep the BatchNorm layers of ``model.img.features`` in eval
            mode during training (running statistics frozen).

    Returns:
        dict with "uno" (best score), "ibs", "state" (deep copy of the best state_dict,
        or None) and "epoch".
    """

    def _refreeze_backbone_bn():
        # model.train() recursively switches every sub-module back to training mode,
        # including the backbone BatchNorm layers apply_unfreeze put in eval mode.
        features = getattr(getattr(model, "img", None), "features", None)
        if features is None:
            return
        for m in features.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                m.eval()

    params_back, params_head = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (params_back if "img.features" in n else params_head).append(p)

    opt = AdamW([
        {"params": params_back, "lr": lr_back, "weight_decay": weight_decay},
        {"params": params_head, "lr": lr_head, "weight_decay": weight_decay},
    ])
    # `verbose` is deprecated in recent PyTorch and defaults to off, so it is not passed.
    scheduler = ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=2)

    best = {"uno": -1.0, "ibs": 1e9, "state": None, "epoch": -1}
    epochs_no_improve = 0

    for ep in range(1, max_epochs + 1):
        model.train()
        if freeze_backbone_bn:
            _refreeze_backbone_bn()
        run_loss, num_ok = 0.0, 0

        for xi, xc, d, e, _ in dl_train:
            xi, xc, d, e = _nan_guard_batch(xi, xc, d, e)
            # Skip batches the Cox loss cannot learn from: no observed event, or a single
            # sample. BatchNorm1d raises on one sample in training mode, and a one-sample
            # partial likelihood is constant.
            if xi.size(0) < 2 or e.sum().item() < 1:
                continue
            xi, xc, d, e = xi.to(device), xc.to(device), d.to(device), e.to(device)

            opt.zero_grad(set_to_none=True)
            log_risk = torch.clamp(model(xi, xc), min=-20.0, max=20.0)
            if not torch.isfinite(log_risk).all():
                continue
            loss = cox_loss(log_risk, d, e)
            if not torch.isfinite(loss):
                continue
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            run_loss += float(loss.detach().cpu().item())
            num_ok += 1

        if num_ok == 0:
            if trial:
                trial.set_user_attr("note", "no valid batches (too many no-event batches)")
            return best

        try:
            uno, ibs = evaluate_surv(model, dl_train, dl_val)
        except Exception as ex:
            # Keep training, but make the failure visible instead of silently scoring NaN.
            print(f"[Train] ep={ep}: evaluation failed ({type(ex).__name__}: {ex}); scoring as NaN")
            uno, ibs = float('nan'), float('nan')

        print(f"[Train] ep={ep} mean_loss={run_loss / num_ok:.4f} batches={num_ok}")

        score = float(uno) if np.isfinite(uno) else -1.0
        scheduler.step(score)

        tie_with_better_ibs = (abs(uno - best["uno"]) < 1e-5
                               and np.isfinite(ibs) and ibs < best["ibs"])
        if np.isfinite(uno) and (uno > best["uno"] or tie_with_better_ibs):
            best.update({"uno": float(uno), "ibs": float(ibs),
                         "state": copy.deepcopy(model.state_dict()), "epoch": ep})
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if trial:
            trial.report(score, ep)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        if epochs_no_improve >= patience:
            break

    return best

In [None]:
import os, torch, json

# Hyper-parameters chosen by the search. Fill these in before running; the cell stops
# with an explicit error while any of them is still unset.
best_params = {
    "unfreeze": None,       # TODO: "block4" or "block43"
    "lr_backbone": None,    # TODO: your best params
    "lr_head": None,        # TODO: your best params
    "weight_decay": None,   # TODO: your best params
}
_unset_params = [k for k, v in best_params.items() if v is None]
if _unset_params:
    raise ValueError(f"fill in best_params before running this cell: {_unset_params}")
if best_params["unfreeze"] not in ("block4", "block43"):
    # apply_unfreeze treats any other string as "block4", which would hide a typo.
    raise ValueError(f"unknown unfreeze mode: {best_params['unfreeze']!r}")
print("[Fixed Phase-B] best_params:", best_params)

model = build_model_from_phaseA()
apply_unfreeze(model, best_params["unfreeze"])

best_fixed = train_one_trial(
    model,
    dl_train, dl_val,
    lr_back=best_params["lr_backbone"],
    lr_head=best_params["lr_head"],
    weight_decay=best_params["weight_decay"],
    max_epochs=15,
    patience=5,
    trial=None,
)

final_pth = os.path.join(SAVE_DIR, "name of best model")
if best_fixed["state"] is None:
    # Nothing was saved this run, so do not evaluate whatever (possibly stale) file
    # already sits at final_pth.
    print("warning: no best ")
else:
    torch.save(best_fixed["state"], final_pth)
    print(f"[Fixed Phase-B] saved best to: {final_pth}")
    print(f"[Fixed Phase-B] Val UnoC={best_fixed['uno']:.4f}, IBS={best_fixed['ibs']:.4f}, epoch={best_fixed['epoch']}")

    # Reload from disk to evaluate the artifact that was actually written. It comes from
    # the same architecture, so load strictly: any key mismatch is a real error.
    model_test = build_model_from_phaseA()
    sd = torch.load(final_pth, map_location="cpu")
    model_test.load_state_dict(sd, strict=True)
    model_test = model_test.to(device)

    uno_te, ibs_te = evaluate_surv(model_test, dl_train, dl_test)
    print(f"[Test] UnoC={uno_te:.4f}, IBS={ibs_te:.4f}")