# Baseline U-Net

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os, glob, time, csv, random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import torchvision.transforms.functional as TF

In [None]:
# SETUP

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda":
    torch.cuda.manual_seed_all(SEED)


Device: cuda


In [None]:
# 1. DATASET DISCOVERY


DATA_ROOT = "/content/drive/MyDrive/archive"

def discover_dirs(root):
    img_dirs, mask_dirs = [], []
    for organ in os.listdir(root):
        organ_dir = os.path.join(root, organ)
        if not os.path.isdir(organ_dir):
            continue
        img_dir  = os.path.join(organ_dir, "tissue images")
        mask_dir = os.path.join(organ_dir, "mask binary")
        if os.path.isdir(img_dir) and os.path.isdir(mask_dir):
            img_dirs.append(img_dir)
            mask_dirs.append(mask_dir)
            print("Using organ:", organ)
    return img_dirs, mask_dirs

IMAGE_DIRS, MASK_DIRS = discover_dirs(DATA_ROOT)



Using organ: mouse spleen
Using organ: mouse thymus
Using organ: mouse liver
Using organ: mouse muscle_tibia
Using organ: mouse heart
Using organ: mouse kidney
Using organ: mouse femur
Using organ: mouse fat (white and brown)_subscapula
Using organ: human umbilical cord
Using organ: human tonsile
Using organ: human tongue
Using organ: human testis
Using organ: human salivory gland
Using organ: human spleen
Using organ: human oesophagus
Using organ: human placenta
Using organ: human pancreas
Using organ: human peritoneum
Using organ: human pylorus
Using organ: human rectum
Using organ: human muscle
Using organ: human liver
Using organ: human melanoma
Using organ: human cerebellum
Using organ: human kidney
Using organ: human lung
Using organ: human epiglottis
Using organ: human cardia
Using organ: human brain
Using organ: human jejunum
Using organ: human bladder


In [None]:
# 2. NuInsSeg Dataset


class NuInsSegDataset(Dataset):
    def __init__(self, image_dirs, mask_dirs, image_size=256, augment=True):
        self.image_paths = []
        self.mask_paths  = []
        self.image_size  = image_size
        self.augment     = augment

        self.resize = T.Resize((image_size, image_size))
        self.to_tensor = T.ToTensor()

        for img_dir, m_dir in zip(image_dirs, mask_dirs):
            imgs  = sorted(glob.glob(os.path.join(img_dir, "**/*.png"), recursive=True))
            masks = sorted(glob.glob(os.path.join(m_dir, "**/*.png"), recursive=True))
            mask_map = {os.path.basename(m): m for m in masks}

            for img in imgs:
                fname = os.path.basename(img)
                if fname in mask_map:
                    self.image_paths.append(img)
                    self.mask_paths.append(mask_map[fname])
                else:
                    print(" Missing mask:", img)

        print(f"Loaded {len(self.image_paths)} image-mask pairs")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        img  = Image.open(self.image_paths[i]).convert("RGB")
        mask = Image.open(self.mask_paths[i]).convert("L")

        img  = self.resize(img)
        mask = self.resize(mask)

        if self.augment:
            if random.random() < 0.5:
                img = TF.hflip(img); mask = TF.hflip(mask)
            if random.random() < 0.5:
                img = TF.vflip(img); mask = TF.vflip(mask)

        img  = self.to_tensor(img)
        mask = self.to_tensor(mask)
        mask = (mask > 0.5).float()

        return img, mask


dataset = NuInsSegDataset(IMAGE_DIRS, MASK_DIRS, image_size=256)

n_total = len(dataset)
n_val = int(n_total * 0.2)
n_train = n_total - n_val

train_ds, val_ds = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=2, shuffle=False)

print(f"Total={n_total}, Train={n_train}, Val={n_val}")

Loaded 665 image-mask pairs
Total=665, Train=532, Val=133


In [None]:
# 3. BASELINE U-NET


