In [None]:
import json
import os

import numpy as np

# Dataset locations (ARCADE "syntax" task: train / val / test splits).
ROOT = "/kaggle/input/arcade-dataset/arcade"
SYNTAX_DIR = os.path.join(ROOT, "syntax")

TRAIN_IMG_DIR = os.path.join(SYNTAX_DIR, "train", "images")
TRAIN_ANN_JSON = os.path.join(SYNTAX_DIR, "train", "annotations", "train.json")
VAL_IMG_DIR = os.path.join(SYNTAX_DIR, "val", "images")
VAL_ANN_JSON = os.path.join(SYNTAX_DIR, "val", "annotations", "val.json")
TEST_IMG_DIR = os.path.join(SYNTAX_DIR, "test", "images")
TEST_ANN_JSON = os.path.join(SYNTAX_DIR, "test", "annotations", "test.json")

# Sanity check: number of directory entries in each split's image folder.
for _split, _img_dir in (("train", TRAIN_IMG_DIR), ("val", VAL_IMG_DIR), ("test", TEST_IMG_DIR)):
    print(f"Number of {_split} images : ", len(os.listdir(_img_dir)))


def _load_json(path):
    """Read and parse one JSON annotation file."""
    with open(path, "r") as fh:
        return json.load(fh)


train_data = _load_json(TRAIN_ANN_JSON)
val_data = _load_json(VAL_ANN_JSON)
test_data = _load_json(TEST_ANN_JSON)

# Index annotations by the image they belong to.
def build_imgid_to_anns(json_data):
    """Group the records of ``json_data["annotations"]`` by their ``"image_id"``.

    Args:
        json_data: parsed annotation dict with an ``"annotations"`` list; every
            record must have an ``"image_id"`` key.

    Returns:
        dict mapping image id -> list of the annotation dicts for that image,
        in the order they appear in the file. Images with no annotations are
        absent (callers use ``.get(img_id, [])``).
    """
    grouped = {}
    for record in json_data["annotations"]:
        key = record["image_id"]
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(record)
    return grouped
    
# One image-id -> annotations index per split.
train_imgid_to_anns, val_imgid_to_anns, test_imgid_to_anns = (
    build_imgid_to_anns(split_json) for split_json in (train_data, val_data, test_data)
)

In [None]:
import cv2

