In [None]:
# Mount Google Drive so the dataset archives and checkpoints stored under
# MyDrive become reachable below /content/drive.
from google.colab import drive

drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import zipfile
import os

# Archives on Google Drive and the folders they are unpacked into.
gta5_zip = "/content/drive/MyDrive/Semantic_Segmentation/GTA5.zip"
cityscapes_zip = "/content/drive/MyDrive/Semantic_Segmentation/Cityscapes.zip"


def _extract_once(zip_path, dest_dir):
    """Extract ``zip_path`` into ``dest_dir`` unless ``dest_dir`` already has content.

    Re-extracting on a re-run would recreate the nested GTA5/GTA5 and
    Cityscapes/Cityscapes folders that the folder-fix cell flattens. That cell
    would then ``shutil.move`` them *into* the already-existing destinations
    (e.g. ``images/images``). Delete ``dest_dir`` to force a fresh extraction,
    for example after an interrupted run left it half-populated.

    Returns:
        True if the archive was extracted, False if extraction was skipped.
    """
    os.makedirs(dest_dir, exist_ok=True)
    if os.listdir(dest_dir):
        print(f"⏭️ {dest_dir} is not empty, skipping extraction of {zip_path}")
        return False
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    return True


# Unzip GTA5, then Cityscapes
_extract_once(gta5_zip, "/content/datasets/GTA5")
_extract_once(cityscapes_zip, "/content/datasets/Cityscapes")

print("✅ Both datasets ready.")


✅ Both datasets extracted.


In [None]:
import shutil
import os

def _flatten_nested(nested_dir, moves, remove_dir, done_msg):
    """Move children of ``nested_dir`` up to their final location, then delete ``remove_dir``.

    Nothing happens when ``nested_dir`` does not exist, so re-running the cell
    after a successful fix is a no-op.

    Args:
        nested_dir: folder whose children are moved; its presence gates the fix.
        moves: ``(child_name, destination_path)`` pairs, moved in order.
        remove_dir: directory tree removed once the moves are done.
        done_msg: message printed after the fix is applied.
    """
    if not os.path.exists(nested_dir):
        return
    for child, destination in moves:
        shutil.move(os.path.join(nested_dir, child), destination)
    shutil.rmtree(remove_dir)
    print(done_msg)


# GTA5: <root>/GTA5/{images,labels} -> <root>/{images,labels}
_flatten_nested(
    "/content/datasets/GTA5/GTA5",
    [("images", "/content/datasets/GTA5/images"),
     ("labels", "/content/datasets/GTA5/labels")],
    "/content/datasets/GTA5/GTA5",
    "✅ Fixed GTA5 folder structure",
)

# Cityscapes: the extracted data sits under Cityscapes/Cityspaces (spelled as in
# the archive); its images/ becomes leftImg8bit/ and gtFine/ moves up as-is.
nested_city = "/content/datasets/Cityscapes/Cityscapes/Cityspaces"
_flatten_nested(
    nested_city,
    [("images", "/content/datasets/Cityscapes/leftImg8bit"),
     ("gtFine", "/content/datasets/Cityscapes/gtFine")],
    "/content/datasets/Cityscapes/Cityscapes",
    "✅ Fixed Cityscapes folder structure",
)


✅ Fixed GTA5 folder structure
✅ Fixed Cityscapes folder structure


Setup — Imports (the transforms are defined further below)

In [None]:
import os
import torch
import torchvision
import numpy as np
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt


In [None]:
# Dataset roots after extraction and flattening.
gta5_root = "/content/datasets/GTA5"
cityscapes_root = "/content/datasets/Cityscapes"

# Report whether each expected sub-directory is present.
for caption, base, sub in (
    ("GTA5 Image Path Exists?", gta5_root, "images"),
    ("GTA5 Label Path Exists?", gta5_root, "labels"),
    ("Cityscapes leftImg8bit Exists?", cityscapes_root, "leftImg8bit"),
    ("Cityscapes gtFine Exists?", cityscapes_root, "gtFine"),
):
    print(caption, os.path.exists(os.path.join(base, sub)))


GTA5 Image Path Exists? True
GTA5 Label Path Exists? True
Cityscapes leftImg8bit Exists? True
Cityscapes gtFine Exists? True


Define GTA5Dataset Class

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class GTA5Dataset(Dataset):
    """GTA5 source-domain dataset yielding ``(image, label)`` pairs.

    Reads ``root/images`` and ``root/labels``. The two folders are paired purely
    by position after sorting their file names, so they must hold the same
    number of files in matching sort order.

    Args:
        root: dataset root directory.
        transform: optional callable applied to the RGB PIL image.
        target_transform: optional callable applied to the label PIL image.

    Raises:
        AssertionError: if the two folders contain different numbers of files.
    """

    def __init__(self, root, transform=None, target_transform=None):
        self.root = root
        self.image_dir, self.label_dir = (
            os.path.join(root, sub) for sub in ("images", "labels")
        )
        self.transform = transform
        self.target_transform = target_transform

        # Sorted listings so index i refers to the same file on every run.
        self.images, self.labels = (
            sorted(os.listdir(folder)) for folder in (self.image_dir, self.label_dir)
        )
        assert len(self.images) == len(self.labels), "Mismatch between images and labels"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path, label_path = (
            os.path.join(self.image_dir, self.images[idx]),
            os.path.join(self.label_dir, self.labels[idx]),
        )

        image = Image.open(img_path).convert("RGB")
        label = Image.open(label_path)

        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label


Define CityscapesDataset Class (for validation)

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class CityscapesDataset(Dataset):
    """Cityscapes split yielding ``(image, train-ID label)`` pairs.

    Scans ``root/leftImg8bit/<split>/<city>/*_leftImg8bit.png``. Each image is
    paired with ``root/gtFine/<split>/<city>/<stem>_gtFine_labelTrainIds.png``.
    Cities and files are visited in sorted order because ``os.listdir`` order is
    arbitrary, which keeps sample indices stable across runs and machines.
    Non-directory entries in the split folder are ignored. Label files are not
    checked for existence here; a missing one fails when that sample is loaded.

    Args:
        root: Cityscapes root directory.
        split: split sub-folder name, e.g. ``'val'`` or ``'train'``.
        transform: optional callable applied to the RGB PIL image.
        target_transform: optional callable applied to the label PIL image.
    """

    def __init__(self, root, split='val', transform=None, target_transform=None):
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        self.img_dir = os.path.join(root, "leftImg8bit", split)
        self.label_dir = os.path.join(root, "gtFine", split)

        self.img_paths = []
        self.label_paths = []

        for city in sorted(os.listdir(self.img_dir)):
            img_city_path = os.path.join(self.img_dir, city)
            if not os.path.isdir(img_city_path):
                continue  # stray files next to the city folders would crash listdir
            label_city_path = os.path.join(self.label_dir, city)

            for file_name in sorted(os.listdir(img_city_path)):
                if not file_name.endswith("_leftImg8bit.png"):
                    continue
                label_name = file_name.replace("_leftImg8bit.png", "_gtFine_labelTrainIds.png")
                self.img_paths.append(os.path.join(img_city_path, file_name))
                self.label_paths.append(os.path.join(label_city_path, label_name))

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image = Image.open(self.img_paths[idx]).convert("RGB")
        label = Image.open(self.label_paths[idx])

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label


DACSDataset Class Implementation

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os
import torch

class DACSDataset(Dataset):
    """ClassMix dataset that pastes GTA5 source pixels onto Cityscapes targets (DACS).

    Item ``idx`` pairs ``gta5_dataset[idx]`` (already transformed by that
    dataset) with the ``idx``-th Cityscapes train image and its pseudo-label.
    It then copies the pixels of a random half of the source classes onto the
    target image and label.

    Args:
        gta5_dataset: dataset returning ``(image_tensor, label)``. With
            ``map_source_to_train_ids=True`` the labels are treated as raw
            Cityscapes-style label IDs and remapped to train IDs.
        cityscapes_root: Cityscapes root containing ``leftImg8bit/train/<city>/``.
        transform_img: transform for the target PIL image. It should match the
            source image transform so the mixed tensors have the same size.
        transform_label: transform for the target pseudo-label PIL image.
        pseudo_dir: folder holding ``<stem>_pseudo_label.png`` files.
        map_source_to_train_ids: if True, remap source label IDs to the 19 train
            IDs. If False, source labels are assumed to already be train IDs.
            In both cases, values outside 0-18 become ``IGNORE_INDEX``.

    Each item is ``(mixed_img, mixed_label)``. ``mixed_label`` is a long tensor
    with values in 0-18 or ``IGNORE_INDEX``.
    """

    IGNORE_INDEX = 255
    NUM_CLASSES = 19
    # Cityscapes label ID -> train ID for the 19 evaluated classes (road,
    # sidewalk, building, wall, fence, pole, traffic light, traffic sign,
    # vegetation, terrain, sky, person, rider, car, truck, bus, train,
    # motorcycle, bicycle). Every other ID maps to IGNORE_INDEX.
    ID_TO_TRAIN_ID = {
        7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
        23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
    }

    def __init__(self, gta5_dataset, cityscapes_root, transform_img, transform_label,
                 pseudo_dir="/content/pseudo_labels", map_source_to_train_ids=True):
        self.gta5_dataset = gta5_dataset
        self.cityscapes_root = cityscapes_root
        self.transform_img = transform_img
        self.transform_label = transform_label
        self.pseudo_dir = pseudo_dir
        self.map_source_to_train_ids = map_source_to_train_ids

        # Paths to Cityscapes train images, sorted for a deterministic order.
        train_dir = os.path.join(cityscapes_root, "leftImg8bit", "train")
        self.city_imgs = sorted(
            os.path.join(train_dir, city, f)
            for city in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, city))
            for f in os.listdir(os.path.join(train_dir, city))
            if f.endswith("_leftImg8bit.png")
        )

        # Derived element-wise from city_imgs (never sorted separately), so
        # pseudo_labels[i] always belongs to city_imgs[i].
        self.pseudo_labels = [
            os.path.join(
                self.pseudo_dir,
                os.path.basename(p).replace("_leftImg8bit.png", "_pseudo_label.png"),
            )
            for p in self.city_imgs
        ]

    def __len__(self):
        # Source and target are paired by index, so only the overlap is usable.
        return min(len(self.gta5_dataset), len(self.city_imgs))

    @staticmethod
    def _label_to_tensor(label):
        """Return ``label`` (PIL image, array or tensor) as a long tensor without a leading size-1 channel dim."""
        if not torch.is_tensor(label):
            import numpy as np  # only needed when a label arrives untransformed
            label = torch.as_tensor(np.asarray(label))
        label = label.long()
        if label.dim() == 3 and label.shape[0] == 1:
            label = label.squeeze(0)  # (1, H, W) -> (H, W)
        return label

    @classmethod
    def _to_train_ids(cls, label):
        """Map label IDs to train IDs 0-18 via ``ID_TO_TRAIN_ID``; everything else becomes IGNORE_INDEX."""
        lut = torch.full((256,), cls.IGNORE_INDEX, dtype=torch.long)
        lut[torch.tensor(list(cls.ID_TO_TRAIN_ID))] = torch.tensor(list(cls.ID_TO_TRAIN_ID.values()))
        # Clamping routes out-of-range values to index 0 or 255, both of which map to ignore.
        return lut[label.long().clamp(0, 255)]

    @classmethod
    def _to_train_range(cls, label):
        """Keep values in [0, NUM_CLASSES) unchanged and send everything else to IGNORE_INDEX."""
        label = label.long()
        valid = (label >= 0) & (label < cls.NUM_CLASSES)
        return torch.where(valid, label, torch.full_like(label, cls.IGNORE_INDEX))

    @classmethod
    def _classmix_mask(cls, src_label):
        """Boolean mask (shape of ``src_label``) covering a random half of the classes present.

        Classes other than IGNORE_INDEX are drawn without replacement. With
        ``n`` such classes present, ``ceil(n / 2)`` are chosen. If only
        IGNORE_INDEX is present, the mask is all False.
        """
        classes = torch.unique(src_label)
        classes = classes[classes != cls.IGNORE_INDEX]
        if classes.numel() == 0:
            return torch.zeros_like(src_label, dtype=torch.bool)
        n_chosen = (classes.numel() + 1) // 2
        chosen = classes[torch.randperm(classes.numel())[:n_chosen]]
        return torch.isin(src_label, chosen)

    def __getitem__(self, idx):
        # --- Source (already transformed by gta5_dataset) ---
        src_img, src_label = self.gta5_dataset[idx]
        src_label = self._label_to_tensor(src_label)
        if self.map_source_to_train_ids:
            # GTA5 labels here are raw label IDs (the GTA5 sanity check shows values
            # such as 21-27), not train IDs. Clamping them to [0, 18] would
            # silently relabel e.g. road (ID 7) as train ID 7.
            src_label = self._to_train_ids(src_label)
        else:
            src_label = self._to_train_range(src_label)

        # --- Target: Cityscapes train image + its pseudo-label ---
        tgt_img = Image.open(self.city_imgs[idx]).convert("RGB")
        tgt_label = Image.open(self.pseudo_labels[idx])
        if self.transform_img:
            tgt_img = self.transform_img(tgt_img)
        if self.transform_label:
            tgt_label = self.transform_label(tgt_label)
        # Pseudo-labels are argmax class indices; anything outside 0-18 is ignored.
        tgt_label = self._to_train_range(self._label_to_tensor(tgt_label))

        # --- ClassMix: paste a random half of the source classes onto the target ---
        mask = self._classmix_mask(src_label)
        mixed_img = torch.where(mask.unsqueeze(0), src_img, tgt_img)  # (1, H, W) broadcasts over channels
        mixed_label = torch.where(mask, src_label, tgt_label)

        return mixed_img, mixed_label


Image and Label Transformations

In [None]:
from torchvision import transforms

# Per-channel RGB normalization statistics (ImageNet mean / std).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Images: resize to (H, W) = (720, 1280), convert to a float tensor, normalize.
transform_img = transforms.Compose(
    [
        transforms.Resize((720, 1280)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Labels: same target size, but nearest-neighbour resizing so class IDs are
# never blended. PILToTensor keeps the integer values as-is (no scaling, no
# normalization).
transform_label = transforms.Compose(
    [
        transforms.Resize(
            (720, 1280),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.PILToTensor(),
    ]
)


Sanity Check

In [None]:
import os

# Peek at one Cityscapes val city to confirm that the gtFine files are in place.
frankfurt_dir = "/content/datasets/Cityscapes/gtFine/val/frankfurt"
files = os.listdir(frankfurt_dir)
n_files = len(files)

print(f"📁 Total files in frankfurt/: {n_files}")
print("📝 Sample files:")
print(files[:10])  # first 10 names, in directory-listing order


📁 Total files in frankfurt/: 534
📝 Sample files:
['frankfurt_000000_012868_gtFine_labelTrainIds.png', 'frankfurt_000001_020693_gtFine_labelTrainIds.png', 'frankfurt_000000_012009_gtFine_labelTrainIds.png', 'frankfurt_000001_032018_gtFine_labelTrainIds.png', 'frankfurt_000001_051737_gtFine_labelTrainIds.png', 'frankfurt_000001_042384_gtFine_labelTrainIds.png', 'frankfurt_000001_029236_gtFine_color.png', 'frankfurt_000001_033655_gtFine_labelTrainIds.png', 'frankfurt_000001_038645_gtFine_labelTrainIds.png', 'frankfurt_000001_075984_gtFine_labelTrainIds.png']


In [None]:
# Build the full GTA5 dataset and the Cityscapes val split (no subsetting)
# with the shared image / label transforms.
gta5_dataset = GTA5Dataset(
    "/content/datasets/GTA5",
    transform=transform_img,
    target_transform=transform_label,
)
cityscapes_dataset = CityscapesDataset(
    "/content/datasets/Cityscapes",
    split="val",
    transform=transform_img,
    target_transform=transform_label,
)

# Load the first sample of each dataset.
gta5_img, gta5_label = gta5_dataset[0]
city_img, city_label = cityscapes_dataset[0]

# Report shapes and label values. NOTE: the GTA5 labels printed here are raw
# label IDs (values above 18 appear), not the 19 train IDs used by Cityscapes.
for header, img_t, label_t in (
    ("GTA5 Sample:", gta5_img, gta5_label),
    ("\n Cityscapes Val Sample:", city_img, city_label),
):
    print(header)
    print("  Image shape:", img_t.shape)
    print("  Label shape:", label_t.shape)
    print("  Label min/max:", label_t.min().item(), label_t.max().item())
    print("  Unique classes:", torch.unique(label_t))


GTA5 Sample:
  Image shape: torch.Size([3, 720, 1280])
  Label shape: torch.Size([1, 720, 1280])
  Label min/max: 0 27
  Unique classes: tensor([ 0,  5,  6,  7, 11, 12, 13, 15, 17, 21, 22, 23, 24, 27],
       dtype=torch.uint8)

 Cityscapes Val Sample:
  Image shape: torch.Size([3, 720, 1280])
  Label shape: torch.Size([1, 720, 1280])
  Label min/max: 0 255
  Unique classes: tensor([  0,   1,   2,   5,   6,   7,   8,   9,  10,  11,  13, 255],
       dtype=torch.uint8)


Initialize GTA5 Dataloader


In [None]:
from torch.utils.data import DataLoader

# Step 3.1: GTA5 dataloader settings (kept as named globals).
batch_size, num_workers = 2, 2
pin_memory = drop_last = shuffle = True

gta5_loader = DataLoader(
    gta5_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=drop_last,
)

# Sanity check: pull one batch and report its shapes and label values.
gta5_imgs, gta5_labels = next(iter(gta5_loader))
print("GTA5 Dataloader Sanity Check:")
for caption, value in (
    ("  Image batch shape:", gta5_imgs.shape),
    ("  Label batch shape:", gta5_labels.shape),
    ("  Unique labels in batch:", torch.unique(gta5_labels)),
):
    print(caption, value)


GTA5 Dataloader Sanity Check:
  Image batch shape: torch.Size([2, 3, 720, 1280])
  Label batch shape: torch.Size([2, 1, 720, 1280])
  Unique labels in batch: tensor([ 0,  1,  4,  5,  6,  7,  8, 11, 12, 13, 15, 17, 19, 20, 21, 22, 23, 24,
        26, 27], dtype=torch.uint8)


Initialize the Cityscapes Validation Dataloader (val_loader)

In [None]:
from torchvision import transforms
import torchvision.transforms.functional as TF
import torch

# Validation uses the same ImageNet normalization as training.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Cityscapes validation images: resize -> tensor -> normalize.
image_transform = transforms.Compose([
    transforms.Resize((720, 1280)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
])

# Validation masks: nearest-neighbour resize keeps the class IDs intact, and
# PILToTensor leaves the integer values unnormalized.
target_transform = transforms.Compose([
    transforms.Resize((720, 1280), interpolation=TF.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
])


In [None]:
from torch.utils.data import DataLoader

# Cityscapes val split, built with the validation transforms.
cityscapes_val_dataset = CityscapesDataset(
    root="/content/datasets/Cityscapes",
    split='val',
    transform=image_transform,
    target_transform=target_transform,
)

val_loader = DataLoader(
    cityscapes_val_dataset,
    batch_size=2,
    shuffle=False,  # fixed order so validation runs are comparable
    num_workers=2,
    pin_memory=True,
)

# 🔍 Sanity check: first batch shapes and the label values of its first mask.
val_imgs, val_masks = next(iter(val_loader))
report = (
    f"Image batch shape: {val_imgs.shape}",
    f"Label batch shape: {val_masks.shape}",
    f"Unique labels: {torch.unique(val_masks[0])}",
)
print("🔎 Cityscapes Val Sample:")
print("\n".join(report))


🔎 Cityscapes Val Sample:
Image batch shape: torch.Size([2, 3, 720, 1280])
Label batch shape: torch.Size([2, 1, 720, 1280])
Unique labels: tensor([  0,   1,   2,   5,   6,   7,   8,   9,  10,  11,  13, 255],
       dtype=torch.uint8)


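DACS DataLoader (preliminary — the pseudo-labels it reads are generated further below, so do not iterate this loader yet)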

In [None]:
from torch.utils.data import DataLoader

# Preliminary DACS dataset / loader (source + pseudo-labelled target, mixed).
# NOTE: on a fresh runtime the pseudo-label files this reads are only written by
# the pseudo-label generation step further down, so iterating this loader
# before that step has run will fail on missing files.
dacs_dataset = DACSDataset(
    gta5_dataset=gta5_dataset,
    cityscapes_root="/content/datasets/Cityscapes",
    transform_img=image_transform,
    transform_label=target_transform,
)

dacs_loader = DataLoader(
    dacs_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)




Clone and Set Up BiSeNet Repository

In [None]:
# Clone the BiSeNet repo (MLDL2024 project template) once and make it importable.
import os
import subprocess
import sys

repo_url = "https://github.com/Gabrysse/MLDL2024_project1.git"
repo_dir = "/content/MLDL2024_project1"

# git refuses to clone into an existing non-empty directory, so only clone on
# the first run. check=True surfaces a failed clone instead of hiding it.
if not os.path.isdir(repo_dir):
    subprocess.run(["git", "clone", repo_url, repo_dir], check=True)

# Avoid stacking duplicate sys.path entries when the cell is re-run.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

# ✅ Test import
from models.bisenet.build_bisenet import BiSeNet
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("BiSeNet repo ready and model imported successfully!")

fatal: destination path 'MLDL2024_project1' already exists and is not an empty directory.
BiSeNet repo cloned and model imported successfully!


In [None]:
from models.bisenet.build_bisenet import BiSeNet
import torch

# Run on the GPU when one is available; the checkpoint is mapped to the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BiSeNet with 19 output classes and a ResNet-18 context path.
model = BiSeNet(19, 'resnet18').to(device)

# Restore the FDA-trained weights saved on Drive (stored as a plain state_dict).
checkpoint_path = "/content/drive/MyDrive/Semantic_Segmentation/bisenet_gta5_fda_final.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)

# Inference from here on (switches BatchNorm/Dropout-style layers to eval behaviour).
model.eval()

print("✅ BiSeNet loaded with FDA-trained weights.")


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 134MB/s]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:01<00:00, 106MB/s]


✅ BiSeNet loaded with FDA-trained weights.


Generate Pseudo-Labels with the FDA-Trained BiSeNet Model

In [None]:
import os
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import numpy as np

# ✅ Path to Cityscapes train images
cityscapes_root = "/content/datasets/Cityscapes"
image_dir = os.path.join(cityscapes_root, "leftImg8bit", "train")

# ✅ Output directory for pseudo-labels
pseudo_label_dir = "/content/pseudo_labels"
os.makedirs(pseudo_label_dir, exist_ok=True)

# A full pass takes well over an hour, so existing pseudo-labels are reused on
# re-runs. Set this to True after changing `model` to regenerate every file.
overwrite_pseudo_labels = False

# ✅ Same image transform as training (ImageNet normalization, 720x1280)
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]
transform_img = transforms.Compose([
    transforms.Resize((720, 1280)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# ✅ Collect every train image in sorted (deterministic) order
all_image_paths = []
for city in sorted(os.listdir(image_dir)):
    city_path = os.path.join(image_dir, city)
    if not os.path.isdir(city_path):
        continue
    for fname in sorted(os.listdir(city_path)):
        if fname.endswith("_leftImg8bit.png"):
            all_image_paths.append(os.path.join(city_path, fname))

print(f"🔍 Found {len(all_image_paths)} images in Cityscapes/train")

# ✅ Generate pseudo-labels: per-pixel argmax class index, saved as a uint8 PNG
model.eval()
n_reused = 0
with torch.no_grad():
    for img_path in tqdm(all_image_paths, desc="Generating Pseudo-labels"):
        out_name = os.path.basename(img_path).replace("_leftImg8bit.png", "_pseudo_label.png")
        out_path = os.path.join(pseudo_label_dir, out_name)
        if not overwrite_pseudo_labels and os.path.exists(out_path):
            n_reused += 1
            continue

        img = Image.open(img_path).convert("RGB")
        input_tensor = transform_img(img).unsqueeze(0).to(device)  # (1, 3, H, W)

        # NOTE(review): the forward's return type is not visible in this
        # notebook. Accept either a bare logits tensor or a tuple whose first
        # element is the main output; indexing a bare tensor with [0] would pick
        # the batch sample instead. The argmax is over the class dim (dim=1).
        output = model(input_tensor)
        logits = output[0] if isinstance(output, (tuple, list)) else output
        pred = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

        Image.fromarray(pred).save(out_path)

print(f"♻️ Reused {n_reused} existing pseudo-labels")
print("✅ All pseudo-labels generated and saved to:", pseudo_label_dir)


🔍 Found 1572 images in Cityscapes/train


Generating Pseudo-labels: 100%|██████████| 1572/1572 [1:27:45<00:00,  3.35s/it]

✅ All pseudo-labels generated and saved to: /content/pseudo_labels





Re-initialize the DACS dataset

In [None]:
# Rebuild the DACS dataset now that the pseudo-labels exist on disk. This
# instance uses the training transforms (transform_img / transform_label).
_dacs_kwargs = dict(
    cityscapes_root="/content/datasets/Cityscapes",
    gta5_dataset=gta5_dataset,
    transform_label=transform_label,
    transform_img=transform_img,
)
dacs_dataset = DACSDataset(**_dacs_kwargs)


Sanity Check (small batch to avoid OOM)

In [None]:
from torch.utils.data import DataLoader, Subset

# Quick check on the first 10 DACS samples only (keeps memory use low).
# NOTE: this rebinds `dacs_loader` to the small loader; the training cell
# builds the full one again.
small_dacs = Subset(dacs_dataset, range(10))
dacs_loader = DataLoader(small_dacs, batch_size=2, shuffle=False)

# Pull one mini-batch of mixed images / labels and report on it.
mixed_imgs, mixed_labels = next(iter(dacs_loader))
for caption, value in (
    ("Mixed image shape:", mixed_imgs.shape),
    ("Mixed label shape:", mixed_labels.shape),
    ("Unique labels in sample:", torch.unique(mixed_labels)),
):
    print(caption, value)


✅ Mixed image shape: torch.Size([2, 3, 720, 1280])
✅ Mixed label shape: torch.Size([2, 720, 1280])
✅ Unique labels in sample: tensor([ 0,  1,  2,  3,  4,  6,  9, 10, 11, 12, 13, 14, 15, 17, 18])


In [None]:
# Cityscapes val split for the training loop's validation phase,
# built with the validation transforms.
val_dataset = CityscapesDataset(
    "/content/datasets/Cityscapes",
    split="val",
    transform=image_transform,
    target_transform=target_transform,
)


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.bisenet.build_bisenet import BiSeNet

# === Data loaders (datasets are built in earlier cells) ===
# NOTE: gta5_loader is rebuilt here but not consumed by the loop below. The
# loop trains only on the mixed DACS samples, whose labels already contain the
# source ground truth inside the ClassMix mask.
gta5_loader = DataLoader(gta5_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)
dacs_loader = DataLoader(dacs_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

# === Model, optimizer, scaler, criterion ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.amp targets CUDA; on CPU, disable autocast and loss scaling.
use_amp = device.type == "cuda"
model = BiSeNet(19, 'resnet18').to(device)

optimizer = optim.SGD(model.parameters(), lr=2.5e-4, momentum=0.9, weight_decay=1e-4)
scaler = GradScaler(enabled=use_amp)
IGNORE_INDEX = 255
NUM_CLASSES = 19
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)


def _main_logits(net, images):
    """Run ``net`` on ``images`` and return its main segmentation logits.

    NOTE(review): the forward's return type is not visible in this notebook.
    It may be a tuple (main output first, auxiliary heads after) or a bare
    ``(B, C, H, W)`` tensor. Indexing a bare tensor with ``[0]`` would silently
    select the first *sample* of the batch, so both cases are handled here.
    """
    out = net(images)
    return out[0] if isinstance(out, (tuple, list)) else out


def _prepare_labels(labels):
    """Move a label batch to ``device`` as (B, H, W) long; values >= NUM_CLASSES become IGNORE_INDEX."""
    labels = labels.to(device, non_blocking=True).long()
    if labels.dim() == 4:  # (B, 1, H, W) from PILToTensor -> (B, H, W)
        labels = labels.squeeze(1)
    return torch.where(labels >= NUM_CLASSES, torch.full_like(labels, IGNORE_INDEX), labels)


# === Checkpoint logic ===
# NOTE: /content is wiped when the Colab runtime resets. Point this at Drive
# if training must resume across sessions.
checkpoint_path = "/content/checkpoints/dacs_bisenet.pth"
best_model_path = os.path.join(os.path.dirname(checkpoint_path), "dacs_bisenet_best.pth")
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
start_epoch = 0
best_val_loss = float("inf")

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # A scaler saved while disabled (e.g. a CPU run) has an empty state, which
    # an enabled scaler refuses to load.
    if checkpoint.get("scaler"):
        scaler.load_state_dict(checkpoint["scaler"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val"]
    print(f"Resumed training from epoch {start_epoch}")

# === Training loop ===
num_epochs = 30
for epoch in range(start_epoch, num_epochs):
    model.train()
    train_loop = tqdm(dacs_loader, desc=f"Train E{epoch+1}")
    for imgs, labels in train_loop:
        imgs = imgs.to(device, non_blocking=True)
        labels = _prepare_labels(labels)

        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=use_amp):
            logits = _main_logits(model, imgs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loop.set_postfix(loss=loss.item())

    # === Validation: mean of per-batch losses on Cityscapes val ===
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for v_imgs, v_lbls in tqdm(val_loader, desc=f"🔍 Val E{epoch+1}"):
            v_imgs = v_imgs.to(device, non_blocking=True)
            v_lbls = _prepare_labels(v_lbls)
            v_out = _main_logits(model, v_imgs)
            val_loss += criterion(v_out, v_lbls).item()
    val_loss /= max(len(val_loader), 1)
    print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f}")

    # === Save best model ===
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print("New best model saved.")

    # === Save checkpoint ===
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "best_val": best_val_loss,
    }, checkpoint_path)
    print(f"Checkpoint saved for epoch {epoch+1}")
