In [None]:
# ===================== Install =====================
!pip install -q segmentation-models-pytorch torchmetrics transformers timm

In [None]:
# ===================== Imports =====================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
from glob import glob
from tqdm import tqdm
import segmentation_models_pytorch as smp
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor

In [None]:
# ===================== Config =====================
# --- Label space and input geometry ---
# Number of segmentation classes (length of the class-weight vector and of
# the colour palette used by the dataset).
NUM_CLASSES = 12
# Size handed to torchvision's Resize for every image and mask.
IMAGE_SIZE = (512, 512)

# --- Optimisation schedule ---
BATCH_SIZE = 4          # samples per DataLoader batch
EPOCHS = 20             # upper bound on training epochs
PATIENCE = 5            # epochs without a better validation loss before stopping
LEARNING_RATE = 2e-4    # initial AdamW learning rate

# --- Compute device ---
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# ===================== Class Weights =====================
# Per-class loss weights, indexed by class id (names follow the notebook's
# class list). Values above 1 emphasise a class, values below 1 de-emphasise it.
class_weights = torch.tensor(
    [
        1.0,  # 0: urbanland
        1.0,  # 1: agricultureland
        1.0,  # 2: rangeland
        1.0,  # 3: forestland
        1.0,  # 4: water
        1.0,  # 5: barrenland
        1.0,  # 6: unknown
        2.0,  # 7: building (Dubai)
        2.0,  # 8: land_unpaved (Dubai)
        3.0,  # 9: road (Dubai)
        2.0,  # 10: vegetation_dubai (Dubai)
        0.5,  # 11: unlabeled
    ],
    device=DEVICE,
)

# ===================== Loss Functions =====================
def focal_loss(outputs, targets, alpha=None, gamma=2.0):
    """Multi-class focal loss, averaged over every pixel.

    Args:
        outputs: raw logits with the class dimension at index 1 and the
            spatial dimensions after it.
        targets: integer class indices; spatial dimensions follow the batch
            dimension. Resized (nearest) to the logits' resolution if needed.
        alpha: optional 1-D tensor of per-class weights, indexed by class id.
        gamma: focusing exponent applied to ``(1 - pt)``.

    Returns:
        Scalar tensor: mean of ``alpha[target] * (1 - pt) ** gamma * ce``
        where ``pt`` is the softmax probability of the target class.
    """
    # Ensure target shape matches outputs for cross_entropy
    if outputs.shape[2:] != targets.shape[1:]:
        targets = F.interpolate(targets.unsqueeze(1).float(), size=outputs.shape[2:], mode='nearest').squeeze(1).long()

    # pt must be the model's actual probability for the target class, so it is
    # derived from the UNWEIGHTED cross-entropy. Feeding `weight=alpha` into
    # this call would make exp(-ce) equal p ** w instead of p and distort the
    # focusing term for every class whose weight is not 1.
    ce_loss = F.cross_entropy(outputs, targets, reduction='none')
    pt = torch.exp(-ce_loss)

    # The class weight is applied once, as a plain multiplicative factor.
    if alpha is None:
        weighted_ce = ce_loss
    else:
        weighted_ce = F.cross_entropy(outputs, targets, reduction='none', weight=alpha)

    focal = ((1 - pt) ** gamma) * weighted_ce
    return focal.mean()

def dice_loss(outputs, targets, smooth=1e-6):
    """Soft Dice loss: one minus the Dice score averaged over batch and classes.

    Args:
        outputs: raw logits shaped [B, C, H, W]; softmax is taken over C.
        targets: integer class indices shaped [B, H', W'].
        smooth: constant added to numerator and denominator of the Dice ratio.

    Returns:
        Scalar tensor ``1 - mean(dice)``.
    """
    n_classes = outputs.shape[1]
    probs = outputs.softmax(dim=1)

    # Class indices -> one-hot planes laid out like the logits: [B, C, H', W'].
    onehot = F.one_hot(targets, n_classes).permute(0, 3, 1, 2).float()

    # If the labels come at another resolution, resample them to the logits' grid.
    if probs.shape != onehot.shape:
        onehot = F.interpolate(onehot, size=probs.shape[2:], mode='nearest')

    spatial_dims = (2, 3)
    overlap = (probs * onehot).sum(dim=spatial_dims)
    total = probs.sum(dim=spatial_dims) + onehot.sum(dim=spatial_dims)

    score = (2 * overlap + smooth) / (total + smooth)
    return 1 - score.mean()