# Rasterize COCO-style polygon annotations into one binary mask of shape (H, W).
def polygons_to_mask(ann, height, width):
    """Fill every polygon of every annotation in ``ann`` into a binary mask.

    Args:
        ann: list of annotation dicts. Each may carry a ``"segmentation"`` entry
            holding a list of flat polygons ``[[x1, y1, x2, y2, ...], ...]``.
            Missing, empty or ``None`` segmentations contribute nothing.
        height: number of rows of the output mask.
        width: number of columns of the output mask.

    Returns:
        np.ndarray of shape (height, width), dtype uint8, with 1 inside any
        polygon and 0 elsewhere.

    Raises:
        ValueError: if a segmentation is a dict (COCO run-length encoding),
            which this function cannot rasterize.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    for a in ann:
        seg = a.get("segmentation") or []
        if isinstance(seg, dict):
            # Iterating an RLE dict would yield its keys ("counts", "size") and
            # fail obscurely inside np.array; report the real problem instead.
            raise ValueError(
                f"RLE segmentation is not supported (annotation id={a.get('id')}); "
                "only polygon lists can be rasterized"
            )
        for poly in seg:
            # A polygon needs at least 3 (x, y) vertices = 6 numbers.
            if not poly or len(poly) < 6:
                continue
            pts = np.round(np.asarray(poly, dtype=np.float32).reshape(-1, 2)).astype(np.int32)
            # No vertex clamping: cv2.fillPoly clips to the image while it
            # rasterizes, whereas moving out-of-frame vertices onto the border
            # would change the slope of the polygon's edges inside the image.
            cv2.fillPoly(mask, [pts], 1)
    return mask

In [None]:
# dataset definition 
import torch 
from torch.utils.data import Dataset,DataLoader

class ArcadeDataset(Dataset):
    """Binary vessel-segmentation dataset over one COCO-style ARCADE split.

    Each item corresponds to one record of ``json_dict["images"]``. The image is
    read from ``img_dir`` and converted BGR -> RGB. A binary mask is rasterized
    from the polygon annotations registered for that image id.

    Args:
        img_dir: directory containing the split's image files.
        json_dict: parsed annotation file. Only ``"images"`` is read here
            (``id``, ``file_name``, ``height``, ``width`` per record).
        imgid_to_anns: mapping image id -> list of annotation dicts.
        transform: optional albumentations-style callable. It is invoked as
            ``transform(image=..., mask=...)`` and must return a dict with
            ``"image"`` and ``"mask"``. It presumably ends with a to-tensor
            step (TODO confirm for any transform you plug in).
        return_meta: if True, items also carry ``{"img_id", "file_name"}``.
    """

    def __init__(self, img_dir, json_dict, imgid_to_anns, transform=None, return_meta=False):
        self.img_dir = img_dir
        self.json_dict = json_dict
        self.imgid_to_anns = imgid_to_anns
        self.transform = transform
        self.return_meta = return_meta
        self.images = json_dict["images"]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(img_t, mask_t)`` or ``(img_t, mask_t, meta)`` for ``idx``.

        Without a transform, ``img_t`` is a float (3, H, W) tensor in [0, 1] and
        ``mask_t`` is a float (1, H, W) tensor with values in {0, 1}.

        Raises:
            FileNotFoundError: if the image file cannot be read.
            ValueError: if the decoded image size differs from the size declared
                in the annotation file. The mask would not line up with the image.
        """
        imgrec = self.images[idx]
        img_id = imgrec["id"]
        fname = imgrec["file_name"]
        h, w = imgrec["height"], imgrec["width"]

        path = os.path.join(self.img_dir, fname)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Image not found {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if img.shape[:2] != (h, w):
            raise ValueError(
                f"Image {path} is {img.shape[1]}x{img.shape[0]} but the annotation "
                f"file declares {w}x{h}"
            )

        # Build the binary mask from the polygons of this image.
        anns = self.imgid_to_anns.get(img_id, [])
        mask = polygons_to_mask(anns, h, w)

        if self.transform is not None:
            out = self.transform(image=img, mask=mask)
            img_t = out["image"]
            # as_tensor also copes with transforms that return a numpy mask.
            mask_t = torch.as_tensor(out["mask"]).unsqueeze(0)
        else:
            img_t = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
            mask_t = torch.from_numpy(mask).unsqueeze(0)

        # Re-binarize: interpolating transforms can leave fractional mask values.
        mask_t = (mask_t.float() > 0.5).float()

        if self.return_meta:
            meta = {"img_id": img_id, "file_name": fname}
            return img_t, mask_t, meta

        return img_t, mask_t

In [None]:
# Build the datasets and dataloaders.
# Only the training loader shuffles. Evaluation loaders iterate in file order
# so validation/test passes are reproducible and comparable across runs,
# thresholds and checkpoints.
BATCH_SIZE = 4
NUM_WORKERS = 2
PIN_MEMORY = torch.cuda.is_available()  # pinning only helps host->GPU copies

train_ds = ArcadeDataset(
    img_dir=TRAIN_IMG_DIR,
    json_dict=train_data,
    imgid_to_anns=train_imgid_to_anns,
    transform=None,
    return_meta=False,
)
val_ds = ArcadeDataset(
    img_dir=VAL_IMG_DIR,
    json_dict=val_data,
    imgid_to_anns=val_imgid_to_anns,
    transform=None,
    return_meta=False,
)
test_ds = ArcadeDataset(
    img_dir=TEST_IMG_DIR,
    json_dict=test_data,
    imgid_to_anns=test_imgid_to_anns,
    transform=None,
    return_meta=False,
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

In [None]:
import torch
import torch.nn as nn 
import numpy as np

class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    The convolutions use padding=1, so height and width are preserved. The
    conv layers have no bias because each is followed by BatchNorm. The layer
    indices inside ``self.net`` (0..5) form the checkpoint's state_dict keys
    (e.g. ``net.0.weight``), so the layout must stay as is.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    """Four-level U-Net that outputs per-pixel logits (no sigmoid applied).

    The encoder is DoubleConv blocks with 2x2 max-pooling between levels,
    followed by a bottleneck at ``base*16`` channels. The decoder upsamples with
    2x2 transposed convolutions, concatenates the matching encoder feature map
    (skip connection) and applies a DoubleConv. A final 1x1 conv maps to
    ``out_ch`` channels.

    Input height/width need not be multiples of 16. After each upsampling the
    decoder map is zero-padded to the skip's size before concatenation. For
    multiples of 16 this padding is a no-op, so outputs and state_dict keys are
    unchanged.
    """

    def __init__(self, in_ch=1, out_ch=1, base=32):
        super().__init__()
        self.enc1 = DoubleConv(in_ch, base)
        self.enc2 = DoubleConv(base, base * 2)
        self.enc3 = DoubleConv(base * 2, base * 4)
        self.enc4 = DoubleConv(base * 4, base * 8)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(base * 8, base * 16)

        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, stride=2)
        self.dec4 = DoubleConv(base * 16, base * 8)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, stride=2)
        self.dec3 = DoubleConv(base * 8, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.dec2 = DoubleConv(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.dec1 = DoubleConv(base * 2, base)

        self.out = nn.Conv2d(base, out_ch, 1)

    @staticmethod
    def _up_and_merge(up, dec, x, skip):
        """Upsample ``x``, pad it to ``skip``'s spatial size, concatenate, decode."""
        x = up(x)
        # MaxPool2d(2) floors odd sizes, so the upsampled map can be one pixel
        # smaller than the skip along an axis; pad (never crop) to match.
        dh = skip.size(-2) - x.size(-2)
        dw = skip.size(-1) - x.size(-1)
        if dh or dw:
            x = nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return dec(torch.cat([x, skip], dim=1))

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        d4 = self._up_and_merge(self.up4, self.dec4, b, e4)
        d3 = self._up_and_merge(self.up3, self.dec3, d4, e3)
        d2 = self._up_and_merge(self.up2, self.dec2, d3, e2)
        d1 = self._up_and_merge(self.up1, self.dec1, d2, e1)
        return self.out(d1)

In [None]:
class BCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    loss = bce_weight * BCE(logits, targets) + (1 - bce_weight) * (1 - mean soft Dice)

    Soft Dice is computed per (sample, channel) over dims (2, 3) from sigmoid
    probabilities, then averaged. ``eps`` smooths both numerator and denominator.
    Expects ``logits`` and ``targets`` of the same 4-D shape (N, C, H, W).
    """

    def __init__(self, bce_weight=0.5, eps=1e-7):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.eps = eps

    def forward(self, logits, targets):
        bce_term = self.bce(logits, targets)
        soft = torch.sigmoid(logits)
        spatial = (2, 3)
        overlap = (soft * targets).sum(dim=spatial)
        total = (soft + targets).sum(dim=spatial)
        dice_score = (2 * overlap + self.eps) / (total + self.eps)
        w = self.bce_weight
        return w * bce_term + (1 - w) * (1 - dice_score.mean())

In [None]:
def dice_iou_from_logits(logits, targets, thr=0.5, eps=1e-7):
    """Hard Dice and IoU after thresholding sigmoid(logits) at ``thr``.

    Scores are computed per (sample, channel) over dims (2, 3) and then averaged.
    ``eps`` makes an empty prediction on an empty target score 1.0.

    Returns:
        (mean_dice, mean_iou) as Python floats.
    """
    hard = (torch.sigmoid(logits) > thr).float()
    dims = (2, 3)
    inter = (hard * targets).sum(dim=dims)
    dice = (2 * inter + eps) / ((hard + targets).sum(dim=dims) + eps)
    union = (hard + targets - hard * targets).sum(dim=dims)
    iou = (inter + eps) / (union + eps)
    return dice.mean().item(), iou.mean().item()

In [None]:
from tqdm import tqdm
def run_epoch(model, loader, optimizer, loss_fn, device, train=True):
    """Run one pass over ``loader``, either training or evaluating ``model``.

    Args:
        model: segmentation model returning logits shaped like the masks.
        loader: iterable yielding ``(images, masks)`` batches.
        optimizer: stepped only when ``train`` is True.
        loss_fn: callable ``loss_fn(logits, masks)`` returning a scalar tensor.
        device: device the batches are moved to.
        train: True puts the model in train mode and updates weights. False
            uses eval mode with gradients disabled.

    Returns:
        (mean_loss, mean_dice, mean_iou). Each is weighted by batch size, so a
        smaller final batch counts proportionally. All are 0.0 for an empty
        loader.
    """
    model.train(train)
    total_loss = total_dice = total_iou = 0.0
    n_samples = 0
    pbar = tqdm(loader, desc="train" if train else "val", leave=False)
    for img, mask in pbar:
        img = img.to(device)
        mask = mask.to(device)
        with torch.set_grad_enabled(train):
            logits = model(img)
            loss = loss_fn(logits, mask)
            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
        dice, iou = dice_iou_from_logits(logits.detach(), mask.detach())
        loss_val = loss.item()
        batch_size = img.size(0)
        # The loss and metrics are batch means; weight them by batch size.
        total_loss += loss_val * batch_size
        total_dice += dice * batch_size
        total_iou += iou * batch_size
        n_samples += batch_size
        pbar.set_postfix(loss=loss_val, dice=dice, iou=iou)

    n = max(n_samples, 1)
    return total_loss / n, total_dice / n, total_iou / n

In [None]:
def main(epochs=25, lr=1e-3, save_path="unet_angio_best.pth", seed=42):
    """Train a 3-channel-input UNet and checkpoint the best validation Dice.

    Uses the module-level ``train_loader`` and ``val_loader``; batch size and
    workers are configured where those loaders are built.

    Args:
        epochs: number of training epochs.
        lr: AdamW learning rate.
        save_path: where the best state_dict (by validation Dice) is saved.
        seed: value passed to ``torch.manual_seed`` before model creation. It
            fixes weight init and the training loader's shuffle order. ``None``
            skips seeding.

    Returns:
        The best validation Dice reached.
    """
    if seed is not None:
        torch.manual_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(in_ch=3, out_ch=1, base=32).to(device)
    loss_fn = BCEDiceLoss(bce_weight=0.5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    best_val_dice = -1

    for ep in range(1, epochs + 1):
        tr_loss, tr_dice, tr_iou = run_epoch(model, train_loader, optimizer, loss_fn, device, train=True)
        va_loss, va_dice, va_iou = run_epoch(model, val_loader, optimizer, loss_fn, device, train=False)
        print(f"Epoch {ep} | train: loss {tr_loss:.4f} dice {tr_dice:.4f} iou {tr_iou:.4f} val: loss {va_loss:.4f} dice {va_dice:.4f} iou {va_iou:.4f}")

        if va_dice > best_val_dice:
            best_val_dice = va_dice
            torch.save(model.state_dict(), save_path)
            print("saved model")
    print("Training done")
    return best_val_dice

In [None]:
# Training is expensive, so the cells below load a pretrained checkpoint.
# Set this to True to retrain from scratch instead.
RUN_TRAINING = False
if RUN_TRAINING:
    main()

In [None]:
# Threshold sweep on the validation set.
@torch.no_grad()
def sweep_thr(model, loader, device="cuda", thr=np.arange(0.1, 0.91, 0.05)):
    """Evaluate hard Dice/IoU on ``loader`` at every threshold in ``thr``.

    Inference runs once per batch, and those logits are scored against every
    threshold. The original re-ran the whole loader once per threshold. Means
    are weighted by batch size.

    Args:
        model: segmentation model returning logits shaped like the masks.
        loader: iterable of batches whose first two items are (images, masks).
        device: device to run inference on.
        thr: iterable of probability thresholds.

    Returns:
        list of ``(threshold, mean_dice, mean_iou)`` tuples, in the order of ``thr``.
    """
    model.eval()
    thresholds = list(thr)
    dice_sums = [0.0] * len(thresholds)
    iou_sums = [0.0] * len(thresholds)
    n_samples = 0
    for batch in loader:
        img, mask = batch[0].to(device), batch[1].to(device)
        logits = model(img)
        batch_size = img.size(0)
        for k, t in enumerate(thresholds):
            dice, iou = dice_iou_from_logits(logits, mask, thr=t)
            dice_sums[k] += dice * batch_size
            iou_sums[k] += iou * batch_size
        n_samples += batch_size

    denom = max(n_samples, 1)
    results = []
    for t, dice_sum, iou_sum in zip(thresholds, dice_sums, iou_sums):
        mean_dice = dice_sum / denom
        mean_iou = iou_sum / denom
        results.append((t, mean_dice, mean_iou))
        print(f"thr = {t:.2f} val dice = {mean_dice:.4f} val_iou = {mean_iou:.4f}")
    return results

In [None]:
# Load the pretrained checkpoint and pick the probability threshold that
# maximizes validation Dice.
CKPT_PATH = "/kaggle/input/xca-unet/pytorch/default/1/unet_angio_best.pth"
thresholds = np.arange(0.1, 0.91, 0.05)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(in_ch=3, out_ch=1, base=32).to(device)
# weights_only=True: the file is a plain state_dict, so refuse arbitrary pickled objects.
state_dict = torch.load(CKPT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
val_results = sweep_thr(model, val_loader, device=device, thr=thresholds)
best_thr, best_val_dice, best_val_iou = max(val_results, key=lambda r: r[1])
print(f"thr = {best_thr:.2f} val dice = {best_val_dice:.4f} val iou = {best_val_iou:.4f}")

In [None]:
# Evaluate a model on any loader (used below for the test set).
@torch.no_grad()
def evaluate_on_loader(model, loader, device="cuda", thr=0.5, desc="Test Inference"):
    """Mean hard Dice and IoU of ``model`` over ``loader`` at threshold ``thr``.

    Args:
        model: segmentation model returning logits shaped like the masks.
        loader: iterable of batches whose first two items are (images, masks).
        device: device to run inference on.
        thr: probability threshold applied to sigmoid(logits).
        desc: progress-bar label.

    Returns:
        (mean_dice, mean_iou), weighted by batch size. Both are 0.0 for an
        empty loader.
    """
    model.eval()
    total_dice, total_iou, n_samples = 0.0, 0.0, 0
    for batch in tqdm(loader, desc=desc):
        img, mask = batch[0].to(device), batch[1].to(device)
        logits = model(img)
        dice, iou = dice_iou_from_logits(logits, mask, thr=thr)
        batch_size = img.size(0)
        total_dice += dice * batch_size
        total_iou += iou * batch_size
        n_samples += batch_size
    denom = max(n_samples, 1)
    return total_dice / denom, total_iou / denom

In [None]:
# Test-set performance of the checkpoint loaded in the threshold-sweep cell,
# using that cell's validation-selected threshold (best_thr). The same weights
# are already in `model`, so they are not reloaded here.
test_dice, test_iou = evaluate_on_loader(model, test_loader, device=device, thr=best_thr)
print(f"TEST Performance - TEST Dice = {test_dice:.4f} TEST IoU = {test_iou:.4f}")

In [None]:
import random
import matplotlib.pyplot as plt

@torch.no_grad()
def predict_single(model, img_t, device="cuda", thr=0.5):
    """Run ``model`` on one image and threshold the first output channel.

    Args:
        img_t: (C, H, W) image tensor (the dataset yields values in [0, 1]).
        device: device to run inference on.
        thr: probability threshold. Pixels with sigmoid(logit) > thr become 1.

    Returns:
        (pred_mask_np, prob_np): numpy arrays of shape (H, W). The first is
        uint8 in {0, 1}; the second holds sigmoid probabilities.
    """
    model.eval()
    batch = img_t[None].to(device)  # add a batch axis -> (1, C, H, W)
    prob_np = torch.sigmoid(model(batch))[0, 0].detach().cpu().numpy()
    return (prob_np > thr).astype(np.uint8), prob_np

def make_overlay(img_np, gt_np, pred_np, alpha=0.5):
    """Blend ground truth (green) and prediction (red) onto an RGB image.

    Args:
        img_np: (H, W, 3) image. uint8 is used as-is. Any other dtype is treated
            as [0, 1] when its max is <= 1 (scaled by 255), otherwise as
            [0, 255]. Both cases are clipped to [0, 255] before casting to uint8.
        gt_np: (H, W) mask; nonzero pixels are tinted green.
        pred_np: (H, W) mask; nonzero pixels are tinted red (applied after green,
            so overlapping pixels receive both tints).
        alpha: blend weight of the tint colour.

    Returns:
        (H, W, 3) uint8 overlay image.
    """
    if img_np.dtype == np.uint8:
        img_disp = img_np.copy()
    else:
        scaled = img_np * 255.0 if img_np.max() <= 1.0 else img_np
        # Clip in both branches: casting out-of-range floats to uint8 wraps around.
        img_disp = np.clip(scaled, 0, 255).astype(np.uint8)

    overlay = img_disp.astype(np.float32)
    gt_mask = gt_np.astype(bool)
    pr_mask = pred_np.astype(bool)

    green = np.array([0, 255, 0], dtype=np.float32)
    red = np.array([255, 0, 0], dtype=np.float32)
    overlay[gt_mask] = (1 - alpha) * overlay[gt_mask] + alpha * green
    overlay[pr_mask] = (1 - alpha) * overlay[pr_mask] + alpha * red

    return overlay.astype(np.uint8)

def plot_random_10(model, dataset, device="cuda", thr=0.5, n=10, seed=None):
    """Plot image / GT / prediction / overlay rows for up to ``n`` random samples.

    Args:
        model: segmentation model returning logits.
        dataset: indexable dataset whose items start with ``(img_t, mask_t)``.
            Here img_t is (C, H, W) in [0, 1] and mask_t is (1, H, W) in {0, 1}.
            A trailing meta element, if present, is ignored.
        device: device to run inference on.
        thr: probability threshold for the predicted mask.
        n: maximum number of samples (rows) to show.
        seed: if given, sample indices with ``random.Random(seed)`` for a
            reproducible selection. Otherwise use the global ``random`` state.
    """
    sampler = random if seed is None else random.Random(seed)
    idxs = sampler.sample(range(len(dataset)), k=min(n, len(dataset)))
    if not idxs:
        print("Dataset is empty; nothing to plot.")
        return

    # One row per sampled image, so a small dataset leaves no empty rows.
    n_rows = len(idxs)
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)
    for row_axes, idx in zip(axes, idxs):
        item = dataset[idx]
        img_t, mask_t = item[0], item[1]

        pred_np, _ = predict_single(model, img_t, device=device, thr=thr)

        img_np = np.transpose(img_t.detach().cpu().numpy(), (1, 2, 0))  # (H, W, C)
        if img_np.shape[2] == 1:
            img_np = np.repeat(img_np, 3, axis=2)  # grayscale -> 3 channels for display

        gt_np = mask_t[0].detach().cpu().numpy().astype(np.uint8)
        overlay = make_overlay(img_np, gt_np, pred_np, alpha=0.55)

        panels = (
            (img_np, None, "XCA Image"),
            (gt_np, "gray", "GT Mask"),
            (pred_np, "gray", "Pred Mask"),
            (overlay, None, "Overlay (GT=Green, Pred=Red)"),
        )
        for ax, (data, cmap, title) in zip(row_axes, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    fig.tight_layout()
    plt.show()

# Visualize random test predictions at the validation-selected threshold
# (best_thr), the same threshold used for the reported test metrics.
plot_random_10(model, test_ds, device=device, thr=best_thr, n=10)

In [None]:
import torch
from tqdm import tqdm

@torch.no_grad()
def dice_per_image_from_logits(logits, targets, thr=0.5, eps=1e-7):
    """Hard Dice for every image in the batch.

    Thresholds sigmoid(logits) at ``thr`` and computes Dice over dims (2, 3).
    For single-channel (B, 1, H, W) inputs the channel axis is squeezed and a
    (B,) tensor is returned.
    """
    binary = (logits.sigmoid() > thr).float()
    axes = (2, 3)
    overlap = (binary * targets).sum(dim=axes)
    size_sum = binary.sum(dim=axes) + targets.sum(dim=axes)
    per_image = (2 * overlap + eps) / (size_sum + eps)
    return per_image.squeeze(1)

def _meta_for_sample(meta, i):
    """Return sample ``i``'s metadata from a collated batch (or None)."""
    if meta is None:
        return None
    if isinstance(meta, dict):
        # default_collate turns a list of per-sample dicts into ONE dict of
        # batched fields, so index each field instead of the dict itself.
        sample = {}
        for key, values in meta.items():
            value = values[i]
            if torch.is_tensor(value) and value.dim() == 0:
                value = value.item()
            sample[key] = value
        return sample
    return meta[i]


@torch.no_grad()
def collect_sorted_test_cases(model, loader, device="cuda", thr=0.5):
    """Score every image in ``loader`` and return the cases sorted by Dice.

    Args:
        model: segmentation model returning (B, 1, H, W) logits.
        loader: iterable of batches ``(images, masks[, meta])``.
        device: device to run inference on.
        thr: probability threshold for the predicted mask.

    Returns:
        list of dicts sorted by ascending Dice. Each dict has keys "img"
        (C, H, W) CPU tensor, "gt" (H, W), "pred" (H, W), "dice" (float) and
        "meta" (that sample's metadata dict, or None when the loader has none).
    """
    model.eval()
    cases = []

    for batch in tqdm(loader, desc="Collecting per-image Dice (test)"):
        img = batch[0].to(device)
        mask = batch[1].to(device)
        meta = batch[2] if len(batch) > 2 else None

        logits = model(img)
        pred = (torch.sigmoid(logits) > thr).float()
        dice_vals = dice_per_image_from_logits(logits, mask, thr=thr)

        # One device->host transfer per batch instead of per sample.
        img_cpu = img.detach().cpu()
        gt_cpu = mask[:, 0].detach().cpu()
        pred_cpu = pred[:, 0].detach().cpu()
        for i in range(img.size(0)):
            cases.append({
                "img": img_cpu[i],
                "gt": gt_cpu[i],
                "pred": pred_cpu[i],
                "dice": float(dice_vals[i].item()),
                "meta": _meta_for_sample(meta, i),
            })

    return sorted(cases, key=lambda case: case["dice"])


In [None]:
# Rank all test images by Dice and keep the 20 worst and the 20 best.
cases_sorted = collect_sorted_test_cases(model, test_loader, device=device, thr=best_thr)
worst_20, best_20 = cases_sorted[:20], cases_sorted[-20:]

for label, group in (("Worst 20 Dice:", worst_20), ("Best  20 Dice:", best_20)):
    print(label, [round(c["dice"], 3) for c in group])


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _to_img_np(img_t):
    img = img_t.numpy()
    img = np.transpose(img, (1,2,0))  # HWC
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    return img

def _overlay(img_np, gt_np, pred_np, alpha=0.55):
    if img_np.max() <= 1.0:
        base = (img_np * 255).astype(np.uint8)
    else:
        base = img_np.astype(np.uint8)

    out = base.astype(np.float32)
    gt = gt_np.astype(bool)
    pr = pred_np.astype(bool)

    # GT green
    out[gt] = (1-alpha)*out[gt] + alpha*np.array([0,255,0])
    # Pred red
    out[pr] = (1-alpha)*out[pr] + alpha*np.array([255,0,0])

    return out.astype(np.uint8)

def plot_case_list(case_list, title):
    """Show one row per case: image, GT mask, predicted mask, and overlay with its Dice.

    Args:
        case_list: dicts with keys "img" (C, H, W tensor), "gt" and "pred"
            ((H, W) tensors) and "dice" (float), as produced when collecting
            test cases.
        title: figure-level title.
    """
    n_rows = len(case_list)
    plt.figure(figsize=(16, 4 * n_rows))
    for row, case in enumerate(case_list):
        img_np = _to_img_np(case["img"])
        gt_np = case["gt"].numpy()
        pred_np = case["pred"].numpy()
        panels = (
            (img_np, None, "Image"),
            (gt_np, "gray", "GT"),
            (pred_np, "gray", "Pred"),
            (_overlay(img_np, gt_np, pred_np), None, f"Overlay | Dice={case['dice']:.3f}"),
        )
        for col, (data, cmap, caption) in enumerate(panels, start=1):
            plt.subplot(n_rows, 4, row * 4 + col)
            plt.imshow(data, cmap=cmap)
            plt.axis("off")
            plt.title(caption)
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()


In [None]:
# Visual review of the lowest- and highest-Dice test cases.
for cases, heading in ((worst_20, "Worst 20 Dice cases (Test)"), (best_20, "Best 20 Dice cases (Test)")):
    plot_case_list(cases, heading)