In [8]:
# ------------------------------
# CELL 1 - Imports and global configuration
# ------------------------------
import os, time, random
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

# ---- Tunable settings (change if needed) ----
# Dataset root; expects camvid/images, camvid/labels_processed, camvid/splits/*.txt
DATA_ROOT = "camvid"
IMG_H = 360
IMG_W = 480
NUM_CLASSES = 11
BATCH_SIZE = 2
EPOCHS = 30
LR = 5e-4
NUM_WORKERS = 0  # keep at 0 in a notebook / on Windows
USE_PRETRAINED_BACKBONE = True

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---- Output folders ----
for _out_dir in (
    "outputs/bisenet_checkpoints",
    "outputs/bisenet_predictions",
    "outputs/bisenet_comparisons",
):
    os.makedirs(_out_dir, exist_ok=True)

# ---- Seeds for the three RNGs used in this notebook ----
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

print("Device:", DEVICE, "Image size:", IMG_H, "x", IMG_W)

Device: cpu Image size: 360 x 480


In [9]:
# ------------------------------
# CELL 2 - Color <-> class utilities
# ------------------------------
# Label colours in class-id order: entry i is the RGB triple paired with class id i.
_PALETTE = (
    (128, 128, 128),  # 0
    (128, 0, 0),      # 1
    (192, 192, 128),  # 2
    (128, 64, 128),   # 3
    (0, 0, 192),      # 4
    (128, 128, 0),    # 5
    (192, 128, 128),  # 6
    (64, 64, 128),    # 7
    (64, 0, 128),     # 8
    (64, 64, 0),      # 9
    (0, 128, 192),    # 10
)

# RGB triple -> class id, and the inverse lookup class id -> RGB triple.
color_to_id = {rgb: class_id for class_id, rgb in enumerate(_PALETTE)}
id_to_color = dict(enumerate(_PALETTE))

