In [None]:
import os
from PIL import Image, UnidentifiedImageError
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MiniBatchKMeans
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Dataset layout: <data_dir>/train and <data_dir>/val hold the image files.
data_dir = os.path.join("/kaggle", "input", "cityscapes-image-pairs", "cityscapes_data")
train_dir, val_dir = (os.path.join(data_dir, split) for split in ("train", "val"))
train_fns, val_fns = (sorted(os.listdir(folder)) for folder in (train_dir, val_dir))
print("Train files:", len(train_fns), "Val files:", len(val_fns))

# Display the first training file (in sorted order) as a visual sanity check.
sample_image_fp = os.path.join(train_dir, train_fns[0])
sample_image = Image.open(sample_image_fp).convert("RGB")
plt.imshow(sample_image)
plt.axis('off')
print("sample:", sample_image_fp)

def split_image(image):
    """
    Cut a side-by-side paired image into its two halves.

    image: anything np.array() converts to an H x W x C array (a copy is made).
    Returns (cityscape, label): the left half and the right half, each
    H x W/2 x C, as slices of that copy.
    An odd width fails the assertion below.
    """
    pixels = np.array(image)
    _, width, _ = pixels.shape
    assert width % 2 == 0, "expected paired image with even width"
    half = width // 2
    left, right = pixels[:, :half, :], pixels[:, half:, :]
    return left, right

# Check the split on the sample image before processing the whole folder.
sample_np = np.array(sample_image)
print("sample shape:", sample_np.shape)
cityscape, label = split_image(sample_np)
print("cityscape / label shapes:", cityscape.shape, label.shape)


# Sample pixel colours from the label halves of the training images. The
# resulting array is what the KMeans label model is fitted on.
MAX_IMAGES_TO_SAMPLE = 200   # only the first N training files (sorted) are visited
MAX_PIXELS = 20000           # total pixel budget across all visited files
num_classes = 10

colors = []
collected = 0
image_files = train_fns[:MAX_IMAGES_TO_SAMPLE]
# Dedicated seeded generator: the sampled pixels (and therefore the fitted
# colour clusters) are the same on every run, independent of global RNG state.
color_rng = np.random.default_rng(42)

