In [4]:
# =============================================
# 3D U-Net for CellMap multi-class segmentation (crop-based KFold)
# Input channels: Raw + Gaussian(sigma=2)  (recommended for 3D U-Net)
# Labels: background(0) + 5 classes (1..5)
# Data loading aligned with your .zattrs method
# =============================================

import os
import json
import time
import random
import numpy as np
import zarr  # type: ignore
from tqdm import tqdm  # type: ignore
from scipy.ndimage import gaussian_filter, sobel  # type: ignore

from sklearn.model_selection import KFold  # type: ignore

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt  # type: ignore

In [5]:
# -----------------------------
# Config
# -----------------------------
# Seed Python's, NumPy's and PyTorch's global RNGs once, when this cell runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Crops
# Ids of the ground-truth crops to load; each id is a sub-directory of GROUNDTRUTH_ROOT.
CROP_IDS = ["crop292"]
# CROP_IDS = ["crop292", "crop234", "crop236", "crop237", "crop239"]

# Full-resolution raw volume (opened with zarr) and the root of the label crops.
RAW_S0 = r"../data/jrc_cos7-1a/jrc_cos7-1a.zarr/recon-1/em/fibsem-uint8/s0"
GROUNDTRUTH_ROOT = r"../data/jrc_cos7-1a/jrc_cos7-1a.zarr/recon-1/labels/groundtruth"

# Classes to segment. Only the KEYS are read by the loader (as label sub-directory
# names); the values are not used anywhere in the visible code — presumably the
# dataset's original class ids, TODO confirm.
SELECT_CLASSES = {
    "cyto": 35,
    "mito_mem": 3,
    "mito_lum": 4,
    "er_mem": 16,
    "er_lum": 17,
}

# Label value written into the multi-class volume for each selected class
# (0 is left for background).
CLASS_ID_MAP = {
    "cyto": 1,
    "mito_mem": 2,
    "mito_lum": 3,
    "er_mem": 4,
    "er_lum": 5,
}

# Display names indexed by label value; must stay in step with CLASS_ID_MAP.
CLASS_NAMES = ["bg", "cyto", "mito_mem", "mito_lum", "er_mem", "er_lum"]
NUM_CLASSES = 6  # 0..5

# Class whose array shape and .zattrs translation define each crop's extent and
# its position inside the raw volume.
REF_CLASS = "nucpl"

# Training
EPOCHS = 20
BATCH_SIZE = 2
LR = 2e-4
WEIGHT_DECAY = 1e-5

# Patch sampling (adjust to your GPU memory)
PATCH_ZYX = (32, 128, 128)  # (Z,Y,X)
PATCHES_PER_EPOCH = 300     # per fold
VAL_PATCHES = 80            # random validation patches drawn per epoch

# Sliding-window inference for full volume evaluation
INFER_STRIDE = (16, 64, 64)  # overlap = patch - stride

# Output
# Per-fold run directories (checkpoints, visualisations) are created under OUT_DIR.
OUT_DIR = "../Result/unet3d_runs"
os.makedirs(OUT_DIR, exist_ok=True)

In [6]:
# -----------------------------
# CNN input channels
# Default: 2 channels = [raw, gauss(sigma=2)]
# Set the flag below to append a 3rd channel: XY gradient magnitude of the
# smoothed image.
# -----------------------------
USE_GRADMAG2_AS_3RD_CH = False


def make_cnn_input(raw_uint8_zyx: np.ndarray) -> np.ndarray:
    """
    Turn a raw intensity volume into the channel stack fed to the network.

    Args:
        raw_uint8_zyx: raw volume laid out (Z,Y,X); values are divided by 255,
            so uint8 input maps to [0, 1].

    Returns:
        float32 array (C,Z,Y,X). Channel 0 is the rescaled raw image, channel 1
        its Gaussian blur (sigma=2 on every axis). When
        USE_GRADMAG2_AS_3RD_CH is true, channel 2 is the Sobel gradient
        magnitude of the blurred image taken along Y and X only.
    """
    raw01 = raw_uint8_zyx.astype(np.float32) / 255.0
    smooth = gaussian_filter(raw01, sigma=2.0)

    channels = [raw01, smooth]
    if USE_GRADMAG2_AS_3RD_CH:
        # In-plane edges only: axis 2 is X, axis 1 is Y; Z is not differentiated.
        d_x = sobel(smooth, axis=2)
        d_y = sobel(smooth, axis=1)
        channels.append(np.sqrt(d_x * d_x + d_y * d_y))

    return np.stack(channels, axis=0).astype(np.float32)