def combined_loss(outputs, targets):
    """Training criterion: class-weighted focal loss plus Dice loss.

    Both terms are evaluated on the same logits/targets pair and summed with
    equal weight; the focal term uses the module-level ``class_weights``.
    """
    focal_term = focal_loss(outputs, targets, alpha=class_weights)
    dice_term = dice_loss(outputs, targets)
    total = focal_term + dice_term
    return total

**Classes:**

* 0: urbanland
* 1: agricultureland
* 2: rangeland
* 3: forestland
* 4: water
* 5: barrenland
* 6: unknown
* 7: building (Dubai)
* 8: land_unpaved (Dubai)
* 9: road (Dubai)
* 10: vegetation_dubai (Dubai)
* 11: unlabeled


In [None]:
# ===================== Dataset =====================
# Mask colour (R, G, B) -> class id, one entry per class in the class list.
# NOTE(review): the dataset starts every class mask at 0, so any mask colour
# that is not listed here ends up as class 0 — confirm both source datasets
# use only these twelve colours.
rgb_to_class = {
    (0, 255, 255): 0,     # urbanland
    (255, 255, 0): 1,     # agricultureland
    (255, 0, 255): 2,     # rangeland
    (0, 255, 0): 3,       # forestland
    (0, 0, 255): 4,       # water
    (255, 255, 255): 5,   # barrenland
    (0, 0, 0): 6,         # unknown
    (60, 16, 152): 7,     # building (Dubai)
    (132, 41, 246): 8,    # land_unpaved (Dubai)
    (110, 193, 228): 9,   # road (Dubai)
    (254, 221, 58): 10,   # vegetation_dubai (Dubai)
    (155, 155, 155): 11,  # unlabeled
}

# Inverse lookup (class id -> colour), used when rendering masks.
class_to_rgb = dict((class_id, colour) for colour, class_id in rgb_to_class.items())