def class_mask_to_rgb(mask):
    """Colourise a class-id mask using the ``id_to_color`` lookup.

    Args:
        mask: 2-D (H x W) integer array of class ids.

    Returns:
        H x W x 3 ``uint8`` array. Pixels whose id has no entry in
        ``id_to_color`` (e.g. 255) are left black.
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id in id_to_color:
        # Boolean-mask assignment paints every pixel of this class at once.
        canvas[mask == class_id] = id_to_color[class_id]
    return canvas

In [10]:
# ------------------------------
# CELL 3 - Transforms (train/val)
# ------------------------------
# Per-channel mean / std handed to TF.normalize by train_transform and val_transform.
# NOTE(review): presumably the ImageNet statistics expected by the pretrained
# MobileNetV2 backbone — confirm against the weights' documented preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def random_scale_and_crop(img, mask, th, tw, min_s=0.95, max_s=1.05):
    """Rescale an image/mask pair by one random factor, then crop both to (th, tw).

    Args:
        img: PIL image (resized with bilinear interpolation).
        mask: PIL label image (resized with nearest-neighbour interpolation).
        th, tw: height and width of the returned crop.
        min_s, max_s: bounds of the uniformly sampled scale factor.

    Returns:
        ``(img, mask)`` PIL images cropped with the same box.

    If the rescaled pair is smaller than the crop, it is padded on the
    bottom/right first: the image by reflection, the mask with 255.
    """
    scale = random.uniform(min_s, max_s)
    new_w = max(1, int(img.width * scale))
    new_h = max(1, int(img.height * scale))
    new_size = (new_w, new_h)
    img = img.resize(new_size, Image.BILINEAR)
    mask = mask.resize(new_size, Image.NEAREST)

    # How many rows / columns are missing to reach the crop size.
    extra_h = max(0, th - new_h)
    extra_w = max(0, tw - new_w)
    if extra_h or extra_w:
        img_arr = np.pad(np.array(img), ((0, extra_h), (0, extra_w), (0, 0)), mode='reflect')
        mask_arr = np.pad(np.array(mask), ((0, extra_h), (0, extra_w)), mode='constant', constant_values=255)
        img = Image.fromarray(img_arr)
        mask = Image.fromarray(mask_arr)

    # Offsets are drawn from the un-padded size, so the box never starts in the padding.
    left = random.randint(0, max(0, new_w - tw))
    top = random.randint(0, max(0, new_h - th))
    box = (left, top, left + tw, top + th)
    return img.crop(box), mask.crop(box)

def train_transform(img, mask, h=IMG_H, w=IMG_W):
    """Training-time augmentation: random horizontal flip, random scale + crop, normalise.

    Args:
        img: PIL image.
        mask: PIL label image paired with ``img``.
        h, w: crop size passed on to ``random_scale_and_crop``.

    Returns:
        ``(image_tensor, mask_tensor)``: the image converted by ``TF.to_tensor``
        and normalised with IMAGENET_MEAN / IMAGENET_STD, and the mask as a
        ``long`` tensor.
    """
    flip = random.random() < 0.5
    if flip:
        img = TF.hflip(img)
        # Mirror the label array along its width axis so it stays aligned with the image.
        mask = Image.fromarray(np.array(mask)[:, ::-1])

    img, mask = random_scale_and_crop(img, mask, h, w)

    img_t = TF.normalize(TF.to_tensor(img), IMAGENET_MEAN, IMAGENET_STD)
    mask_t = torch.from_numpy(np.array(mask)).long()
    return img_t, mask_t

def val_transform(img, mask, h=IMG_H, w=IMG_W):
    """Deterministic preprocessing: resize both inputs to (h, w) and normalise the image.

    The image is resized bilinearly and the mask with nearest-neighbour.
    Returns ``(image_tensor, mask_tensor)`` with the mask as a ``long`` tensor.
    """
    target_size = (w, h)  # resize() is called with (width, height)
    img = img.resize(target_size, Image.BILINEAR)
    mask = mask.resize(target_size, Image.NEAREST)

    img_t = TF.normalize(TF.to_tensor(img), IMAGENET_MEAN, IMAGENET_STD)
    mask_t = torch.from_numpy(np.array(mask)).long()
    return img_t, mask_t

In [11]:
# ------------------------------
# CELL 4 - Dataset and DataLoaders
# ------------------------------
class CamVidDataset(Dataset):
    """Image / label-mask pairs listed in a split file.

    Each non-blank line of ``split_file`` is a sample name; the sample is read
    from ``<root>/images/<name>.png`` and ``<root>/labels_processed/<name>.png``.

    Args:
        root: dataset root directory.
        split_file: text file with one sample name per line.
        mode: ``"train"`` applies ``train_transform``; any other value applies
            ``val_transform``.
    """

    def __init__(self, root, split_file, mode="val"):
        self.root = Path(root)
        self.img_dir = self.root / "images"
        self.lbl_dir = self.root / "labels_processed"
        with open(split_file, "r") as f:
            # Skip blank lines (e.g. trailing newlines): an empty name would count
            # as a sample and later resolve to the non-existent file ".png".
            self.items = [line.strip() for line in f if line.strip()]
        self.mode = mode

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, mask_tensor)``."""
        name = self.items[idx]
        # Image.open is lazy and keeps the file handle open; load the pixels inside a
        # context manager so the handle is released before the transforms run.
        with Image.open(self.img_dir / (name + ".png")) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(self.lbl_dir / (name + ".png")) as raw_mask:
            mask = raw_mask.copy()
        if self.mode == "train":
            return train_transform(img, mask)
        return val_transform(img, mask)

def make_loaders(root=DATA_ROOT, batch_size=BATCH_SIZE):
    """Build the train / val / test DataLoaders from ``<root>/splits/*.txt``.

    The train loader shuffles and uses ``batch_size``; the val and test loaders
    are unshuffled with batch size 1 and use ``val_transform``.

    Returns:
        ``(train_loader, val_loader, test_loader)``
    """
    train_set = CamVidDataset(root, f"{root}/splits/train.txt", mode="train")
    val_set = CamVidDataset(root, f"{root}/splits/val.txt", mode="val")
    test_set = CamVidDataset(root, f"{root}/splits/test.txt", mode="val")

    # Settings shared by the two evaluation loaders.
    eval_kwargs = dict(batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=False,
    )
    val_loader = DataLoader(val_set, **eval_kwargs)
    test_loader = DataLoader(test_set, **eval_kwargs)
    return train_loader, val_loader, test_loader


