<a href="https://colab.research.google.com/github/ZHAOTIEZHU2333/COMP3702-A1-Code-2025-main/blob/main/Part4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, math, glob, argparse, random, json
from pathlib import Path
import numpy as np
from typing import List, Tuple, Sequence, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    from PIL import Image
    _HAVE_PIL = True
except Exception:
    _HAVE_PIL = False

def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def pick_device():
    if torch.cuda.is_available(): return torch.device("cuda")
    return torch.device("cpu")

def zscore(x: np.ndarray, eps=1e-6):
    m, s = x.mean(), x.std()
    return (x - m) / (s + eps)

def minmax01(x: np.ndarray, eps=1e-6):
    mi, ma = x.min(), x.max()
    return (x - mi) / (ma - mi + eps)

class OASISSliceDataset(Dataset):
    def __init__(self,
                 items: Sequence[Tuple[Path, Optional[Path]]],
                 plane: str = "axial",
                 slice_step: int = 1,
                 classes: int = 4,
                 for_vae: bool = False,
                 ignore_label: Optional[int] = None):
        self.items = list(items)
        self.plane = plane
        self.slice_step = slice_step
        self.classes = classes
        self.for_vae = for_vae
        self.ignore_label = ignore_label
        self.slices: List[Tuple[Path, Optional[Path], int]] = []
        self._index_slices()

    def _index_slices(self):
        for (ip, mp) in self.items:
            img = self._load_volume(ip)
            D, H, W = img.shape
            if mp is not None:
                msk = self._load_volume(mp).astype(np.int64)
                assert msk.shape == img.shape, f"Mask shape mismatch for {ip}"
            else:
                msk = None
            depth = D
            for s in range(0, depth, self.slice_step):
                self.slices.append((ip, mp, s))

    def _load_volume(self, p: Path) -> np.ndarray:
        # Corrected to handle single image files which might not be 3D
        img = Image.open(p).convert("L")
        arr = np.array(img, dtype=np.float32)
        if arr.ndim == 2:
            arr = arr[None, ...] # Add channel dimension for 2D images
        return arr

    def _extract(self, vol: np.ndarray, s: int) -> np.ndarray:
        # This method assumes a 3D volume, need to handle 2D case
        if vol.ndim == 3:
             return vol[s, :, :]
        elif vol.ndim == 2: # Handle case where the input is already a 2D slice
             return vol
        else:
             raise ValueError(f"Unexpected volume dimension: {vol.ndim}")

    def __len__(self): return len(self.slices)

    def __getitem__(self, idx: int):
        ip, mp, s = self.slices[idx]
        img3d = self._load_volume(ip).astype(np.float32)
        # Handle cases where _load_volume returns a 2D image directly
        if img3d.ndim == 2:
            img2d = img3d
        else:
            img2d = self._extract(img3d, s)

        img2d = zscore(img2d)
        img2d = np.clip(img2d, -5, 5)
        img2d = minmax01(img2d)
        img = torch.from_numpy(img2d[None, ...]).float()

        if self.for_vae or (mp is None):
            # Return a dummy mask if for VAE or no mask is provided
            # Assuming a consistent image size for dummy mask
            dummy_mask = torch.zeros(img.shape[2:], dtype=torch.long)
            return img, dummy_mask # Returning a dummy mask instead of a scalar 0

        msk3d = self._load_volume(mp).astype(np.int64)
        # Handle cases where _load_volume returns a 2D image directly
        if msk3d.ndim == 2:
            msk2d = msk3d
        else:
            msk2d = self._extract(msk3d, s)

        uniq = sorted(int(v) for v in np.unique(msk2d))
        bg = 0 if 0 in uniq else uniq[0]
        non_bg = [v for v in uniq if v != bg]
        keep = [bg] + non_bg[:max(0, self.classes - 1)]
        lut = {v: i for i, v in enumerate(keep)}
        msk2d = np.vectorize(lambda v: lut.get(int(v), 0))(msk2d).astype(np.int64)

        msk = torch.from_numpy(msk2d).long()
        return img, msk

