Imports & paths

In [13]:
import os, math, json, numpy as np, torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode as IM

# root with your 80-frame .npz files
NPZ_DIR = r"D:/dataset/npz_80"

# fixed split as you requested
all_npz = sorted([f for f in os.listdir(NPZ_DIR) if f.endswith(".npz")])
train_files = all_npz[:210]
val_files   = all_npz[210:255]

len(train_files), len(val_files)


(210, 45)

Utilities (letterbox resize + paired aug)

In [2]:
from typing import Tuple

def letterbox_params(H:int, W:int, target_hw:Tuple[int,int]):
    th, tw = target_hw
    scale = min(th / H, tw / W)
    nh, nw = int(round(H * scale)), int(round(W * scale))
    pad_h = th - nh
    pad_w = tw - nw
    pad_top    = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left   = pad_w // 2
    pad_right  = pad_w - pad_left
    return nh, nw, scale, (pad_left, pad_right, pad_top, pad_bottom)

def apply_letterbox(img_t: torch.Tensor, target_hw:Tuple[int,int]):
    """
    img_t: (C,H,W) in [0,1]
    returns (C,th,tw), meta (H,W, nh,nw, scale, pads)
    """
    C, H, W = img_t.shape
    nh, nw, scale, pads = letterbox_params(H, W, target_hw)
    th, tw = target_hw
    img_t = F.interpolate(img_t.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False).squeeze(0)
    pl, pr, pt, pb = pads
    img_t = F.pad(img_t, (pl, pr, pt, pb), value=0.0)
    return img_t, (H, W, nh, nw, scale, pads)

def undo_letterbox(mask_t: torch.Tensor, meta):
    """
    mask_t: (1,th,tw) float/binary after letterbox
    meta from apply_letterbox
    returns (1,H,W) in original size (nearest)
    """
    H, W, nh, nw, scale, (pl, pr, pt, pb) = meta
    # remove padding
    m = mask_t[..., pt:pt+nh, pl:pl+nw]
    # resize back
    m = F.interpolate(m, size=(H, W), mode="nearest")
    return m


Dataset (positive frames only, 2D)

In [3]:
class SegPositiveNPZ2D(Dataset):
    """
    Yields (img, mask, meta) for frames where label in {1,2}.
    - img: (1, th, tw) normalized to [-1,1]
    - mask: (1, th, tw) in {0,1}
    - meta: dict with file, frame_idx, pixel_spacing, original HW, letterbox meta
    """
    def __init__(self, npz_dir, files, target_hw=(288, 352), augment=True):
        self.npz_dir   = npz_dir
        self.files     = files
        self.target_hw = target_hw
        self.augment   = augment
        self.items     = []  # (fname, frame_idx)

        for f in files:
            case = np.load(os.path.join(npz_dir, f), mmap_mode="r")
            labels = case["label"].astype(np.int64)
            pos = np.where((labels==1)|(labels==2))[0]
            for i in pos:
                self.items.append((f, int(i)))

    def __len__(self): return len(self.items)

    def _load_frame(self, case, idx):
        img = case["image"][idx].astype(np.float32)   # (H,W)
        msk = case["mask"][idx].astype(np.uint8)      # (H,W) values {0,1,2}
        # per-frame min-max -> [0,1]
        mn, mx = img.min(), img.max()
        img = (img - mn) / (mx - mn + 1e-8)
        # map {1,2} -> 1
        msk = (msk > 0).astype(np.float32)
        return img, msk

    def __getitem__(self, idx):
        fname, i = self.items[idx]
        case = np.load(os.path.join(self.npz_dir, fname), mmap_mode="r")

        img_np, msk_np = self._load_frame(case, i)
        H, W = img_np.shape

        # to torch
        img_t = torch.from_numpy(img_np)[None, ...]  # (1,H,W)
        msk_t = torch.from_numpy(msk_np)[None, ...]  # (1,H,W)

        # --- paired augmentations (keep modest) ---
        if self.augment:
            # horizontal flip
            if torch.rand(1).item() < 0.5:
                img_t = torch.flip(img_t, dims=[2])
                msk_t = torch.flip(msk_t, dims=[2])
            # small random rotation (-15..15)
            if torch.rand(1).item() < 0.5:
                ang = (torch.rand(1).item() * 30.0) - 15.0
                img_t = TF.rotate(img_t, ang, interpolation=IM.BILINEAR, fill=0.0)
                msk_t = TF.rotate(msk_t, ang, interpolation=IM.NEAREST,  fill=0.0)
            # slight noise
            if torch.rand(1).item() < 0.25:
                img_t = (img_t + 0.03*torch.randn_like(img_t)).clamp(0,1)

        # letterbox to target
        img_t, meta_lb = apply_letterbox(img_t, self.target_hw)   # (1,th,tw)
        msk_t, _       = apply_letterbox(msk_t, self.target_hw)   # (1,th,tw)

        # final normalize to [-1,1]
        img_t = (img_t - 0.5)/0.5

        meta = {
            "file": fname,
            "frame_idx": i,
            "pixel_spacing": float(case["pixel_spacing"]),
            "orig_hw": (H, W),
            "lb_meta": meta_lb,
        }
        return img_t.float(), msk_t.float(), meta


