In [10]:
import os
import cv2
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm


In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"


In [12]:

class HCDataset(Dataset):
    def __init__(self, root, csv_path, transform=None):
        self.root = root
        self.df = pd.read_csv(csv_path)
        self.transform = transform

        self.images = sorted([f for f in os.listdir(root) if f.endswith("_HC.png")])
        self.spacing = dict(zip(self.df["filename"], self.df["pixel size(mm)"]))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        fname = self.images[idx]

        img = cv2.imread(os.path.join(self.root, fname), 0).astype(np.float32) / 255.0
        img = np.stack([img]*3, axis=-1)

        mask = cv2.imread(
            os.path.join(self.root, fname.replace(".png", "_Annotation.png")), 0
        )
        mask = (mask > 0).astype(np.float32)

        if self.transform:
            aug = self.transform(image=img, mask=mask)
            img, mask = aug["image"], aug["mask"]

        return img, mask, self.spacing[fname], fname


In [13]:
train_tfms = A.Compose([
    A.Resize(224,224),
    A.Rotate(limit=20, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)),
    ToTensorV2()
])


In [14]:
class SwinSeg(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model(
            "swin_tiny_patch4_window7_224",
            pretrained=True,
            features_only=True
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(768,256,2,2),
            nn.ReLU(),
            nn.ConvTranspose2d(256,64,2,2),
            nn.ReLU()
        )

        self.head = nn.Conv2d(64,1,1)

    def forward(self, x):
        f = self.encoder(x)[-1]          # (B, H', W', C)
        f = f.permute(0,3,1,2)           # (B, C, H', W')

        d = self.decoder(f)              # still low-res
        out = self.head(d)

        # ðŸ”¥ CRITICAL FIX: upsample to input size
        out = torch.nn.functional.interpolate(
            out,
            size=x.shape[2:],            # (H, W) = 224x224
            mode="bilinear",
            align_corners=False
        )

        return torch.sigmoid(out)


In [32]:
def dice_score_torch(pred, target, eps=1e-6):
    pred = (pred > 0.5).float()
    target = target.float()

    inter = (pred * target).sum(dim=(1,2,3))
    union = pred.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3))

    return ((2 * inter + eps) / (union + eps)).mean().item()


def iou_score_torch(pred, target, eps=1e-6):
    pred = (pred > 0.5).float()
    target = target.float()

    inter = (pred * target).sum(dim=(1,2,3))
    union = pred.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3)) - inter

    return ((inter + eps) / (union + eps)).mean().item()


In [15]:
def dice_loss(pred, target, eps=1e-6):
    pred = pred.view(-1)
    target = target.view(-1)
    inter = (pred * target).sum()
    return 1 - (2*inter + eps) / (pred.sum() + target.sum() + eps)


In [16]:
dataset = HCDataset("training_set", "training_set_pixel_size_and_HC.csv", train_tfms)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
ds_train, ds_val = random_split(dataset, [train_size, val_size])

dl_train = DataLoader(ds_train, batch_size=4, shuffle=True)
dl_val   = DataLoader(ds_val, batch_size=4, shuffle=False)


In [17]:
model = SwinSeg().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)


