In [1]:
from google.colab import drive

# Make Google Drive (dataset archives, checkpoints under
# MyDrive/Semantic_Segmentation) visible inside this Colab runtime.
drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [2]:
import zipfile
import os

# Define paths
gta5_zip = "/content/drive/MyDrive/Semantic_Segmentation/GTA5.zip"
cityscapes_zip = "/content/drive/MyDrive/Semantic_Segmentation/Cityscapes.zip"

# Marker written after a successful extraction. It makes this cell idempotent:
# re-running it after the folder-fix cell has flattened the layout must not
# unpack a second nested copy. Otherwise the fix-up cell would move folders
# *into* the already-existing ones.
EXTRACTED_MARKER = ".extracted"


def extract_once(zip_path, dest_dir):
    """Extract ``zip_path`` into ``dest_dir`` unless a previous run already did.

    Args:
        zip_path: path of the ``.zip`` archive to unpack.
        dest_dir: directory to extract into (created if missing).

    Returns:
        True if the archive was extracted now, False if it was skipped because
        ``dest_dir`` already contains the extraction marker.
    """
    os.makedirs(dest_dir, exist_ok=True)
    marker = os.path.join(dest_dir, EXTRACTED_MARKER)
    if os.path.exists(marker):
        return False
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    # Write the marker only after extractall() returns, so an interrupted
    # extraction is retried on the next run instead of being skipped.
    with open(marker, "w"):
        pass
    return True


# Unzip GTA5
extract_once(gta5_zip, "/content/datasets/GTA5")

# Unzip Cityscapes
extract_once(cityscapes_zip, "/content/datasets/Cityscapes")

print("Both datasets extracted.")


Both datasets extracted.


In [3]:
import shutil
import os

# Both archives unpack one directory too deep. Hoist the data folders up to
# the layout the Dataset classes expect, then delete the leftover wrapper.

# --- GTA5: GTA5/GTA5/{images,labels} -> GTA5/{images,labels}
if os.path.exists("/content/datasets/GTA5/GTA5"):
    gta5_moves = [
        ("/content/datasets/GTA5/GTA5/images", "/content/datasets/GTA5/images"),
        ("/content/datasets/GTA5/GTA5/labels", "/content/datasets/GTA5/labels"),
    ]
    for move_src, move_dst in gta5_moves:
        shutil.move(move_src, move_dst)
    shutil.rmtree("/content/datasets/GTA5/GTA5")
    print("Fixed GTA5 folder structure")

# --- Cityscapes: the archive's inner folder is literally named "Cityspaces".
# Its "images" folder becomes leftImg8bit; gtFine keeps its name.
nested_city = "/content/datasets/Cityscapes/Cityscapes/Cityspaces"
if os.path.exists(nested_city):
    city_moves = (
        ("images", "/content/datasets/Cityscapes/leftImg8bit"),
        ("gtFine", "/content/datasets/Cityscapes/gtFine"),
    )
    for sub_name, move_dst in city_moves:
        shutil.move(os.path.join(nested_city, sub_name), move_dst)
    shutil.rmtree("/content/datasets/Cityscapes/Cityscapes")
    print("Fixed Cityscapes folder structure")


Fixed GTA5 folder structure
Fixed Cityscapes folder structure


Clone the BiSeNet Repo & Import the Model

In [4]:
# Clone the course BiSeNet repo (once) and make its `models` package importable.
# Uses an explicit target directory so the clone location always matches the
# sys.path entry, and skips the clone when a checkout already exists so the
# cell can be re-run without a "destination path already exists" error.
import os
import subprocess
import sys

REPO_URL = "https://github.com/Gabrysse/MLDL2024_project1.git"
REPO_DIR = "/content/MLDL2024_project1"

if not os.path.isdir(REPO_DIR):
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

# Add to Python path for import (only once, so re-runs don't stack duplicates)
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

# Import BiSeNet
from models.bisenet.build_bisenet import BiSeNet
import torch

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("✅ BiSeNet repo cloned and model class imported.")