# Module-level loaders used by the training and prediction cells below.
train_loader, val_loader, test_loader = make_loaders()
print(
    "Train:", len(train_loader.dataset),
    "Val:", len(val_loader.dataset),
    "Test:", len(test_loader.dataset),
)

Train: 490 Val: 105 Test: 106


In [12]:
# ------------------------------
# CELL 5 - BiSeNetV2-Lite model (FINAL FIXED VERSION)
# ------------------------------
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

class DetailBranch(nn.Module):
    """Shallow branch made of two stages of 3x3 conv -> BatchNorm -> ReLU units.

    Each stage starts with a stride-2 unit followed by a stride-1 unit, so a
    (N, 3, H, W) input becomes a 128-channel map at about 1/4 of the input
    height and width.
    """

    @staticmethod
    def _conv_bn_relu(c_in, c_out, stride):
        """Layers of one 3x3 conv (no bias) + BatchNorm + in-place ReLU unit."""
        return [
            nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        ]

    def __init__(self):
        super().__init__()
        unit = self._conv_bn_relu
        # Stage 1: 3 -> 64 channels, halves the resolution.
        self.s1 = nn.Sequential(*unit(3, 64, 2), *unit(64, 64, 1))
        # Stage 2: 64 -> 128 channels, halves the resolution again.
        self.s2 = nn.Sequential(*unit(64, 128, 2), *unit(128, 128, 1))

    def forward(self, x):
        """Run both stages in sequence."""
        return self.s2(self.s1(x))


class SemanticBranch(nn.Module):
    """Context branch built on the MobileNetV2 feature extractor.

    The backbone's ``features`` are split into three consecutive stages; each
    stage output gets a 1x1 projection, and the deepest projection is upsampled
    and added to the middle one.

    Args:
        pretrained: load ``MobileNet_V2_Weights.DEFAULT`` when True, otherwise
            build the backbone without weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        backbone_weights = None
        if pretrained:
            backbone_weights = MobileNet_V2_Weights.DEFAULT
        layers = mobilenet_v2(weights=backbone_weights).features

        # Consecutive slices of the backbone.
        self.stage1 = nn.Sequential(*layers[:4])   # 24 output channels
        self.stage2 = nn.Sequential(*layers[4:7])  # 32 output channels
        self.stage3 = nn.Sequential(*layers[7:])   # 1280 output channels

        # 1x1 projections of each stage output.
        self.reduce_low = nn.Conv2d(24, 64, 1, bias=False)
        self.reduce_mid = nn.Conv2d(32, 128, 1, bias=False)
        self.reduce_high = nn.Conv2d(1280, 128, 1, bias=False)

        self.bn = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``(low, fused)``: the 64-channel stage-1 projection and the
        128-channel fusion of the stage-2 and stage-3 projections."""
        feat_low = self.stage1(x)
        feat_mid = self.stage2(feat_low)
        feat_high = self.stage3(feat_mid)

        proj_mid = self.reduce_mid(feat_mid)
        # Bring the deepest projection up to the stage-2 resolution before adding.
        proj_high = nn.functional.interpolate(
            self.reduce_high(feat_high),
            size=proj_mid.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        # ReLU is applied to the sum first, BatchNorm second.
        fused = self.bn(self.relu(proj_mid + proj_high))

        return self.reduce_low(feat_low), fused


class BiSeNetV2Lite(nn.Module):
    """Two-branch segmentation network: DetailBranch + SemanticBranch, fused by concatenation.

    Args:
        num_classes: number of output logit channels.
        pretrained_backbone: forwarded to ``SemanticBranch(pretrained=...)``.
    """

    def __init__(self, num_classes=NUM_CLASSES, pretrained_backbone=True):
        super().__init__()
        self.detail = DetailBranch()
        self.semantic = SemanticBranch(pretrained=pretrained_backbone)

        # 128 detail channels + 128 semantic channels are concatenated -> 256.
        fuse_layers = []
        for _ in range(2):
            fuse_layers += [
                nn.Conv2d(256, 256, 3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            ]
        self.fuse_conv = nn.Sequential(*fuse_layers)

        # Segmentation head: 3x3 conv block, dropout, then a 1x1 conv to class logits.
        self.classifier = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, 1),
        )

    def forward(self, x):
        """Return per-class logits resized to the spatial size of ``x``."""
        detail_feat = self.detail(x)
        # The first (low-level) output of the semantic branch is not used here.
        _, semantic_feat = self.semantic(x)

        # Match the semantic map to the detail branch resolution before concatenating.
        semantic_feat = nn.functional.interpolate(
            semantic_feat, size=detail_feat.shape[2:], mode="bilinear", align_corners=False
        )
        merged = torch.cat([detail_feat, semantic_feat], dim=1)
        logits = self.classifier(self.fuse_conv(merged))

        # final upsample to input size
        return nn.functional.interpolate(
            logits, size=x.shape[2:], mode="bilinear", align_corners=False
        )


