In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from collections import Counter

import albumentations as A
from albumentations.pytorch import ToTensorV2

from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor

In [None]:
# --- Config ---
NUM_CLASSES = 12
IMAGE_SIZE = (512, 512)
BATCH_SIZE = 4
EPOCHS = 30
PATIENCE = 5
LEARNING_RATE = 2e-4
SEED = 42  # same value as the train_test_split random_state

# Seed every RNG this notebook relies on: Python `random` (sample selection for
# visualisation), NumPy, and torch (model head init, WeightedRandomSampler,
# DataLoader worker seeds). torch.manual_seed also seeds all CUDA devices.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", DEVICE)

In [None]:
# RGB to Class Mapping
# DeepGlobe colours (classes 0-6) and Dubai colours (classes 7-11) share one
# 12-class label space. A colour missing from this table is not an error: the
# mask decoder leaves such pixels at class 0, so every colour that actually
# occurs in the masks must be listed here.
rgb_to_class = {
    (0, 255, 255): 0,     # urbanland
    (255, 255, 0): 1,     # agricultureland
    (255, 0, 255): 2,     # rangeland
    (0, 255, 0): 3,       # forestland
    (0, 0, 255): 4,       # water
    (255, 255, 255): 5,   # barrenland
    (0, 0, 0): 6,         # unknown
    (60, 16, 152): 7,     # building
    (132, 41, 246): 8,    # land_unpaved
    (110, 193, 228): 9,   # road
    (254, 221, 58): 10,   # vegetation_dubai
    (155, 155, 155): 11,  # unlabeled
    # Dubai "Water" (#E2A929). Without this entry Dubai water pixels were
    # silently labelled class 0 (urbanland); merge them into the shared water class.
    (226, 169, 41): 4,    # water_dubai
}

# Class to RGB Mapping for visualization.
# Several colours may map to the same class (water above), so keep the FIRST
# colour listed for each class instead of letting later entries overwrite it.
class_to_rgb = {}
for color, class_idx in rgb_to_class.items():
    class_to_rgb.setdefault(class_idx, color)

In [None]:
def rgb_mask_to_class(mask):
    """Convert an RGB label image into a 2-D array of class indices.

    Every pixel whose (R, G, B) triple exactly equals a key of the global
    ``rgb_to_class`` table receives that class index. Pixels with any other
    colour are left at 0, which is the index of the first class (urbanland).

    Args:
        mask: PIL image or array-like; assumed (H, W, 3) RGB — TODO confirm for
            non-RGB inputs.

    Returns:
        np.ndarray of dtype uint8 with shape ``mask.shape[:2]``.
    """
    pixels = np.array(mask)
    labels = np.zeros(pixels.shape[:2], dtype=np.uint8)
    for colour, class_index in rgb_to_class.items():
        matches = np.all(pixels == colour, axis=-1)
        labels[matches] = class_index
    return labels

In [None]:
# DeepGlobe Paths
deepglobe_dir = "/kaggle/input/deepglobe-land-cover-classification-dataset/train"

# Dubai Paths (multi-folder format: one sub-folder per tile, each holding
# an `images/` and a `masks/` directory)
dubai_dir = "/kaggle/input/semantic-segmentation-of-aerial-imagery/Semantic segmentation dataset"

# Images and masks are paired purely by their sorted position, so each image
# list must line up one-to-one with its mask list. A single missing file would
# silently shift every later pair; the checks below fail loudly instead.

# Load DeepGlobe images and masks
deepglobe_images = sorted(glob(os.path.join(deepglobe_dir, "*.jpg")))
deepglobe_masks = sorted(glob(os.path.join(deepglobe_dir, "*.png")))
if len(deepglobe_images) != len(deepglobe_masks):
    raise ValueError(
        f"DeepGlobe: {len(deepglobe_images)} images but {len(deepglobe_masks)} masks in {deepglobe_dir}"
    )