In [18]:
for epoch in range(20):
    model.train()
    total = 0

    for img, mask, _, _ in dl_train:
        img = img.to(device)
        mask = mask.unsqueeze(1).to(device)

        pred = model(img)
        loss = dice_loss(pred, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total += loss.item()

    print(f"Epoch {epoch+1:02d} | Train Loss {total/len(dl_train):.4f}")


Epoch 01 | Train Loss 0.9851
Epoch 02 | Train Loss 0.9850
Epoch 03 | Train Loss 0.9848
Epoch 04 | Train Loss 0.9847
Epoch 05 | Train Loss 0.9847
Epoch 06 | Train Loss 0.9845
Epoch 07 | Train Loss 0.9845
Epoch 08 | Train Loss 0.9846
Epoch 09 | Train Loss 0.9845
Epoch 10 | Train Loss 0.9846
Epoch 11 | Train Loss 0.9845
Epoch 12 | Train Loss 0.9845
Epoch 13 | Train Loss 0.9845
Epoch 14 | Train Loss 0.9845
Epoch 15 | Train Loss 0.9846
Epoch 16 | Train Loss 0.9846
Epoch 17 | Train Loss 0.9845
Epoch 18 | Train Loss 0.9845
Epoch 19 | Train Loss 0.9844
Epoch 20 | Train Loss 0.9844


In [19]:
def clean_mask(mask):
    mask = (mask > 0.5).astype(np.uint8)

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask)
    largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
    mask = (labels == largest).astype(np.uint8)

    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5,5),np.uint8))
    return mask


In [20]:
def fit_ellipse_from_mask(mask):
    cnts,_ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(cnts)==0 or len(cnts[0])<5:
        return None
    (cx,cy),(MA,ma),angle = cv2.fitEllipse(max(cnts, key=cv2.contourArea))
    return cx, cy, MA/2, ma/2, angle


In [21]:
def ellipse_circumference(a,b):
    return math.pi*(3*(a+b)-math.sqrt((3*a+b)*(a+3*b)))


In [22]:
def compute_biometry(mask, spacing):
    ell = fit_ellipse_from_mask(mask)
    if ell is None:
        return None,None,None

    _,_,a,b,_ = ell
    a_mm = a * spacing
    b_mm = b * spacing

    HC  = ellipse_circumference(a_mm, b_mm)
    BPD = 2 * b_mm
    OFD = 2 * a_mm
    return HC, BPD, OFD


In [25]:
df = pd.read_csv("training_set_pixel_size_and_HC.csv")
gt_hc_dict = dict(zip(df["filename"], df["head circumference (mm)"]))


In [33]:
model.eval()

dice_vals = []
iou_vals  = []
pred_hc   = []
gt_hc     = []

with torch.no_grad():
    for img, gt_mask, spacing, fname in dl_val:
        img = img.to(device)
        gt_mask = gt_mask.unsqueeze(1).to(device)

        # ---- prediction ----
        pred = model(img)

        # ---- Dice / IoU (CORRECT) ----
        dice_vals.append(dice_score_torch(pred, gt_mask))
        iou_vals.append(iou_score_torch(pred, gt_mask))

        # ---- HC computation (per image) ----
        for b in range(img.size(0)):
            pred_np = pred[b,0].cpu().numpy()
            clean = clean_mask(pred_np)

            ellipse = fit_ellipse_from_mask(clean)
            if ellipse is None:
                continue

            _, _, a, b_ax, _ = ellipse
            hc_pred = ellipse_circumference(
                a * spacing[b].item(),
                b_ax * spacing[b].item()
            )

            hc_gt = gt_hc_dict[fname[b]]
            pred_hc.append(hc_pred)
            gt_hc.append(hc_gt)

dice_mean = np.mean(dice_vals)
iou_mean  = np.mean(iou_vals)

pred_hc = np.array(pred_hc)
gt_hc   = np.array(gt_hc)

hc_mae  = np.mean(np.abs(pred_hc - gt_hc))
hc_rmse = np.sqrt(np.mean((pred_hc - gt_hc)**2))

print(f"Validation Dice : {dice_mean:.4f}")
print(f"Validation IoU  : {iou_mean:.4f}")
print(f"Valid HC samples: {len(pred_hc)} / {len(ds_val)}")
print(f"HC MAE  (mm)    : {hc_mae:.2f}")
print(f"HC RMSE (mm)    : {hc_rmse:.2f}")



Validation Dice : 0.0152
Validation IoU  : 0.0076
Valid HC samples: 162 / 162
HC MAE  (mm)    : 158.67
HC RMSE (mm)    : 169.50