class VAE(nn.Module):
    def __init__(self, z_dim=64):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(inplace=True),
        )
        self.to_mu = nn.Conv2d(128, z_dim, 1)
        self.to_logv = nn.Conv2d(128, z_dim, 1)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 128, 2, 2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 2, 2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 2, 2), nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, 3, 1, 1), nn.Sigmoid(),
        )

    def reparam(self, mu, logv):
        std = (0.5 * logv).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        h = self.enc(x)
        mu = self.to_mu(h)
        logv = self.to_logv(h)
        z = self.reparam(mu, logv)
        xrec = self.dec(z)
        return xrec, mu, logv

def kld_loss(mu, logv):
    return 0.5 * torch.mean(torch.sum(torch.exp(logv) + mu**2 - 1.0 - logv, dim=[1,2,3]))

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class UNet2D(nn.Module):
    def __init__(self, in_ch=1, n_classes=4, base=32):
        super().__init__()
        self.d1 = DoubleConv(in_ch, base)
        self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(base, base*2)
        self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(base*2, base*4)
        self.p3 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base*4, base*8)
        self.u3 = nn.ConvTranspose2d(base*8, base*4, 2, 2)
        self.u2 = nn.ConvTranspose2d(base*4, base*2, 2, 2)
        self.u1 = nn.ConvTranspose2d(base*2, base,   2, 2)
        self.c3 = DoubleConv(base*8, base*4)
        self.c2 = DoubleConv(base*4, base*2)
        self.c1 = DoubleConv(base*2, base)
        self.head = nn.Conv2d(base, n_classes, 1)

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(self.p1(d1))
        d3 = self.d3(self.p2(d2))
        bn = self.bottleneck(self.p3(d3))
        x  = self.u3(bn)
        x  = self.c3(torch.cat([x, d3], dim=1))
        x  = self.u2(x)
        x  = self.c2(torch.cat([x, d2], dim=1))
        x  = self.u1(x)
        x  = self.c1(torch.cat([x, d1], dim=1))
        logits = self.head(x)
        return logits

def transplant_encoder_from_vae(unet: UNet2D, vae: VAE):
    vae_layers = [m for m in vae.enc.modules() if isinstance(m, nn.Conv2d)]
    unet_layers = []
    for block in [unet.d1, unet.d2, unet.d3]:
        for m in block.modules():
            if isinstance(m, nn.Conv2d): unet_layers.append(m)
    k = min(len(vae_layers), len(unet_layers))
    with torch.no_grad():
        for i in range(k):
            if vae_layers[i].weight.shape == unet_layers[i].weight.shape:
                unet_layers[i].weight.copy_(vae_layers[i].weight)
                if (vae_layers[i].bias is not None) and (unet_layers[i].bias is not None):
                    unet_layers[i].bias.copy_(vae_layers[i].bias)

class SoftDiceLoss(nn.Module):
    def __init__(self, smooth=1.0, ignore_index: Optional[int] = None):
        super().__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, logits, targets_onehot):
        probs = torch.softmax(logits, dim=1)
        dims = (0,2,3)
        intersection = torch.sum(probs * targets_onehot, dim=dims)
        cardinality = torch.sum(probs + targets_onehot, dim=dims)
        dice = (2. * intersection + self.smooth) / (cardinality + self.smooth)
        return 1. - dice.mean()

@torch.no_grad()
def dice_per_class(logits, targets, num_classes, ignore_index: Optional[int] = None):
    probs = torch.softmax(logits, dim=1)
    pred = probs.argmax(dim=1)
    dices = []
    for c in range(num_classes):
        if (ignore_index is not None) and (c == ignore_index):
            dices.append(float("nan"))
            continue
        pc = (pred == c).float()
        tc = (targets == c).float()
        inter = (pc*tc).sum().item()
        denom = pc.sum().item() + tc.sum().item()
        d = (2*inter + 1.0) / (denom + 1.0)
        dices.append(d)
    return dices

import re