for position, fn in enumerate(tqdm(image_files, desc="collecting label colors")):
    remaining = MAX_PIXELS - collected
    if remaining <= 0:
        # Budget exhausted: stop before decoding another file.
        break

    fp = os.path.join(train_dir, fn)
    try:
        # The context manager closes the underlying file as soon as the
        # pixel data has been copied into the array.
        with Image.open(fp) as img:
            _, lbl = split_image(np.array(img.convert("RGB")))
    except (UnidentifiedImageError, OSError) as e:
        print(f"skipping corrupt file {fp}: {e}")
        continue

    pix = lbl.reshape(-1, 3)

    # Split what is left of the budget evenly over the files not yet visited
    # (ceiling division). Dividing by the total file count instead would
    # shrink the per-file quota geometrically and leave a large part of
    # MAX_PIXELS unused. The quota never exceeds `remaining`, so the total
    # stays within the budget; a skipped file's share goes to later files.
    files_left = len(image_files) - position
    take = min(len(pix), -(-remaining // files_left))
    if take < len(pix):
        idx = color_rng.choice(len(pix), size=take, replace=False)
        pix = pix[idx]
    colors.append(pix)
    collected += pix.shape[0]

if not colors:
    raise RuntimeError("No colors collected from any training file. Check file paths and images.")

color_array = np.vstack(colors).astype(np.float32)
print("color_array shape (for KMeans):", color_array.shape)


from sklearn.cluster import MiniBatchKMeans

# Cluster the sampled label colours into `num_classes` groups; the fitted
# model is used afterwards to map a label pixel's colour to a cluster index.
kmeans_settings = dict(
    n_clusters=num_classes,
    random_state=42,
    batch_size=4096,
    n_init="auto",
)
label_model = MiniBatchKMeans(**kmeans_settings)
kmeans_input = color_array.astype(np.float64)
label_model.fit(kmeans_input)
print("KMeans centroids shape:", label_model.cluster_centers_.shape)


# Visual check: photo half, raw colour label half, and the per-pixel cluster
# index predicted by the fitted label model for the sample image.
cityscape, label = split_image(sample_np)
label_h, label_w = label.shape[0], label.shape[1]
label_pixels = label.reshape(-1, 3).astype(np.float64)
label_class = label_model.predict(label_pixels).reshape(label_h, label_w)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = ((cityscape, "city"), (label, "raw label"), (label_class, "kmeans class"))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')
plt.show()

class CityscapeDataset(Dataset):
    """Paired cityscape/label images served as (image tensor, class-id mask).

    Every file in ``image_dir`` is expected to hold the photo in its left half
    and the colour-coded label in its right half (see ``split_image``). The
    label half is turned into per-pixel class ids with ``label_model.predict``.

    Args:
        image_dir: directory whose (sorted) entries are the samples.
        label_model: fitted model exposing ``predict`` on an (N, 3) float64
            array of colours and returning one class index per row.
        transform: callable applied to the H x W x 3 uint8 photo half. When
            omitted, ``ToTensor`` followed by normalisation with the
            mean/std constants below is used.
    """

    # (H, W) of the all-zero placeholder sample returned for an unreadable file.
    # NOTE(review): assumed to match the real half-image size so that batches
    # still collate - confirm against the dataset.
    FALLBACK_SIZE = (256, 256)

    def __init__(self, image_dir, label_model, transform=None):
        self.image_dir = image_dir
        self.image_fns = sorted(os.listdir(image_dir))
        self.label_model = label_model
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ])

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.image_fns)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the file at ``index``.

        ``image`` is the transformed photo half; ``mask`` is an int64 tensor
        of class ids with the label half's height and width. An unreadable
        file yields an all-zero placeholder pair (and a warning) instead of
        raising, so one bad file does not abort an epoch.
        """
        image_fp = os.path.join(self.image_dir, self.image_fns[index])
        try:
            # Context manager: release the file handle once pixels are copied.
            with Image.open(image_fp) as img:
                image = np.array(img.convert('RGB'))
        except (UnidentifiedImageError, OSError) as e:
            import warnings

            # Best-effort fallback is kept, but it is no longer silent:
            # placeholder samples are labelled class 0 everywhere and would
            # otherwise skew training and metrics unnoticed.
            warnings.warn(
                f"unreadable image {image_fp}: {e}; returning an all-zero placeholder sample"
            )
            h, w = self.FALLBACK_SIZE
            cityscape = np.zeros((h, w, 3), dtype=np.uint8)
            label_class = np.zeros((h, w), dtype=np.int64)
            return self.transform(cityscape), torch.from_numpy(label_class).long()

        cityscape, label = split_image(image)

        # uint8 -> float64 is exact, so no intermediate float32 cast is needed.
        pix = label.reshape(-1, 3).astype(np.float64)
        cls = self.label_model.predict(pix).reshape(label.shape[0], label.shape[1])

        city_t = self.transform(cityscape)
        label_t = torch.from_numpy(cls).long()
        return city_t, label_t


In [None]:
# -------------------------------
# Generator
# -------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights

# -------------------------------
# Feature Pyramid Network (FPN)
# -------------------------------
class FPN(nn.Module):
    """Top-down feature pyramid over a list of backbone feature maps.

    Each input level gets a 1x1 lateral convolution to ``out_channels``;
    coarser levels are upsampled (nearest) and added to the next finer level,
    and every merged map is passed through its own 3x3 convolution.
    """

    def __init__(self, in_channels_list, out_channels):
        super().__init__()

        # One 1x1 projection per input level.
        laterals = [
            nn.Conv2d(channels, out_channels, kernel_size=1)
            for channels in in_channels_list
        ]
        self.lateral_convs = nn.ModuleList(laterals)

        # One 3x3 convolution per level, applied after merging.
        smoothers = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            for _ in in_channels_list
        ]
        self.smooth_convs = nn.ModuleList(smoothers)

    def forward(self, features):
        """Return one map per level, in the same order as ``features``.

        ``features`` is indexed with the last element treated as the start of
        the top-down pass.
        """
        running = self.lateral_convs[-1](features[-1])
        coarse_to_fine = [running]

        lower_levels = list(zip(self.lateral_convs, features[:-1]))
        for project, feat in reversed(lower_levels):
            lateral = project(feat)
            upsampled = F.interpolate(running, size=lateral.shape[-2:], mode="nearest")
            running = lateral + upsampled
            coarse_to_fine.append(running)

        # Restore the input ordering before the per-level 3x3 convolutions.
        merged = coarse_to_fine[::-1]
        return [smooth(level) for smooth, level in zip(self.smooth_convs, merged)]

# -------------------------------
# ASPP 
# -------------------------------
class ASPP(nn.Module):
    """Parallel 1x1 / dilated 3x3 / global-pool branches fused by a 1x1 conv.

    Branches: one 1x1 convolution, one dilated 3x3 convolution per entry of
    ``rates`` (padding equal to the dilation), and a global-average-pool
    branch that is resized back to the input's spatial size. All branch
    outputs are concatenated on the channel axis and projected to
    ``out_channels`` (conv, batch norm, ReLU, dropout 0.1).
    """

    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):
        super().__init__()

        def with_bn_relu(conv):
            # Every branch shares the same conv -> BatchNorm -> ReLU tail.
            return [conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]

        pointwise = nn.Sequential(
            *with_bn_relu(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False))
        )
        dilated = [
            nn.Sequential(*with_bn_relu(
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          padding=rate, dilation=rate, bias=False)
            ))
            for rate in rates
        ]
        self.atrous_blocks = nn.ModuleList([pointwise] + dilated)

        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            *with_bn_relu(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)),
        )

        # 1x1 branch + one branch per rate + pooled branch.
        branch_count = len(rates) + 2
        self.project = nn.Sequential(
            *with_bn_relu(
                nn.Conv2d(out_channels * branch_count, out_channels, kernel_size=1, bias=False)
            ),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Return the fused map; spatial size equals that of ``x``."""
        spatial = x.shape[2:]
        branches = [block(x) for block in self.atrous_blocks]
        pooled = F.interpolate(
            self.global_pool(x), size=spatial, mode='bilinear', align_corners=False
        )
        branches.append(pooled)
        return self.project(torch.cat(branches, dim=1))