# -----------------------------
# Data loading (aligned to your method)
# -----------------------------
def load_one_crop(crop_id: str, raw_zarr) -> dict:
    """
    Load one ground-truth crop: the matching raw sub-volume plus a merged
    multi-class label volume.

    The crop's size comes from the REF_CLASS label array and its position in
    the raw volume from the REF_CLASS `.zattrs` (translation / scale of the
    first multiscale dataset).
    NOTE(review): this assumes `raw_zarr` has the same voxel size as the label
    arrays — confirm against the raw volume's own metadata.

    Args:
        crop_id: name of the crop directory under GROUNDTRUTH_ROOT.
        raw_zarr: opened raw array indexed as [z, y, x].

    Returns:
        dict with
          raw:   raw sub-volume (Z,Y,X) as read from `raw_zarr`
          label: uint8 (Z,Y,X); 0 = background, other values from CLASS_ID_MAP
          shape: (Z,Y,X)
          id:    crop_id

    Raises:
        ValueError: if the crop window derived from the metadata does not lie
            entirely inside `raw_zarr` (slicing would otherwise silently
            shorten or wrap the window and desynchronise raw and labels).
    """
    crop_root = os.path.join(GROUNDTRUTH_ROOT, crop_id)
    ref_s0 = os.path.join(crop_root, REF_CLASS, "s0")
    ref_zattr = os.path.join(crop_root, REF_CLASS, ".zattrs")

    ref_arr = zarr.open(ref_s0, mode="r")
    Dz, Dy, Dx = ref_arr.shape

    with open(ref_zattr, "r", encoding="utf-8") as f:
        attrs = json.load(f)
    # Expected layout: coordinateTransformations = [scale, translation].
    ms = attrs["multiscales"][0]["datasets"][0]
    scale = ms["coordinateTransformations"][0]["scale"]
    trans = ms["coordinateTransformations"][1]["translation"]
    scale_z, scale_y, scale_x = scale
    tz, ty, tx = trans

    # Round instead of truncating: translation/scale is a float division, and a
    # result such as 1199.9999 must map to voxel 1200, not 1199.
    vz0 = int(round(tz / scale_z))
    vy0 = int(round(ty / scale_y))
    vx0 = int(round(tx / scale_x))
    vz1, vy1, vx1 = vz0 + Dz, vy0 + Dy, vx0 + Dx

    # Refuse windows that leave the raw volume instead of letting slicing
    # silently clip (too-short raw) or wrap (negative start).
    for axis, lo, hi, extent in zip("zyx", (vz0, vy0, vx0), (vz1, vy1, vx1), raw_zarr.shape):
        if lo < 0 or hi > extent:
            raise ValueError(
                f"Crop {crop_id}: {axis}-range [{lo}, {hi}) lies outside the raw volume (size {extent})"
            )

    raw_crop = raw_zarr[vz0:vz1, vy0:vy1, vx0:vx1]  # (Z,Y,X)

    # Build multi-class label. Classes are painted in SELECT_CLASSES order, so
    # where masks overlap the class listed later wins.
    label_multi = np.zeros((Dz, Dy, Dx), dtype=np.uint8)
    for cname in SELECT_CLASSES.keys():
        path = os.path.join(crop_root, cname, "s0")
        try:
            arr = zarr.open(path, mode="r")[:]  # mask; any value > 0 counts as foreground
            cid = CLASS_ID_MAP[cname]
            label_multi[arr > 0] = cid
        except Exception as e:
            # Best effort by design: a class missing from this crop is reported
            # and simply left out of the label volume.
            print(f"Warning: failed load class {cname} in {crop_id}: {e}")

    return {"raw": raw_crop, "label": label_multi, "shape": (Dz, Dy, Dx), "id": crop_id}