# create model
# Pass the notebook-level settings explicitly so that flipping
# USE_PRETRAINED_BACKBONE in the config cell actually takes effect
# (the flag was previously defined but never read).
model = BiSeNetV2Lite(num_classes=NUM_CLASSES, pretrained_backbone=USE_PRETRAINED_BACKBONE)
print("Model OK. Params:", sum(p.numel() for p in model.parameters())//1_000_000, "M")

Model OK. Params: 4 M


In [13]:
# ------------------------------
# CELL 6 - Loss, optimizer, LR schedule, mIoU
# ------------------------------
# Cross-entropy over the class logits; target pixels equal to 255 are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=255)

# AdamW over the model parameters that require gradients.
optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=LR,
    weight_decay=1e-4,
)

def poly_lr(optimizer, init_lr, cur_iter, max_iter, power=0.9):
    """Apply polynomial LR decay to every param group of ``optimizer``.

    Sets ``lr = init_lr * (1 - cur_iter / max_iter) ** power`` in place.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        init_lr: learning rate at ``cur_iter == 0``.
        cur_iter: current iteration index.
        max_iter: iteration at which the learning rate reaches 0.
        power: exponent of the decay curve.

    Raises:
        ZeroDivisionError: if ``max_iter`` is 0.
    """
    remaining = 1 - float(cur_iter) / float(max_iter)
    # Past max_iter the base would be negative, and a negative float raised to a
    # fractional power yields a complex number in Python; hold the LR at 0 instead.
    remaining = max(0.0, remaining)
    lr = init_lr * remaining ** power
    for group in optimizer.param_groups:
        group['lr'] = lr

def compute_miou_batch(preds, masks, num_classes=NUM_CLASSES, ignore_index=255):
    """Mean of per-image mIoU over a batch.

    For each image, IoU is computed per class over pixels whose label is not
    ``ignore_index``; classes with an empty union are skipped and the remaining
    IoUs are averaged. The result is the mean of those per-image averages
    (0.0 when no image contributes). Note this is a per-image average, not an
    IoU accumulated over the whole dataset.

    Args:
        preds: tensor of predicted class ids, one 2-D map per batch item.
        masks: tensor of ground-truth class ids with the same layout.
        num_classes: class ids ``0 .. num_classes-1`` are scored.
        ignore_index: label value excluded from both intersection and union.
    """
    pred_np = preds.cpu().numpy()
    mask_np = masks.cpu().numpy()

    per_image = []
    for pred, target in zip(pred_np, mask_np):
        keep = target != ignore_index
        class_ious = []
        for cls in range(num_classes):
            in_pred = pred == cls
            in_target = target == cls
            union = ((in_pred | in_target) & keep).sum()
            if union > 0:
                inter = (in_pred & in_target & keep).sum()
                class_ious.append(inter / union)
        if class_ious:
            per_image.append(sum(class_ious) / len(class_ious))

    return sum(per_image, 0.0) / max(1, len(per_image))

In [14]:
# ------------------------------
# CELL 7 - Training loop (poly LR + early stopping)
# ------------------------------
model = model.to(DEVICE)


def _mean_val_miou(loader):
    """Average ``compute_miou_batch`` over ``loader`` with the global model in eval mode."""
    model.eval()
    total = 0.0
    batches = 0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs = imgs.to(DEVICE); masks = masks.to(DEVICE)
            preds = torch.argmax(model(imgs), dim=1)
            total += compute_miou_batch(preds, masks)
            batches += 1
    return total / max(1, batches)


