In [None]:
!pip -q install \
    diffusers==0.30.3 \
    accelerate==1.1.1 \
    torchmetrics[image]==1.4.0.post0 \
    torchvision==0.19.1 \
    torch-fidelity==0.3.0 \
    lpips==0.1.4 \
    tqdm \
    pillow
!pip -q install einops scipy

from google.colab import drive
drive.mount('/content/drive')


# ENV + IMPORTS + GLOBAL CONFIG

In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import re, copy, csv, pathlib, random, json
from pathlib import Path
import numpy as np
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import transforms, utils as tvu
import torchvision.transforms.functional as tvf

# The class is spelled `UNet2DModel` (this is the name build_model() instantiates).
from diffusers import UNet2DModel, DDPMScheduler
from accelerate import Accelerator

from torchmetrics.functional import (
    structural_similarity_index_measure as ssim_fn,
    peak_signal_noise_ratio as psnr_fn,
    mean_squared_error as mse_fn
)
from torchmetrics.image.fid import FrechetInceptionDistance
import lpips
from tqdm import tqdm

# ---------------- params ----------------
# Module-level defaults; most are bound as default argument values by the
# functions and classes defined further down.
IMG       = 128        # side length (pixels) that masks and CT slices are resized to
BATCH     = 8          # physical batch size
GRAD_ACC  = 2          # micro-batches accumulated per optimizer step (effective batch ~ 16)
EPOCHS    = 200        # default number of training epochs
LR        = 1e-5       # AdamW learning rate
T_STEPS   = 1000       # number of diffusion training timesteps
# Device used for the metric objects and as the default for sampling/evaluation.
DEVICE    = "cuda" if torch.cuda.is_available() else "cpu"
# Guidance weight w in: eps = eps_u + w * (eps_c - eps_u)
CFG_SCALE = 2.0

# HU windowing (lungs): intensities are clipped to [WL - WW/2, WL + WW/2]
WL        = -600.0     # window level (centre)
WW        = 1500.0     # window width

# labels from masks
MIN_NODULE_PX = 10     # a slice is labelled 1 when it has at least this many nodule pixels
NODULE_VAL    = 255    # pixel value counted as "nodule" inside a combined mask

# how often to compute heavy metrics (FID / LPIPS run when epoch % EVAL_EVERY == 0)
EVAL_EVERY = 10

# --------------- regex ------------------
# Each pattern captures (prefix, slice index) from a file name; the pair is
# the key used to match a CT slice with its masks.
_re_lung_np   = re.compile(r"^(.*)_lung_mask_(\d+)\.npy$")
_re_comb_np   = re.compile(r"^(.*)_combined_mask_(\d+)\.npy$")
_re_nod_np    = re.compile(r"^(.*)_nodule_mask_(\d+)\.npy$")
_re_slice_npy = re.compile(r"^(.*)_slice_(\d+)\.npy$")

# HELPERS