In [7]:
# -----------------------------
# Patch sampler dataset
# -----------------------------
class RandomPatchDataset(Dataset):
    """
    Map-style dataset that returns a freshly sampled random patch on every
    access; the requested index is ignored and only `n_patches` (the reported
    length) controls how many samples make up one epoch.

    Random positions are drawn from torch's global RNG rather than NumPy's.
    DataLoader reseeds torch in every worker process, whereas NumPy's state is
    copied unchanged into each forked worker, which makes all workers (and all
    epochs) produce the same patches.

    Args:
        crops: list of dicts with keys "raw" (Z,Y,X), "label" (Z,Y,X),
            "shape" (Z,Y,X) and "id".
        n_patches: number of samples reported by __len__.
        patch_zyx: patch size (Z,Y,X).
    """

    def __init__(self, crops: list[dict], n_patches: int, patch_zyx=(32,128,128)):
        self.crops = crops
        self.n_patches = n_patches
        self.pz, self.py, self.px = patch_zyx

        # Per crop: largest valid start offset along each axis.
        self.ranges = []
        for c in self.crops:
            Dz, Dy, Dx = c["shape"]
            assert Dz >= self.pz and Dy >= self.py and Dx >= self.px, f"Crop too small: {c['id']}"
            self.ranges.append((Dz - self.pz, Dy - self.py, Dx - self.px))

    def __len__(self):
        return self.n_patches

    @staticmethod
    def _randint(high: int) -> int:
        """Uniform integer in [0, high) drawn from torch's global RNG."""
        return int(torch.randint(high, (1,)).item())

    def __getitem__(self, idx):
        """
        Returns:
            x: float32 tensor (C,Z,Y,X) built by make_cnn_input from the raw patch
            y: int64 tensor (Z,Y,X) with the label values of the same window
        """
        # pick a crop, then a window start inside it
        ci = self._randint(len(self.crops))
        crop = self.crops[ci]
        Dz_off, Dy_off, Dx_off = self.ranges[ci]

        z0 = self._randint(Dz_off + 1)
        y0 = self._randint(Dy_off + 1)
        x0 = self._randint(Dx_off + 1)

        raw_patch = crop["raw"][z0:z0+self.pz, y0:y0+self.py, x0:x0+self.px]
        y_patch = crop["label"][z0:z0+self.pz, y0:y0+self.py, x0:x0+self.px]

        x_patch = make_cnn_input(raw_patch)  # (C,Z,Y,X)

        x = torch.from_numpy(x_patch)                      # float32
        y = torch.from_numpy(y_patch.astype(np.int64))     # long (Z,Y,X)
        return x, y

In [8]:
# -----------------------------
# 3D U-Net building block
# -----------------------------
class DoubleConv(nn.Module):
    """
    Two consecutive stages of Conv3d(3x3x3, padding=1) -> InstanceNorm3d ->
    LeakyReLU(0.1). Padding keeps the spatial size unchanged; the first stage
    maps in_ch -> out_ch, the second out_ch -> out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so parameter names
        # (net.0.*, net.3.*) stay compatible with saved checkpoints.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv3d(stage_in, out_ch, 3, padding=1),
                nn.InstanceNorm3d(out_ch),
                nn.LeakyReLU(0.1, inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to x of shape (B,in_ch,Z,Y,X)."""
        return self.net(x)


