In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
import zipfile
import os

# Define paths
gta5_zip = "/content/drive/MyDrive/Semantic_Segmentation/GTA5.zip"
cityscapes_zip = "/content/drive/MyDrive/Semantic_Segmentation/Cityscapes.zip"

os.makedirs("/content/datasets/GTA5", exist_ok=True)
os.makedirs("/content/datasets/Cityscapes", exist_ok=True)

# Unzip GTA5
with zipfile.ZipFile(gta5_zip, 'r') as zip_ref:
    zip_ref.extractall("/content/datasets/GTA5")

# Unzip Cityscapes
with zipfile.ZipFile(cityscapes_zip, 'r') as zip_ref:
    zip_ref.extractall("/content/datasets/Cityscapes")

print("Both datasets extracted.")


Both datasets extracted.


In [3]:
import shutil
import os

# Fix GTA5
if os.path.exists("/content/datasets/GTA5/GTA5"):
    shutil.move("/content/datasets/GTA5/GTA5/images", "/content/datasets/GTA5/images")
    shutil.move("/content/datasets/GTA5/GTA5/labels", "/content/datasets/GTA5/labels")
    shutil.rmtree("/content/datasets/GTA5/GTA5")
    print("Fixed GTA5 folder structure")

# Fix Cityscapes
nested_city = "/content/datasets/Cityscapes/Cityscapes/Cityspaces"
if os.path.exists(nested_city):
    shutil.move(os.path.join(nested_city, "images"), "/content/datasets/Cityscapes/leftImg8bit")
    shutil.move(os.path.join(nested_city, "gtFine"), "/content/datasets/Cityscapes/gtFine")
    shutil.rmtree("/content/datasets/Cityscapes/Cityscapes")
    print("Fixed Cityscapes folder structure")


Fixed GTA5 folder structure
Fixed Cityscapes folder structure


In [4]:
# Clone the repo (your custom version)
!git clone https://github.com/Gabrysse/MLDL2024_project1.git

# Add to Python path for import
import sys
sys.path.append("/content/MLDL2024_project1")

# Import BiSeNet
from models.bisenet.build_bisenet import BiSeNet
import torch

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("BiSeNet repo cloned and model class imported.")


Cloning into 'MLDL2024_project1'...
remote: Enumerating objects: 34, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 34 (delta 9), reused 3 (delta 3), pack-reused 13 (from 1)[K
Receiving objects: 100% (34/34), 11.29 KiB | 5.64 MiB/s, done.
Resolving deltas: 100% (9/9), done.
BiSeNet repo cloned and model class imported.


In [5]:
from models.bisenet.build_bisenet import BiSeNet
import torch

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build model (uses ResNet-18 by default as per repo structure)
model = BiSeNet(19, 'resnet18')
model = model.to(device)

# Load FDA-trained weights
checkpoint_path = "/content/drive/MyDrive/Semantic_Segmentation/bisenet_gta5_fda_final.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)

# Set to evaluation mode
model.eval()

print("BiSeNet loaded with FDA-trained weights.")


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 194MB/s]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:01<00:00, 170MB/s]


BiSeNet loaded with FDA-trained weights.


In [6]:
from PIL import Image
from torchvision import transforms
import os
from tqdm import tqdm
import numpy as np

# Define paths
image_dir = "/content/datasets/Cityscapes/leftImg8bit/train"
pseudo_label_dir = "/content/pseudo_labels_dice"
os.makedirs(pseudo_label_dir, exist_ok=True)