# -------------------------------
# Decoder 
# -------------------------------
class Decoder(nn.Module):
    """3x3 conv block followed by a 1x1 classifier, with optional resizing."""

    def __init__(self, in_channels, mid_channels, num_classes):
        super().__init__()
        stem = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*stem)
        self.classifier = nn.Conv2d(mid_channels, num_classes, kernel_size=1)

    def forward(self, x, target_size=None):
        """Return per-class logits.

        When ``target_size`` is given the logits are bilinearly resized to
        it; otherwise they keep the spatial size of ``x``.
        """
        logits = self.classifier(self.conv1(x))
        if target_size is None:
            return logits
        return F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)

# -------------------------------
# Generator: EfficientNet-B0 feature backbone + FPN + ASPP + Decoder
# -------------------------------
class FPN_ASPP_Generator(nn.Module):
    """Segmentation generator: feature backbone -> FPN -> ASPP -> Decoder.

    Args:
        num_classes: number of output channels of the decoder's classifier.
        use_imagenet_weights: forwarded to ``create_model`` as ``pretrained``.
        fpn_out: channel width of every FPN output (and of the ASPP input).
        aspp_out: channel width of the ASPP output (and of the decoder input).
        decoder_mid: hidden channel width inside the decoder.

    NOTE(review): ``create_model`` is not imported anywhere in the visible
    source. The call signature (``features_only``, ``out_indices``,
    ``feature_info.channels()``) looks like ``timm.create_model`` - confirm
    and add the import, otherwise constructing this class raises NameError.
    """

    def __init__(self, num_classes=10, use_imagenet_weights=True, fpn_out=256, aspp_out=256, decoder_mid=128):
        super().__init__()
        # The backbone is 'efficientnet_b0' requested as a feature extractor
        # with four outputs (out_indices 1..4).
        self.backbone = create_model(
            'efficientnet_b0',
            pretrained=use_imagenet_weights,
            features_only=True,
            out_indices=(1, 2, 3, 4)
        )
        # Per-level channel counts reported by the backbone size the FPN laterals.
        in_channels_list = self.backbone.feature_info.channels()
        self.fpn = FPN(in_channels_list, fpn_out)
        self.aspp = ASPP(fpn_out, aspp_out)
        self.decoder = Decoder(aspp_out, decoder_mid, num_classes)

    def forward(self, x):
        """Return class logits resized to the spatial size of ``x``."""
        size = x.shape[2:]
        features = self.backbone(x)
        # Unpacking enforces that the backbone returned exactly four maps.
        c2, c3, c4, c5 = features
        fpn_feats = self.fpn([c2, c3, c4, c5])
        # Only the first FPN output is used from here on; the remaining
        # pyramid levels are computed but not consumed.
        x = fpn_feats[0]
        x = self.aspp(x)
        x = self.decoder(x, target_size=size)
        return x