class UNet3D(nn.Module):
    """
    Three-level 3D U-Net: three DoubleConv + 2x max-pool encoder levels, a
    bottleneck, and three transposed-conv + DoubleConv decoder levels with
    skip connections, followed by a 1x1x1 classification conv.

    Channel widths are base, 2*base, 4*base and 8*base (bottleneck). Because
    each of the three poolings halves the size and the skips are concatenated
    after 2x upsampling, every spatial dimension of the input must be
    divisible by 8.

    Args:
        in_ch: number of input channels.
        n_classes: number of output logit channels.
        base: channel width of the first encoder level.
    """

    def __init__(self, in_ch, n_classes, base=32):
        super().__init__()
        w1, w2, w3, w4 = base, base*2, base*4, base*8

        # Encoder
        self.enc1 = DoubleConv(in_ch, w1)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = DoubleConv(w1, w2)
        self.pool2 = nn.MaxPool3d(2)
        self.enc3 = DoubleConv(w2, w3)
        self.pool3 = nn.MaxPool3d(2)

        # Bottleneck
        self.bott = DoubleConv(w3, w4)

        # Decoder: each DoubleConv sees upsampled + skip features, hence the
        # doubled input width.
        self.up3 = nn.ConvTranspose3d(w4, w3, 2, stride=2)
        self.dec3 = DoubleConv(w4, w3)
        self.up2 = nn.ConvTranspose3d(w3, w2, 2, stride=2)
        self.dec2 = DoubleConv(w3, w2)
        self.up1 = nn.ConvTranspose3d(w2, w1, 2, stride=2)
        self.dec1 = DoubleConv(w2, w1)

        # Per-voxel class logits
        self.out = nn.Conv3d(w1, n_classes, 1)

    def forward(self, x):
        """Map x (B,in_ch,Z,Y,X) to logits (B,n_classes,Z,Y,X)."""
        encoder_levels = (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
        )
        decoder_levels = (
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )

        skips = []
        feat = x
        for encode, downsample in encoder_levels:
            feat = encode(feat)
            skips.append(feat)      # kept for the matching decoder level
            feat = downsample(feat)

        feat = self.bott(feat)

        for upsample, decode in decoder_levels:
            # Upsampled features first, then the skip from the same resolution.
            feat = decode(torch.cat([upsample(feat), skips.pop()], dim=1))

        return self.out(feat)


# -----------------------------
# Soft Dice loss (multi-class); combined with cross-entropy by the trainer
# -----------------------------
def soft_dice_loss(logits, targets, num_classes=6, eps=1e-6):
    """
    Mean soft Dice loss over all classes.

    Args:
        logits:  (B,C,Z,Y,X) raw network outputs; softmax is applied over C.
        targets: (B,Z,Y,X) integer class ids in [0, num_classes).
        num_classes: number of classes C used for the one-hot encoding.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor 1 - mean(dice_c). Dice is accumulated per class over the
        whole batch and volume. Background is included in the mean; slice the
        per-class vector with [1:] to leave it out.
    """
    class_probs = torch.softmax(logits, dim=1)

    # (B,Z,Y,X,C) -> (B,C,Z,Y,X) so it lines up with the probabilities.
    target_1hot = F.one_hot(targets, num_classes=num_classes)
    target_1hot = target_1hot.permute(0, 4, 1, 2, 3).float()

    reduce_axes = (0, 2, 3, 4)  # everything except the class axis
    overlap = (class_probs * target_1hot).sum(dim=reduce_axes)
    total = (class_probs + target_1hot).sum(dim=reduce_axes)
    per_class_dice = (2 * overlap + eps) / (total + eps)

    return 1.0 - per_class_dice.mean()


# -----------------------------
# Evaluation metric: hard Dice for each class
# -----------------------------
@torch.no_grad()
def dice_per_class(pred, gt, num_classes=6, eps=1e-6):
    """
    Dice overlap between two integer label volumes, one score per class.

    Args:
        pred, gt: label volumes of identical shape, NumPy arrays or torch
            tensors (tensors are moved to the CPU and converted).
        num_classes: classes 0..num_classes-1 are scored.
        eps: smoothing term; a class absent from both volumes scores 1.

    Returns:
        list of num_classes floats.
    """
    def _to_numpy(volume):
        return volume.cpu().numpy() if isinstance(volume, torch.Tensor) else volume

    pred_np = _to_numpy(pred)
    gt_np = _to_numpy(gt)

    scores = []
    for cls in range(num_classes):
        pred_mask = pred_np == cls
        gt_mask = gt_np == cls
        overlap = np.logical_and(pred_mask, gt_mask).sum()
        total = pred_mask.sum() + gt_mask.sum()
        scores.append((2 * overlap + eps) / (total + eps))
    return scores