Cloning into 'MLDL2024_project1'...
remote: Enumerating objects: 34, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 34 (delta 9), reused 3 (delta 3), pack-reused 13 (from 1)[K
Receiving objects: 100% (34/34), 11.29 KiB | 11.29 MiB/s, done.
Resolving deltas: 100% (9/9), done.
✅ BiSeNet repo cloned and model class imported.


Load FDA-Trained Weights into BiSeNet

In [5]:
from models.bisenet.build_bisenet import BiSeNet
import torch

# Run on the GPU when one is available; the model and inputs go to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# BiSeNet with 19 output classes and a ResNet-18 context path (the
# architecture whose parameters the FDA checkpoint below provides).
model = BiSeNet(19, 'resnet18').to(device)

# The checkpoint file is loaded straight into load_state_dict, i.e. it is a
# plain state_dict rather than a dict wrapping one.
checkpoint_path = "/content/drive/MyDrive/Semantic_Segmentation/bisenet_gta5_fda_final.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)

# Inference only from here on (switches layers such as BatchNorm to eval behaviour).
model.eval()

print("BiSeNet loaded with FDA-trained weights.")


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 225MB/s]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:01<00:00, 150MB/s]


BiSeNet loaded with FDA-trained weights.


Generate Pseudo-Labels using FDA-Trained BiSeNet

In [6]:
from PIL import Image
from torchvision import transforms
import os
from tqdm import tqdm
import numpy as np

# Define paths
image_dir = "/content/datasets/Cityscapes/leftImg8bit/train"
pseudo_label_dir = "/content/pseudo_labels_dice"
os.makedirs(pseudo_label_dir, exist_ok=True)