# Load Dubai images and masks from tile folders
dubai_images, dubai_masks = [], []
for tile in sorted(os.listdir(dubai_dir)):
    tile_path = os.path.join(dubai_dir, tile)
    if not os.path.isdir(tile_path):
        continue
    img_folder = os.path.join(tile_path, "images")
    mask_folder = os.path.join(tile_path, "masks")
    tile_images = sorted(glob(os.path.join(img_folder, '*.jpg')))
    tile_masks = sorted(glob(os.path.join(mask_folder, '*.png')))
    # Check per tile so the offending folder is named before pairs get shifted.
    if len(tile_images) != len(tile_masks):
        raise ValueError(
            f"Dubai tile {tile_path}: {len(tile_images)} images but {len(tile_masks)} masks"
        )
    dubai_images.extend(tile_images)
    dubai_masks.extend(tile_masks)

print(f"Loaded {len(deepglobe_images)} DeepGlobe images")
print(f"Loaded {len(dubai_images)} Dubai images")


In [None]:
# Combine both datasets. `sources` tags every sample with the dataset it came
# from, which is what the split is stratified on.
all_images = [*deepglobe_images, *dubai_images]
all_masks = [*deepglobe_masks, *dubai_masks]
sources = ['deepglobe'] * len(deepglobe_images) + ['dubai'] * len(dubai_images)

# Train/Val Split (Stratified): 80/20, keeping each source's share equal in both splits
(train_imgs, val_imgs,
 train_masks, val_masks,
 train_sources, val_sources) = train_test_split(
    all_images, all_masks, sources,
    test_size=0.2, stratify=sources, random_state=42,
)


def print_split_info(name, source_list):
    """Print how many images each source contributes to one split, plus the total."""
    report = [f"{name} Set:"]
    report.extend(
        f"  {dataset}: {count} images" for dataset, count in Counter(source_list).items()
    )
    report.append(f"  Total: {len(source_list)} images\n")
    for line in report:
        print(line)


print_split_info("Train", train_sources)
print_split_info("Validation", val_sources)