# -----------------------------
# Sliding-window inference
# -----------------------------
def _window_starts(extent, patch, stride):
    """Start offsets covering [0, extent) with `patch`-sized windows; requires extent >= patch."""
    last = extent - patch
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        # Add a final window flush with the far edge so no voxels are missed.
        starts.append(last)
    return starts


@torch.no_grad()
def sliding_window_predict(model, raw_zyx: np.ndarray, patch_zyx=(32,128,128), stride_zyx=(16,64,64)):
    """
    Predict a full label volume by averaging softmax scores of overlapping patches.

    Args:
        model: network mapping (1,C,pz,py,px) inputs to (1,K,pz,py,px) logits.
            It is switched to eval mode and left in it.
        raw_zyx: raw volume (Z,Y,X); each patch is converted with make_cnn_input.
        patch_zyx: patch size fed to the model.
        stride_zyx: step between patch starts (overlap = patch - stride).

    Returns:
        uint8 array (Z,Y,X) holding the arg-max class of the averaged scores.

    Every patch handed to the model has exactly `patch_zyx` voxels: an axis
    shorter than the patch is mirror-padded first and the padding is cropped
    from the result. The number of classes K is taken from the model output.
    """
    model.eval()

    pz, py, px = patch_zyx
    sz, sy, sx = stride_zyx
    Dz, Dy, Dx = raw_zyx.shape

    # Grow axes that are shorter than the patch so a full-size patch always fits.
    pad_after = [max(p - d, 0) for p, d in zip((pz, py, px), (Dz, Dy, Dx))]
    if any(pad_after):
        raw_zyx = np.pad(raw_zyx, [(0, extra) for extra in pad_after], mode="symmetric")
    Pz, Py, Px = raw_zyx.shape

    # Feed patches on whatever device the model lives on; fall back to the
    # notebook-wide DEVICE for parameter-free models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    z_starts = _window_starts(Pz, pz, sz)
    y_starts = _window_starts(Py, py, sy)
    x_starts = _window_starts(Px, px, sx)

    # score accumulation; `scores` is allocated once K is known
    scores = None
    counts = np.zeros((Pz, Py, Px), dtype=np.float32)

    for z0 in tqdm(z_starts, desc="Infer z"):
        for y0 in y_starts:
            for x0 in x_starts:
                window = (slice(z0, z0 + pz), slice(y0, y0 + py), slice(x0, x0 + px))

                x_patch = make_cnn_input(raw_zyx[window])  # (C,Z,Y,X)
                x = torch.from_numpy(x_patch).unsqueeze(0).to(device)  # (1,C,Z,Y,X)

                logits = model(x)
                prob = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()  # (K,Z,Y,X)

                if scores is None:
                    scores = np.zeros((prob.shape[0], Pz, Py, Px), dtype=np.float32)

                scores[(slice(None),) + window] += prob
                counts[window] += 1.0

    # Average overlapping predictions, then drop any padding added above.
    scores /= np.maximum(counts[None, ...], 1e-6)
    pred = np.argmax(scores[:, :Dz, :Dy, :Dx], axis=0).astype(np.uint8)
    return pred