# Image preprocessing: resize to 720x1280, then ImageNet mean/std normalization.
transform_img = transforms.Compose([
    transforms.Resize((720, 1280)),  # Resize to BiSeNet input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Gather all image paths (sorted so the processing order is reproducible)
all_image_paths = [
    os.path.join(image_dir, city, fname)
    for city in sorted(os.listdir(image_dir))
    for fname in sorted(os.listdir(os.path.join(image_dir, city)))
    if fname.endswith("_leftImg8bit.png")
]

print(f"🔍 Found {len(all_image_paths)} images in Cityscapes/train")

# Generate pseudo-labels: per-pixel argmax of the model's main head, saved as
# uint8 PNGs at the model's output resolution.
model.eval()
with torch.no_grad():
    for img_path in tqdm(all_image_paths, desc="⚙️ Generating Pseudo-labels"):
        # Context manager closes the file handle as soon as the pixels are loaded.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        input_tensor = transform_img(img).unsqueeze(0).to(device)  # (1, 3, H, W)

        output = model(input_tensor)
        # BiSeNet's forward may return a tuple whose first element is the main
        # head, or (presumably in eval mode) just the main tensor. Indexing a
        # bare tensor with [0] would select the batch element rather than the
        # main head, so only unwrap real tuples/lists.
        logits = output[0] if isinstance(output, (tuple, list)) else output  # (1, 19, H, W)
        pred = torch.argmax(logits, dim=1)[0].cpu().numpy().astype(np.uint8)  # (H, W)

        # Save predicted mask
        out_name = os.path.basename(img_path).replace("_leftImg8bit.png", "_pseudo_label.png")
        out_path = os.path.join(pseudo_label_dir, out_name)
        Image.fromarray(pred).save(out_path)

print(f"✅ All pseudo-labels saved to: {pseudo_label_dir}")


🔍 Found 1572 images in Cityscapes/train


⚙️ Generating Pseudo-labels: 100%|██████████| 1572/1572 [04:28<00:00,  5.85it/s]

✅ All pseudo-labels saved to: /content/pseudo_labels_dice





Rebuild Dataset Classes for Hybrid Training

In [19]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
import random

# Transforms
# Images: resize (torchvision's default bilinear interpolation) to the
# 720x1280 training resolution, then ImageNet mean/std normalization.
transform_image = transforms.Compose([
    transforms.Resize((720, 1280)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Labels: nearest-neighbour resize so class IDs are never blended into
# non-existent values. PILToTensor keeps the raw integer IDs (no /255 scaling,
# unlike ToTensor). torchvision's InterpolationMode enum is used instead of the
# legacy PIL integer constant.
transform_label = transforms.Compose([
    transforms.Resize((720, 1280), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor()
])

# GTA5 Dataset
class GTA5Dataset(Dataset):
    """GTA5 source-domain images with their per-pixel label maps.

    Reads images from ``root/images`` and labels from ``root/labels``. A label
    shares its image's file name, except that a ``_leftImg8bit.png`` suffix (if
    present) is swapped for ``_gtFine_labelIds.png``.

    NOTE(review): the losses downstream use 19 classes with ignore 255; confirm
    these label files hold train IDs (0-18 / 255) rather than raw label IDs.
    """

    def __init__(self, root, transform_img, transform_lbl):
        self.img_dir, self.lbl_dir = (os.path.join(root, sub) for sub in ("images", "labels"))
        # Index -> file mapping is the sorted listing, so it is stable across runs.
        self.imgs = sorted(os.listdir(self.img_dir))
        self.transform_img, self.transform_lbl = transform_img, transform_lbl

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)``; with a label transform the label is squeezed and cast to long."""
        img_name = self.imgs[idx]
        lbl_name = img_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png")

        image = Image.open(os.path.join(self.img_dir, img_name)).convert("RGB")
        label = Image.open(os.path.join(self.lbl_dir, lbl_name))

        image = self.transform_img(image) if self.transform_img else image
        label = self.transform_lbl(label).squeeze().long() if self.transform_lbl else label
        return image, label

# Cityscapes Val Dataset
class CityscapesValDataset(Dataset):
    """Cityscapes ``val`` split: images from ``leftImg8bit/val/<city>`` paired
    with ground truth from ``gtFine/val/<city>``.

    Args:
        root: Cityscapes root containing ``leftImg8bit`` and ``gtFine``.
        transform_img: callable applied to the RGB PIL image (skipped if falsy).
        transform_lbl: callable applied to the label PIL image; its result is
            squeezed and cast to ``long`` (skipped if falsy).
        label_suffix: replaces ``_leftImg8bit.png`` to form the label file name.
            Defaults to ``_gtFine_labelIds.png`` (the original behaviour).

    NOTE(review): standard ``*_gtFine_labelIds.png`` files hold raw Cityscapes
    label IDs (0-33), not the 19 train IDs the losses expect. If the dataset
    ships ``*_gtFine_labelTrainIds.png``, pass that as ``label_suffix``.
    TODO confirm which files are present.
    """

    def __init__(self, root, transform_img, transform_lbl, label_suffix="_gtFine_labelIds.png"):
        self.img_dir = os.path.join(root, "leftImg8bit", "val")
        self.lbl_dir = os.path.join(root, "gtFine", "val")
        # Relative "<city>/<file>" paths, sorted so validation order is deterministic.
        self.imgs = [
            os.path.join(city, fname)
            for city in sorted(os.listdir(self.img_dir))
            for fname in sorted(os.listdir(os.path.join(self.img_dir, city)))
            if fname.endswith("_leftImg8bit.png")
        ]
        self.transform_img = transform_img
        self.transform_lbl = transform_lbl
        self.label_suffix = label_suffix
        self.root = root

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th validation image."""
        # os.path.split undoes os.path.join portably (no hard-coded '/').
        city, fname = os.path.split(self.imgs[idx])
        lbl_name = fname.replace("_leftImg8bit.png", self.label_suffix)

        img = Image.open(os.path.join(self.img_dir, city, fname)).convert("RGB")
        lbl = Image.open(os.path.join(self.lbl_dir, city, lbl_name))

        if self.transform_img:
            img = self.transform_img(img)
        if self.transform_lbl:
            lbl = self.transform_lbl(lbl).squeeze().long()

        return img, lbl

# DACS Dataset
class DACSDataset(Dataset):
    """ClassMix samples for DACS: GTA5 classes pasted onto Cityscapes images.

    Each item uses a Cityscapes train image and its pseudo-label as the target
    and a GTA5 ``(img, lbl)`` pair as the source. It picks half of the source's
    classes (excluding ``ignore_index``) at random and pastes those source
    pixels, image and label, onto the target.

    Args:
        gta5_dataset: indexable dataset returning ``(img, lbl)`` tensors. Its
            labels must be 2-D and its images match the transformed target
            image's spatial size, since the mask built from the label is
            applied to both.
        cityscapes_root: root containing ``leftImg8bit/train/<city>``.
        pseudo_root: directory of ``<name>_pseudo_label.png`` files, one per
            Cityscapes ``<name>_leftImg8bit.png``.
        transform_img: transform for the target RGB PIL image.
        transform_lbl: transform for the pseudo-label PIL image.
        ignore_index: label value never pasted from the source and kept as-is
            in pseudo-labels.
        num_classes: other pseudo-label values are clamped to
            ``[0, num_classes - 1]``.
        random_source: if True (default), pair each target with a uniformly
            random GTA5 sample so the whole source set is used and pairings vary
            between epochs. If False, keep the old fixed pairing
            ``gta5_dataset[idx]`` with ``len = min(len(gta5), len(cityscapes))``,
            which only ever touches the first ``len(self)`` GTA5 samples.
    """

    def __init__(self, gta5_dataset, cityscapes_root, pseudo_root, transform_img, transform_lbl,
                 ignore_index=255, num_classes=19, random_source=True):
        self.gta5_dataset = gta5_dataset
        self.transform_img = transform_img
        self.transform_lbl = transform_lbl
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.random_source = random_source

        # Parallel lists: city_imgs[i] and pseudo_labels[i] describe the same
        # frame. Sorted so indices are stable across runs.
        self.city_imgs = []
        self.pseudo_labels = []
        train_dir = os.path.join(cityscapes_root, "leftImg8bit", "train")
        for city in sorted(os.listdir(train_dir)):
            city_path = os.path.join(train_dir, city)
            for fname in sorted(os.listdir(city_path)):
                if fname.endswith("_leftImg8bit.png"):
                    self.city_imgs.append(os.path.join(city_path, fname))
                    pseudo = os.path.join(pseudo_root, fname.replace("_leftImg8bit.png", "_pseudo_label.png"))
                    self.pseudo_labels.append(pseudo)

    def __len__(self):
        if len(self.gta5_dataset) == 0:
            return 0
        if self.random_source:
            return len(self.city_imgs)
        return min(len(self.gta5_dataset), len(self.city_imgs))

    @staticmethod
    def _class_mix(src_img, src_lbl, tgt_img, tgt_lbl, ignore_index=255):
        """Paste a random half of ``src_lbl``'s classes onto the target.

        ``src_img``/``tgt_img`` are (C, H, W) and ``src_lbl``/``tgt_lbl`` are
        (H, W). Pixels equal to ``ignore_index`` in the source are never
        selected, so they always come from the target.
        Returns ``(mixed_img, mixed_lbl)``.
        """
        classes = torch.unique(src_lbl)
        classes = classes[classes != ignore_index]
        # randperm(0) is empty, so a source with no valid classes selects nothing.
        selected = classes[torch.randperm(len(classes))[:len(classes) // 2]]
        mask = torch.isin(src_lbl, selected)  # (H, W) bool
        mixed_lbl = torch.where(mask, src_lbl, tgt_lbl)
        mixed_img = torch.where(mask.unsqueeze(0), src_img, tgt_img)
        return mixed_img, mixed_lbl

    def __getitem__(self, idx):
        # Python's `random` and torch's RNG are reseeded per DataLoader worker
        # by PyTorch, so workers draw different source indices / class subsets.
        src_idx = random.randrange(len(self.gta5_dataset)) if self.random_source else idx
        src_img, src_lbl = self.gta5_dataset[src_idx]

        with Image.open(self.city_imgs[idx]) as im:
            tgt_img = self.transform_img(im.convert("RGB"))
        with Image.open(self.pseudo_labels[idx]) as im:
            tgt_lbl = self.transform_lbl(im).squeeze().long()

        # Clamp stray values into the class range but keep ignore pixels, so
        # confidence-filtered pseudo-labels (marked ignore_index) stay ignored.
        ignore = tgt_lbl == self.ignore_index
        tgt_lbl = torch.clamp(tgt_lbl, 0, self.num_classes - 1).masked_fill(ignore, self.ignore_index)

        return self._class_mix(src_img, src_lbl, tgt_img, tgt_lbl, self.ignore_index)

Initialize Datasets & DataLoaders (using the transforms defined above)

In [12]:
from torch.utils.data import DataLoader

# Paths
gta5_root = "/content/datasets/GTA5"
cityscapes_root = "/content/datasets/Cityscapes"
pseudo_label_dir = "/content/pseudo_labels_dice"

# Initialize datasets
gta5_dataset = GTA5Dataset(gta5_root, transform_image, transform_label)
val_dataset  = CityscapesValDataset(cityscapes_root, transform_image, transform_label)
dacs_dataset = DACSDataset(gta5_dataset, cityscapes_root, pseudo_label_dir, transform_image, transform_label)

# DataLoaders
gta5_loader = DataLoader(gta5_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
val_loader  = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)
dacs_loader = DataLoader(dacs_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

# ✅ Sanity check
batch_imgs, batch_lbls = next(iter(dacs_loader))

# The 19-class losses used for training/validation (ignore_index=255) only
# accept labels in [0, 18] or 255. Fail here with a readable message rather
# than deep inside the training loop (on CUDA, out-of-range targets surface as
# an opaque device-side assert). Raw Cityscapes label IDs (up to 33) are caught.
NUM_CLASSES, IGNORE_INDEX = 19, 255


def _assert_train_ids(labels, name):
    """Raise AssertionError if ``labels`` has values outside [0, NUM_CLASSES) or IGNORE_INDEX."""
    bad = ~(((labels >= 0) & (labels < NUM_CLASSES)) | (labels == IGNORE_INDEX))
    assert not bad.any(), f"{name}: unexpected label values {torch.unique(labels[bad]).tolist()}"


_assert_train_ids(batch_lbls, "DACS batch")
_assert_train_ids(next(iter(val_loader))[1], "Cityscapes val batch")

print("DACS ✅", batch_imgs.shape, batch_lbls.shape)


DACS ✅ torch.Size([2, 3, 720, 1280]) torch.Size([2, 720, 1280])


DataLoader Setup (repeats the previous cell's initialization with the same paths; running either one is sufficient)

In [13]:
from torch.utils.data import DataLoader

# ✅ Dataset locations (pseudo-labels are the ones written by the
# pseudo-label generation cell above)
gta5_root, cityscapes_root = "/content/datasets/GTA5", "/content/datasets/Cityscapes"
pseudo_label_dir = "/content/pseudo_labels_dice"

# ✅ Datasets: labelled GTA5 source, Cityscapes val, and ClassMix-ed DACS pairs
gta5_dataset = GTA5Dataset(gta5_root, transform_image, transform_label)
val_dataset = CityscapesValDataset(cityscapes_root, transform_image, transform_label)
dacs_dataset = DACSDataset(gta5_dataset, cityscapes_root, pseudo_label_dir,
                           transform_image, transform_label)

# ✅ Loaders: both training loaders share one configuration; validation keeps
# a fixed order without pinned memory.
train_loader_kwargs = dict(batch_size=2, shuffle=True, num_workers=4,
                           pin_memory=True, drop_last=True)
gta5_loader = DataLoader(gta5_dataset, **train_loader_kwargs)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)
dacs_loader = DataLoader(dacs_dataset, **train_loader_kwargs)

# ✅ Sanity check: pull one mixed batch and report its shapes
sample_imgs, sample_lbls = next(iter(dacs_loader))
print("DACS DataLoader ✅", sample_imgs.shape, sample_lbls.shape)


DACS DataLoader ✅ torch.Size([2, 3, 720, 1280]) torch.Size([2, 720, 1280])


Define Hybrid Loss (CrossEntropy + Dice)

HybridSegmentationLoss

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridSegmentationLoss(nn.Module):
    """Cross-entropy plus a weighted soft Dice term for semantic segmentation.

    Args:
        weight_dice: multiplier on the Dice term; 0 returns plain cross-entropy.
        ignore_index: label value excluded from both terms.

    Forward inputs:
        logits: (B, C, H, W) raw class scores.
        targets: (B, H, W) integer class indices.

    Returns:
        Scalar ``CE + weight_dice * mean_{b,c}(1 - Dice_{b,c})``. The soft Dice
        score is computed per image and per class over valid pixels (not
        ``ignore_index`` and inside ``[0, C)``).
    """

    def __init__(self, weight_dice=1.0, ignore_index=255):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.weight_dice = weight_dice
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        ce = self.ce_loss(logits, targets)

        if self.weight_dice == 0:
            return ce

        num_classes = logits.shape[1]
        # Pixels taking part in Dice: not ignored and a real class index.
        # (Previously out-of-range labels were clamped to C-1 and counted.)
        valid_mask = (targets != self.ignore_index) & (targets >= 0) & (targets < num_classes)

        # one_hot needs in-range indices everywhere, so invalid pixels get a
        # placeholder 0. valid_mask zeroes them out below. `targets` itself is
        # not modified.
        safe_targets = targets.masked_fill(~valid_mask, 0)
        targets_one_hot = F.one_hot(safe_targets, num_classes=num_classes)  # (B, H, W, C)
        targets_one_hot = targets_one_hot.permute(0, 3, 1, 2).float()  # (B, C, H, W)

        # Softmax in float32 so the per-class pixel sums stay accurate when the
        # logits come out of autocast in half precision.
        probs = torch.softmax(logits.float(), dim=1)  # (B, C, H, W)

        valid = valid_mask.unsqueeze(1).float()  # (B, 1, H, W)
        probs = probs * valid
        targets_one_hot = targets_one_hot * valid

        # Dice calculation
        eps = 1e-7
        intersection = (probs * targets_one_hot).sum(dim=(2, 3))
        union = probs.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3))
        dice = 1 - (2 * intersection + eps) / (union + eps)
        dice_loss = dice.mean()

        return ce + self.weight_dice * dice_loss


In [None]:
import os
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

# === CONFIG ===
num_epochs = 10
save_dir = "/content/drive/MyDrive/Semantic_Segmentation/extension_checkpoints"
os.makedirs(save_dir, exist_ok=True)
best_model_path = os.path.join(save_dir, "best_model_dacs_dice.pth")
checkpoint_path = os.path.join(save_dir, "last_checkpoint.pth")
ignore_index = 255  # label value skipped by both the training and validation losses


def _main_logits(model_output):
    """Return the main (B, C, H, W) logits from a BiSeNet forward pass.

    In train mode the model output is indexed with ``[0]`` to get the main
    head, i.e. it is a tuple/list. In eval mode the output may instead be a
    bare tensor. There, ``[0]`` would select the first *image* of the batch
    rather than the main head, so only tuples/lists are unwrapped.
    """
    if isinstance(model_output, (tuple, list)):
        return model_output[0]
    return model_output


# === MODEL ===
# NOTE(review): this builds a fresh BiSeNet instead of fine-tuning the
# FDA-trained `model` loaded earlier (which only produced the pseudo-labels).
# Confirm that training from the default initialisation is intended.
model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scaler = GradScaler()
criterion = HybridSegmentationLoss(weight_dice=1.0, ignore_index=ignore_index).to(device)
val_criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)  # plain CE for model selection

# === RESUME CHECKPOINT IF EXISTS ===
start_epoch = 0
best_val_loss = float("inf")
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    print(f"Resumed training from epoch {start_epoch}")

# === TRAINING LOOP ===
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0.0
    loop = tqdm(dacs_loader, desc=f"Epoch {epoch+1}", leave=False)

    for imgs, labels in loop:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        with autocast():
            logits = _main_logits(model(imgs))  # auxiliary heads (if any) are not used in the loss
            loss = criterion(logits, labels)

        # Mixed-precision step: scale the loss to avoid fp16 gradient underflow.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        loop.set_postfix(train_loss=loss.item())

    avg_train_loss = total_loss / len(dacs_loader)

    # === VALIDATION ===
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for v_imgs, v_labels in val_loader:
            v_imgs, v_labels = v_imgs.to(device), v_labels.to(device)
            v_logits = _main_logits(model(v_imgs))  # whole batch: (B, C, H, W)
            v_loss = val_criterion(v_logits, v_labels)
            val_loss += v_loss.item()
    val_loss /= len(val_loader)

    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # === SAVE BEST MODEL ===
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved Best Model at Epoch {epoch+1}")

    # === SAVE CHECKPOINT ===
    # Saved every epoch so an interrupted Colab session can resume above.
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
        'best_val_loss': best_val_loss
    }, checkpoint_path)

print("Training Complete")