In [None]:
class SegmentationDataset(Dataset):
    """Dataset of (image, class-index mask) pairs read from paired file paths.

    Each item is ``(image, mask)``:
    - ``image``: float tensor. With a transform, it is whatever the transform
      produced (normalised CHW here). Without one, it comes from
      ``torchvision.transforms.ToTensor`` (CHW, values scaled to [0, 1]).
    - ``mask``: long tensor of class indices decoded by ``rgb_mask_to_class``.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        """
        Args:
            image_paths: sequence of image file paths.
            mask_paths: sequence of RGB mask paths, aligned index-by-index with
                ``image_paths``.
            transform: optional albumentations-style callable invoked as
                ``transform(image=..., mask=...)``. It must return a dict whose
                'image' and 'mask' entries are tensors.

        Raises:
            ValueError: if the two path sequences differ in length.
        """
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context managers close file handles promptly; otherwise they linger
        # until garbage collection, and with several DataLoader workers that
        # can exhaust open-file limits.
        with Image.open(self.image_paths[idx]) as image_file:
            img = np.array(image_file.convert('RGB'))
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = rgb_mask_to_class(mask_file.convert('RGB'))

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
        else:
            img = T.ToTensor()(img)
            # The mask is still a NumPy array here (ndarray has no `.long()`).
            mask = torch.from_numpy(mask)

        return img.float(), mask.long()

In [None]:
# Albumentations transforms.
# Both pipelines resize to IMAGE_SIZE, normalise with ImageNet mean/std and
# convert to tensors. Only the training pipeline adds random horizontal flips
# and 90-degree rotations.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

train_transform = A.Compose([
    A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

In [None]:
# Create datasets
train_dataset = SegmentationDataset(train_imgs, train_masks, transform=train_transform)
val_dataset = SegmentationDataset(val_imgs, val_masks, transform=val_transform)

# Weighted sampling to balance the smaller Dubai dataset: each sample is
# weighted by 1 / (size of its source), so every source carries a total weight
# of 1 and is drawn equally often in expectation.
source_counts = Counter(train_sources)
inverse_source_freq = {src: 1.0 / count for src, count in source_counts.items()}
weights = [inverse_source_freq[src] for src in train_sources]
train_sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# DataLoaders (validation keeps a fixed order)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

In [None]:
# Load Hugging Face Segformer-B3 model.
# The checkpoint name indicates ADE20k fine-tuning at 512x512.
# `ignore_mismatched_sizes=True` lets weights whose shape does not fit
# NUM_CLASSES labels be freshly initialised instead of raising on load.
SEGFORMER_CHECKPOINT = "nvidia/segformer-b3-finetuned-ade-512-512"
model = SegformerForSemanticSegmentation.from_pretrained(
    SEGFORMER_CHECKPOINT,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)
model = model.to(DEVICE)

# Feature extractor: not referenced by the training/inference code visible in
# this notebook (the albumentations pipelines do resizing and normalisation);
# kept for completeness.
# NOTE(review): SegformerFeatureExtractor is deprecated upstream in favour of
# SegformerImageProcessor — confirm before upgrading transformers.
feature_extractor = SegformerFeatureExtractor(do_reduce_labels=False, size=IMAGE_SIZE)

In [None]:
# Free memory before training.
# Order matters: gc.collect() first destroys unreachable Python objects
# (including tensors caught in reference cycles). Their CUDA blocks only go back
# to PyTorch's caching allocator, and empty_cache() then returns those cached,
# unused blocks to the driver. Running empty_cache() first would miss them.
import gc

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss on softmax probabilities.

    Dice is computed per (sample, class) over the two spatial dimensions and
    then averaged; the loss is ``1 - mean(dice)``.
    """

    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        # Added to numerator and denominator so an empty class gives 0/0 -> 1, not NaN.
        self.smooth = smooth

    def forward(self, preds, targets):
        """
        Args:
            preds: raw logits of shape (N, C, H, W); C is taken from here
                rather than from a global, so the loss works for any head size.
            targets: integer class indices of shape (N, H, W), values in [0, C).

        Returns:
            Scalar tensor ``1 - mean Dice``.
        """
        num_classes = preds.shape[1]
        probs = torch.softmax(preds, dim=1)
        # one_hot appends the class axis last: (N, H, W, C) -> (N, C, H, W)
        targets_one_hot = torch.nn.functional.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()

        intersection = (probs * targets_one_hot).sum(dim=(2, 3))
        union = probs.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3))

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()
        
class FocalLoss(nn.Module):
    """Focal loss for dense multi-class prediction: ``alpha * (1 - p_t)^gamma * CE``.

    Args:
        alpha: weighting factor. Either a scalar applied to every pixel, or a
            per-class sequence/tensor of length C indexed by each pixel's target
            class. The original implementation stored ``alpha`` but never applied
            it; with the default ``alpha=1`` results are unchanged.
        gamma: focusing parameter; larger values down-weight well-classified pixels.
    """

    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, targets):
        """
        Args:
            preds: raw logits, (N, C, ...) as accepted by ``F.cross_entropy``.
            targets: integer class indices, (N, ...).

        Returns:
            Scalar tensor: mean focal loss over all pixels.
        """
        ce = torch.nn.functional.cross_entropy(preds, targets, reduction='none')
        pt = torch.exp(-ce)  # model probability of the true class
        focal = (1 - pt) ** self.gamma * ce

        alpha = torch.as_tensor(self.alpha, dtype=focal.dtype, device=focal.device)
        if alpha.ndim > 0:
            # Per-class weights: look up each pixel's weight by its target class.
            alpha = alpha[targets]
        return (alpha * focal).mean()