class DoubleConv(nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(a, b, 3, padding=1, bias=False),
            nn.BatchNorm2d(b),
            nn.ReLU(True),
            nn.Conv2d(b, b, 3, padding=1, bias=False),
            nn.BatchNorm2d(b),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.block(x)

class UNetBaseline(nn.Module):
    def __init__(self, ch=32):
        super().__init__()
        self.inc = DoubleConv(3, ch)
        self.d1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(ch, ch*2))
        self.d2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(ch*2, ch*4))
        self.d3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(ch*4, ch*8))

        self.u1 = nn.ConvTranspose2d(ch*8, ch*4, 2, 2)
        self.c1 = DoubleConv(ch*8, ch*4)
        self.u2 = nn.ConvTranspose2d(ch*4, ch*2, 2, 2)
        self.c2 = DoubleConv(ch*4, ch*2)
        self.u3 = nn.ConvTranspose2d(ch*2, ch, 2, 2)
        self.c3 = DoubleConv(ch*2, ch)

        self.out = nn.Conv2d(ch, 1, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.d1(x1)
        x3 = self.d2(x2)
        x4 = self.d3(x3)

        x  = torch.cat([self.u1(x4), x3], 1); x = self.c1(x)
        x  = torch.cat([self.u2(x), x2], 1); x = self.c2(x)
        x  = torch.cat([self.u3(x), x1], 1); x = self.c3(x)
        return self.out(x)


model_name = "UNetBaseline_Colab"
model = UNetBaseline().to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)



# 4. METRICS


def dice(pred, true, eps=1e-7):
    pred = (pred > 0.5).float()
    true = (true > 0.5).float()
    inter = (pred * true).sum((1,2,3))
    union = pred.sum((1,2,3)) + true.sum((1,2,3))
    return ((2*inter + eps)/(union+eps)).mean().item()

def iou(pred, true, eps=1e-7):
    pred = (pred>0.5).float()
    true = (true>0.5).float()
    inter = (pred*true).sum((1,2,3))
    union = pred.sum((1,2,3)) + true.sum((1,2,3)) - inter
    return ((inter+eps)/(union+eps)).mean().item()

def accuracy(pred, true):
    return ((pred>0.5).float() == (true>0.5).float()).float().mean().item()

In [None]:
# 5. TRAIN/EVAL LOOPS


def train_epoch():
    model.train()
    tl, td = 0, 0
    for img, mask in train_loader:
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        logits = model(img)
        loss = criterion(logits, mask)
        loss.backward()
        optimizer.step()
        prob = torch.sigmoid(logits)
        tl += loss.item()
        td += dice(prob, mask)
    return tl/len(train_loader), td/len(train_loader)

def eval_epoch():
    model.eval()
    d, j, a = 0, 0, 0
    with torch.no_grad():
        for img, mask in val_loader:
            img, mask = img.to(device), mask.to(device)
            logits = model(img)
            prob = torch.sigmoid(logits)
            d += dice(prob, mask)
            j += iou(prob, mask)
            a += accuracy(prob, mask)
    n = len(val_loader)
    return d/n, j/n, a/n


In [None]:

# 6. TRAINING LOOP (ALIGNED LOGGING)


epochs = 20
t_loss, t_dice = [], []
v_dice, v_iou, v_acc = [], [], []

print("\n--- BASELINE TRAINING STARTED ---\n")

for ep in range(1, epochs+1):
    start = time.time()

    L, D = train_epoch()
    vd, vi, va = eval_epoch()

    et = time.time() - start

    t_loss.append(L); t_dice.append(D)
    v_dice.append(vd); v_iou.append(vi); v_acc.append(va)

    print(
        f"[{model_name}] "
        f"Epoch {ep:>2}/{epochs} | "
        f"Loss: {L:>7.4f} | "
        f"TrainDice: {D:>7.4f} | "
        f"ValDice: {vd:>7.4f} | "
        f"ValIoU: {vi:>7.4f} | "
        f"ValAcc: {va:>7.4f} | "
        f"Time: {et:>6.2f}s"
    )

print("\n--- TRAINING COMPLETE ---\n")


--- BASELINE TRAINING STARTED ---