# -------------------------------
# PatchGAN Discriminator 
# -------------------------------
class PatchDiscriminator(nn.Module):
    """Stack of 4x4 convolutions producing a map of per-patch logits.

    Layout: one stride-2 conv + LeakyReLU, then ``n_layers - 1`` stride-2
    conv/BatchNorm/LeakyReLU blocks (width doubling, capped at 512), one
    stride-1 block of the same kind, and a final stride-1 conv to a single
    channel.
    """

    def __init__(self, in_channels=4, ndf=64, n_layers=3):
        super().__init__()

        def norm_block(c_in, c_out, stride):
            # Bias is dropped because BatchNorm follows directly.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ]

        width = ndf
        stack = [
            nn.Conv2d(in_channels, width, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        for _ in range(n_layers - 1):
            wider = min(width * 2, 512)
            stack.extend(norm_block(width, wider, stride=2))
            width = wider

        last_width = min(width * 2, 512)
        stack.extend(norm_block(width, last_width, stride=1))
        stack.append(nn.Conv2d(last_width, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the single-channel logit map for input ``x``."""
        return self.model(x)

# -------------------------------
# Hinge GAN losses
# -------------------------------
def hinge_d_loss(real_logits, fake_logits):
    """Discriminator hinge loss: mean(relu(1 - real)) + mean(relu(1 + fake))."""
    real_term = F.relu(1.0 - real_logits).mean()
    fake_term = F.relu(1.0 + fake_logits).mean()
    return real_term + fake_term

def hinge_g_loss(fake_logits):
    """Generator hinge loss: the negated mean of the discriminator's logits."""
    mean_logit = fake_logits.mean()
    return -mean_logit

# -------------------------------
# SegmentationGAN wrapper
# -------------------------------
class SegmentationGAN(nn.Module):
    """Bundles the segmentation generator with a patch discriminator.

    The discriminator sees the input image concatenated with a one-channel
    auxiliary map (3 + 1 input channels).
    """

    def __init__(self, num_classes=10, use_imagenet_weights=True):
        super().__init__()
        self.num_classes = num_classes
        self.generator = FPN_ASPP_Generator(
            num_classes=num_classes,
            use_imagenet_weights=use_imagenet_weights,
        )
        self.discriminator = PatchDiscriminator(in_channels=3 + 1)

    @staticmethod
    def build_aux_from_prob(prob):
        """Return the per-pixel maximum over dim 1 of ``prob``, keeping that dim."""
        peak, _ = prob.max(dim=1, keepdim=True)
        return peak

    @staticmethod
    def build_aux_from_mask(masks, num_classes):
        """Return a float32 all-ones map shaped (B, 1, H, W) on ``masks``' device.

        Only the shape and device of ``masks`` are used; ``num_classes`` is
        accepted but not read.
        """
        dims = masks.shape
        target_shape = (dims[0], 1, dims[1], dims[2])
        return torch.ones(target_shape, device=masks.device, dtype=torch.float32)

    def forward(self, imgs):
        """Return ``(seg_logits, disc_logits)`` for a batch of images."""
        seg_logits = self.generator(imgs)
        confidence = self.build_aux_from_prob(F.softmax(seg_logits, dim=1))
        disc_logits = self.discriminator(torch.cat([imgs, confidence], dim=1))
        return seg_logits, disc_logits


In [None]:
# ================================================================
# 5-Fold Cross Validation 
# ================================================================

import os
import copy
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import Subset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

# ---------- Hyperparameters ----------
# Reproducibility and cross-validation layout.
seed = 42
num_folds = 5

# Model / optimisation settings.
num_classes = 10
batch_size = 16
num_epochs = 10
lr = 1e-3
weight_decay = 0

# Seed NumPy's and PyTorch's global random number generators.
np.random.seed(seed)
torch.manual_seed(seed)

# ---------- Metrics ----------
def compute_confusion_matrix(preds, labels, num_classes):
    """Count (ground truth, prediction) pairs into a num_classes x num_classes grid.

    Rows index the ground-truth class, columns the predicted class. Positions
    whose label lies outside [0, num_classes) are ignored.
    """
    in_range = np.logical_and(labels >= 0, labels < num_classes)
    gt = labels[in_range].astype(int)
    pred = preds[in_range].astype(int)
    # Encode each pair as a single flat cell index: row * width + column.
    flat_index = num_classes * gt + pred
    counts = np.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def compute_metrics_from_confusion(conf):
    """Derive per-class and mean IoU / pixel accuracy from a confusion matrix.

    Rows of ``conf`` are treated as ground truth and columns as predictions.
    The means are taken only over classes with at least one ground-truth
    pixel, and are 0.0 when there is no such class.
    """
    eps = 1e-10  # guards the divisions for classes with empty denominators
    hits = np.diag(conf).astype(np.float64)
    gt_total = conf.sum(axis=1).astype(np.float64)
    pred_total = conf.sum(axis=0).astype(np.float64)

    union = gt_total + pred_total - hits
    iou = hits / (union + eps)
    per_class_acc = hits / (gt_total + eps)

    present = gt_total > 0
    if present.any():
        mean_iou = np.mean(iou[present])
        mean_acc = np.mean(per_class_acc[present])
    else:
        mean_iou = 0.0
        mean_acc = 0.0

    return {
        'iou_per_class': iou,
        'acc_per_class': per_class_acc,
        'mIoU': float(mean_iou),
        'mPA': float(mean_acc),
    }

# ---------- Dataset ----------
# Files under train_dir feed the cross-validation folds; files under val_dir
# are kept aside and used as the test set.
test_dataset = CityscapeDataset(val_dir, label_model)
train_dataset_full = CityscapeDataset(train_dir, label_model)
n_train_samples, n_test_samples = len(train_dataset_full), len(test_dataset)
print(f"Train dataset size: {n_train_samples}, Test (val_dir) size: {n_test_samples}")
# ---------- GAN Training Utils ----------
def d_step(discriminator, real_imgs, fake_aux, real_aux, opt_d):
    """
    Run one discriminator update with the hinge loss.

    The "real" input is the image concatenated with real_aux on the channel
    axis, the "fake" input is the same image concatenated with fake_aux. The
    fake input is detached, so no gradient reaches whatever produced it.
    Returns the loss as a Python float.
    """
    discriminator.train()
    opt_d.zero_grad(set_to_none=True)

    real_pair = torch.cat([real_imgs, real_aux], dim=1)
    fake_pair = torch.cat([real_imgs, fake_aux], dim=1).detach()

    # Real batch first, then fake batch.
    d_loss = hinge_d_loss(discriminator(real_pair), discriminator(fake_pair))
    d_loss.backward()
    opt_d.step()
    return d_loss.item()

def g_step(generator, discriminator, imgs, masks, opt_g, ce_loss, num_classes, lambda_adv=0.2):
    """
    Run one generator update: segmentation loss + weighted adversarial hinge loss.

    Args:
        generator: segmentation network mapping imgs to per-class logits.
        discriminator: network scoring the image concatenated with a
            one-channel auxiliary map.
        imgs: input image batch.
        masks: target class-id masks passed to ce_loss.
        opt_g: optimizer that is zeroed and stepped here.
        ce_loss: callable computing the segmentation loss from (logits, masks).
        num_classes: accepted for signature compatibility; not read here.
        lambda_adv: weight of the adversarial term in the total loss.

    Returns:
        (total loss, segmentation loss, adversarial loss) as Python floats.
    """
    generator.train()
    opt_g.zero_grad(set_to_none=True)

    # Forward generator
    seg_logits = generator(imgs)
    ce = ce_loss(seg_logits, masks)

    # The auxiliary map must stay attached to the autograd graph. Building it
    # under torch.no_grad() cuts the path from the discriminator's output back
    # to the generator, so the adversarial term would contribute no gradient
    # to the generator's parameters at all.
    prob = F.softmax(seg_logits, dim=1)
    aux_fake = SegmentationGAN.build_aux_from_prob(prob)

    disc_in = torch.cat([imgs, aux_fake], dim=1)
    fake_logits = discriminator(disc_in)
    adv_g = hinge_g_loss(fake_logits)

    # Total G loss: CE + lambda_adv * adversarial
    g_loss = ce + lambda_adv * adv_g
    # backward() also accumulates gradients on the discriminator's parameters;
    # only opt_g is stepped here, and d_step clears them with
    # opt_d.zero_grad() before its own backward pass.
    g_loss.backward()
    opt_g.step()

    return g_loss.item(), ce.item(), adv_g.item()

# ---------- One epoch over a loader ----------
def train_one_epoch(model_gan, loader, opt_g, opt_d, ce_loss, device, num_classes, fold, epoch, total_epochs):
    """
    Train the GAN for one pass over ``loader``, alternating D and G updates.

    For every batch the discriminator is updated first (d_step), then the
    generator (g_step). ``fold``, ``epoch`` and ``total_epochs`` are only used
    for the progress-bar caption.

    Returns:
        dict with the sample-weighted epoch averages of the discriminator
        loss ("D"), total generator loss ("G"), segmentation loss ("CE") and
        adversarial generator loss ("G_adv").
    """
    model_gan.train()
    gen = model_gan.generator
    disc = model_gan.discriminator

    pbar = tqdm(loader, desc=f"Fold {fold} | Epoch {epoch}/{total_epochs} [GAN Train]", leave=False)
    running = {"D": 0.0, "G": 0.0, "CE": 0.0, "G_adv": 0.0}
    n = 0

    for imgs, masks in pbar:
        imgs = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        # This forward pass only feeds the discriminator update, which never
        # backpropagates into the generator (d_step detaches its fake input).
        # Running it under no_grad yields the same values without building an
        # autograd graph that would be thrown away immediately.
        with torch.no_grad():
            seg_logits = gen(imgs)
            prob = F.softmax(seg_logits, dim=1)
            aux_fake = SegmentationGAN.build_aux_from_prob(prob)
            aux_real = SegmentationGAN.build_aux_from_mask(masks, num_classes)

        # 1) D step (hinge)
        d_loss = d_step(disc, imgs, aux_fake, aux_real, opt_d)

        # 2) G step (CE + adversarial hinge); runs its own generator forward
        g_loss, ce, g_adv = g_step(gen, disc, imgs, masks, opt_g, ce_loss, num_classes)

        # Weight every batch by its size so a short last batch is not over-counted.
        bsz = imgs.size(0)
        running["D"] += d_loss * bsz
        running["G"] += g_loss * bsz
        running["CE"] += ce * bsz
        running["G_adv"] += g_adv * bsz
        n += bsz

        pbar.set_postfix({
            "D": f"{running['D']/n:.3f}",
            "G": f"{running['G']/n:.3f}",
            "CE": f"{running['CE']/n:.3f}",
            "G_adv": f"{running['G_adv']/n:.3f}",
        })

    # max(1, n) keeps an empty loader from dividing by zero.
    for k in running:
        running[k] = running[k] / max(1, n)
    return running

@torch.no_grad()
def evaluate_model_on_loader(model_gan, loader, device, num_classes):
    """Score the generator's segmentation output over ``loader``.

    The model is switched to eval mode, predictions are the argmax over the
    generator's logits, and all batches are accumulated into one confusion
    matrix. Returns ``(metrics, confusion_matrix)``.
    """
    model_gan.eval()
    generator = model_gan.generator
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)

    for imgs, masks in loader:
        logits = generator(imgs.to(device))
        preds = logits.argmax(dim=1).cpu().numpy()
        gts = masks.cpu().numpy()
        # Pair counts are additive, so the batch can be tallied in one call.
        conf += compute_confusion_matrix(preds.ravel(), gts.ravel(), num_classes=num_classes)

    return compute_metrics_from_confusion(conf), conf


# ---------- K-Fold Cross Validation ----------
# For every fold: train a fresh GAN on the fold's training split, keep the
# weights of the epoch with the best validation mIoU, then score those weights
# on the held-out test loader. One test-metrics dict per fold is collected in
# `fold_metrics`.
kf = KFold(n_splits=num_folds, shuffle=True, random_state=seed)
indices = np.arange(len(train_dataset_full))
fold_metrics = []

# Folds are numbered from 1 for display.
for fold, (train_idx, val_idx) in enumerate(kf.split(indices), 1):
    print(f"\n========== Fold {fold}/{num_folds} ==========")

    # Both splits are index views over the same underlying dataset object.
    train_subset = Subset(train_dataset_full, train_idx)
    val_subset = Subset(train_dataset_full, val_idx)

    # Only the training loader shuffles; evaluation loaders keep a fixed order.
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # Build GAN (generator + patch discriminator); a new model per fold, so no
    # weights carry over between folds.
    gan = SegmentationGAN(num_classes=num_classes, use_imagenet_weights=True).to(device)

    # Two optimizers: one for G, one for D. The discriminator uses a fixed
    # learning rate of 1e-5 instead of `lr`.
    opt_g = optim.Adam(gan.generator.parameters(), lr=lr, weight_decay=weight_decay)
    opt_d = optim.Adam(gan.discriminator.parameters(), lr=1e-5, weight_decay=weight_decay)

    # CE loss for segmentation
    ce_loss = nn.CrossEntropyLoss()

    # Best-so-far tracking; -1.0 guarantees the first epoch is recorded.
    best_val_mIoU = -1.0
    best_state = None

    epoch_iter = tqdm(range(1, num_epochs + 1), desc=f"Fold {fold} GAN Training", leave=False)
    for epoch in epoch_iter:
        train_stats = train_one_epoch(
            gan, train_loader, opt_g, opt_d, ce_loss, device, num_classes, fold, epoch, num_epochs
        )

        # Evaluate generator segmentation performance on fold-val
        val_metrics, _ =evaluate_model_on_loader(gan, val_loader, device, num_classes)
        val_mIoU = val_metrics['mIoU']
        val_mPA = val_metrics['mPA']

        epoch_iter.set_postfix({
            "D": f"{train_stats['D']:.3f}",
            "G": f"{train_stats['G']:.3f}",
            "CE": f"{train_stats['CE']:.3f}",
            "G_adv": f"{train_stats['G_adv']:.3f}",
            "Val_mIoU": f"{val_mIoU:.4f}",
            "Val_mPA": f"{val_mPA:.4f}"
        })

        print(f"Fold {fold} | Epoch {epoch:02d}/{num_epochs} -> "
              f"D: {train_stats['D']:.3f} | G: {train_stats['G']:.3f} | "
              f"CE: {train_stats['CE']:.3f} | G_adv: {train_stats['G_adv']:.3f} | "
              f"Val mIoU: {val_mIoU:.4f} | Val mPA: {val_mPA:.4f}")

        # Snapshot both networks whenever validation mIoU improves. The deep
        # copy detaches the snapshot from the live, still-training weights.
        if val_mIoU > best_val_mIoU:
            best_val_mIoU = val_mIoU
            best_state = {
                "gen": copy.deepcopy(gan.generator.state_dict()),
                "disc": copy.deepcopy(gan.discriminator.state_dict())
            }

    # Restore the best epoch's weights before the final test evaluation.
    if best_state is not None:
        gan.generator.load_state_dict(best_state["gen"])
        gan.discriminator.load_state_dict(best_state["disc"])

    test_metrics, _ = evaluate_model_on_loader(gan, test_loader, device, num_classes)
    print(f"✅ Fold {fold} TEST results -> mIoU: {test_metrics['mIoU']:.4f}, mPA: {test_metrics['mPA']:.4f}")
    fold_metrics.append(test_metrics)


# ---------- Results Summary ----------
# Parameter counts are read from `gan`, i.e. the model left over from the
# last fold of the loop above.
num_params_gen = sum(w.numel() for w in gan.generator.parameters() if w.requires_grad)
num_params_disc = sum(w.numel() for w in gan.discriminator.parameters() if w.requires_grad)
total_params = num_params_gen + num_params_disc

print(f"Generator parameters: {num_params_gen:,}")
print(f"Discriminator parameters: {num_params_disc:,}")
print(f"Total trainable parameters: {total_params:,}")

# Per-fold test scores, then their mean and standard deviation.
all_mIoU = [scores['mIoU'] for scores in fold_metrics]
all_mPA = [scores['mPA'] for scores in fold_metrics]
print("\n===== Cross-validation summary =====")
for fold_no, scores in enumerate(fold_metrics, start=1):
    print(f"Fold {fold_no} -> mIoU: {scores['mIoU']:.4f}  mPA: {scores['mPA']:.4f}")
mean_mIoU, std_mIoU = np.mean(all_mIoU), np.std(all_mIoU)
mean_mPA, std_mPA = np.mean(all_mPA), np.std(all_mPA)
print(f"\nAverage mIoU over {num_folds} folds: {mean_mIoU:.4f} ± {std_mIoU:.4f}")
print(f"Average mPA  over {num_folds} folds: {mean_mPA:.4f} ± {std_mPA:.4f}")