In [None]:
def _load_npy(p: pathlib.Path) -> np.ndarray:
    a = np.load(p)
    return a[len(a)//2] if a.ndim == 3 else a

def read_patient_list(path: pathlib.Path):
    """Read patient IDs from a text file, one entry per line.

    Blank lines are skipped and only the part of each entry before the first
    underscore is kept. Returns a ``set`` of IDs; the set is empty when
    ``path`` does not exist.
    """
    patients = set()
    if not path.exists():
        return patients
    with open(path, "r") as f:
        for raw in f:
            entry = raw.strip()
            if entry:
                patients.add(entry.split("_")[0])
    return patients

def _patient_from_prefix(prefix: str) -> str:
    return prefix.split("_")[0]

def window_hu_01(img_hu: np.ndarray, wl=WL, ww=WW) -> np.ndarray:
    """Clip ``img_hu`` to the window centred at ``wl`` with width ``ww`` and
    rescale it linearly so the window maps onto (approximately) [0, 1].
    """
    half_width = ww / 2.0
    lo = wl - half_width
    hi = wl + half_width
    clipped = np.clip(img_hu, lo, hi)
    # The small epsilon guards against division by zero for a zero-width window.
    return (clipped - lo) / (hi - lo + 1e-6)

# DATASET

In [None]:
class LungSlice(Dataset):
    """
    NPY-only dataset pairing each CT slice with its lung mask.

    ``*.npy`` files are discovered recursively under ``root`` and grouped by
    the ``(prefix, index)`` captured from their names. A sample is kept only
    when both a lung mask and a CT slice exist for the key; nodule and
    combined masks are optional and are used solely to derive the label.

    Returns (per item):
        lung_mask (1,H,W), ct (1,H,W), class_map (1,H,W), label (0/1).

    Args:
        root: dataset directory.
        img_size: side length that masks and CT slices are resized to.
        split: "train" / "val" / "test" restricts samples to the patients
            listed in ``<root>/<split>_patients.txt``; anything else uses all.
        forced_patients: non-empty collection of patient IDs that overrides
            ``split``. An empty collection is ignored (falls back to ``split``).
        seed: seed of the private RNG used to shuffle the sample order. With a
            fixed seed the index -> sample mapping is identical on every run,
            which is required for indices persisted to disk (for example a
            saved test subset) to stay meaningful. ``None`` gives a different
            order on every run.

    Raises:
        RuntimeError: if no (CT slice, lung mask) pair is found.
    """
    def __init__(self, root, img_size=IMG, split=None, forced_patients=None, seed=0):
        root = pathlib.Path(root)
        self.root, self.img_size = root, img_size

        if forced_patients is not None and len(forced_patients) > 0:
            split_set = forced_patients
        elif split in {"train", "val", "test"}:
            split_set = read_patient_list(root / f"{split}_patients.txt")
        else:
            split_set = None  # use all

        # Sorting the paths makes the result independent of directory
        # enumeration order (relevant if two files share the same key).
        lung_np, nod_np, comb_np, ct_npy = {}, {}, {}, {}
        for p in sorted(root.rglob("*.npy")):
            n = p.name
            if   (m := _re_lung_np.match(n)):   lung_np[(m[1], m[2])] = p
            elif (m := _re_nod_np.match(n)):    nod_np[(m[1], m[2])]  = p
            elif (m := _re_comb_np.match(n)):   comb_np[(m[1], m[2])] = p
            elif (m := _re_slice_npy.match(n)): ct_npy[(m[1], m[2])]  = p

        self.trip = []
        # Only keys that have a lung mask can yield a sample. Iterating them
        # in sorted order (instead of iterating a set, whose order changes
        # with string hash randomisation) keeps the sample order reproducible.
        for k in sorted(lung_np):
            prefix, idx = k
            if split_set is not None and _patient_from_prefix(prefix) not in split_set:
                continue

            lung_p = lung_np[k]

            ct_p = ct_npy.get(k)
            if ct_p is None:
                # Fall back to a slice stored directly under root.
                cand = root / f"{prefix}_slice_{idx}.npy"
                ct_p = cand if cand.exists() else None
            if ct_p is None:
                continue

            nm_p  = nod_np.get(k)
            cmb_p = comb_np.get(k)
            self.trip.append((lung_p, nm_p, cmb_p, ct_p, prefix, idx))

        if not self.trip:
            raise RuntimeError("No (CT slice, lung_mask[, nodule/combined]) pairs found.")

        # Shuffle with a private RNG: reproducible, and it neither depends on
        # nor disturbs the global `random` state.
        random.Random(seed).shuffle(self.trip)

        self.mask_tf = transforms.Compose([
            transforms.Resize((img_size, img_size), antialias=True,
                              interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t > 0.5).float())
        ])

        # Labels are computed once up front. `_precheck` keeps
        # (sample name, label, label source) for quick sanity printing.
        self.labels, self._precheck = [], []
        for lung_p, nod_p, cmb_p, ct_p, prefix, idx in self.trip:
            if nod_p is not None:
                a = _load_npy(nod_p)
                y = int((a > 0).sum() >= MIN_NODULE_PX)
                src = "nodule_mask"
            elif cmb_p is not None:
                a = _load_npy(cmb_p)
                y = int((a == NODULE_VAL).sum() >= MIN_NODULE_PX)
                src = "combined_mask"
            else:
                # No nodule information at all: treated as negative.
                y = 0
                src = "none"
            self.labels.append(y)
            self._precheck.append((f"{prefix}_slice_{idx}", y, src))

    def __len__(self):
        """Number of (CT slice, lung mask) samples."""
        return len(self.trip)

    def __getitem__(self, i):
        """Return ``(lung_mask, ct, class_map, label)`` for sample ``i``.

        ``lung_mask`` is binary, ``ct`` is HU-windowed and scaled to [-1, 1],
        and ``class_map`` has the shape of ``lung_mask`` and is filled with
        the label value.
        """
        lung_p, _, _, ct_p, prefix, idx = self.trip[i]

        lung = _load_npy(lung_p)
        lung = ((lung > 0).astype(np.uint8) * 255)
        m_lung = self.mask_tf(Image.fromarray(lung))  # (1,H,W)

        # Window to [0, 1] with the shared helper, resize, then map to [-1, 1].
        ct = window_hu_01(_load_npy(ct_p).astype(np.float32))
        ct = torch.from_numpy(ct).unsqueeze(0)
        ct = F.interpolate(ct.unsqueeze(0), size=(self.img_size, self.img_size),
                           mode="bilinear", align_corners=False).squeeze(0)
        ct = ct * 2 - 1

        y = int(self.labels[i])
        class_map = torch.full_like(m_lung, float(y))

        return m_lung, ct, class_map, y