Small UNet (2D)

In [4]:
class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
        )
    def forward(self,x): return self.net(x)

class Down(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_c, out_c))
    def forward(self,x): return self.net(x)

class Up(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, in_c//2, 2, stride=2)
        self.conv = DoubleConv(in_c, out_c)
    def forward(self, x, skip):
        x = self.up(x)
        dh = skip.size(2) - x.size(2)
        dw = skip.size(3) - x.size(3)
        x = F.pad(x, [dw//2, dw-dw//2, dh//2, dh-dh//2])
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

class UNet2D(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, base=32):
        super().__init__()
        self.inc = DoubleConv(in_channels, base)
        self.d1  = Down(base, base*2)
        self.d2  = Down(base*2, base*4)
        self.d3  = Down(base*4, base*8)
        self.bott= DoubleConv(base*8, base*16)
        self.u3  = Up(base*16, base*8)
        self.u2  = Up(base*8,  base*4)
        self.u1  = Up(base*4,  base*2)
        self.u0  = Up(base*2,  base)
        self.out = nn.Conv2d(base, out_channels, 1)
    def forward(self,x):
        x0 = self.inc(x)
        x1 = self.d1(x0)
        x2 = self.d2(x1)
        x3 = self.d3(x2)
        xb = self.bott(x3)
        x  = self.u3(xb, x3)
        x  = self.u2(x,  x2)
        x  = self.u1(x,  x1)
        x  = self.u0(x,  x0)
        return self.out(x)  # logits (B,1,H,W)


Loss & metric (Dice + BCE)

In [5]:
class DiceBCELoss(nn.Module):
    def __init__(self, bce_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
    def forward(self, logits, target):
        bce = self.bce(logits, target)
        p = torch.sigmoid(logits)
        inter = (p*target).sum(dim=(1,2,3))
        denom = p.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3)) + 1e-6
        dice = 1 - (2*inter/denom).mean()
        return self.bce_weight*bce + (1-self.bce_weight)*dice

@torch.no_grad()
def dice_from_logits(logits, target, thr=0.5):
    p = (torch.sigmoid(logits) >= thr).float()
    inter = (p*target).sum(dim=(1,2,3))
    denom = p.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3)) + 1e-6
    return (2*inter/denom).mean().item()


Dataloaders

In [6]:
target_hw = (288, 352)  # keep this modest for your GPU
train_ds = SegPositiveNPZ2D(NPZ_DIR, train_files, target_hw=target_hw, augment=True)
val_ds   = SegPositiveNPZ2D(NPZ_DIR, val_files,   target_hw=target_hw, augment=False)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,  num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False, num_workers=0, pin_memory=True)


Train (AMP, early stop on val Dice)

In [8]:
import os
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

# ============================================================
# Dataset
# ============================================================
class NPZSegDataset(Dataset):
    def __init__(self, npz_dir, files, transform=None):
        self.paths = [os.path.join(npz_dir, f) for f in files]
        self.transform = transform

        # Pre-read only shapes using mmap, not entire arrays
        self.index_map = []  # (file_idx, frame_idx)
        for file_idx, path in enumerate(self.paths):
            with np.load(path, mmap_mode="r") as case:
                n_frames = case["image"].shape[0]
                for frame_idx in range(n_frames):
                    self.index_map.append((file_idx, frame_idx))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        file_idx, frame_idx = self.index_map[idx]
        path = self.paths[file_idx]

        # Load only one frame with mmap
        with np.load(path, mmap_mode="r") as case:
            img = case["image"][frame_idx].astype(np.float32, copy=False)
            msk = case["mask"][frame_idx].astype(np.float32, copy=False)

        # normalize per-frame
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        img = torch.from_numpy(img).unsqueeze(0)  # (1,H,W)
        msk = torch.from_numpy(msk).unsqueeze(0)  # (1,H,W)

        # resize to 224x224
        img = F.interpolate(img.unsqueeze(0), size=(224,224),
                            mode="bilinear", align_corners=False).squeeze(0)
        msk = F.interpolate(msk.unsqueeze(0), size=(224,224),
                            mode="nearest").squeeze(0)

        # apply transforms (if any)
        if self.transform:
            img = self.transform(img)

        return img, msk