class SatelliteDataset(Dataset):
    """Pairs of RGB satellite images and colour-coded segmentation masks.

    Each item is ``(image, class_mask)``:
      * ``image``: float tensor produced by ``Resize(IMAGE_SIZE)`` + ``ToTensor``.
      * ``class_mask``: ``torch.long`` tensor of class ids, obtained by looking
        up every mask pixel's colour in the module-level ``rgb_to_class``.
        Pixels whose colour is not in the palette keep the initial value 0.

    Args:
        image_paths: sequence of image file paths.
        mask_paths: sequence of mask file paths, aligned index-by-index with
            ``image_paths``.
        augment: when True, the same random flips / rotation are applied to
            the image and its mask.
    """

    def __init__(self, image_paths, mask_paths, augment=False):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.augment = augment
        # Image pipeline: resize, then convert to a float tensor.
        self.transform = T.Compose([
            T.Resize(IMAGE_SIZE),
            T.ToTensor()
        ])
        # Mask pipeline: resize with NEAREST so no new colours are invented.
        # A smoothing resample blends neighbouring class colours along every
        # border, and the blended colours match no palette entry.
        self.mask_resize = T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.NEAREST)
        self.aug = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomRotation(20)
        ])

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, optionally augment, and encode the pair at position ``idx``."""
        img = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("RGB")

        if self.augment:
            # Re-seeding torch's RNG before each call makes the random
            # transforms draw identical parameters for image and mask.
            seed = np.random.randint(2147483647)
            torch.manual_seed(seed)
            img = self.aug(img)
            torch.manual_seed(seed)
            mask = self.aug(mask)

        img = self.transform(img)

        # Keep the mask as uint8 RGB, shape [H, W, 3]. Comparing integers is
        # exact; comparing ToTensor's float32 values against float64
        # `rgb / 255.0` only succeeds for channel values 0 and 255, which
        # silently dropped every colour with any other channel value.
        mask_rgb = np.asarray(self.mask_resize(mask), dtype=np.uint8)

        # Allocate from the actual mask shape so non-square IMAGE_SIZE works
        # (Resize takes (height, width)).
        class_mask = np.zeros(mask_rgb.shape[:2], dtype=np.int64)
        for rgb, cls in rgb_to_class.items():
            matches = np.all(mask_rgb == np.array(rgb, dtype=np.uint8), axis=-1)
            class_mask[matches] = cls

        return img, torch.tensor(class_mask, dtype=torch.long)

In [None]:
# ===================== Visualization =====================
def visualize_prediction(image, pred_mask, true_mask):
    """Show an image next to its predicted and ground-truth masks.

    Args:
        image: image tensor in channel-first layout (it is permuted to
            channel-last for display).
        pred_mask: 2-D array or tensor of predicted class ids.
        true_mask: 2-D array or tensor of ground-truth class ids.

    Class ids are rendered with the module-level ``class_to_rgb`` palette.
    """
    image = image.permute(1, 2, 0).cpu().numpy()

    def _to_numpy(mask):
        # Accept tensors on any device as well as array-likes.
        if torch.is_tensor(mask):
            return mask.detach().cpu().numpy()
        return np.asarray(mask)

    pred_mask = _to_numpy(pred_mask)
    true_mask = _to_numpy(true_mask)

    # The palette holds 0-255 integers, so the canvases must be uint8.
    # Writing those values into a float canvas makes imshow clip them to
    # [0, 1] and the colours come out wrong.
    pred_mask_rgb = np.zeros(pred_mask.shape + (3,), dtype=np.uint8)
    true_mask_rgb = np.zeros(true_mask.shape + (3,), dtype=np.uint8)

    for cls, rgb in class_to_rgb.items():
        pred_mask_rgb[pred_mask == cls] = rgb
        true_mask_rgb[true_mask == cls] = rgb

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(image)
    axs[0].set_title("Image")
    axs[1].imshow(pred_mask_rgb)
    axs[1].set_title("Prediction")
    axs[2].imshow(true_mask_rgb)
    axs[2].set_title("Ground Truth")
    plt.show()


In [None]:
# ===================== Train Function =====================
def _forward_logits(model, img, size):
    """Run ``model`` on ``img``; dict-style outputs are reduced to their
    ``"logits"`` entry and bilinearly upsampled to ``size``."""
    output = model(img)
    if isinstance(output, dict):
        output = F.interpolate(output["logits"], size=size, mode="bilinear", align_corners=False)
    return output


def _predictions_like(output, mask):
    """Arg-max class map on the CPU, resized (nearest) to ``mask``'s grid if needed."""
    preds = torch.argmax(output, dim=1)
    if preds.shape[1:] != mask.shape[1:]:
        preds = F.interpolate(preds.unsqueeze(1).float(), size=mask.shape[1:], mode='nearest').squeeze(1).long()
    return preds.detach().cpu()


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, name):
    """Train ``model`` with early stopping on the validation loss.

    Runs at most ``EPOCHS`` epochs. After each epoch the summed validation
    loss is passed to ``scheduler.step``; whenever it improves, the weights
    are saved to ``f"{name}_best.pth"``. Training stops after ``PATIENCE``
    consecutive epochs without improvement.

    Args:
        model: network to train; may return logits or a dict with "logits".
        train_loader: iterable of ``(image, mask)`` training batches.
        val_loader: iterable of ``(image, mask)`` validation batches.
        criterion: callable ``criterion(logits, mask) -> scalar loss``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: scheduler whose ``step`` accepts the validation loss.
        name: prefix of the checkpoint file.
    """
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        train_acc = MulticlassAccuracy(num_classes=NUM_CLASSES).to("cpu")
        train_iou = MulticlassJaccardIndex(num_classes=NUM_CLASSES).to("cpu")

        for img, mask in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
            img, mask = img.to(DEVICE), mask.to(DEVICE)
            optimizer.zero_grad()
            output = _forward_logits(model, img, tuple(mask.shape[-2:]))
            loss = criterion(output, mask)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            preds = _predictions_like(output, mask)
            train_acc.update(preds, mask.cpu())
            train_iou.update(preds, mask.cpu())

        model.eval()
        val_loss = 0
        val_acc = MulticlassAccuracy(num_classes=NUM_CLASSES).to("cpu")
        val_iou = MulticlassJaccardIndex(num_classes=NUM_CLASSES).to("cpu")

        with torch.no_grad():
            for img, mask in val_loader:
                img, mask = img.to(DEVICE), mask.to(DEVICE)
                # Same forward path as training. Previously dict outputs were
                # left at the model's native resolution here, so validation
                # loss and metrics were computed on a different grid than the
                # training ones and were not comparable.
                output = _forward_logits(model, img, tuple(mask.shape[-2:]))
                loss = criterion(output, mask)
                val_loss += loss.item()

                preds = _predictions_like(output, mask)
                val_acc.update(preds, mask.cpu())
                val_iou.update(preds, mask.cpu())

        scheduler.step(val_loss)

        # max(..., 1) keeps the report from dividing by zero on an empty loader.
        n_train = max(len(train_loader), 1)
        n_val = max(len(val_loader), 1)
        print(f"\nEpoch {epoch+1}:")
        print(f"  Train Loss: {train_loss/n_train:.4f}, Acc: {train_acc.compute():.4f}, mIoU: {train_iou.compute():.4f}")
        print(f"  Val   Loss: {val_loss/n_val:.4f}, Acc: {val_acc.compute():.4f}, mIoU: {val_iou.compute():.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f"{name}_best.pth")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("Early stopping")
                break

In [None]:
# ===================== Load Dataset =====================
# --- DeepGlobe: images and masks live side by side in one folder ---
deepglobe_dir = "/kaggle/input/deepglobe-land-cover-classification-dataset/train"
deepglobe_images = sorted(glob(os.path.join(deepglobe_dir, '*_sat.jpg')))
deepglobe_masks = sorted(glob(os.path.join(deepglobe_dir, '*_mask.png')))
# Images and masks are paired purely by sort order, so a count mismatch means
# at least one pair is wrong; fail loudly instead of training on it.
if len(deepglobe_images) != len(deepglobe_masks):
    raise ValueError(
        f"DeepGlobe: {len(deepglobe_images)} images but {len(deepglobe_masks)} masks in {deepglobe_dir}"
    )

# --- Dubai: one sub-folder per tile, each with images/ and masks/ ---
dubai_dir = "/kaggle/input/semantic-segmentation-of-aerial-imagery/Semantic segmentation dataset"
dubai_images, dubai_masks = [], []
for tile in sorted(os.listdir(dubai_dir)):
    tile_path = os.path.join(dubai_dir, tile)
    if not os.path.isdir(tile_path): continue
    img_folder = os.path.join(tile_path, "images")
    mask_folder = os.path.join(tile_path, "masks")
    tile_images = sorted(glob(os.path.join(img_folder, '*.jpg')))
    tile_masks = sorted(glob(os.path.join(mask_folder, '*.png')))
    # Check per tile: a surplus image in one tile and a surplus mask in
    # another would otherwise cancel out and shift every pair in between.
    if len(tile_images) != len(tile_masks):
        raise ValueError(
            f"Dubai tile {tile!r}: {len(tile_images)} images but {len(tile_masks)} masks"
        )
    dubai_images.extend(tile_images)
    dubai_masks.extend(tile_masks)

all_images = deepglobe_images + dubai_images
all_masks = deepglobe_masks + dubai_masks
if not all_images:
    raise ValueError("No image/mask pairs found - check deepglobe_dir and dubai_dir")

# Fixed random_state keeps the 80/20 split identical across runs.
train_imgs, val_imgs, train_masks, val_masks = train_test_split(all_images, all_masks, test_size=0.2, random_state=42)

# Only the training set is augmented; the validation loader keeps dataset order.
train_dataset = SatelliteDataset(train_imgs, train_masks, augment=True)
val_dataset = SatelliteDataset(val_imgs, val_masks)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [None]:
# ===================== Train SegFormer-B3 =====================
# Start from the pretrained SegFormer-B3 checkpoint with a NUM_CLASSES-way head.
# NOTE(review): ignore_mismatched_sizes=True is presumably needed because the
# checkpoint's classifier has a different number of labels — confirm.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b3-finetuned-ade-512-512",
    ignore_mismatched_sizes=True,
    num_labels=NUM_CLASSES,
)
segformer = segformer.to(DEVICE)