def train_bisenet(epochs=EPOCHS, init_lr=LR, patience=8):
    """Train the global ``model`` with poly LR decay, best-checkpointing and early stopping.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``val_loader``.

    Args:
        epochs: maximum number of epochs.
        init_lr: learning rate at the first optimisation step.
        patience: stop after this many consecutive epochs without a validation
            mIoU improvement of more than 1e-4.

    Side effects:
        Writes the best weights to ``outputs/bisenet_checkpoints/best_bisenet.pth``
        and prints one progress line per epoch.
    """
    best_miou = -1.0
    no_improve = 0
    steps_per_epoch = len(train_loader)
    max_iter = epochs * steps_per_epoch
    iter_count = 0

    for epoch in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        t0 = time.time()
        for imgs, masks in train_loader:
            imgs = imgs.to(DEVICE); masks = masks.to(DEVICE)
            # Schedule on the number of *completed* steps: the first step runs at
            # init_lr and the last one at a non-zero rate. Incrementing before the
            # call made the final step use lr == 0, i.e. a wasted update.
            poly_lr(optimizer, init_lr, iter_count, max_iter)
            iter_count += 1
            optimizer.zero_grad()
            out = model(imgs)
            loss = criterion(out, masks)
            loss.backward()
            optimizer.step()
            running_loss += float(loss.item())

        # validation
        val_miou = _mean_val_miou(val_loader)
        # max(1, ...) keeps the report well-defined if the train loader is empty.
        mean_loss = running_loss / max(1, steps_per_epoch)
        print(f"Epoch {epoch}/{epochs} loss={mean_loss:.4f} val_mIoU={val_miou:.4f} time={time.time()-t0:.1f}s lr={optimizer.param_groups[0]['lr']:.2e}")

        # checkpoint
        if val_miou > best_miou + 1e-4:
            best_miou = val_miou
            torch.save(model.state_dict(), "outputs/bisenet_checkpoints/best_bisenet.pth")
            print("  Saved best model.")
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print("Early stopping at epoch", epoch)
                break

    print("Training finished. Best mIoU:", best_miou)

# run training
train_bisenet()

Epoch 1/30 loss=0.7829 val_mIoU=0.4616 time=143.8s lr=4.85e-04
  Saved best model.
Epoch 2/30 loss=0.4780 val_mIoU=0.5083 time=145.9s lr=4.70e-04
  Saved best model.
Epoch 3/30 loss=0.3969 val_mIoU=0.5018 time=136.6s lr=4.55e-04
Epoch 4/30 loss=0.3346 val_mIoU=0.5275 time=134.4s lr=4.40e-04
  Saved best model.
Epoch 5/30 loss=0.3032 val_mIoU=0.5169 time=135.7s lr=4.24e-04
Epoch 6/30 loss=0.2879 val_mIoU=0.4832 time=136.1s lr=4.09e-04
Epoch 7/30 loss=0.2696 val_mIoU=0.5343 time=132.2s lr=3.94e-04
  Saved best model.
Epoch 8/30 loss=0.2553 val_mIoU=0.5440 time=134.1s lr=3.78e-04
  Saved best model.
Epoch 9/30 loss=0.2484 val_mIoU=0.5525 time=136.1s lr=3.63e-04
  Saved best model.
Epoch 10/30 loss=0.2212 val_mIoU=0.5671 time=134.9s lr=3.47e-04
  Saved best model.
Epoch 11/30 loss=0.2323 val_mIoU=0.5592 time=134.6s lr=3.31e-04
Epoch 12/30 loss=0.2075 val_mIoU=0.5451 time=131.6s lr=3.16e-04
Epoch 13/30 loss=0.1923 val_mIoU=0.5825 time=131.9s lr=3.00e-04
  Saved best model.
Epoch 14/30 loss=