class CombinedLoss(nn.Module):
    """Unweighted sum of a Dice loss and a Focal loss over the same logits.

    Args:
        alpha: forwarded to ``FocalLoss``.
        gamma: forwarded to ``FocalLoss``.
    """

    def __init__(self, alpha=1, gamma=2):
        super(CombinedLoss, self).__init__()
        self.dice = DiceLoss()
        self.focal = FocalLoss(alpha=alpha, gamma=gamma)

    def forward(self, preds, targets):
        dice_term = self.dice(preds, targets)
        focal_term = self.focal(preds, targets)
        return dice_term + focal_term


In [None]:
def compute_metrics(preds, targets, num_classes=None):
    """Segmentation metrics for one batch.

    Args:
        preds: logits of shape (N, C, H, W); argmax over dim 1 gives labels.
        targets: integer class-index tensor of shape (N, H, W).
        num_classes: classes to evaluate. Defaults to ``preds.shape[1]`` (the
            model's output channels, equal to NUM_CLASSES here).

    Returns:
        dict with:
        - "Pixel Accuracy": fraction of pixels whose predicted label equals the
          target.
        - "Mean IoU", "Mean Dice", "Mean F1": means over classes present in
          the prediction or the target. A class absent from both is undefined
          (NaN) and excluded, for IoU and Dice alike. F1 equals Dice for
          per-class binary masks.
    """
    if num_classes is None:
        num_classes = preds.shape[1]
    pred_labels = torch.argmax(preds, dim=1).detach().cpu().numpy()
    true_labels = targets.detach().cpu().numpy()

    ious, dices = [], []
    for cls in range(num_classes):
        pred_cls = pred_labels == cls
        true_cls = true_labels == cls

        intersection = np.logical_and(pred_cls, true_cls).sum()
        union = np.logical_or(pred_cls, true_cls).sum()
        if union == 0:
            # Absent from both: previously Dice scored 0 here while IoU was NaN,
            # which dragged Mean Dice/F1 down for every missing class.
            ious.append(np.nan)
            dices.append(np.nan)
        else:
            ious.append(intersection / union)
            dices.append(2 * intersection / (pred_cls.sum() + true_cls.sum()))

    mean_dice = np.nanmean(dices)
    return {
        # Compare labels directly. The old per-class one-vs-rest tally counted
        # every true negative C times and badly inflated accuracy.
        "Pixel Accuracy": float((pred_labels == true_labels).mean()),
        "Mean IoU": np.nanmean(ious),
        "Mean Dice": mean_dice,
        "Mean F1": mean_dice,
    }

In [None]:
def train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Logits are upsampled bilinearly to the mask resolution before the loss,
    because the model's output resolution may differ from the mask size.
    """
    model.train()
    running_loss = 0
    for batch_imgs, batch_masks in tqdm(loader, desc="Training", leave=False):
        batch_imgs = batch_imgs.to(DEVICE)
        batch_masks = batch_masks.to(DEVICE)

        optimizer.zero_grad()
        logits = model(pixel_values=batch_imgs).logits
        logits = torch.nn.functional.interpolate(
            logits, size=batch_masks.shape[1:], mode='bilinear', align_corners=False
        )
        batch_loss = loss_fn(logits, batch_masks)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)

def validate_one_epoch(model, loader, loss_fn):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean batch loss, dict of metrics). Each metric is the unweighted mean
        of per-batch ``compute_metrics`` values, so it approximates rather than
        equals the dataset-level value.
    """
    model.eval()
    total_loss = 0
    per_batch = {name: [] for name in ("Pixel Accuracy", "Mean IoU", "Mean Dice", "Mean F1")}

    with torch.no_grad():
        for batch_imgs, batch_masks in tqdm(loader, desc="Validation", leave=False):
            batch_imgs = batch_imgs.to(DEVICE)
            batch_masks = batch_masks.to(DEVICE)
            logits = model(pixel_values=batch_imgs).logits
            logits = torch.nn.functional.interpolate(
                logits, size=batch_masks.shape[1:], mode='bilinear', align_corners=False
            )
            total_loss += loss_fn(logits, batch_masks).item()

            for name, value in compute_metrics(logits, batch_masks).items():
                per_batch[name].append(value)

    averaged = {name: np.mean(values) for name, values in per_batch.items()}
    return total_loss / len(loader), averaged