# AdamW over all parameters; the LR is halved after 2 epochs without a
# better validation loss.
segformer_optimizer = optim.AdamW(segformer.parameters(), lr=LEARNING_RATE)
segformer_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    segformer_optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)

# The best weights are written to "segformer_best.pth" by train_model.
train_model(
    model=segformer,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=combined_loss,
    optimizer=segformer_optimizer,
    scheduler=segformer_scheduler,
    name="segformer",
)

In [None]:
import torch.nn.functional as F

# ===================== Evaluation =====================
# map_location lets a checkpoint saved on the GPU be restored on a CPU-only run.
segformer.load_state_dict(torch.load("segformer_best.pth", map_location=DEVICE))
segformer.eval()

# Show up to 15 predictions (fewer if the validation set is smaller)
num_samples = min(15, len(val_dataset))

for i in range(num_samples):
    sample_img, sample_mask = val_dataset[i]

    with torch.no_grad():
        pred_segformer = segformer(sample_img.unsqueeze(0).to(DEVICE))
        if isinstance(pred_segformer, dict):
            pred_segformer = pred_segformer["logits"]

        # Upsample the logits to the ground-truth mask's resolution rather
        # than a hard-coded size, so this follows IMAGE_SIZE automatically.
        pred_segformer = F.interpolate(pred_segformer, size=tuple(sample_mask.shape[-2:]), mode='bilinear', align_corners=False)
        # squeeze(0) drops only the batch dimension.
        pred_segformer = pred_segformer.argmax(1).squeeze(0).cpu()

    visualize_prediction(sample_img, pred_segformer, sample_mask)
