In [1]:
############################################################
# CELL 1: SETUP & IMPORTS
############################################################
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from PIL import Image
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("=" * 50)
print(f"DEVICE = {DEVICE}")
print("=" * 50)

# Seed every RNG this notebook draws from (NumPy drives the random flips in the
# dataset, PyTorch drives weight initialisation and DataLoader shuffling) so a
# Restart & Run All starts from the same random state each time.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # Release cached blocks left over from a previous run in the same kernel.
    torch.cuda.empty_cache()
    # Let cuDNN auto-tune convolution algorithms; input size is fixed
    # (IMG_H x IMG_W), so the tuning cost is paid once.
    # NOTE(review): auto-tuned kernels may not be bit-for-bit reproducible
    # even with the seed above.
    torch.backends.cudnn.benchmark = True

# Parameters
IMG_H, IMG_W = 256, 512   # every image/mask is resized to this (height, width)
BATCH_SIZE = 8            # training batch size
EPOCHS = 20               # number of passes over the training set
LR = 1e-3                 # initial learning rate passed to the optimizer
NUM_CLASSES = 20          # valid class ids are 0..NUM_CLASSES-1; NUM_CLASSES itself is the "ignore" id

print(f"Image: {IMG_H}x{IMG_W} | Batch: {BATCH_SIZE} | Epochs: {EPOCHS}")
print(f"Target Classes: {NUM_CLASSES}")
print("=" * 50)

DEVICE = cuda
Image: 256x512 | Batch: 8 | Epochs: 20
Target Classes: 20


In [None]:
############################################################
# CELL 2: DATA PREPROCESSING
############################################################
print(f"Binning grayscale values (0-255) into {NUM_CLASSES} classes...")


def value_to_class(value):
    """Convert one 8-bit mask value into a class id.

    Args:
        value: grayscale intensity, expected in 0..255.

    Returns:
        ``NUM_CLASSES`` (the "ignore" id) when ``value`` is exactly 255;
        otherwise the index of the equal-width bin the value falls into,
        i.e. ``int(value * NUM_CLASSES / 256)`` capped at ``NUM_CLASSES - 1``.
    """
    # 255 is reserved as the "do not score this pixel" marker.
    if value == 255:
        return NUM_CLASSES

    bin_index = int(value * NUM_CLASSES / 256)
    last_class = NUM_CLASSES - 1
    # Cap at the last valid class id.
    return last_class if bin_index > last_class else bin_index


print(f"âœ“ Mapping created: 0-255 â†’ 0-{NUM_CLASSES-1}")
print("=" * 50)

In [None]:
############################################################
# CELL 3: DATASET CLASS
############################################################
class CityscapesDataset(Dataset):
    """Image / label-mask pairs read from two parallel directories.

    Every file name found in ``img_dir`` is expected to exist under
    ``mask_dir`` as well. Each sample is resized to ``(IMG_H, IMG_W)``,
    optionally flipped horizontally, and the mask's grayscale values are
    converted to class ids with ``value_to_class``.

    Args:
        img_dir: directory containing the input images.
        mask_dir: directory containing the masks, named like the images.
        augment: when True, each sample is mirrored left-right with
            probability 0.5 (image and mask together).
    """

    # Per-channel normalisation constants (the usual ImageNet mean/std, which
    # is what the ImageNet-pretrained encoder used in this notebook expects).
    # Stored as float32 so normalisation does not go through float64.
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, img_dir, mask_dir, augment=False):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.files = sorted(os.listdir(img_dir))
        self.augment = augment
        # An 8-bit mask can only hold 256 distinct values, so evaluate
        # value_to_class once per value and remap whole masks by indexing
        # instead of making one Python call per pixel via np.vectorize.
        self._class_lut = np.array(
            [value_to_class(v) for v in range(256)], dtype=np.int64
        )

    def __len__(self):
        """Number of samples (one per file in ``img_dir``)."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, resize, augment and tensorise sample ``idx``.

        Returns:
            ``(img, mask)`` where ``img`` is a float32 tensor of shape
            ``(3, IMG_H, IMG_W)`` normalised per channel, and ``mask`` is an
            int64 tensor of shape ``(IMG_H, IMG_W)`` holding class ids
            (``NUM_CLASSES`` marks ignored pixels).

        Raises:
            FileNotFoundError: if the mask for this sample cannot be read.
        """
        name = self.files[idx]
        img_path = os.path.join(self.img_dir, name)
        mask_path = os.path.join(self.mask_dir, name)

        # Close the file handle deterministically rather than relying on GC.
        with Image.open(img_path) as handle:
            img = np.array(handle.convert("RGB"))

        # cv2.imread signals failure by returning None instead of raising.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(
                f"Could not read mask for '{name}' at '{mask_path}'"
            )

        img = cv2.resize(img, (IMG_W, IMG_H))
        # Nearest-neighbour keeps mask values discrete (no blended labels).
        mask = cv2.resize(mask, (IMG_W, IMG_H), interpolation=cv2.INTER_NEAREST)

        if self.augment and np.random.rand() > 0.5:
            # .copy() turns the negative-stride view into a contiguous array,
            # which torch.from_numpy requires.
            img = np.fliplr(img).copy()
            mask = np.fliplr(mask).copy()

        mask_mapped = self._class_lut[mask]

        img = img.astype(np.float32) / 255.0
        img = (img - self._MEAN) / self._STD

        img = torch.from_numpy(img).permute(2, 0, 1).float()
        mask_mapped = torch.from_numpy(mask_mapped).long()

        return img, mask_mapped