In [None]:
loss_fn = CombinedLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# `verbose=True` is deprecated for PyTorch LR schedulers, so the learning rate
# is printed explicitly each epoch instead.
# Design note: the LR schedule follows val loss, while checkpointing and early
# stopping follow val mIoU.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)

BEST_MODEL_PATH = "best_model.pth"
# Start below any achievable mIoU so the first (non-NaN) epoch always writes a checkpoint.
best_miou = float("-inf")
patience_counter = 0

history = {
    "train_loss": [],
    "val_loss": [],
    "val_pixel_acc": [],
    "val_miou": [],
    "val_dice": [],
    "val_f1": []
}

for epoch in range(1, EPOCHS + 1):
    print(f"\n Epoch {epoch}/{EPOCHS}")

    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn)
    val_loss, val_metrics = validate_one_epoch(model, val_loader, loss_fn)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    for k, v in val_metrics.items():
        print(f"{k}: {v:.4f}")

    scheduler.step(val_loss)
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Save history
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_pixel_acc"].append(val_metrics["Pixel Accuracy"])
    history["val_miou"].append(val_metrics["Mean IoU"])
    history["val_dice"].append(val_metrics["Mean Dice"])
    history["val_f1"].append(val_metrics["Mean F1"])

    # Save best model
    if val_metrics["Mean IoU"] > best_miou:
        best_miou = val_metrics["Mean IoU"]
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print("Best model saved!")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= PATIENCE:
        print("Early stopping triggered.")
        break

# Restore the selected checkpoint. Otherwise `model` keeps the last-epoch weights
# (up to PATIENCE epochs past the best), and later evaluation/visualisation
# cells would not reflect the model that was actually chosen.
if os.path.exists(BEST_MODEL_PATH):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
    print(f"Restored best model (Mean IoU {best_miou:.4f}) from {BEST_MODEL_PATH}")


