In [None]:
"""
Underwater Semantic Segmentation using SegFormer
Dataset: AI Summit Track B - Underwater Imagery
Resolution: 320x256 for consistency
Split: 80% train, 20% validation
"""

import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import SegformerConfig, SegformerForSemanticSegmentation
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score


In [None]:
# Dataset paths (relative to the notebook's working directory)
DATASET_ROOT = "./rclone-v1.73.0-linux-amd64/TrackB/dataset"
IMG_DIR = os.path.join(DATASET_ROOT, "images")
MASK_DIR = os.path.join(DATASET_ROOT, "masks/combined")

# Model & training config
IMG_HEIGHT = 256          # every image/mask is resized to IMG_WIDTH x IMG_HEIGHT
IMG_WIDTH = 320
NUM_CLASSES = 8           # must match the number of entries in COLOR_MAP
BATCH_SIZE = 4            # training batch size
EPOCHS = 30
LEARNING_RATE = 6e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed PyTorch's global RNG so that data-loader shuffling and
# any freshly initialised model weights are repeatable on Restart & Run All.
# 42 is the same value the train/val split uses as its random_state.
# This is best-effort only: some GPU kernels remain non-deterministic.
SEED = 42
torch.manual_seed(SEED)

# Best model save path
BEST_MODEL_PATH = "./best_segformer_underwater.pth"

print(f"Device: {DEVICE}")
print(f"Images: {IMG_DIR}")
print(f"Masks: {MASK_DIR}")


In [None]:
# Mask palette, listed in class-index order: position i is the RGB colour
# that encodes class i in the ground-truth masks.
_PALETTE = (
    (0, 0, 0),          # 0: Background/Sea-floor
    (255, 0, 0),        # 1: Fish
    (0, 255, 0),        # 2: Reefs
    (0, 0, 255),        # 3: Plants
    (255, 255, 0),      # 4: Wrecks
    (255, 0, 255),      # 5: Divers
    (0, 255, 255),      # 6: Robots
    (128, 128, 128),    # 7: Others
)

# RGB tuple -> class index (used to decode masks into label maps).
COLOR_MAP = {rgb: class_idx for class_idx, rgb in enumerate(_PALETTE)}

# Class index -> RGB tuple (inverse lookup, e.g. for colouring predictions).
IDX_TO_COLOR = dict(enumerate(_PALETTE))


In [None]:
def rgb_to_label(mask, color_map=None):
    """Convert a colour-coded mask into an integer class-label mask.

    Args:
        mask: PIL image or array-like of shape (H, W, C) with C >= 3. Only the
            first three channels are compared, so an RGBA mask is accepted and
            its alpha channel is ignored.
        color_map: Optional mapping ``{(r, g, b): class_index}``. When omitted,
            the module-level ``COLOR_MAP`` is looked up at call time.

    Returns:
        ``np.ndarray`` of shape (H, W) and dtype ``int64``. Pixels whose colour
        is not a key of ``color_map`` keep the initial value 0.

    Raises:
        ValueError: If ``mask`` is not a 3-D array with at least 3 channels.
    """
    if color_map is None:
        # Resolved here rather than in the signature so that re-running the
        # COLOR_MAP cell takes effect without re-defining this function.
        color_map = COLOR_MAP

    pixels = np.asarray(mask)
    if pixels.ndim != 3 or pixels.shape[-1] < 3:
        raise ValueError(
            f"Expected an (H, W, C>=3) colour mask, got array of shape {pixels.shape}"
        )
    # Drop alpha (or any extra channel) so the comparison below broadcasts
    # against 3-element RGB tuples.
    pixels = pixels[..., :3]

    label = np.zeros(pixels.shape[:2], dtype=np.int64)
    for rgb, idx in color_map.items():
        # A pixel matches only when all three channels equal the palette colour.
        matches = np.all(pixels == np.asarray(rgb), axis=-1)
        label[matches] = idx

    return label