# BALANCED BATCH SAMPLER

In [None]:
class BalancedBatchSampler(Sampler):
    """Batch sampler that yields batches with half label-1 and half label-0 indices.

    The number of batches per epoch is set by the smaller class
    (``min(#pos, #neg) * 2 // batch_size``), with a minimum of one batch so a
    very small class never results in an empty epoch (its indices are then
    re-used within the batch).

    Both index pools are reshuffled at the start of every epoch. Without
    that, the larger class would contribute the same leading subset in every
    epoch and the rest of its samples would never be drawn.

    Args:
        labels: sequence of 0/1 labels, one per dataset index.
        batch_size: even batch size.

    Raises:
        AssertionError: if ``batch_size`` is odd.
        RuntimeError: if one of the two classes has no samples.
    """
    def __init__(self, labels, batch_size):
        assert batch_size % 2 == 0, "BATCH must be even"
        self.batch_size = batch_size
        self.pos = [i for i, y in enumerate(labels) if y == 1]
        self.neg = [i for i, y in enumerate(labels) if y == 0]
        if len(self.pos) == 0 or len(self.neg) == 0:
            raise RuntimeError("Both classes are required.")
        # At least one batch: a zero-length sampler would make the training
        # loop run "successfully" without ever updating the model.
        self.len_batches = max(1, min(len(self.pos), len(self.neg)) * 2 // self.batch_size)

    def __iter__(self):
        # New permutation every epoch so all samples of the larger class get
        # a chance to be used.
        random.shuffle(self.pos)
        random.shuffle(self.neg)
        half = self.batch_size // 2
        pos_ptr, neg_ptr = 0, 0
        for _ in range(self.len_batches):
            batch = []
            for _ in range(half):
                if pos_ptr >= len(self.pos):
                    # Pool exhausted: reshuffle and start over.
                    random.shuffle(self.pos)
                    pos_ptr = 0
                batch.append(self.pos[pos_ptr])
                pos_ptr += 1
            for _ in range(half):
                if neg_ptr >= len(self.neg):
                    random.shuffle(self.neg)
                    neg_ptr = 0
                batch.append(self.neg[neg_ptr])
                neg_ptr += 1
            # Mix positives and negatives inside the batch.
            random.shuffle(batch)
            yield batch

    def __len__(self):
        """Number of batches produced per epoch."""
        return self.len_batches

# MODEL + SCHEDULER

In [None]:
def build_model(img_size=IMG, t_steps=T_STEPS):
    """Create the denoising U-Net and its DDPM noise scheduler.

    The U-Net takes three input channels (lung mask, class map, noisy image)
    and outputs a single channel.

    Returns:
        ``(unet, scheduler)``.
    """
    n_levels = 3
    unet_kwargs = dict(
        sample_size=img_size,
        in_channels=3,   # [lung_mask, class_map, x_t]
        out_channels=1,
        block_out_channels=(32, 64, 96),
        down_block_types=("DownBlock2D",) * n_levels,
        up_block_types=("UpBlock2D",) * n_levels,
        layers_per_block=1,
    )
    model = UNet2DModel(**unet_kwargs)
    noise_scheduler = DDPMScheduler(num_train_timesteps=t_steps)
    return model, noise_scheduler


# EMA

In [None]:
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` maps parameter name -> averaged tensor. Only parameters with
    ``requires_grad=True`` are tracked.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        self.register(model)

    @staticmethod
    def _trainable(model):
        """Yield ``(name, parameter)`` for every trainable parameter of ``model``."""
        return ((name, param) for name, param in model.named_parameters() if param.requires_grad)

    def register(self, model):
        """(Re)initialise the shadow entries with copies of the current parameters."""
        self.shadow.update(
            {name: param.detach().clone() for name, param in self._trainable(model)}
        )

    def update(self, model):
        """Blend the current parameters into the shadow: s = decay*s + (1-decay)*p."""
        keep = self.decay
        blend = 1.0 - keep
        for name, param in self._trainable(model):
            averaged = self.shadow[name]
            averaged.mul_(keep).add_(param.detach(), alpha=blend)

    def copy_to(self, model):
        """Overwrite the trainable parameters of ``model`` with the shadow values."""
        for name, param in self._trainable(model):
            param.data.copy_(self.shadow[name].data)


# SAMPLING GRID

In [None]:
@torch.no_grad()
def sample_grid(unet, base_sched, dataset, fn, steps=50, img_size=IMG, device=DEVICE):
    """Sample a few images and save a (ground truth | generated) grid to ``fn``.

    Up to two label-1 and two label-0 cases are picked from the first 128
    dataset items, using a private RNG with a fixed seed so the same cases
    are chosen on every call. Nothing is written when the dataset is empty.

    Args:
        unet: denoising model; called as ``unet(x, t).sample``.
        base_sched: scheduler; a deep copy is configured for ``steps`` steps,
            so the object passed in is left untouched.
        dataset: items are ``(lung_mask, ct, class_map, label)``.
        fn: output image path.
        steps: number of sampling steps.
        img_size: side length of the generated images.
        device: device used for sampling.
    """
    was_training = unet.training
    unet.eval()
    try:
        test_indices = list(range(min(len(dataset), 128)))

        pos, neg = [], []
        for idx in test_indices:
            ml, ct, cls, y = dataset[idx]
            (pos if y == 1 else neg).append((idx, ml, ct, cls, y))

        # A local generator gives the same picks as reseeding the global RNG
        # with 0, but does not reset the global `random` state (which other
        # code, e.g. batch shuffling, draws from).
        rng = random.Random(0)
        rng.shuffle(pos)
        rng.shuffle(neg)
        cases = pos[:2] + neg[:2]
        if len(cases) == 0:
            return

        lung = torch.stack([c[1] for c in cases]).to(device)
        ctgt = torch.stack([c[2] for c in cases])
        cls  = torch.stack([c[3] for c in cases]).to(device)

        scheduler = copy.deepcopy(base_sched)
        scheduler.set_timesteps(steps, device=device)

        x = torch.randn((len(cases), 1, img_size, img_size), device=device)
        for t in scheduler.timesteps:
            # Guidance: combine the prediction for the real class map with the
            # prediction for an all-zero class map.
            # NOTE(review): an all-zero class map is also the conditioning of
            # label-0 samples, so for those the two predictions coincide.
            inp_c = torch.cat([lung, cls, x], 1)
            inp_u = torch.cat([lung, torch.zeros_like(cls), x], 1)
            eps_c = unet(inp_c, t).sample
            eps_u = unet(inp_u, t).sample
            eps   = eps_u + CFG_SCALE * (eps_c - eps_u)
            x     = scheduler.step(eps, t, x).prev_sample
        gen = x.clamp(-1, 1).cpu()

        # Two columns per row: ground-truth CT, generated CT (as 3-channel images).
        rows = []
        for y, g in zip(ctgt, gen):
            rows.extend([y.repeat(3, 1, 1), g.repeat(3, 1, 1)])
        grid = tvu.make_grid(rows, nrow=2, normalize=True, value_range=(-1, 1))
        tvf.to_pil_image(grid).save(fn)
    finally:
        # Hand the model back in the mode it was received in.
        unet.train(was_training)


# %% FIXED TEST SUBSET

def build_test_set(dataset, save_dir, n_pos=16, n_neg=16):
    """Pick a fixed evaluation subset and save it as ``test_indices.npy``.

    The first ``n_pos`` label-1 and the first ``n_neg`` label-0 indices (in
    dataset order) are selected. If one class is short, the gap is filled
    with further, not yet selected samples of the other class, so the subset
    has ``n_pos + n_neg`` entries whenever the dataset is large enough.

    Labels are read from ``dataset.labels`` when that attribute exists and
    has one entry per item (this avoids loading every sample); otherwise from
    the fourth element of ``dataset[i]``.

    Args:
        dataset: indexable collection of ``(lung_mask, ct, class_map, label)``.
        save_dir: directory (created if needed) for ``test_indices.npy``.
        n_pos: number of label-1 samples wanted.
        n_neg: number of label-0 samples wanted.

    Returns:
        List of selected indices: the "positive" slots first, then the
        "negative" slots.
    """
    labels = getattr(dataset, "labels", None)
    if labels is not None and len(labels) == len(dataset):
        labelled = enumerate(labels)
    else:
        labelled = ((i, dataset[i][3]) for i in range(len(dataset)))

    pos_idx, neg_idx = [], []
    spare_pos, spare_neg = [], []  # samples beyond the quota, kept for top-ups
    for i, y in labelled:
        if y == 1:
            (pos_idx if len(pos_idx) < n_pos else spare_pos).append(i)
        elif y == 0:
            (neg_idx if len(neg_idx) < n_neg else spare_neg).append(i)
        if len(pos_idx) == n_pos and len(neg_idx) == n_neg:
            break

    # Top up a short class from the other class's surplus. At most one of the
    # two slices is non-empty: a class that is short has no surplus.
    pos_idx += spare_neg[:n_pos - len(pos_idx)]
    neg_idx += spare_pos[:n_neg - len(neg_idx)]

    test_indices = pos_idx + neg_idx
    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, "test_indices.npy"), np.array(test_indices, dtype=np.int32))
    return test_indices

# METRICS OBJECTS (GLOBAL)

In [None]:
# Metric objects are created once at module level and moved to DEVICE;
# compute_fid_lpips_masked() resets and reuses them on every call.
# NOTE(review): normalize=True presumably makes FID expect float images in
# [0, 1] (which is what compute_fid_lpips_masked feeds it) - confirm against
# the torchmetrics documentation.
fid_metric_global   = FrechetInceptionDistance(normalize=True).to(DEVICE)
# LPIPS distance using the 'alex' backbone.
lpips_metric_global = lpips.LPIPS(net='alex').to(DEVICE)


@torch.no_grad()
def compute_fid_lpips_masked(real_ct, gen_ct, lung_mask):
    """Compute FID and mean LPIPS between real and generated CTs inside the lung mask.

    Both image sets are multiplied by ``lung_mask``, mapped from [-1, 1] to
    [0, 1] and repeated to three channels. FID is computed on copies resized
    to 299x299; LPIPS is computed at the original resolution.

    Args:
        real_ct: real images in [-1, 1]; assumed (N, 1, H, W) - the channel
            axis is repeated three times.
        gen_ct: generated images, same layout and range as ``real_ct``.
        lung_mask: mask multiplied element-wise with both image sets.

    Returns:
        ``(fid, lpips_mean)`` as Python floats.
    """
    real = (real_ct * lung_mask + 1) / 2
    fake = (gen_ct  * lung_mask + 1) / 2
    real3 = real.repeat(1, 3, 1, 1).to(DEVICE)
    fake3 = fake.repeat(1, 3, 1, 1).to(DEVICE)
    real299 = F.interpolate(real3, size=(299, 299), mode="bilinear", align_corners=False)
    fake299 = F.interpolate(fake3, size=(299, 299), mode="bilinear", align_corners=False)

    # The metric object is shared, so clear any state from a previous call.
    fid_metric_global.reset()
    fid_metric_global.update(real299, real=True)
    fid_metric_global.update(fake299, real=False)
    fid_val = float(fid_metric_global.compute().cpu())

    # LPIPS expects inputs in [-1, 1] by default; the tensors here are in
    # [0, 1], so normalize=True is required to have them rescaled first.
    # The whole batch is scored in one call (one distance per image pair).
    lpips_per_pair = lpips_metric_global(real3, fake3, normalize=True)
    lpips_val = float(lpips_per_pair.mean().cpu())
    return fid_val, lpips_val


def inception_score_placeholder(gen_imgs):
    """Stand-in for an Inception Score: ignores ``gen_imgs`` and returns ``(nan, nan)``."""
    not_computed = float("nan")
    return not_computed, not_computed


# EVALUATION

In [None]:
@torch.no_grad()
def evaluate_model(
    unet,
    base_sched,
    dataset,
    test_indices,
    save_dir,
    epoch,
    steps=1000,
    img_size=IMG,
    device=DEVICE,
    use_ema=False,
    ema=None,
    compute_heavy=False
):
    """Generate images for a fixed subset and compute image-quality metrics.

    For every index in ``test_indices`` an image is sampled conditioned on the
    sample's lung mask and class map, then compared with the real CT inside
    the lung mask (PSNR / SSIM / MSE, each computed per image). FID and LPIPS
    are computed only when ``compute_heavy`` is true. A grid with the first
    four (real, generated) pairs is saved as ``eval_epXXX.png`` in ``save_dir``.

    When ``use_ema`` is true and ``ema`` is given, the EMA weights are copied
    into ``unet`` for the duration of the evaluation. The original weights and
    the train/eval mode are restored afterwards, also when an exception is
    raised, so a failed evaluation cannot leave EMA weights in the model that
    is being trained.

    Args:
        unet: denoising model; called as ``unet(x, t).sample``.
        base_sched: scheduler; a deep copy is configured for ``steps`` steps.
        dataset: items are ``(lung_mask, ct, class_map, label)``.
        test_indices: dataset indices to evaluate.
        save_dir: directory for the evaluation grid (created if needed).
        epoch: epoch number used in the grid file name.
        steps: number of sampling steps.
        img_size: side length of the generated images.
        device: device used for sampling.
        use_ema: evaluate with EMA weights.
        ema: object providing ``copy_to(model)``.
        compute_heavy: also compute FID and LPIPS.

    Returns:
        Dict with keys ``psnr_mean``, ``psnr_std``, ``ssim_mean``, ``ssim_std``,
        ``mse_mean``, ``mse_std``, ``fid``, ``lpips``, ``is_mean``, ``is_std``.
        Metrics that were not computed are reported as ``0.0`` (not NaN);
        all values are ``0`` when ``test_indices`` is empty.
    """
    swap_ema = use_ema and ema is not None
    was_training = unet.training
    bak = None
    if swap_ema:
        # Snapshot the live weights before they are overwritten by the EMA.
        bak = {k: v.detach().clone() for k, v in unet.state_dict().items()}
        ema.copy_to(unet)

    try:
        unet.eval()
        batch = [dataset[i] for i in test_indices]
        if len(batch) == 0:
            return {
                "psnr_mean": 0, "psnr_std": 0,
                "ssim_mean": 0, "ssim_std": 0,
                "mse_mean": 0,  "mse_std": 0,
                "fid": 0, "lpips": 0, "is_mean": 0, "is_std": 0
            }

        lung = torch.stack([b[0] for b in batch]).to(device)
        ct   = torch.stack([b[1] for b in batch])
        cls  = torch.stack([b[2] for b in batch]).to(device)

        scheduler = copy.deepcopy(base_sched)
        scheduler.set_timesteps(steps, device=device)

        x = torch.randn((len(batch), 1, img_size, img_size), device=device)
        for t in scheduler.timesteps:
            # Guidance: conditional prediction vs. prediction with an
            # all-zero class map.
            inp_c = torch.cat([lung, cls, x], 1)
            inp_u = torch.cat([lung, torch.zeros_like(cls), x], 1)
            eps_c = unet(inp_c, t).sample
            eps_u = unet(inp_u, t).sample
            eps   = eps_u + CFG_SCALE * (eps_c - eps_u)
            x     = scheduler.step(eps, t, x).prev_sample
        gen = x.clamp(-1, 1).cpu()

        # Per-image metrics, restricted to the lung region by masking both images.
        psnr_list, ssim_list, mse_list = [], [], []
        lung_cpu = lung.cpu()
        for j, (g, y) in enumerate(zip(gen, ct)):
            mask = lung_cpu[j]
            g_m = g * mask
            y_m = y * mask
            g4, y4 = g_m.unsqueeze(0), y_m.unsqueeze(0)
            # data_range=2.0 because images live in [-1, 1].
            psnr_list.append(float(psnr_fn(g4, y4, data_range=2.0)))
            ssim_list.append(float(ssim_fn(g4, y4, data_range=2.0)))
            mse_list.append(float(mse_fn(g4, y4)))

        fid_val = lpips_val = float("nan")
        is_mean = is_std = float("nan")
        if compute_heavy:
            fid_val, lpips_val = compute_fid_lpips_masked(ct, gen, lung_cpu)
            is_mean, is_std    = inception_score_placeholder(gen)

        # NOTE(review): NaN ("not computed") is mapped to 0.0 below, so a 0.0
        # FID/LPIPS/IS in the results means "skipped", not a perfect score.
        metrics = {
            "psnr_mean": float(np.mean(psnr_list)), "psnr_std": float(np.std(psnr_list)),
            "ssim_mean": float(np.mean(ssim_list)), "ssim_std": float(np.std(ssim_list)),
            "mse_mean":  float(np.mean(mse_list)),  "mse_std":  float(np.std(mse_list)),
            "fid":   0.0 if np.isnan(fid_val)   else float(fid_val),
            "lpips": 0.0 if np.isnan(lpips_val) else float(lpips_val),
            "is_mean": 0.0 if np.isnan(is_mean) else float(is_mean),
            "is_std":  0.0 if np.isnan(is_std)  else float(is_std),
        }

        # Save a small visual check: real CT next to the generated CT.
        rows = []
        for j in range(min(4, len(gen))):
            rows.extend([ct[j].repeat(3, 1, 1), gen[j].repeat(3, 1, 1)])
        grid = tvu.make_grid(rows, nrow=2, normalize=True, value_range=(-1, 1))
        os.makedirs(save_dir, exist_ok=True)
        eval_png = os.path.join(save_dir, f"eval_ep{epoch:03d}.png")
        tvf.to_pil_image(grid).save(eval_png)

        return metrics
    finally:
        # Always put the live weights and the previous mode back.
        if bak is not None:
            unet.load_state_dict(bak)
        unet.train(was_training)

# CHECKPOINT UTILS

In [None]:
def get_last_checkpoint(save_dir):
    """Return the highest epoch number that has an ``epNNN`` directory in ``save_dir``.

    Only sub-directories whose name is ``ep`` followed by digits are counted;
    plain files with such a name are ignored.

    Returns:
        The epoch number as ``int``, or ``None`` when ``save_dir`` does not
        exist or contains no checkpoint directory.
    """
    if not os.path.isdir(save_dir):
        return None
    eps = [
        int(d[2:])
        for d in os.listdir(save_dir)
        if d.startswith("ep") and d[2:].isdigit() and os.path.isdir(os.path.join(save_dir, d))
    ]
    return max(eps) if eps else None


# TRAINING LOOP

In [None]:
def train(
    dataset_dir,
    save_dir,
    img_size=IMG,
    batch_size=BATCH,
    epochs=EPOCHS,
    lr=LR,
    t_steps=T_STEPS,
    resume=False,
    max_batches=None,
    use_ema=True
):
    """Train the mask/class-conditioned diffusion model.

    Training data are the patients listed in ``train_patients.txt`` and
    ``test_patients.txt``; the "val" split is used for sample grids and for
    the fixed evaluation subset (``test_indices.npy`` in ``save_dir``).

    After every epoch the raw weights (and, with ``use_ema``, a state dict
    holding the EMA weights) are saved under ``save_dir/epNNN``, a sample
    grid is written, metrics are computed and a row is appended to
    ``metrics.csv``. FID/LPIPS are computed every ``EVAL_EVERY`` epochs.

    Args:
        dataset_dir: dataset root (``.npy`` files and patient lists).
        save_dir: output directory for checkpoints, images and metrics.
        img_size: image side length.
        batch_size: physical batch size (must be even for the balanced sampler).
        epochs: last epoch number to train (inclusive).
        lr: AdamW learning rate.
        t_steps: number of diffusion training timesteps.
        resume: continue from the newest ``epNNN`` checkpoint in ``save_dir``.
            Only model (and EMA) weights are restored; the optimizer state is
            not checkpointed.
        max_batches: if set, stop each epoch after this many batches.
        use_ema: maintain and evaluate an exponential moving average of the weights.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Save dir: {save_dir}")

    root = pathlib.Path(dataset_dir)

    # train = train + test patients
    train_pat = read_patient_list(root / "train_patients.txt")
    test_pat  = read_patient_list(root / "test_patients.txt")
    train_all = train_pat | test_pat
    if not train_all:
        # LungSlice ignores an empty forced_patients collection and, with
        # split=None, then uses every patient it finds.
        print("WARNING: train_patients.txt/test_patients.txt missing or empty; "
              "the training set will contain ALL patients, including validation ones.")

    ds_train = LungSlice(dataset_dir, img_size=img_size, split=None, forced_patients=train_all)
    ds_val   = LungSlice(dataset_dir, img_size=img_size, split="val")
    print(f"Train={len(ds_train)}  Val={len(ds_val)}")

    test_idx_path = os.path.join(save_dir, "test_indices.npy")
    if os.path.exists(test_idx_path):
        test_indices = np.load(test_idx_path).astype(int).tolist()
        print("Loaded test_indices.npy")
        # Guard against a stale file written for a different validation set.
        in_range = [i for i in test_indices if 0 <= i < len(ds_val)]
        if len(in_range) != len(test_indices):
            print(f"WARNING: dropped {len(test_indices) - len(in_range)} out-of-range test indices")
            test_indices = in_range
    else:
        print("Building test set (16 pos + 16 neg) from VAL")
        test_indices = build_test_set(ds_val, save_dir, n_pos=16, n_neg=16)

    sampler = BalancedBatchSampler(ds_train.labels, batch_size)
    loader  = DataLoader(
        ds_train,
        batch_sampler=sampler,
        pin_memory=False,
        num_workers=2,
        persistent_workers=False
    )

    unet, scheduler = build_model(img_size=img_size, t_steps=t_steps)

    optim = torch.optim.AdamW(unet.parameters(), lr=lr)
    acc = Accelerator(mixed_precision="fp16" if DEVICE == "cuda" else "no")
    unet, optim, loader = acc.prepare(unet, optim, loader)

    ema = EMA(unet, decay=0.999) if use_ema else None

    start_epoch = 1
    if resume:
        last = get_last_checkpoint(save_dir)
        if last is not None:
            ckpt_dir  = os.path.join(save_dir, f"ep{last:03d}")
            ckpt_path = os.path.join(ckpt_dir, "diffusion_pytorch_model.bin")
            print(f"Resuming from ep{last:03d}")
            unet.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
            if ema is not None:
                # Start the average from the restored weights instead of the
                # random initialisation it was created with...
                ema.register(unet)
                # ...and, if an EMA checkpoint exists, load it into the shadow
                # only. The live model must keep the raw weights.
                ema_path = os.path.join(ckpt_dir, "diffusion_pytorch_model_ema.bin")
                if os.path.exists(ema_path):
                    ema_state = torch.load(ema_path, map_location=DEVICE)
                    with torch.no_grad():
                        for name, shadow_param in ema.shadow.items():
                            if name in ema_state:
                                shadow_param.copy_(ema_state[name])
            start_epoch = last + 1
        else:
            print("No checkpoint found. Starting from scratch.")

    metrics_csv = os.path.join(save_dir, "metrics.csv")
    if not os.path.exists(metrics_csv):
        with open(metrics_csv, "w", newline='') as f:
            csv.writer(f).writerow(
                ["epoch", "psnr_mean", "psnr_std", "ssim_mean", "ssim_std",
                 "mse_mean", "mse_std", "fid", "lpips", "is_mean", "is_std"]
            )

    print("Sample sanity from TRAIN:")
    for name, y, src in ds_train._precheck[:5]:
        print(f"  {name}: label={y} src={src}")

    def _optimizer_step():
        """Apply the accumulated gradients, then fold the new weights into the EMA."""
        optim.step()
        optim.zero_grad(set_to_none=True)
        # The EMA tracks optimizer steps: between steps the weights do not
        # change, so updating per micro-batch would only shorten the average.
        if ema is not None:
            ema.update(unet)

    for ep in range(start_epoch, epochs + 1):
        torch.cuda.empty_cache()
        unet.train()
        tot = 0
        n_seen = 0
        pending = 0  # micro-batches back-propagated since the last optimizer step
        pbar = tqdm(loader, desc=f"ep {ep:03d}", ncols=80)
        for i, batch in enumerate(pbar, start=1):
            lung, ct, cls_map, _ = batch
            b = ct.size(0)
            t = torch.randint(0, scheduler.config.num_train_timesteps, (b,), device=ct.device)
            noise = torch.randn_like(ct)
            x_t = scheduler.add_noise(ct, noise, t)

            # Model input channels: [lung_mask, class_map, noisy image].
            inp = torch.cat([lung, cls_map, x_t], 1)
            eps_hat = unet(inp, t).sample
            # Scale so the accumulated gradient is the mean over GRAD_ACC micro-batches.
            loss = F.mse_loss(eps_hat, noise) / GRAD_ACC

            acc.backward(loss)
            pending += 1
            if pending == GRAD_ACC:
                _optimizer_step()
                pending = 0

            # Undo the GRAD_ACC scaling for logging (sample-weighted running sum).
            tot += float(loss.item()) * b * GRAD_ACC
            n_seen += b

            if max_batches is not None and i >= max_batches:
                break

        # Flush an incomplete accumulation window so its gradients are not
        # carried over into the next epoch.
        if pending:
            _optimizer_step()

        if acc.is_main_process:
            print(f"[ep {ep:03d}] loss={tot / max(1, n_seen):.4f}")

            ep_dir = os.path.join(save_dir, f"ep{ep:03d}")
            os.makedirs(ep_dir, exist_ok=True)
            torch.save(unet.state_dict(), os.path.join(ep_dir, "diffusion_pytorch_model.bin"))
            if ema is not None:
                # EMA checkpoint = the model's state dict with the tracked
                # parameters replaced by their averages (built on CPU, without
                # cloning the whole model).
                ema_state = {k: v.detach().cpu() for k, v in unet.state_dict().items()}
                ema_state.update({k: v.detach().cpu() for k, v in ema.shadow.items()})
                torch.save(ema_state, os.path.join(ep_dir, "diffusion_pytorch_model_ema.bin"))
                del ema_state

            sample_grid(
                unet, scheduler, ds_val,
                os.path.join(save_dir, f"samples_ep{ep:03d}.png"),
                steps=50, img_size=img_size, device=DEVICE
            )

            compute_heavy = (ep % EVAL_EVERY == 0)
            metrics = evaluate_model(
                unet, scheduler, ds_val, test_indices, save_dir, epoch=ep,
                steps=1000, img_size=img_size, device=DEVICE,
                use_ema=(use_ema and ema is not None), ema=ema,
                compute_heavy=compute_heavy
            )
            print(f"  PSNR: {metrics['psnr_mean']:.2f} ± {metrics['psnr_std']:.2f}")
            print(f"  SSIM: {metrics['ssim_mean']:.3f} ± {metrics['ssim_std']:.3f}")
            print(f"  MSE:  {metrics['mse_mean']:.5f} ± {metrics['mse_std']:.5f}")
            if compute_heavy:
                print(f"  FID:  {metrics['fid']:.2f}")
                print(f"  LPIPS:{metrics['lpips']:.4f}")
                print(f"  IS:   {metrics['is_mean']:.3f} ± {metrics['is_std']:.3f}")

            with open(metrics_csv, "a", newline='') as f:
                csv.writer(f).writerow([
                    ep,
                    metrics['psnr_mean'], metrics['psnr_std'],
                    metrics['ssim_mean'], metrics['ssim_std'],
                    metrics['mse_mean'],  metrics['mse_std'],
                    metrics['fid'], metrics['lpips'],
                    metrics['is_mean'], metrics['is_std']
                ])

    print("Done.")