[UNetBaseline_Colab] Epoch  1/20 | Loss:  0.4293 | TrainDice:  0.5825 | ValDice:  0.6129 | ValIoU:  0.4720 | ValAcc:  0.8419 | Time: 340.64s
[UNetBaseline_Colab] Epoch  2/20 | Loss:  0.3339 | TrainDice:  0.6513 | ValDice:  0.6504 | ValIoU:  0.5087 | ValAcc:  0.9170 | Time:  24.38s
[UNetBaseline_Colab] Epoch  3/20 | Loss:  0.2834 | TrainDice:  0.6730 | ValDice:  0.6709 | ValIoU:  0.5334 | ValAcc:  0.8961 | Time:  23.61s
[UNetBaseline_Colab] Epoch  4/20 | Loss:  0.2447 | TrainDice:  0.7034 | ValDice:  0.7013 | ValIoU:  0.5635 | ValAcc:  0.9248 | Time:  23.71s
[UNetBaseline_Colab] Epoch  5/20 | Loss:  0.2274 | TrainDice:  0.7063 | ValDice:  0.6999 | ValIoU:  0.5631 | ValAcc:  0.9256 | Time:  23.51s
[UNetBaseline_Colab] Epoch  6/20 | Loss:  0.2058 | TrainDice:  0.7244 | ValDice:  0.7242 | ValIoU:  0.5843 | ValAcc:  0.9224 | Time:  23.11s
[UNetBaseline_Colab] Epoch  7/20 | Loss:  0.1925 | TrainDice:  0.7312 | ValDice:  0.7526 | ValIoU:  0.6217 | ValAcc:  

In [None]:
# 7. MADs / FLOPs COUNTER