def discover_pairs_keras(root: Path):
    img_dirs = [
        root / "keras_png_slices_train",
        root / "keras_png_slices_validate",
        root / "keras_png_slices_test",
    ]
    seg_dirs = [
        root / "keras_png_slices_seg_train",
        root / "keras_png_slices_seg_validate",
        root / "keras_png_slices_seg_test",
    ]

    # Check if root directory exists
    if not root.exists():
        print(f"[ERROR] Data root directory not found: {root}")
        return []

    # Check if expected subdirectories exist
    for d in img_dirs + seg_dirs:
        if not d.exists():
            print(f"[ERROR] Expected data subdirectory not found: {d}")
            print("Please ensure your data is organized correctly within the data-root directory.")
            return []

    pairs = []
    for idr, sdr in zip(img_dirs, seg_dirs):
        img_map = {}
        for ip in idr.glob("*.png"):
            # Extract the common part from the image file name
            m = re.search(r'^case_(.+)\.nii\.png$', ip.name)
            if m:
                key = m.group(1)
                img_map[key] = ip

        seg_map = {}
        for sp in sdr.glob("*.png"):
            # Extract the common part from the segmentation file name
            m = re.search(r'^seg_(.+)\.nii\.png$', sp.name)
            if m:
                key = m.group(1)
                seg_map[key] = sp

        # Find common keys (base names) and create pairs
        common_keys = img_map.keys() & seg_map.keys()
        for k in sorted(common_keys):
            pairs.append((img_map[k], seg_map[k]))

    return pairs

def split_train_val(items, val_ratio=0.1, seed=42):
    rng = np.random.RandomState(seed)
    idx = np.arange(len(items)); rng.shuffle(idx)
    n_val = max(1, int(len(items)*val_ratio))
    val_idx, tr_idx = idx[:n_val], idx[n_val:]
    tr = [items[i] for i in tr_idx]
    va = [items[i] for i in val_idx]
    return tr, va

def train_vae(vae, loader, epochs, device, lr=1e-3):
    vae = vae.to(device)
    opt = torch.optim.Adam(vae.parameters(), lr=lr)
    scaler = torch.amp.GradScaler(enabled=(device.type=="cuda"))
    for ep in range(1, epochs+1):
        vae.train()
        total = 0.0
        for x, _ in loader:
            x = x.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
                xrec, mu, logv = vae(x)
            with torch.amp.autocast(device_type='cpu', enabled=False):
                rec = F.binary_cross_entropy(xrec.float(), x.float(), reduction="mean")
                kld = kld_loss(mu.float(), logv.float())
                loss = rec + 1e-3 * kld
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
            total += loss.item()*x.size(0)
        print(f"[VAE] epoch {ep}/{epochs} | loss={total/len(loader.dataset):.4f}")
    return vae

def train_unet(unet, loaders, epochs, device, n_classes, ignore_index=None, lr=1e-3):
    train_loader, val_loader = loaders
    unet = unet.to(device)
    ce = nn.CrossEntropyLoss()
    dice = SoftDiceLoss(ignore_index=ignore_index)
    opt = torch.optim.Adam(unet.parameters(), lr=lr)
    scaler = torch.amp.GradScaler(enabled=(device.type=="cuda"))
    best = {"epoch":0, "mean_dsc":0.0}
    for ep in range(1, epochs+1):
        unet.train()
        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y1h = F.one_hot(y, n_classes).permute(0,3,1,2).float()
            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
                logits = unet(x)
                loss = ce(logits, y) + dice(logits, y1h)
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        unet.eval()
        all_dsc = []
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device); y = y.to(device)
                logits = unet(x)
                dscs = dice_per_class(logits, y, n_classes, ignore_index)
                all_dsc.append(np.array(dscs))
        dsc_mean = float(np.nanmean(np.vstack(all_dsc), axis=0).mean())
        print(f"[UNet] epoch {ep}/{epochs} | mean DSC={dsc_mean:.4f}")
        if dsc_mean > best["mean_dsc"]:
            best = {"epoch": ep, "mean_dsc": dsc_mean}
    print(f"[UNet] best mean DSC={best['mean_dsc']:.4f} @ epoch {best['epoch']}")
    return unet


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--data-root", type=str, default="/content/drive/MyDrive/Colab Notebooks/keras_png_slices_data")
    ap.add_argument("--plane", type=str, default="axial", choices=["axial","sagittal","coronal"])
    ap.add_argument("--slice-step", type=int, default=1)
    ap.add_argument("--classes", type=int, default=4)
    ap.add_argument("--ignore-index", type=int, default=None)
    ap.add_argument("--batch", type=int, default=32)
    ap.add_argument("--workers", type=int, default=2)
    ap.add_argument("--vae-epochs", type=int, default=3)
    ap.add_argument("--unet-epochs", type=int, default=10)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--out", type=str, default="runs_part4_colab")
    args, _ = ap.parse_known_args()
    return args