In [None]:
class UnderwaterSegDataset(Dataset):
    """Image / colour-mask pairs served as tensors for semantic segmentation.

    Each item is loaded from disk on access, resized to
    ``IMG_WIDTH x IMG_HEIGHT`` and converted to tensors.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the colour-coded masks.
        pairs: Sequence of ``(image_filename, mask_filename)`` tuples, relative
            to ``img_dir`` and ``mask_dir`` respectively.
    """

    def __init__(self, img_dir, mask_dir, pairs):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.pairs = pairs

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load, resize and tensorise the pair at position ``idx``.

        Returns:
            Tuple ``(img, mask)`` where ``img`` is a float tensor of shape
            (3, IMG_HEIGHT, IMG_WIDTH) scaled to [0, 1] and ``mask`` is a long
            tensor of shape (IMG_HEIGHT, IMG_WIDTH) holding class indices.
        """
        img_name, mask_name = self.pairs[idx]

        # Context managers guarantee the underlying files are closed even if
        # decoding fails; convert()/resize() return new in-memory images, so
        # the results stay valid after the file is closed.
        with Image.open(os.path.join(self.img_dir, img_name)) as img_file:
            img = img_file.convert("RGB").resize(
                (IMG_WIDTH, IMG_HEIGHT), Image.BILINEAR
            )
        with Image.open(os.path.join(self.mask_dir, mask_name)) as mask_file:
            # NEAREST keeps mask colours exact; interpolating would blend
            # palette colours into values that map to no class.
            mask = mask_file.convert("RGB").resize(
                (IMG_WIDTH, IMG_HEIGHT), Image.NEAREST
            )

        # HWC uint8 -> CHW float in [0, 1].
        # NOTE(review): no mean/std normalisation is applied here — confirm this
        # matches the preprocessing the pretrained checkpoint expects.
        img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(rgb_to_label(mask)).long()

        return img, mask


In [None]:
def build_pairs(img_dir, mask_dir):
    """Match every image file with a same-named mask file.

    Images are the entries of ``img_dir`` whose lower-cased name ends in
    .jpg, .png or .jpeg, processed in sorted order. For each image the mask
    is the first of ``<stem>.bmp``, ``<stem>.png``, ``<stem>.jpg`` present in
    ``mask_dir`` (exact, case-sensitive name match). Images without a mask
    are skipped.

    Args:
        img_dir: Directory holding the images.
        mask_dir: Directory holding the masks.

    Returns:
        List of ``(image_filename, mask_filename)`` tuples. The number of
        pairs found is also printed.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    mask_exts = (".bmp", ".png", ".jpg")  # order = lookup priority

    image_names = [
        name for name in os.listdir(img_dir)
        if name.lower().endswith(image_exts)
    ]
    image_names.sort()
    available_masks = set(os.listdir(mask_dir))

    pairs = []
    for image_name in image_names:
        stem, _ = os.path.splitext(image_name)
        candidates = (stem + ext for ext in mask_exts)
        found = next((c for c in candidates if c in available_masks), None)
        if found is not None:
            pairs.append((image_name, found))

    print(f"Found {len(pairs)} image-mask pairs")
    return pairs


In [None]:
# Pair every image with its mask, then hold out 20% of the pairs for
# validation. The fixed random_state makes the split repeatable.
all_pairs = build_pairs(IMG_DIR, MASK_DIR)
train_pairs, val_pairs = train_test_split(
    all_pairs, test_size=0.2, random_state=42, shuffle=True
)

print(f"Train samples: {len(train_pairs)}")
print(f"Val samples: {len(val_pairs)}")

train_ds, val_ds = (
    UnderwaterSegDataset(IMG_DIR, MASK_DIR, subset)
    for subset in (train_pairs, val_pairs)
)

# Worker / memory settings shared by both loaders.
_loader_opts = {"num_workers": 2, "pin_memory": True}

# Training batches are reshuffled every epoch; validation order is fixed and
# uses a smaller batch size of 2.
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, **_loader_opts
)
val_loader = DataLoader(
    val_ds, batch_size=2, shuffle=False, **_loader_opts
)


In [None]:
# Hub id of the pretrained checkpoint used for both the config and the weights.
_CHECKPOINT = "nvidia/segformer-b2-finetuned-ade-512-512"

# Start from the checkpoint's configuration but declare NUM_CLASSES labels.
config = SegformerConfig.from_pretrained(_CHECKPOINT, num_labels=NUM_CLASSES)

# ignore_mismatched_sizes lets loading proceed when a checkpoint tensor's shape
# differs from what the NUM_CLASSES-label config requires (presumably the
# classification head, which is then freshly initialised — confirm in the
# loading log).
model = SegformerForSemanticSegmentation.from_pretrained(
    _CHECKPOINT,
    config=config,
    ignore_mismatched_sizes=True,
    use_safetensors=True,
)
model = model.to(DEVICE)

# All model parameters are optimised with a single learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

print("Model loaded: SegFormer-B2")