print("âœ“ Dataset class defined")

In [None]:
############################################################
# CELL 4: CREATE DATALOADERS
############################################################
print("Creating datasets...")
# Only the training split is augmented; validation must be deterministic.
train_dataset = CityscapesDataset(img_dir="train/img", mask_dir="train/label", augment=True)
val_dataset = CityscapesDataset(img_dir="val/img", mask_dir="val/label", augment=False)

# num_workers=0 loads every batch in the main process.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)

print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}")

# Test data loading: pull a single sample and report its shapes, the class ids
# present in its mask, and how much of the mask is not the "ignore" id.
print("\nTesting data...")
test_img, test_mask = train_dataset[0]
print(f"Image: {test_img.shape}")
print(f"Mask: {test_mask.shape}")
unique = test_mask.unique()
print(f"Unique classes: {unique.tolist()}")
total_pixels = test_mask.numel()
valid_pixels = int((test_mask < NUM_CLASSES).sum())
print(f"Valid pixels: {valid_pixels} / {total_pixels} ({100*valid_pixels/total_pixels:.1f}%)")
print("=" * 50)

In [None]:
############################################################
# CELL 5: MODEL ARCHITECTURE WITH SELF-ATTENTION
############################################################

# Self-Attention Module 
class SelfAttention(nn.Module):
    """Spatial self-attention over a feature map with a learnable residual gate.

    Every spatial position attends to every other position of the same
    feature map. Queries and keys are projected to ``channels // 8`` channels
    by 1x1 convolutions; values keep all ``channels``. The attended features
    are scaled by ``gamma`` (initialised to 0, so the module starts out as the
    identity) and added back to the input.

    Note: the attention matrix has shape ``(B, H*W, H*W)``, so memory grows
    quadratically with the number of spatial positions.

    Args:
        channels: number of channels of the input (and output) feature map.
    """

    def __init__(self, channels):
        super().__init__()
        # Keep at least one query/key channel so tiny inputs (< 8 channels)
        # do not produce a zero-width projection.
        reduced = max(channels // 8, 1)
        self.query = nn.Conv2d(channels, reduced, 1)
        self.key = nn.Conv2d(channels, reduced, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(B, C, H, W)``; output has the same shape."""
        B, C, H, W = x.size()

        # Query, Key, Value, flattened over the spatial dimensions
        Q = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, HW, C')
        K = self.key(x).view(B, -1, H * W)                      # (B, C', HW)
        V = self.value(x).view(B, -1, H * W)                    # (B, C, HW)

        # Attention scores: row i holds the weights position i puts on all positions.
        attention = torch.bmm(Q, K)  # (B, HW, HW)
        # torch.softmax is used directly: the notebook never imports
        # torch.nn.functional as F, so the previous F.softmax raised NameError.
        attention = torch.softmax(attention, dim=-1)

        # Apply attention to values: out[:, :, i] = sum_j V[:, :, j] * attention[:, i, j]
        out = torch.bmm(V, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)

        # Residual connection with learnable weight
        out = self.gamma * out + x

        return out

# Convolutional Block
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    changes (``in_ch`` -> ``out_ch`` in the first stage, ``out_ch`` ->
    ``out_ch`` in the second).

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by both stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage consumes in_ch channels, the second out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        return self.conv(x)

# U-Net with Self-Attention
class FastUNet(nn.Module):
    """U-Net style segmentation network on a pretrained ResNet34 encoder.

    The encoder is a torchvision ResNet34 (ImageNet weights) cut into five
    stages. Self-attention is applied to the outputs of the three deepest
    stages before they are used by the decoder. The decoder upsamples with
    five stride-2 transposed convolutions, concatenating the matching encoder
    output after each of the first four, and ends with a 1x1 convolution
    that produces ``num_classes`` channels.

    Args:
        num_classes: number of output channels (one logit map per class).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Encoder: ResNet34 backbone, split into its successive stages
        backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

        self.enc1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.enc2 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.enc3 = backbone.layer2
        self.enc4 = backbone.layer3
        self.enc5 = backbone.layer4

        # Self-attention for the 512-, 256- and 128-channel encoder outputs
        self.att5 = SelfAttention(512)
        self.att4 = SelfAttention(256)
        self.att3 = SelfAttention(128)

        # Decoder: each ConvBlock's input width is (upsampled channels + skip channels)
        self.up5 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec5 = ConvBlock(512, 256)

        self.up4 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(256, 128)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(128, 64)

        self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(128, 64)

        # The last upsampling step has no skip connection.
        self.up1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(64, 64)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _upsample_and_fuse(upsample, refine, deep, skip):
        """Upsample ``deep``, concatenate ``skip`` along channels, then refine."""
        merged = torch.cat([upsample(deep), skip], dim=1)
        return refine(merged)

    def forward(self, x):
        """Return per-pixel class logits for a batch of images ``x``.

        NOTE(review): the skip concatenations need the upsampled and encoder
        feature maps to have equal spatial size, which presumably requires the
        input height and width to be divisible by 32 - confirm before feeding
        other resolutions.
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        # Attention is applied after the whole encoder has run, so deeper
        # stages are computed from the un-attended shallower features.
        e5, e4, e3 = self.att5(e5), self.att4(e4), self.att3(e3)

        # Decoder with skip connections
        feat = self._upsample_and_fuse(self.up5, self.dec5, e5, e4)
        feat = self._upsample_and_fuse(self.up4, self.dec4, feat, e3)
        feat = self._upsample_and_fuse(self.up3, self.dec3, feat, e2)
        feat = self._upsample_and_fuse(self.up2, self.dec2, feat, e1)

        feat = self.dec1(self.up1(feat))

        return self.final(feat)

print("âœ“ Model architecture with Self-Attention defined")
print("  - Self-Attention modules: 3 (manually implemented)")
print("  - Encoder: ResNet34")
print("  - Decoder: U-Net style with skip connections")

In [None]:
############################################################
# CELL 6: BUILD MODEL
############################################################
print("Building model...")
# Instantiate first, then move parameters and buffers to the selected device.
model = FastUNet(num_classes=NUM_CLASSES)
model = model.to(DEVICE)
print(f"âœ“ Model Ready ({NUM_CLASSES} classes)")
print("=" * 50)

In [None]:
############################################################
# CELL 7: TRAINING SETUP
############################################################
# Pixels whose target equals NUM_CLASSES (the "ignore" id) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=NUM_CLASSES)
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR,
    weight_decay=1e-4,
)
# Cosine learning-rate schedule spanning the whole run (stepped once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS)

def evaluate(model, loader, max_batches=100, num_classes=None, device=None):
    """Compute pixel accuracy and mean IoU of ``model`` over ``loader``.

    Both metrics are derived from one confusion matrix accumulated over all
    evaluated batches and restricted to pixels with a valid label
    (``0 <= label < num_classes``). Pixels carrying the "ignore" id therefore
    affect neither accuracy nor IoU, and IoU is a dataset-level ratio rather
    than an average of per-batch ratios.

    Args:
        model: network returning logits of shape ``(B, num_classes, H, W)``.
        loader: iterable of ``(images, masks)`` batches; masks hold class ids.
        max_batches: evaluate at most this many batches.
        num_classes: number of valid classes; defaults to ``NUM_CLASSES``.
        device: device to run on; defaults to ``DEVICE``.

    Returns:
        ``(pixel_accuracy, mean_iou)``. Each is 0 when no valid pixel was
        seen. Classes that appear in neither predictions nor labels are left
        out of the mean.

    Side effects:
        Puts ``model`` in eval mode and does not switch it back.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if device is None:
        device = DEVICE

    model.eval()
    # Flattened (label, prediction) confusion matrix kept on the device so the
    # loop needs no per-class host synchronisation.
    n_cells = num_classes * num_classes
    confusion = torch.zeros(n_cells, dtype=torch.long, device=device)

    with torch.no_grad():
        for i, (imgs, masks) in enumerate(loader):
            if i >= max_batches:
                break
            imgs, masks = imgs.to(device), masks.to(device)
            preds = torch.argmax(model(imgs), dim=1)

            # Drop ignored pixels before counting anything.
            valid = (masks >= 0) & (masks < num_classes)
            cell = masks[valid] * num_classes + preds[valid]
            confusion += torch.bincount(cell, minlength=n_cells)

    confusion = confusion.view(num_classes, num_classes)  # rows: label, cols: prediction
    inter = confusion.diag()
    union = confusion.sum(dim=0) + confusion.sum(dim=1) - inter

    pixels = int(confusion.sum())
    acc = int(inter.sum()) / pixels if pixels > 0 else 0

    seen = union > 0  # classes present in labels or predictions
    if bool(seen.any()):
        miou = (inter[seen].double() / union[seen].double()).mean().item()
    else:
        miou = 0

    return acc, miou

print("âœ“ Training setup ready")
print("=" * 50)

In [None]:
############################################################
# CELL 8: TRAINING LOOP
############################################################
print("ðŸ”¥ TRAINING STARTED ðŸ”¥")
print("=" * 50)

# Per-epoch history (one entry per epoch) and the best validation mIoU so far.
losses = []
accs = []
mious = []
best_miou = 0

for epoch in range(EPOCHS):
    # evaluate() leaves the model in eval mode, so switch back every epoch.
    model.train()
    running_loss = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch_imgs, batch_masks in progress:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_masks = batch_masks.to(DEVICE)

        optimizer.zero_grad()
        loss = criterion(model(batch_imgs), batch_masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress.set_postfix({'loss': f'{loss.item():.3f}'})

    avg_loss = running_loss / len(train_loader)
    losses.append(avg_loss)
    scheduler.step()

    # Validate after the first epoch, after every second epoch, and after
    # each of the last two epochs.
    is_eval_epoch = epoch == 0 or (epoch + 1) % 2 == 0 or epoch >= EPOCHS - 2
    if not is_eval_epoch:
        # Repeat the latest measurement so the histories stay one-per-epoch.
        accs.append(accs[-1] if accs else 0)
        mious.append(mious[-1] if mious else 0)
    else:
        print("\nEvaluating...")
        acc, miou = evaluate(model, val_loader)
        accs.append(acc)
        mious.append(miou)

        # Checkpoint only on a strict improvement of validation mIoU.
        if miou > best_miou:
            best_miou = miou
            torch.save(model.state_dict(), "best_model.pth")
            print(f"ðŸ’¾ Saved! mIoU={miou*100:.1f}%")

    print(f"Epoch {epoch+1} | Loss={avg_loss:.3f} | Acc={accs[-1]*100:.1f}% | mIoU={mious[-1]*100:.1f}% | Best={best_miou*100:.1f}%")

print("\nðŸŽ¯ TRAINING COMPLETE!")
print("=" * 50)

In [None]:
############################################################
# CELL 9: FINAL EVALUATION
############################################################
print("ðŸ“Š FINAL EVALUATION")
# Reload the best checkpoint written by the training loop. map_location remaps
# the stored tensors onto DEVICE, so a checkpoint saved on the GPU still loads
# when this cell is re-run in a CPU-only session.
best_state = torch.load("best_model.pth", map_location=DEVICE)
model.load_state_dict(best_state)
# A larger batch cap than the in-training evaluation (which uses the default of 100).
final_acc, final_miou = evaluate(model, val_loader, max_batches=250)

print("=" * 50)
print(f"âœ… Pixel Accuracy = {final_acc*100:.2f}%")
print(f"âœ… mIoU           = {final_miou*100:.2f}%")
print(f"âœ… Classes        = {NUM_CLASSES}")
print("=" * 50)

In [None]:
############################################################
# CELL 10: SAVE MODEL
############################################################
# Full checkpoint: weights plus the metadata needed to interpret them.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'num_classes': NUM_CLASSES,
    'pixel_acc': final_acc,
    'miou': final_miou
}
torch.save(checkpoint, "cityscapes_checkpoint.pth")

# Weights-only copy for loading straight into a FastUNet instance.
weights_only_state = model.state_dict()
torch.save(weights_only_state, "cityscapes_model_weights.pth")

print("âœ“ Model saved!")
print("  - best_model.pth")
print("  - cityscapes_checkpoint.pth")
print("  - cityscapes_model_weights.pth")
print("=" * 50)

In [None]:
############################################################
# CELL 11: PLOT RESULTS
############################################################
plt.figure(figsize=(12, 4))

# One panel per tracked quantity: (values to plot, panel title, y-axis label).
# Accuracy and mIoU are stored as fractions and shown as percentages.
panels = (
    (losses, "Training Loss", "Loss"),
    ([a*100 for a in accs], "Pixel Accuracy (%)", "Accuracy"),
    ([m*100 for m in mious], "mIoU (%)", "mIoU"),
)

for position, (series, title, ylabel) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.plot(series)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.grid(True)

plt.tight_layout()
# Save before show() so the file is written from the fully drawn figure.
plt.savefig("training_results.png", dpi=120)
plt.show()

print("âœ“ Training curves saved: training_results.png")