In [9]:
# ============================================================
# U-Net Model
# ============================================================
class ConvBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNet2D(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, base=32):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, base)
        self.enc2 = ConvBlock(base, base*2)
        self.enc3 = ConvBlock(base*2, base*4)
        self.enc4 = ConvBlock(base*4, base*8)

        self.pool = nn.MaxPool2d(2,2)
        self.bottleneck = ConvBlock(base*8, base*16)

        self.up4 = nn.ConvTranspose2d(base*16, base*8, 2, stride=2)
        self.dec4 = ConvBlock(base*16, base*8)

        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.dec3 = ConvBlock(base*8, base*4)

        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec2 = ConvBlock(base*4, base*2)

        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec1 = ConvBlock(base*2, base)

        self.out_conv = nn.Conv2d(base, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.out_conv(d1)



In [10]:
# ============================================================
# Loss (BCE + Dice)
# ============================================================
class BCEDiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.smooth = smooth

    def forward(self, logits, targets):
        bce = self.bce(logits, targets)
        probs = torch.sigmoid(logits)
        num = (probs*targets).sum(dim=(2,3))*2 + self.smooth
        den = probs.sum(dim=(2,3)) + targets.sum(dim=(2,3)) + self.smooth
        dice = 1 - (num/den).mean()
        return bce + dice


In [11]:
# ============================================================
# Training
# ============================================================
def train_one_epoch(model, loader, optimizer, scaler, criterion, device):
    model.train()
    total_loss = 0
    for img, msk in tqdm(loader, desc="Train", leave=False):
        img, msk = img.to(device), msk.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(device=="cuda")):
            logits = model(img)
            loss   = criterion(logits, msk)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()*img.size(0)

    return total_loss/len(loader.dataset)

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for img, msk in tqdm(loader, desc="Val", leave=False):
            img, msk = img.to(device), msk.to(device)
            with torch.cuda.amp.autocast(enabled=(device=="cuda")):
                logits = model(img)
                loss   = criterion(logits, msk)
            total_loss += loss.item()*img.size(0)

    return total_loss/len(loader.dataset)



In [12]:
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
    npz_dir = "D:/dataset/npz_80"
    all_files = sorted(os.listdir(npz_dir))
    train_files = all_files[:210]
    val_files   = all_files[210:255]

    # augmentations (optional)
    train_transform = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
    ])
    val_transform = None

    train_dataset = NPZSegDataset(npz_dir, train_files, transform=train_transform)
    val_dataset   = NPZSegDataset(npz_dir, val_files, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, pin_memory=True)
    val_loader   = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0, pin_memory=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    model = UNet2D(in_channels=1, out_channels=1, base=32).to(device)

    criterion = BCEDiceLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3, factor=0.5)

    scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

    # ================================
    # Early Stopping + Checkpointing
    # ================================
    best_val_loss = float("inf")
    patience = 5
    epochs_no_improve = 0

    n_epochs = 20
    for epoch in range(1, n_epochs+1):
        print(f"\nEpoch {epoch}/{n_epochs}")
        train_loss = train_one_epoch(model, train_loader, optimizer, scaler, criterion, device)
        val_loss   = validate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "best_model.pth")
            print("✅ New best model saved!")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epochs.")

        # Early stopping
        if epochs_no_improve >= patience:
            print("⏹ Early stopping triggered!")
            break

    # Save last model anyway
    torch.save(model.state_dict(), "last_model.pth")
    print("Training finished. Best model saved as best_model.pth")


  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))


Using device: cuda

Epoch 1/20


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
                                                      

Train Loss: 1.0538 | Val Loss: 0.9889
✅ New best model saved!

Epoch 2/20


                                                            

Train Loss: 1.0232 | Val Loss: 0.9572
✅ New best model saved!

Epoch 3/20


                                                            

Train Loss: 1.0100 | Val Loss: 0.9235
✅ New best model saved!

Epoch 4/20


                                                            

Train Loss: 0.9970 | Val Loss: 0.9106
✅ New best model saved!

Epoch 5/20


                                                            

Train Loss: 0.9812 | Val Loss: 0.9001
✅ New best model saved!

Epoch 6/20


                                                            

Train Loss: 0.9633 | Val Loss: 0.8893
✅ New best model saved!

Epoch 7/20


                                                            

Train Loss: 0.9388 | Val Loss: 0.8664
✅ New best model saved!

Epoch 8/20


                                                            

KeyboardInterrupt: 