In [18]:
def evaluate_test_bisenet(
    weights="outputs/bisenet_checkpoints/best_bisenet.pth",
    root=DATA_ROOT,
    h=IMG_H,
    w=IMG_W
):
    """Load a checkpoint and report the mean per-image mIoU on the test split.

    Args:
        weights: path to a ``state_dict`` saved with ``torch.save``.
        root: dataset root; the split is read from ``<root>/splits/test.txt``.
        h, w: currently unused (the dataset applies ``val_transform`` with its
            own default size); kept so existing calls keep working.

    Returns:
        The average of ``compute_miou_batch`` over all test batches (batch size 1).
    """
    device = DEVICE

    # Test dataset & loader
    test_set = CamVidDataset(
        root,
        f"{root}/splits/test.txt",
        mode="val"   # val_transform used internally
    )
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)

    # Load model. Every tensor is overwritten by the checkpoint below, so skip the
    # pretrained backbone weights: fetching them is wasted work and would make
    # evaluation depend on those weights being available.
    model = BiSeNetV2Lite(num_classes=NUM_CLASSES, pretrained_backbone=False).to(device)
    state = torch.load(weights, map_location=device)
    model.load_state_dict(state)
    model.eval()

    # Compute mIoU
    miou = 0.0
    count = 0
    with torch.no_grad():
        for imgs, masks in test_loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            out = model(imgs)
            preds = torch.argmax(out, dim=1)
            miou += compute_miou_batch(preds, masks)
            count += 1

    miou = miou / max(1, count)
    print("BiSeNetV2 Test mIoU:", miou)
    return miou


# Run evaluation
test_miou_bisenet = evaluate_test_bisenet()

BiSeNetV2 Test mIoU: 0.6197696506028855


In [15]:
# ------------------------------
# CELL 8 - Load best checkpoint and prepare for inference
# ------------------------------
# The checkpoint supplies every weight, so build the network without the
# pretrained backbone instead of fetching weights that are overwritten right away.
best_model = BiSeNetV2Lite(pretrained_backbone=False)
best_model.load_state_dict(torch.load("outputs/bisenet_checkpoints/best_bisenet.pth", map_location=DEVICE))
best_model = best_model.to(DEVICE)
best_model.eval()  # inference mode for BatchNorm / Dropout
print("Loaded best BiSeNet model.")

Loaded best BiSeNet model.


In [16]:
# ------------------------------
# CELL 9 - Save predictions and side-by-side comparisons for all test images
# ------------------------------
def save_all_predictions_and_comparisons(out_pred_dir="outputs/bisenet_predictions", out_comp_dir="outputs/bisenet_comparisons"):
    """Save a colourised prediction and an image / GT / prediction figure per test batch.

    Uses the module-level ``best_model`` and ``test_loader``; only the first
    item of each batch is rendered (the test loader uses batch size 1).

    Args:
        out_pred_dir: folder receiving ``pred_<idx>.png``.
        out_comp_dir: folder receiving ``comparison_<idx>.png``.
    """
    os.makedirs(out_pred_dir, exist_ok=True)
    os.makedirs(out_comp_dir, exist_ok=True)

    # Channel statistics as arrays so they broadcast over an H x W x 3 image.
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)

    with torch.no_grad():
        for idx, (imgs, masks) in enumerate(test_loader):
            imgs = imgs.to(DEVICE)
            masks_np = masks.numpy()[0]
            out = best_model(imgs)                      # (B,C,H,W)
            pred = torch.argmax(out, dim=1).cpu().numpy()[0]

            # Undo the mean/std normalisation to recover the displayed colours.
            # Min-max rescaling shifted colours and contrast per image and divided
            # by zero on a constant image.
            img_np = imgs.cpu().numpy()[0].transpose(1,2,0)
            img_np = np.clip(img_np * std + mean, 0.0, 1.0)

            gt_rgb = class_mask_to_rgb(masks_np)
            pred_rgb = class_mask_to_rgb(pred)
            Image.fromarray(pred_rgb).save(f"{out_pred_dir}/pred_{idx}.png")

            fig, ax = plt.subplots(1,3,figsize=(14,4))
            ax[0].imshow(img_np); ax[0].set_title("Image"); ax[0].axis("off")
            ax[1].imshow(gt_rgb);  ax[1].set_title("GT Mask"); ax[1].axis("off")
            ax[2].imshow(pred_rgb);ax[2].set_title("Pred Mask"); ax[2].axis("off")
            fig.tight_layout()
            fig.savefig(f"{out_comp_dir}/comparison_{idx}.png", dpi=120)
            plt.close(fig)  # release the figure so memory does not grow across the loop
    print("Saved predictions and comparisons.")

save_all_predictions_and_comparisons()

Saved predictions and comparisons.