def main():
    args = parse_args()
    seed_everything(args.seed)
    device = pick_device()
    print(f"[INFO] device={device}")
    root = Path(args.data_root)
    pairs = discover_pairs_keras(root)
    assert len(pairs) > 0, f"No image/mask png pairs found under {root}"
    train_items, val_items = split_train_val(pairs, val_ratio=0.1, seed=args.seed)
    print(f"[DATA] subjects: train={len(train_items)} val={len(val_items)}")
    vae_train = OASISSliceDataset(train_items, plane=args.plane, slice_step=args.slice_step,
                                  classes=args.classes, for_vae=True)
    seg_train = OASISSliceDataset(train_items, plane=args.plane, slice_step=args.slice_step,
                                  classes=args.classes, for_vae=False, ignore_label=args.ignore_index)
    seg_val   = OASISSliceDataset(val_items,   plane=args.plane, slice_step=args.slice_step,
                                  classes=args.classes, for_vae=False, ignore_label=args.ignore_index)
    vae_loader = DataLoader(vae_train, batch_size=args.batch, shuffle=True,
                            num_workers=args.workers, pin_memory=(device.type=="cuda"))
    tr_loader  = DataLoader(seg_train, batch_size=args.batch, shuffle=True,
                            num_workers=args.workers, pin_memory=(device.type=="cuda"))
    va_loader  = DataLoader(seg_val,   batch_size=args.batch, shuffle=False,
                            num_workers=args.workers, pin_memory=(device.type=="cuda"))
    vae = VAE(z_dim=64)
    vae = train_vae(vae, vae_loader, epochs=args.vae_epochs, device=device, lr=1e-3)
    unet = UNet2D(in_ch=1, n_classes=args.classes, base=32)
    transplant_encoder_from_vae(unet, vae)
    unet = train_unet(unet, (tr_loader, va_loader), epochs=args.unet_epochs,
                      device=device, n_classes=args.classes, ignore_index=args.ignore_index, lr=1e-3)
    unet.eval()
    all_dsc = []
    with torch.no_grad():
        for x, y in va_loader:
            x = x.to(device); y = y.to(device)
            logits = unet(x)
            dscs = dice_per_class(logits, y, args.classes, args.ignore_index)
            all_dsc.append(np.array(dscs))
    dsc_matrix = np.vstack(all_dsc)
    dsc_mean_per_class = np.nanmean(dsc_matrix, axis=0).tolist()
    dsc_mean = float(np.nanmean(dsc_matrix))
    os.makedirs(args.out, exist_ok=True)
    with open(Path(args.out)/"metrics.json","w") as f:
        json.dump({"mean_dsc": dsc_mean, "dsc_per_class": dsc_mean_per_class}, f, indent=2)
    print(f"[RESULT] mean DSC={dsc_mean:.4f} | per-class={np.round(dsc_mean_per_class,4)}")
    print(f"[SAVE] {Path(args.out)/'metrics.json'}")

if __name__ == "__main__":
    main()

[INFO] device=cuda
[DATA] subjects: train=10196 val=1132


  scaler = GradScaler(enabled=(device.type=="cuda"))
  with autocast(enabled=(device.type=="cuda")):
  with autocast(enabled=False):


[VAE] epoch 1/3 | loss=0.4097
[VAE] epoch 2/3 | loss=0.3948
[VAE] epoch 3/3 | loss=0.3935


  scaler = GradScaler(enabled=(device.type=="cuda"))
  with autocast(enabled=(device.type=="cuda")):


KeyboardInterrupt: 