def count_mads(model, input_size=(1,3,256,256), device="cuda"):
    mads = []
    hooks = []

    def conv_hook(m, inp, out):
        if isinstance(m, nn.Conv2d):
            x = inp[0]
            Cin = x.shape[1]
            Cout = m.out_channels
            kH, kW = m.kernel_size
            Hout, Wout = out.shape[2], out.shape[3]
            groups = m.groups
            mads.append(Cout * (Cin//groups) * kH * kW * Hout * Wout)

    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            hooks.append(module.register_forward_hook(conv_hook))

    dummy = torch.randn(*input_size).to(device)
    model.eval()
    with torch.no_grad():
        _ = model(dummy)

    for h in hooks:
        h.remove()

    return sum(mads)


In [None]:
# 8. RUNTIME, MEMORY, PARAMS


params = sum(p.numel() for p in model.parameters())
mads   = count_mads(model, device=device)

print(f"[{model_name}] Params: {params}")
print(f"[{model_name}] MADs  : {mads:.3g}")

dummy = torch.randn(1,3,256,256).to(device)

# Inference time
model.eval()
if device=="cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

with torch.no_grad():
    for _ in range(5):
        _ = model(dummy)

if device=="cuda": torch.cuda.synchronize()
t0 = time.time()
with torch.no_grad():
    for _ in range(50):
        _ = model(dummy)
if device=="cuda": torch.cuda.synchronize()
t1 = time.time()

infer_ms = (t1 - t0)/50*1000
mem_infer = torch.cuda.max_memory_allocated()/(1024**2) if device=="cuda" else 0

print(f"[{model_name}] Inference: {infer_ms:.3f} ms | Memory: {mem_infer:.2f} MB")

# Training step time
crit_step = nn.BCEWithLogitsLoss()
opt2 = optim.SGD(model.parameters(), lr=1e-3)
dummy_mask = torch.rand(1,1,256,256).to(device)

if device=="cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

for _ in range(5):
    opt2.zero_grad(); out=model(dummy)
    loss=crit_step(out,dummy_mask)
    loss.backward(); opt2.step()

if device=="cuda": torch.cuda.synchronize()

t0=time.time()
for _ in range(50):
    opt2.zero_grad(); out=model(dummy)
    loss=crit_step(out,dummy_mask)
    loss.backward(); opt2.step()
if device=="cuda": torch.cuda.synchronize()
t1=time.time()

train_ms=(t1-t0)/50*1000
mem_train=torch.cuda.max_memory_allocated()/(1024**2) if device=="cuda" else 0

peak_mem=max(mem_infer,mem_train)

print(f"[{model_name}] TrainStep: {train_ms:.3f} ms | Memory: {peak_mem:.2f} MB")

[UNetBaseline_Colab] Params: 1927009
[UNetBaseline_Colab] MADs  : 8.82e+09
[UNetBaseline_Colab] Inference: 6.292 ms | Memory: 79.18 MB
[UNetBaseline_Colab] TrainStep: 17.614 ms | Memory: 225.49 MB


In [None]:
# 9. CSV EXPORT


final_train_dice=t_dice[-1]
final_val_dice  =v_dice[-1]
final_val_iou   =v_iou[-1]
final_val_acc   =v_acc[-1]

csv_path=f"/content/{model_name}_metrics.csv"
with open(csv_path,"w",newline="") as f:
    wr=csv.writer(f)
    wr.writerow([
        "Model","Params","MADs","Infer_ms","Train_ms","Memory_MB",
        "FinalTrainDice","FinalValDice","FinalValIoU","FinalValAcc"
    ])
    wr.writerow([
        model_name,params,mads,infer_ms,train_ms,peak_mem,
        final_train_dice,final_val_dice,final_val_iou,final_val_acc
    ])

print("CSV saved:",csv_path)



# 10. SAVE WEIGHTS

torch.save(model.state_dict(), f"/content/{model_name}.pth")
print("Model saved:", f"/content/{model_name}.pth")

CSV saved: /content/UNetBaseline_Colab_metrics.csv
Model saved: /content/UNetBaseline_Colab.pth


# ORIENTED-1D U-Net

In [None]:
# 2. NuInsSeg Dataset

class NuInsSegDataset(Dataset):
    def __init__(self, image_dirs, mask_dirs, image_size=256, augment=True):
        self.image_paths = []
        self.mask_paths = []
        self.image_size = image_size
        self.augment = augment
        self.resize = T.Resize((image_size, image_size))
        self.to_tensor = T.ToTensor()

        for img_dir, m_dir in zip(image_dirs, mask_dirs):
            imgs = sorted(glob.glob(os.path.join(img_dir, "**/*.png"), recursive=True))
            masks = sorted(glob.glob(os.path.join(m_dir, "**/*.png"), recursive=True))
            mask_map = {os.path.basename(m): m for m in masks}

            for img in imgs:
                fname = os.path.basename(img)
                if fname in mask_map:
                    self.image_paths.append(img)
                    self.mask_paths.append(mask_map[fname])
                else:
                    print(" Missing mask:", img)

        print(f"Loaded {len(self.image_paths)} image-mask pairs")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        img = Image.open(self.image_paths[i]).convert("RGB")
        mask = Image.open(self.mask_paths[i]).convert("L")

        img = self.resize(img)
        mask = self.resize(mask)

        if self.augment:
            if random.random() < 0.5:
                img = TF.hflip(img); mask = TF.hflip(mask)
            if random.random() < 0.5:
                img = TF.vflip(img); mask = TF.vflip(mask)

        img = self.to_tensor(img)
        mask = self.to_tensor(mask)
        mask = (mask > 0.5).float()
        return img, mask


dataset = NuInsSegDataset(IMAGE_DIRS, MASK_DIRS, image_size=256)

n_total = len(dataset)
n_val = int(n_total * 0.2)
n_train = n_total - n_val

train_ds, val_ds = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=2, shuffle=False)

print(f"Total={n_total}, Train={n_train}, Val={n_val}")


Loaded 665 image-mask pairs
Total=665, Train=532, Val=133


In [None]:
# 3. ORIENTED DEPTHWISE 1D CONV (8 ANGLES)


class OrientedConv1D(nn.Module):
    """
    Depthwise 1D conv applied along 8 rotated directions.
    (Simple PyTorch implementation; not CUDA fast, but works.)
    """
    def __init__(self, channels, kernel_size=15, angles=8):
        super().__init__()
        self.angles = angles
        self.kernel_size = kernel_size

        # depthwise 1D convolution
        self.conv1d = nn.Conv1d(
            channels, channels,
            kernel_size,
            groups=channels,
            padding=kernel_size // 2,
            bias=False
        )

    def forward(self, x):
        B, C, H, W = x.shape
        out_sum = 0

        for a in range(self.angles):
            angle = a * (180 / self.angles)

            # rotate image
            x_rot = TF.rotate(x, angle, interpolation=TF.InterpolationMode.BILINEAR)

            # flatten to 1D along width
            x1d = x_rot.view(B, C, -1)  # shape (B, C, H*W)

            y = self.conv1d(x1d)       # depthwise 1D conv
            y = y.view(B, C, H, W)     # restore 2D shape

            out_sum = out_sum + y

        return out_sum / self.angles


In [None]:
# 4. ORIENTED 1D U-NET


class DoubleConvO1D(nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.block = nn.Sequential(
            OrientedConv1D(a, kernel_size=15, angles=8),
            nn.BatchNorm2d(a),
            nn.ReLU(True),
            nn.Conv2d(a, b, 3, padding=1),
            nn.BatchNorm2d(b),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.block(x)

class UNetOriented1D(nn.Module):
    def __init__(self, ch=32):
        super().__init__()
        self.inc = DoubleConvO1D(3, ch)
        self.d1 = nn.Sequential(nn.MaxPool2d(2), DoubleConvO1D(ch, ch*2))
        self.d2 = nn.Sequential(nn.MaxPool2d(2), DoubleConvO1D(ch*2, ch*4))
        self.d3 = nn.Sequential(nn.MaxPool2d(2), DoubleConvO1D(ch*4, ch*8))

        self.u1 = nn.ConvTranspose2d(ch*8, ch*4, 2, 2)
        self.c1 = DoubleConvO1D(ch*8, ch*4)
        self.u2 = nn.ConvTranspose2d(ch*4, ch*2, 2, 2)
        self.c2 = DoubleConvO1D(ch*4, ch*2)
        self.u3 = nn.ConvTranspose2d(ch*2, ch, 2, 2)
        self.c3 = DoubleConvO1D(ch*2, ch)

        self.out = nn.Conv2d(ch, 1, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.d1(x1)
        x3 = self.d2(x2)
        x4 = self.d3(x3)

        x = torch.cat([self.u1(x4), x3], 1)
        x = self.c1(x)
        x = torch.cat([self.u2(x), x2], 1)
        x = self.c2(x)
        x = torch.cat([self.u3(x), x1], 1)
        x = self.c3(x)
        return self.out(x)


model_name = "UNet_Oriented1D_Colab"
model = UNetOriented1D().to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# 5. METRICS


def dice(pred, true, eps=1e-7):
    pred = (pred > 0.5).float()
    true = (true > 0.5).float()
    inter = (pred * true).sum((1,2,3))
    union = pred.sum((1,2,3)) + true.sum((1,2,3))
    return ((2*inter + eps) / (union + eps)).mean().item()

def iou(pred, true, eps=1e-7):
    pred = (pred > 0.5).float()
    true = (true > 0.5).float()
    inter = (pred * true).sum((1,2,3))
    union = pred.sum((1,2,3)) + true.sum((1,2,3)) - inter
    return ((inter + eps) / (union + eps)).mean().item()

def accuracy(pred, true):
    return ((pred > 0.5).float() == (true > 0.5).float()).float().mean().item()

In [None]:
# 6. TRAIN / EVAL LOOPS


def train_epoch():
    model.train()
    tl, td = 0, 0
    for img, mask in train_loader:
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        logits = model(img)
        loss = criterion(logits, mask)
        loss.backward()
        optimizer.step()
        prob = torch.sigmoid(logits)
        tl += loss.item()
        td += dice(prob, mask)
    return tl/len(train_loader), td/len(train_loader)

def eval_epoch():
    model.eval()
    d, j, a = 0, 0, 0
    with torch.no_grad():
        for img, mask in val_loader:
            img, mask = img.to(device), mask.to(device)
            logits = model(img)
            prob = torch.sigmoid(logits)
            d += dice(prob, mask)
            j += iou(prob, mask)
            a += accuracy(prob, mask)
    n = len(val_loader)
    return d/n, j/n, a/n


In [None]:
# 7. TRAINING LOOP WITH ALIGNED LOGGING


epochs = 20
t_loss, t_dice = [], []
v_dice, v_iou, v_acc = [], [], []

print("\n----- TRAINING STARTED (ORIENTED 1D) -----\n")

for ep in range(1, epochs+1):
    ep_start = time.time()

    L, D = train_epoch()
    vd, vi, va = eval_epoch()

    ep_time = time.time() - ep_start

    t_loss.append(L); t_dice.append(D)
    v_dice.append(vd); v_iou.append(vi); v_acc.append(va)

    print(
        f"[{model_name}] "
        f"Epoch {ep:>2}/{epochs} | "
        f"Loss: {L:>7.4f} | "
        f"TrainDice: {D:>7.4f} | "
        f"ValDice: {vd:>7.4f} | "
        f"ValIoU: {vi:>7.4f} | "
        f"ValAcc: {va:>7.4f} | "
        f"Time: {ep_time:>6.2f}s"
    )

print("\n----- TRAINING COMPLETE -----\n")


----- TRAINING STARTED (ORIENTED 1D) -----

[UNet_Oriented1D_Colab] Epoch  1/20 | Loss:  0.5274 | TrainDice:  0.0582 | ValDice:  0.0457 | ValIoU:  0.0270 | ValAcc:  0.8119 | Time:  61.11s
[UNet_Oriented1D_Colab] Epoch  2/20 | Loss:  0.4628 | TrainDice:  0.0279 | ValDice:  0.0444 | ValIoU:  0.0257 | ValAcc:  0.8279 | Time:  60.46s
[UNet_Oriented1D_Colab] Epoch  3/20 | Loss:  0.4382 | TrainDice:  0.0305 | ValDice:  0.0453 | ValIoU:  0.0260 | ValAcc:  0.8201 | Time:  61.09s
[UNet_Oriented1D_Colab] Epoch  4/20 | Loss:  0.4296 | TrainDice:  0.0296 | ValDice:  0.0363 | ValIoU:  0.0206 | ValAcc:  0.8372 | Time:  66.83s
[UNet_Oriented1D_Colab] Epoch  5/20 | Loss:  0.4229 | TrainDice:  0.0348 | ValDice:  0.0428 | ValIoU:  0.0241 | ValAcc:  0.8349 | Time:  62.46s
[UNet_Oriented1D_Colab] Epoch  6/20 | Loss:  0.4222 | TrainDice:  0.0366 | ValDice:  0.0354 | ValIoU:  0.0198 | ValAcc:  0.8371 | Time:  60.89s
[UNet_Oriented1D_Colab] Epoch  7/20 | Loss:  0.4218 | TrainDice:  0.0325 | ValDice:  0.0404

In [None]:
# 8. FLOPs / MADs COUNTER (SAME AS BASELINE)


def count_mads(model, input_size=(1,3,256,256), device="cuda"):
    hooks = []
    mads = []

    def conv_hook(m, inp, out):
        if isinstance(m, nn.Conv2d):
            x = inp[0]
            Cin = x.shape[1]
            Cout = m.out_channels
            kH, kW = m.kernel_size
            Hout, Wout = out.shape[2], out.shape[3]
            groups = m.groups
            mads.append(Cout * (Cin//groups) * kH * kW * Hout * Wout)

    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            hooks.append(module.register_forward_hook(conv_hook))

    dummy = torch.randn(*input_size).to(device)
    model.eval()
    with torch.no_grad():
        _ = model(dummy)

    for h in hooks:
        h.remove()

    return sum(mads)



In [None]:
# 9. RUNTIME, MEMORY, PARAMS, CSV EXPORT


# Param count
params = sum(p.numel() for p in model.parameters())

# MADs
mads = count_mads(model, device=device)
print(f"[{model_name}] Params: {params}")
print(f"[{model_name}] MADs  : {mads:.3g}")

# Inference time + GPU memory
dummy = torch.randn(1,3,256,256).to(device)

model.eval()
if device == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

with torch.no_grad():
    for _ in range(5):
        _ = model(dummy)

if device == "cuda":
    torch.cuda.synchronize()

t0 = time.time()
with torch.no_grad():
    for _ in range(50):
        _ = model(dummy)
if device == "cuda":
    torch.cuda.synchronize()
t1 = time.time()

infer_ms = (t1 - t0) / 50 * 1000.0
mem_infer = torch.cuda.max_memory_allocated()/(1024**2) if device == "cuda" else 0

print(f"[{model_name}] Inference: {infer_ms:.3f} ms, Memory: {mem_infer:.2f} MB")

# Training-step time
crit_step = nn.BCEWithLogitsLoss()
opt2 = optim.SGD(model.parameters(), lr=1e-3)
dummy_mask = torch.rand(1,1,256,256).to(device)

if device == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

for _ in range(5):
    opt2.zero_grad()
    out = model(dummy)
    loss = crit_step(out, dummy_mask)
    loss.backward()
    opt2.step()

if device == "cuda":
    torch.cuda.synchronize()

t0 = time.time()
for _ in range(50):
    opt2.zero_grad()
    out = model(dummy)
    loss = crit_step(out, dummy_mask)
    loss.backward()
    opt2.step()
if device == "cuda":
    torch.cuda.synchronize()
t1 = time.time()

train_ms = (t1 - t0) / 50 * 1000.0
mem_train = torch.cuda.max_memory_allocated()/(1024**2) if device == "cuda" else 0

peak_mem = max(mem_infer, mem_train)

print(f"[{model_name}] TrainStep: {train_ms:.3f} ms, Memory: {peak_mem:.2f} MB")

[UNet_Oriented1D_Colab] Params: 960884
[UNet_Oriented1D_Colab] MADs  : 4.59e+09
[UNet_Oriented1D_Colab] Inference: 34.266 ms, Memory: 137.94 MB
[UNet_Oriented1D_Colab] TrainStep: 88.475 ms, Memory: 510.09 MB


In [None]:
# 10. CSV EXPORT


final_train_dice = t_dice[-1]
final_val_dice   = v_dice[-1]
final_val_iou    = v_iou[-1]
final_val_acc    = v_acc[-1]

csv_path = f"/content/{model_name}_metrics.csv"

with open(csv_path, "w", newline="") as f:
    wr = csv.writer(f)
    wr.writerow([
        "Model","Params","MADs",
        "Infer_ms","Train_ms","Memory_MB",
        "FinalTrainDice","FinalValDice",
        "FinalValIoU","FinalValAcc"
    ])
    wr.writerow([
        model_name,
        params,
        mads,
        infer_ms,
        train_ms,
        peak_mem,
        final_train_dice,
        final_val_dice,
        final_val_iou,
        final_val_acc
    ])

print(f"[{model_name}] CSV saved at: {csv_path}")



[UNet_Oriented1D_Colab] CSV saved at: /content/UNet_Oriented1D_Colab_metrics.csv


In [None]:
# 11. SAVE WEIGHTS


torch.save(model.state_dict(), f"/content/{model_name}.pth")
print("Model saved:", f"/content/{model_name}.pth")

Model saved: /content/UNet_Oriented1D_Colab.pth


# Metrcis Comparision

In [None]:
# COMPARISON SCRIPT: BASELINE vs ORIENTED-1D (NuInsSeg)


import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os


# 1. LOAD METRICS FROM CSV FILES


baseline_csv = "/content/UNetBaseline_Colab_metrics.csv"
oriented_csv = "/content/UNet_Oriented1D_Colab_metrics.csv"

df_b = pd.read_csv(baseline_csv)
df_o = pd.read_csv(oriented_csv)

df_b["ModelType"] = "Baseline"
df_o["ModelType"] = "Oriented1D"

df = pd.concat([df_b, df_o], ignore_index=True)
df


Unnamed: 0,Model,Params,MADs,Infer_ms,Train_ms,Memory_MB,FinalTrainDice,FinalValDice,FinalValIoU,FinalValAcc,ModelType
0,UNetBaseline_Colab,1927009,8816427008,6.291833,17.614226,225.485352,0.777706,0.774445,0.650614,0.938533,Baseline
1,UNet_Oriented1D_Colab,960884,4588568576,34.265575,88.474669,510.087402,0.061143,0.068691,0.040263,0.839196,Oriented1D


In [None]:

# 2. METRICS TO COMPARE


metrics = [
    "Params",
    "MADs",
    "Infer_ms",
    "Train_ms",
    "Memory_MB",
    "FinalTrainDice",
    "FinalValDice",
    "FinalValIoU",
    "FinalValAcc",
]

# Create output folder
out_dir = "/content/comparison_plots"
os.makedirs(out_dir, exist_ok=True)

print("Saving plots to:", out_dir)


Saving plots to: /content/comparison_plots


In [None]:

# 3. BAR PLOTS FOR EACH METRIC


for metric in metrics:
    plt.figure(figsize=(6,4))

    values = [
        df[df.ModelType == "Baseline"][metric].values[0],
        df[df.ModelType == "Oriented1D"][metric].values[0]
    ]

    plt.bar(["Baseline", "Oriented-1D"], values, color=["skyblue","orange"])
    plt.title(f"{metric} Comparison")
    plt.ylabel(metric)
    plt.grid(axis="y", linestyle="--", alpha=0.5)

    plt.savefig(f"{out_dir}/{metric}_comparison.png", dpi=200, bbox_inches="tight")
    plt.close()

print("All comparison plots saved!")


All comparison plots saved!


In [None]:

# 4. COMBINED TABLE SUMMARY (PRINTABLE)


summary = pd.DataFrame({
    "Metric": metrics,
    "Baseline": [df_b[m].values[0] for m in metrics],
    "Oriented1D": [df_o[m].values[0] for m in metrics],
})

summary["Difference (O1D - Base)"] = summary["Oriented1D"] - summary["Baseline"]

print("\n===== COMPARISON TABLE =====\n")
print(summary)
summary.to_csv(f"{out_dir}/comparison_summary.csv", index=False)
print("\nSaved:", f"{out_dir}/comparison_summary.csv")



===== COMPARISON TABLE =====

           Metric      Baseline    Oriented1D  Difference (O1D - Base)
0          Params  1.927009e+06  9.608840e+05            -9.661250e+05
1            MADs  8.816427e+09  4.588569e+09            -4.227858e+09
2        Infer_ms  6.291833e+00  3.426558e+01             2.797374e+01
3        Train_ms  1.761423e+01  8.847467e+01             7.086044e+01
4       Memory_MB  2.254854e+02  5.100874e+02             2.846021e+02
5  FinalTrainDice  7.777062e-01  6.114342e-02            -7.165628e-01
6    FinalValDice  7.744446e-01  6.869071e-02            -7.057538e-01
7     FinalValIoU  6.506139e-01  4.026259e-02            -6.103513e-01
8     FinalValAcc  9.385327e-01  8.391958e-01            -9.933688e-02

Saved: /content/comparison_plots/comparison_summary.csv


In [None]:
# 5. OPTIONAL: SAVE ALL PLOTS INTO A SINGLE PDF

from matplotlib.backends.backend_pdf import PdfPages

pdf_path = f"{out_dir}/comparison_plots.pdf"
with PdfPages(pdf_path) as pdf:
    for metric in metrics:
        img = plt.imread(f"{out_dir}/{metric}_comparison.png")
        plt.figure(figsize=(6,4))
        plt.imshow(img)
        plt.axis("off")
        pdf.savefig(bbox_inches="tight")
        plt.close()

print("Combined PDF saved:", pdf_path)


Combined PDF saved: /content/comparison_plots/comparison_plots.pdf