# Define image preprocessing
transform_img = transforms.Compose([
    transforms.Resize((720, 1280)),  # Resize to BiSeNet input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Gather all image paths
all_image_paths = []
for city in os.listdir(image_dir):
    city_path = os.path.join(image_dir, city)
    for fname in os.listdir(city_path):
        if fname.endswith("_leftImg8bit.png"):
            all_image_paths.append(os.path.join(city_path, fname))

print(f"Found {len(all_image_paths)} images in Cityscapes/train")

# Generate pseudo-labels
model.eval()
with torch.no_grad():
    for img_path in tqdm(all_image_paths, desc="Generating Pseudo-labels"):
        img = Image.open(img_path).convert("RGB")
        input_tensor = transform_img(img).unsqueeze(0).to(device)  # (1, 3, H, W)

        output = model(input_tensor)[0]  # output shape: (1, 19, H, W)
        pred = torch.argmax(output.squeeze(), dim=0).cpu().numpy().astype(np.uint8)

        # Save predicted mask
        out_name = os.path.basename(img_path).replace("_leftImg8bit.png", "_pseudo_label.png")
        out_path = os.path.join(pseudo_label_dir, out_name)
        Image.fromarray(pred).save(out_path)

print(f"All pseudo-labels saved to: {pseudo_label_dir}")


Found 1572 images in Cityscapes/train


Generating Pseudo-labels: 100%|██████████| 1572/1572 [04:43<00:00,  5.54it/s]

All pseudo-labels saved to: /content/pseudo_labels_dice





In [7]:
# Backup pseudo labels
!cp -r  /content/pseudo_labels_dice /content/drive/MyDrive/Semantic_Segmentation/pseudo_labels_backup/

In [8]:
#!mkdir /content/pseudo_labels_dice
!cp -r /content/drive/MyDrive/Semantic_Segmentation/pseudo_labels_backup/ /content/pseudo_labels_dice

In [9]:
import os
print(os.listdir("/content/datasets/GTA5/labels")[:5])


['02492.png', '01407.png', '00345.png', '01436.png', '01773.png']


In [10]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
import random

# Transforms
transform_image = transforms.Compose([
    transforms.Resize((720, 1280)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

transform_label = transforms.Compose([
    transforms.Resize((720, 1280), interpolation=Image.NEAREST),
    transforms.PILToTensor()
])

# GTA5 Dataset
class GTA5Dataset(Dataset):
    def __init__(self, root, transform_img, transform_lbl):
        self.img_dir = os.path.join(root, "images")
        self.lbl_dir = os.path.join(root, "labels")
        self.imgs = sorted(os.listdir(self.img_dir))
        self.transform_img = transform_img
        self.transform_lbl = transform_lbl

    def __len__(self):
        return len(self.imgs)

    #def __getitem__(self, idx):
     #   img_path = os.path.join(self.img_dir, self.imgs[idx])
      #  lbl_path = os.path.join(self.lbl_dir, self.imgs[idx].replace("_leftImg8bit.png", "_gtFine_labelIds.png"))

       # img = Image.open(img_path).convert("RGB")
        #lbl = Image.open(lbl_path)

       # if self.transform_img:
        #    img = self.transform_img(img)
        #if self.transform_lbl:
         #   lbl = self.transform_lbl(lbl).squeeze().long()

     #   return img, lbl

    def __getitem__(self, idx):
           img_name = self.imgs[idx]                          # e.g., '00878_leftImg8bit.png'
           lbl_name = img_name.replace("_leftImg8bit.png", ".png")  # → '00878.png'

           img_path = os.path.join(self.img_dir, img_name)
           lbl_path = os.path.join(self.lbl_dir, lbl_name)

           img = Image.open(img_path).convert("RGB")
           lbl = Image.open(lbl_path)

           if self.transform_img:
               img = self.transform_img(img)
           if self.transform_lbl:
               lbl = self.transform_lbl(lbl).squeeze().long()

           return img, lbl


# Cityscapes Val Dataset
class CityscapesValDataset(Dataset):
    def __init__(self, root, transform_img, transform_lbl):
        self.img_dir = os.path.join(root, "leftImg8bit", "val")
        self.lbl_dir = os.path.join(root, "gtFine", "val")
        self.imgs = []
        for city in os.listdir(self.img_dir):
            for fname in os.listdir(os.path.join(self.img_dir, city)):
                if fname.endswith("_leftImg8bit.png"):
                    self.imgs.append(os.path.join(city, fname))
        self.transform_img = transform_img
        self.transform_lbl = transform_lbl
        self.root = root

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_rel = self.imgs[idx]
        city, fname = img_rel.split('/')
        lbl_name = fname.replace("_leftImg8bit.png", "_gtFine_labelTrainIds.png")

        img = Image.open(os.path.join(self.img_dir, city, fname)).convert("RGB")
        lbl = Image.open(os.path.join(self.lbl_dir, city, lbl_name))

        if self.transform_img:
            img = self.transform_img(img)
        if self.transform_lbl:
            lbl = self.transform_lbl(lbl).squeeze().long()

        return img, lbl

# DACS Dataset
class DACSDataset(Dataset):
    def __init__(self, gta5_dataset, cityscapes_root, pseudo_root, transform_img, transform_lbl, ignore_index=255):
        self.gta5_dataset = gta5_dataset
        self.transform_img = transform_img
        self.transform_lbl = transform_lbl
        self.ignore_index = ignore_index

        self.city_imgs = []
        self.pseudo_labels = []
        for city in os.listdir(os.path.join(cityscapes_root, "leftImg8bit", "train")):
            city_path = os.path.join(cityscapes_root, "leftImg8bit", "train", city)
            for fname in os.listdir(city_path):
                if fname.endswith("_leftImg8bit.png"):
                    self.city_imgs.append(os.path.join(city_path, fname))
                    pseudo = os.path.join(pseudo_root, fname.replace("_leftImg8bit.png", "_pseudo_label.png"))
                    self.pseudo_labels.append(pseudo)

    def __len__(self):
        return min(len(self.gta5_dataset), len(self.city_imgs))

    def __getitem__(self, idx):
        src_img, src_lbl = self.gta5_dataset[idx]
        tgt_img = Image.open(self.city_imgs[idx]).convert("RGB")
        tgt_lbl = Image.open(self.pseudo_labels[idx])

        tgt_img = self.transform_img(tgt_img)
        tgt_lbl = self.transform_lbl(tgt_lbl).squeeze().long()
        tgt_lbl = torch.clamp(tgt_lbl, 0, 18)

        # Clamp src_lbl to valid range or ignore index
        src_lbl = torch.clamp(src_lbl, 0, 18) # Clamp valid classes
        src_lbl[src_lbl > 18] = self.ignore_index # Set values > 18 to ignore_index

        # ClassMix
        classes = torch.unique(src_lbl)
        classes = classes[classes != self.ignore_index]
        selected = classes[torch.randperm(len(classes))[:len(classes)//2]] if len(classes) > 0 else torch.tensor([], dtype=torch.long)
        mask = torch.zeros_like(src_lbl, dtype=torch.bool)
        for c in selected:
            mask[src_lbl == c] = True
        mask = mask.unsqueeze(0)

        # Handle ignore index during mixing
        mixed_lbl = torch.where(mask.squeeze(0), src_lbl, tgt_lbl)
        # Ensure ignore index from source is preserved where mask is true
        mixed_lbl = torch.where((mask.squeeze(0) & (src_lbl == self.ignore_index)), self.ignore_index, mixed_lbl)


        mixed_img = tgt_img * (~mask) + src_img * mask.float()


        return mixed_img, mixed_lbl

In [11]:
from torch.utils.data import DataLoader

# Paths
gta5_root = "/content/datasets/GTA5"
cityscapes_root = "/content/datasets/Cityscapes"
pseudo_label_dir = "/content/pseudo_labels_dice"

# Initialize datasets
gta5_dataset = GTA5Dataset(gta5_root, transform_image, transform_label)
val_dataset  = CityscapesValDataset(cityscapes_root, transform_image, transform_label)
dacs_dataset = DACSDataset(gta5_dataset, cityscapes_root, pseudo_label_dir, transform_image, transform_label)

# DataLoaders
gta5_loader = DataLoader(gta5_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
val_loader  = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)
dacs_loader = DataLoader(dacs_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

# Sanity check
batch_imgs, batch_lbls = next(iter(dacs_loader))
print("DACS ✅", batch_imgs.shape, batch_lbls.shape)


DACS ✅ torch.Size([2, 3, 720, 1280]) torch.Size([2, 720, 1280])


In [12]:
from torch.utils.data import DataLoader

# Paths
gta5_root = "/content/datasets/GTA5"
cityscapes_root = "/content/datasets/Cityscapes"
pseudo_label_dir = "/content/pseudo_labels_dice"

# Initialize datasets
gta5_dataset = GTA5Dataset(gta5_root, transform_image, transform_label)
val_dataset  = CityscapesValDataset(cityscapes_root, transform_image, transform_label)
dacs_dataset = DACSDataset(gta5_dataset, cityscapes_root, pseudo_label_dir, transform_image, transform_label)

# Set pin_memory=False for CPU safety
gta5_loader = DataLoader(gta5_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=False, drop_last=True)
val_loader  = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4, pin_memory=False)
dacs_loader = DataLoader(dacs_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=False, drop_last=True)

# Sanity check
batch_imgs, batch_lbls = next(iter(dacs_loader))
print("DACS ✅", batch_imgs.shape, batch_lbls.shape)


DACS ✅ torch.Size([2, 3, 720, 1280]) torch.Size([2, 720, 1280])


In [13]:
from torch.utils.data import DataLoader

# Dataset paths
gta5_root = "/content/datasets/GTA5"
cityscapes_root = "/content/datasets/Cityscapes"
pseudo_label_dir = "/content/pseudo_labels_dice"  # updated pseudo-label path

# Initialize datasets
gta5_dataset = GTA5Dataset(gta5_root, transform_image, transform_label)
val_dataset  = CityscapesValDataset(cityscapes_root, transform_image, transform_label)
dacs_dataset = DACSDataset(gta5_dataset, cityscapes_root, pseudo_label_dir, transform_image, transform_label)

# Initialize DataLoaders
#gta5_loader = DataLoader(gta5_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
#val_loader  = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)
#dacs_loader = DataLoader(dacs_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

# Adaptive pin_memory based on device
use_cuda = torch.cuda.is_available()

gta5_loader = DataLoader(gta5_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=False, drop_last=True)
val_loader  = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4, pin_memory=False)
dacs_loader = DataLoader(dacs_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=False, drop_last=True)


# Sanity check
sample_imgs, sample_lbls = next(iter(dacs_loader))
print("DACS DataLoader ✅", sample_imgs.shape, sample_lbls.shape)


DACS DataLoader ✅ torch.Size([2, 3, 720, 1280]) torch.Size([2, 720, 1280])


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridSegmentationLoss(nn.Module):
    def __init__(self, weight_dice=1.0, ignore_index=255):
        super(HybridSegmentationLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.weight_dice = weight_dice
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        # logits: (B, C, H, W)
        # targets: (B, H, W)

        ce = self.ce_loss(logits, targets)

        if self.weight_dice == 0:
            return ce

        # Convert targets to one-hot format for dice
        num_classes = logits.shape[1]

        # Create a mask for valid pixels (not ignore_index)
        valid_mask = (targets != self.ignore_index)

        # Apply mask to targets and logits for dice calculation
        # We only consider valid pixels for dice loss
        targets_dice = targets[valid_mask]
        logits_dice = logits.permute(0, 2, 3, 1)[valid_mask] # (N_valid, C)

        if targets_dice.numel() == 0: # Handle case where no valid pixels exist
            return ce

        # Clamp valid targets to the range [0, num_classes - 1] before one_hot
        targets_dice_clamped = torch.clamp(targets_dice, 0, num_classes - 1)

        targets_one_hot = F.one_hot(targets_dice_clamped, num_classes=num_classes).float() # (N_valid, C)

        probs = torch.softmax(logits_dice, dim=-1) # (N_valid, C)

        # Dice calculation on valid pixels
        eps = 1e-7
        intersection = (probs * targets_one_hot).sum(dim=0)
        union = probs.sum(dim=0) + targets_one_hot.sum(dim=0)
        dice = 1 - (2 * intersection + eps) / (union + eps)
        dice_loss = dice.mean()

        return ce + self.weight_dice * dice_loss

In [15]:
print(f"Using device: {device}")
print("Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


Using device: cuda
Device Name: Tesla T4


In [16]:
try:
    dummy = torch.randn(1, 3, 720, 1280).to(device)
    print("Dummy tensor successfully moved to GPU")
except RuntimeError as e:
    print("Failed to move dummy tensor to GPU:", e)


Dummy tensor successfully moved to GPU


In [17]:
#device = torch.device("cpu")
#print("Switched to:", device)



In [18]:
gta5_loader = DataLoader(gta5_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=False, drop_last=True)
val_loader  = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=False)
dacs_loader = DataLoader(dacs_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=False, drop_last=True)



In [19]:
#model = BiSeNet(num_classes=19, context_path='resnet18').to(device)



In [20]:
import os

val_city_dir = "/content/datasets/Cityscapes/gtFine/val/munster"
files = os.listdir(val_city_dir)
print("Files found:", files[:5])


Files found: ['munster_000154_000019_gtFine_color.png', 'munster_000132_000019_gtFine_color.png', 'munster_000063_000019_gtFine_labelTrainIds.png', 'munster_000026_000019_gtFine_color.png', 'munster_000145_000019_gtFine_color.png']


In [21]:
import os
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

# === CONFIG ===
num_epochs = 50
save_dir = "/content/drive/MyDrive/Semantic_Segmentation/extension_checkpoints"
os.makedirs(save_dir, exist_ok=True)
best_model_path = os.path.join(save_dir, "best_model_dacs_dice.pth")
checkpoint_path = os.path.join(save_dir, "last_checkpoint.pth")
ignore_index = 255 # Define ignore index here

# === MODEL ===
model = BiSeNet(num_classes=19, context_path='resnet18').to(device)

# Add CUDA synchronization and error checking (keeping for now, though the new error is different)
if torch.cuda.is_available():
    torch.cuda.synchronize()
    try:
        torch.cuda.current_stream().synchronize()
    except RuntimeError as e:
        print(f"CUDA error after model to device: {e}")


        pass # Use pass to continue execution if it was just a warning

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scaler = GradScaler()
criterion = HybridSegmentationLoss(weight_dice=1.0, ignore_index=ignore_index).to(device)
val_criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index).to(device) # Set ignore_index for validation loss and move to device

# === RESUME CHECKPOINT IF EXISTS ===
start_epoch = 0
best_val_loss = float("inf")
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    print(f"Resumed training from epoch {start_epoch}")

# === TRAINING LOOP ===
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0.0
    loop = tqdm(dacs_loader, desc=f"Epoch {epoch+1}", leave=False)

    for imgs, labels in loop:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        with autocast():
            logits = model(imgs)[0]
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        loop.set_postfix(train_loss=loss.item())

    avg_train_loss = total_loss / len(dacs_loader)

    # === VALIDATION ===
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for v_imgs, v_labels in val_loader:
            v_imgs, v_labels = v_imgs.to(device), v_labels.to(device)
            v_logits = model(v_imgs)[0]

            # Add a check for the tensor's dimensions
            if v_logits.dim() == 3:
              v_logits = v_logits.unsqueeze(0) # Adds the batch dimension back in

            # Reshape logits and labels for the loss function
            v_logits = v_logits.permute(0, 2, 3, 1).reshape(-1, 19)
            v_labels = v_labels.reshape(-1)

            v_loss = val_criterion(v_logits, v_labels)
            val_loss += v_loss.item()
    val_loss /= len(val_loader)

    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # === SAVE BEST MODEL ===
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved Best Model at Epoch {epoch+1}")

    # === SAVE CHECKPOINT ===
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
        'best_val_loss': best_val_loss
    }, checkpoint_path)

print("Training Complete")

Resumed training from epoch 48




Epoch 49 | Train Loss: 1.0134 | Val Loss: 4.8244




Epoch 50 | Train Loss: 1.0064 | Val Loss: 5.4947
Training Complete


In [22]:
import torch
from models.bisenet.build_bisenet import BiSeNet

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize BiSeNet (same setup as training)
model = BiSeNet(19, 'resnet18').to(device)

# Load best trained weights after DACS
checkpoint_path = "/content/drive/MyDrive/Semantic_Segmentation/extension_checkpoints/best_model_dacs_dice.pth"
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)

# Set to eval mode
model.eval()
print(" Best DACS-trained model loaded successfully.")


✅ Best DACS-trained model loaded successfully.


In [25]:
import os
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

class CityscapesValDataset(Dataset):
    def __init__(self, img_root, lbl_root, transform=None, target_transform=None):
        self.img_root = img_root
        self.lbl_root = lbl_root
        self.transform = transform
        self.target_transform = target_transform
        self.img_paths = []
        self.lbl_paths = []

        for city in os.listdir(img_root):
            img_dir = os.path.join(img_root, city)
            lbl_dir = os.path.join(lbl_root, city)
            for file in os.listdir(img_dir):
                if file.endswith("_leftImg8bit.png"):
                    self.img_paths.append(os.path.join(img_dir, file))
                    lbl_file = file.replace("_leftImg8bit.png", "_gtFine_labelTrainIds.png")
                    self.lbl_paths.append(os.path.join(lbl_dir, lbl_file))

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert("RGB")
        lbl = Image.open(self.lbl_paths[idx])

        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            lbl = self.target_transform(lbl)

        return img, lbl

# Define transforms
val_transform = transforms.Compose([
    transforms.Resize((720, 1280)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

lbl_transform = transforms.Compose([
    transforms.Resize((720, 1280), interpolation=Image.NEAREST),
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
])

# Create dataset and dataloader
val_dataset = CityscapesValDataset(
    img_root="/content/datasets/Cityscapes/leftImg8bit/val",
    lbl_root="/content/datasets/Cityscapes/gtFine/val",
    transform=val_transform,
    target_transform=lbl_transform
)

val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
print("Validation DataLoader ready.")


Validation DataLoader ready.


In [35]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# Initialize confusion matrix
NUM_CLASSES = 19
conf_matrix = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)

def fast_hist(pred, label, num_classes):
    # Print shapes before applying mask
    #print("Shape of pred in fast_hist:", pred.shape)
    #print("Shape of label in fast_hist:", label.shape)
    mask = (label >= 0) & (label < num_classes)
    #print("Shape of mask in fast_hist:", mask.shape)

    # Apply mask before flattening
    pred = pred[mask]
    label = label[mask]

    hist = np.bincount(
        num_classes * label.astype(int) + pred.astype(int),
        minlength=num_classes**2
    ).reshape(num_classes, num_classes)
    return hist

model.eval()
with torch.no_grad():
    for imgs, lbls in tqdm(val_loader, desc="Evaluating"):
        imgs = imgs.to(device)
        lbls = lbls.squeeze(1).cpu().numpy()  # shape: [B, H, W]

        outputs = model(imgs)[0]  # This might return [19, H, W] if batch size is 1
        # Add batch dimension if it's missing
        if outputs.dim() == 3:
             outputs = outputs.unsqueeze(0) # Adds the batch dimension back in

        #print("Shape of outputs before argmax:", outputs.shape) # Check shape here
        preds = outputs.argmax(dim=1).cpu().numpy()  # Now should be [B, H, W]

        #print("Shape of preds after argmax:", preds.shape) # Check shape here

        for i in range(preds.shape[0]): # Iterate through the batch using index
            pred = preds[i] # Get individual prediction (H, W)
            label = lbls[i] # Get individual label (H, W)

            # Print shapes before calling fast_hist
            #print("Shape of pred before fast_hist:", pred.shape)
            #print("Shape of label before fast_hist:", label.shape)
            conf_matrix += fast_hist(pred, label, NUM_CLASSES) # Pass unflattened arrays to fast_hist

# Compute mIoU
intersection = np.diag(conf_matrix)
union = conf_matrix.sum(1) + conf_matrix.sum(0) - np.diag(conf_matrix)
iou = intersection / np.maximum(union, 1)
miou = np.nanmean(iou)

# Compute pixel accuracy
pixel_acc = intersection.sum() / conf_matrix.sum()

print(f"\n Evaluation Results:")
print(f"Pixel Accuracy: {pixel_acc:.4f}")
print(f"Mean IoU: {miou:.4f}")

Evaluating: 100%|██████████| 500/500 [01:33<00:00,  5.34it/s]


 Evaluation Results:
Pixel Accuracy: 0.3242
Mean IoU: 0.0759





In [36]:
# Calculate per-class IoU
# intersection and union were already calculated in the previous cell

# Get class names (assuming Cityscapes classes) - you might need to adjust this based on your specific class mapping
class_names = [
    "road", "sidewalk", "building", "wall", "fence", "pole",
    "traffic light", "traffic sign", "vegetation", "terrain", "sky",
    "person", "rider", "car", "truck", "bus", "train", "motorcycle",
    "bicycle"
]

# Ensure the number of class names matches the number of classes
if len(class_names) != NUM_CLASSES:
    print(f"Warning: Number of class names ({len(class_names)}) does not match NUM_CLASSES ({NUM_CLASSES}).")
    # Create generic names if there's a mismatch
    class_names = [f"class_{i}" for i in range(NUM_CLASSES)]


print("\nPer-Class IoU:")
for i in range(NUM_CLASSES):
    # Avoid division by zero for classes not present in the ground truth or predictions
    if union[i] > 0:
        print(f"{class_names[i]}: {iou[i]:.4f}")
    else:
        print(f"{class_names[i]}: N/A (no instances)")


Per-Class IoU:
road: 0.2073
sidewalk: 0.1622
building: 0.5351
wall: 0.0020
fence: 0.0242
pole: 0.0000
traffic light: 0.0000
traffic sign: 0.0010
vegetation: 0.0000
terrain: 0.0006
sky: 0.0000
person: 0.0372
rider: 0.0002
car: 0.4578
truck: 0.0118
bus: 0.0000
train: 0.0000
motorcycle: 0.0000
bicycle: 0.0033