In [9]:
# -----------------------------
# Train one fold
# -----------------------------
def _seed_train_worker(worker_id):
    """Give each training worker its own NumPy/`random` stream, derived from its torch seed.

    DataLoader assigns every worker a distinct torch seed that changes each
    epoch, but NumPy's and `random`'s state is copied unchanged into forked
    workers; without this, workers draw identical "random" numbers.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def _seed_val_worker(worker_id):
    """Pin every RNG in a validation worker so each epoch scores the same patches."""
    fixed_seed = SEED + worker_id
    random.seed(fixed_seed)
    np.random.seed(fixed_seed)
    torch.manual_seed(fixed_seed)


def _save_val_previews(vis_dir, val_crop, pred):
    """Write raw / ground-truth / prediction panels for three Z slices of one crop."""
    os.makedirs(vis_dir, exist_ok=True)
    Dz = val_crop["raw"].shape[0]

    for z in (Dz//4, Dz//2, (3*Dz)//4):
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        panels = (
            ("Raw", val_crop["raw"][z], {"cmap": "gray"}),
            ("GT", val_crop["label"][z], {"cmap": "tab10", "vmin": 0, "vmax": 9}),
            ("Pred", pred[z], {"cmap": "tab10", "vmin": 0, "vmax": 9}),
        )
        for ax, (title, image, style) in zip(axes, panels):
            ax.set_title(title)
            ax.imshow(image, **style)
            ax.axis("off")
        fig.tight_layout()
        outp = os.path.join(vis_dir, f"val_{val_crop['id']}_z{z:04d}.png")
        fig.savefig(outp, dpi=200)
        plt.close(fig)  # figures are created in a loop; release them promptly


def train_one_fold(fold_id, train_crops, val_crops):
    """
    Train a 3D U-Net on random patches for one fold and evaluate the best checkpoint.

    Each epoch trains on PATCHES_PER_EPOCH random patches from `train_crops`
    and scores VAL_PATCHES patches from `val_crops` with 0.5*CE + 0.5*soft-Dice.
    The weights with the lowest validation loss are saved to
    OUT_DIR/fold_<id>/best.pt, reloaded, run over the whole first validation
    crop with sliding-window inference, and reported as per-class Dice plus a
    few PNG slice previews.

    Args:
        fold_id: used in log messages and the run directory name.
        train_crops: list of crop dicts (see load_one_crop) to sample training patches from.
        val_crops: list of crop dicts for validation; only the first one is
            used for the full-volume evaluation.

    Returns:
        The best (lowest) validation loss seen.

    Raises:
        RuntimeError: if no epoch produced a checkpoint (e.g. the validation
            loss was NaN throughout), instead of loading a missing or stale file.
    """
    run_dir = os.path.join(OUT_DIR, f"fold_{fold_id}")
    os.makedirs(run_dir, exist_ok=True)

    in_ch = 3 if USE_GRADMAG2_AS_3RD_CH else 2
    model = UNet3D(in_ch=in_ch, n_classes=NUM_CLASSES, base=32).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    ce = nn.CrossEntropyLoss()

    def criterion(logits, target):
        # Equal-weight mix of cross-entropy and multi-class soft Dice.
        return 0.5 * ce(logits, target) + 0.5 * soft_dice_loss(logits, target, num_classes=NUM_CLASSES)

    train_ds = RandomPatchDataset(train_crops, n_patches=PATCHES_PER_EPOCH, patch_zyx=PATCH_ZYX)
    val_ds   = RandomPatchDataset(val_crops, n_patches=VAL_PATCHES, patch_zyx=PATCH_ZYX)

    # Pinned host memory only helps host-to-GPU copies; on CPU it just warns.
    use_pinned = DEVICE == "cuda"
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
        pin_memory=use_pinned, worker_init_fn=_seed_train_worker,
    )
    # Fixed worker seed + no shuffling: the validation loss is measured on the
    # same patches every epoch, so epochs can be compared when picking "best".
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False, num_workers=1,
        pin_memory=use_pinned, worker_init_fn=_seed_val_worker,
    )

    best_val = float("inf")
    best_path = os.path.join(run_dir, "best.pt")
    checkpoint_written = False

    for epoch in range(1, EPOCHS+1):
        t0 = time.time()
        model.train()
        tr_loss = 0.0

        for x, y in tqdm(train_loader, desc=f"[Fold {fold_id}] Train epoch {epoch}", leave=False):
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()

        tr_loss /= max(len(train_loader), 1)

        # val
        model.eval()
        va_loss = 0.0
        with torch.no_grad():
            for x, y in tqdm(val_loader, desc=f"[Fold {fold_id}] Val epoch {epoch}", leave=False):
                x = x.to(DEVICE)
                y = y.to(DEVICE)
                va_loss += criterion(model(x), y).item()
        va_loss /= max(len(val_loader), 1)

        dt = time.time() - t0
        print(f"[Fold {fold_id}] Epoch {epoch:02d} | train={tr_loss:.4f} val={va_loss:.4f} | {dt/60:.2f} min")

        # save best
        if va_loss < best_val:
            best_val = va_loss
            torch.save({"model": model.state_dict(), "in_ch": in_ch}, best_path)
            checkpoint_written = True

    if not checkpoint_written:
        raise RuntimeError(
            f"[Fold {fold_id}] no checkpoint was written (validation loss never improved on inf)"
        )

    print(f"[Fold {fold_id}] Best val loss: {best_val:.4f} saved to {best_path}")

    # Evaluate on full validation crop (first val crop only)
    ckpt = torch.load(best_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model"])
    model.eval()

    val_crop = val_crops[0]
    pred = sliding_window_predict(model, val_crop["raw"], patch_zyx=PATCH_ZYX, stride_zyx=INFER_STRIDE)

    dices = dice_per_class(pred, val_crop["label"], num_classes=NUM_CLASSES)
    print(f"[Fold {fold_id}] Full-volume Dice per class:")
    for ci, d in enumerate(dices):
        print(f"  {ci}:{CLASS_NAMES[ci]}  Dice={d:.4f}")

    # Save a few slice visualizations
    _save_val_previews(os.path.join(run_dir, "vis"), val_crop, pred)

    return best_val




In [10]:
def main():
    """
    Debug entry point: load the first crop in CROP_IDS and run a single
    training "fold" that trains and validates on that same crop (no KFold).
    """
    print("DEVICE:", DEVICE)

    # 1. Open the raw volume (read-only)
    raw_zarr = zarr.open(RAW_S0, mode="r")
    print("Raw shape:", raw_zarr.shape)

    # 2. Load a single crop for debugging
    print("\n===== Loading one crop for debug =====")
    debug_id = CROP_IDS[0]   # only the first crop is used, e.g. crop292
    print(f"Loading {debug_id} ...")
    debug_crop = load_one_crop(debug_id, raw_zarr)
    print("  shape:", debug_crop["shape"], "labels:", np.unique(debug_crop["label"]))
    all_crops = [debug_crop]

    # 3. Train and validate on the SAME crop (debug only, so the validation
    #    numbers say nothing about generalisation)
    print("\n===== DEBUG TRAIN (no KFold) =====")
    train_one_fold(fold_id=0, train_crops=all_crops, val_crops=all_crops)

    print("\n===== DEBUG RUN FINISHED =====")


In [11]:
# def main():
#     print("DEVICE:", DEVICE)
#     raw_zarr = zarr.open(RAW_S0, mode="r")
#     print("Raw shape:", raw_zarr.shape)

#     # Load all crops into RAM (simple & robust; if OOM, we can stream per-epoch)
#     all_crops = []
#     print("\n===== Loading crops =====")
#     for cid in CROP_IDS:
#         print(f"Loading {cid}...")
#         c = load_one_crop(cid, raw_zarr)
#         print("  shape:", c["shape"], "labels:", np.unique(c["label"]))
#         all_crops.append(c)

#     # KFold on crop index
#     crop_indices = np.arange(len(all_crops))
#     kf = KFold(n_splits=len(all_crops), shuffle=True, random_state=SEED)

#     fold_losses = []
#     for fold_id, (tr_idx, va_idx) in enumerate(kf.split(crop_indices)):
#         train_crops = [all_crops[i] for i in tr_idx]
#         val_crops   = [all_crops[i] for i in va_idx]  # single crop

#         print("\n" + "="*60)
#         print(f"FOLD {fold_id} | train={[c['id'] for c in train_crops]} | val={[c['id'] for c in val_crops]}")
#         loss = train_one_fold(fold_id, train_crops, val_crops)
#         fold_losses.append(loss)

#     print("\n===== Done =====")
#     print("Fold best val losses:", fold_losses)


# Run the pipeline when this code is executed as the top-level module.
# The KFold variant of main() above is commented out, so this calls the
# single-crop debug main().
if __name__ == "__main__":
    main()


DEVICE: cpu
Raw shape: (1813, 4368, 20609)

===== Loading one crop for debug =====
Loading crop292 ...
  shape: (400, 400, 400) labels: [0 1 2 3 4 5]

===== DEBUG TRAIN (no KFold) =====


                                                                         

[Fold 0] Epoch 01 | train=1.1083 val=0.9969 | 36.70 min


                                                                         

[Fold 0] Epoch 02 | train=0.9495 val=0.8697 | 35.51 min


                                                                        

KeyboardInterrupt: 