In [None]:
def plot_training_history(history):
    """Draw a 2x3 grid of per-epoch curves from the training ``history`` dict.

    Panel 1 overlays train and validation loss. Panels 2-5 show validation
    pixel accuracy, mean IoU, mean Dice and F1. The sixth grid slot is left empty.
    """
    epochs = range(1, len(history["train_loss"]) + 1)
    plt.figure(figsize=(18, 12))

    # (grid position, [(history key, legend label), ...], title, y-axis label)
    panels = [
        (1, [("train_loss", 'Train Loss'), ("val_loss", 'Val Loss')], 'Loss', 'Loss'),
        (2, [("val_pixel_acc", 'Pixel Accuracy')], 'Pixel Accuracy', 'Accuracy'),
        (3, [("val_miou", 'Mean IoU')], 'Mean IoU', 'IoU'),
        (4, [("val_dice", 'Mean Dice')], 'Mean Dice', 'Dice'),
        (5, [("val_f1", 'F1 Score')], 'F1 Score', 'F1 Score'),
    ]
    for position, curves, title, y_label in panels:
        plt.subplot(2, 3, position)
        for key, legend_label in curves:
            plt.plot(epochs, history[key], label=legend_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
def class_mask_to_rgb(mask):
    """Colour a 2-D class-index mask with the global ``class_to_rgb`` palette.

    Args:
        mask: 2-D array of class indices (a non-2-D input raises on unpacking).

    Returns:
        uint8 array of shape (H, W, 3). Indices with no palette entry stay black,
        which is also the colour of the "unknown" class.
    """
    height, width = mask.shape
    coloured = np.zeros((height, width, 3), dtype=np.uint8)
    for class_index in class_to_rgb:
        hits = np.equal(mask, class_index)
        coloured[hits] = class_to_rgb[class_index]
    return coloured


In [None]:
# Plot training performance: loss curves and validation metrics per epoch
plot_training_history(history=history)

In [None]:
def calculate_iou_and_dice(pred, target, num_classes=None):
    """Mean IoU and mean Dice between two class-index label maps.

    A class absent from both ``pred`` and ``target`` is undefined (NaN) for
    both metrics and excluded from the means. Previously Dice scored such
    classes as 0, which dragged mean Dice down inconsistently with IoU.

    Args:
        pred: array of predicted class indices.
        target: array of ground-truth class indices, same shape as ``pred``.
        num_classes: evaluate classes ``0..num_classes-1``. If None, evaluate
            only the classes that occur in ``pred`` or ``target``. Absent
            classes are excluded either way, so the result is the same for any
            ``num_classes`` larger than the highest label.

    Returns:
        (mean_iou, mean_dice) as floats (NaN if no class is present).
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    classes = range(num_classes) if num_classes is not None else np.union1d(pred, target)

    ious = []
    dices = []
    for cls in classes:
        pred_cls = (pred == cls)
        target_cls = (target == cls)
        intersection = np.logical_and(pred_cls, target_cls).sum()
        union = np.logical_or(pred_cls, target_cls).sum()
        if union == 0:
            ious.append(float('nan'))
            dices.append(float('nan'))
            continue
        ious.append(intersection / union)
        dices.append((2 * intersection) / (pred_cls.sum() + target_cls.sum()))
    return np.nanmean(ious), np.nanmean(dices)


def visualize_predictions(model, dataset, num_samples):
    """Plot input / ground-truth / prediction rows for random samples of ``dataset``.

    Args:
        model: segmentation model called as ``model(pixel_values=...)`` and
            returning an object with ``.logits``, as elsewhere in this notebook.
        dataset: indexable dataset yielding ``(image, mask)`` tensors. The image
            is assumed CHW and normalised with ImageNet mean/std (as
            ``val_transform`` does); that assumption drives the de-normalisation
            below.
        num_samples: number of random samples to show. It is clamped to
            ``len(dataset)``, because ``random.sample`` raises when asked for
            more items than exist.

    Raises:
        ValueError: if ``num_samples`` is negative.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")
    model.eval()
    num_samples = min(num_samples, len(dataset))
    if num_samples == 0:
        return
    indices = random.sample(range(len(dataset)), num_samples)

    # ImageNet statistics used to undo the Normalize step for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(15, num_samples * 3))

    for i, idx in enumerate(indices):
        img, gt_mask = dataset[idx]
        img_input = img.unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            output = model(pixel_values=img_input).logits
            output = torch.nn.functional.interpolate(output, size=gt_mask.shape, mode='bilinear', align_corners=False)
            # Index the batch axis rather than squeeze(), which would also drop
            # the class axis of a single-class head.
            pred_mask = torch.argmax(output[0], dim=0).cpu().numpy()

        gt_np = gt_mask.cpu().numpy()

        # Compute metrics
        mean_iou, mean_dice = calculate_iou_and_dice(pred_mask, gt_np)

        # Convert to RGB
        rgb_gt = class_mask_to_rgb(gt_np)
        rgb_pred = class_mask_to_rgb(pred_mask)
        img_np = img.permute(1, 2, 0).cpu().numpy() * std + mean
        img_np = np.clip(img_np * 255, 0, 255).astype(np.uint8)

        # Show row
        plt.subplot(num_samples, 3, i * 3 + 1)
        plt.imshow(img_np)
        plt.title("Input Image")
        plt.axis("off")

        plt.subplot(num_samples, 3, i * 3 + 2)
        plt.imshow(rgb_gt)
        plt.title("Ground Truth")
        plt.axis("off")

        plt.subplot(num_samples, 3, i * 3 + 3)
        plt.imshow(rgb_pred)
        plt.title(f"Prediction\nDice: {mean_dice:.2f} | IoU: {mean_iou:.2f}")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Show predictions for 15 random validation samples (input, ground truth, prediction)
visualize_predictions(model=model, dataset=val_dataset, num_samples=15)