In [None]:
def mean_iou(pred, target, num_classes=NUM_CLASSES):
    """Mean intersection-over-union over the classes that actually occur.

    A class contributes to the mean only if it appears in ``pred`` or in
    ``target``. Classes absent from both are skipped: scoring them as a
    perfect 1.0 would inflate the metric whenever a batch contains only a few
    of the classes, and would disagree with ``mean_f1``, whose macro average
    also covers only labels that occur.

    Args:
        pred: Integer tensor of predicted class indices.
        target: Integer tensor of ground-truth class indices, same shape as
            ``pred``.
        num_classes: Number of classes; indices ``0 .. num_classes - 1`` are
            evaluated.

    Returns:
        Mean IoU as a float in [0, 1]. If no evaluated class occurs at all
        (e.g. empty tensors) 1.0 is returned, since there is nothing to get
        wrong.
    """
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        union = (pred_c | target_c).sum().item()
        if union == 0:
            # Class absent from both prediction and target: no evidence
            # either way, so leave it out of the average.
            continue
        inter = (pred_c & target_c).sum().item()
        ious.append(inter / union)

    if not ious:
        return 1.0
    return np.mean(ious)


def mean_f1(pred, target):
    """Macro-averaged F1 score between predicted and ground-truth labels.

    Both tensors are moved to the CPU and flattened so that every pixel is
    treated as one classification sample.

    Args:
        pred: Tensor of predicted class indices.
        target: Tensor of ground-truth class indices, same shape as ``pred``.

    Returns:
        The macro F1 score computed by ``sklearn.metrics.f1_score`` with
        ``zero_division=1``.
    """
    y_true = target.cpu().numpy().reshape(-1)
    y_pred = pred.cpu().numpy().reshape(-1)
    return f1_score(y_true, y_pred, average="macro", zero_division=1)


In [None]:
def train_one_epoch(loader):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Uses the module-level ``model``, ``optimizer`` and ``DEVICE``. The loss is
    the one the model computes itself when ``labels`` are passed.

    Args:
        loader: Iterable yielding ``(images, masks)`` batches.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    running_loss = 0

    progress = tqdm(loader, desc="Training")
    for batch_imgs, batch_masks in progress:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_masks = batch_masks.to(DEVICE)

        optimizer.zero_grad()
        loss = model(pixel_values=batch_imgs, labels=batch_masks).loss
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the total and the progress bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.4f}'})

    return running_loss / len(loader)


In [None]:
def evaluate(loader):
    """Score the module-level ``model`` on ``loader`` without gradients.

    For every batch the logits are resized to (IMG_HEIGHT, IMG_WIDTH) — the
    resolution of the label masks — before taking the arg-max over the class
    dimension. mIoU and macro F1 are computed per batch and then averaged
    over batches.

    Args:
        loader: Iterable yielding ``(images, masks)`` batches.

    Returns:
        Tuple ``(mean_miou, mean_f1)`` of the batch-averaged metrics.
    """
    model.eval()
    label_size = (IMG_HEIGHT, IMG_WIDTH)
    batch_mious = []
    batch_f1s = []

    with torch.no_grad():
        for batch_imgs, batch_masks in tqdm(loader, desc="Evaluating"):
            batch_imgs = batch_imgs.to(DEVICE)
            batch_masks = batch_masks.to(DEVICE)

            raw_logits = model(pixel_values=batch_imgs).logits
            # Bring the logits to the label resolution so predictions and
            # masks can be compared pixel by pixel.
            resized_logits = F.interpolate(
                raw_logits, size=label_size, mode="bilinear", align_corners=False
            )
            preds = resized_logits.argmax(dim=1)

            batch_mious.append(mean_iou(preds, batch_masks))
            batch_f1s.append(mean_f1(preds, batch_masks))

    return np.mean(batch_mious), np.mean(batch_f1s)


In [None]:
# Highest validation mIoU seen so far; a checkpoint is written only when an
# epoch beats it.
best_miou = 0.0

print("\n" + "="*60)
print("TRAINING START")
print("="*60 + "\n")

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 40)

    train_loss = train_one_epoch(train_loader)
    val_miou, val_f1 = evaluate(val_loader)

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val mIoU:   {val_miou:.4f}")
    print(f"Val F1:     {val_f1:.4f}")

    if val_miou > best_miou:
        best_miou = val_miou
        # Metrics are stored as plain Python floats: NumPy scalars inside a
        # checkpoint are rejected by torch.load() when weights_only=True
        # (the default in recent PyTorch releases).
        torch.save({
            'epoch': epoch + 1,  # 1-based epoch number
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'miou': float(val_miou),
            'f1': float(val_f1),
        }, BEST_MODEL_PATH)
        print(f"✓ Best model saved! (mIoU: {val_miou:.4f})")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print(f"Best Validation mIoU: {best_miou:.4f}")
print(f"Best model saved at: {BEST_MODEL_PATH}")
print("="*60)
