# Fine-tune SegFormer (B2/B5) on RELLIS-3D or RUGD

This notebook trains a larger SegFormer model on Google Colab Pro (A100 GPU).

## Setup
1. Upload your **preprocessed dataset** to Google Drive root (My Drive) as an unzipped folder:
   - `rellis3d_processed/` or `rugd_processed/` (Cell 3 copies this folder to the local disk)
   - The folder should contain: `train/images/`, `train/labels/`, `val/images/`, `val/labels/`, `id2label.json`, `label2id.json`
2. Set the config in Cell 2 below
3. **Runtime → Change runtime type → A100 GPU**
4. Run all cells

## Estimated Training Time (B2 @ 512×512, 50 epochs)
- **A100 (Colab Pro):** ~1.5–2 hours
- **T4 (Colab Free):** ~3–4 hours

With Colab Pro you can close the tab — training continues in the background.

In [None]:
# Cell 1: Install dependencies
!pip install -q transformers datasets torch torchvision Pillow

In [None]:
# Cell 2: Configuration — EDIT THESE

DATASET = "rellis3d"          # "rellis3d" or "rugd"
MODEL_SIZE = "b2"             # "b0", "b1", "b2", "b3", "b4", "b5"
BASE_MODEL = "ade"            # "ade" (recommended for off-road) or "cityscapes"
IMAGE_SIZE = 768              # 768 recommended (near RUGD native res, good for RELLIS-3D too)
EPOCHS = 50
BATCH_SIZE = 4                # 4 to avoid system RAM spikes during checkpoint saves
LEARNING_RATE = 6e-5
FP16 = True                   # Mixed precision (faster on T4/A100)
GRADIENT_CHECKPOINTING = False # Set True if you get OOM errors

# Paths (absolute paths inside the Colab VM; Drive is mounted under /content/drive)
DRIVE_ZIP = f"/content/drive/MyDrive/{DATASET}_processed.zip"  # uploaded zip
DATASET_DIR = f"/content/{DATASET}_processed"
OUTPUT_DIR = f"/content/drive/MyDrive/segformer_outputs/{DATASET}_segformer_{MODEL_SIZE}"


def build_model_name(model_size, base_model):
    """Return the Hugging Face Hub id of the pretrained SegFormer checkpoint.

    Args:
        model_size: One of "b0" .. "b5".
        base_model: "ade" or "cityscapes".

    Returns:
        The `nvidia/segformer-...` checkpoint id.

    Raises:
        ValueError: If `model_size` or `base_model` is not a supported value.
    """
    valid_sizes = ("b0", "b1", "b2", "b3", "b4", "b5")
    if model_size not in valid_sizes:
        raise ValueError(f"MODEL_SIZE must be one of {valid_sizes}, got {model_size!r}")
    if base_model == "cityscapes":
        return f"nvidia/segformer-{model_size}-finetuned-cityscapes-1024-1024"
    if base_model == "ade":
        # The ADE20K checkpoints are published at 512x512 for b0-b4, but at 640x640 for b5.
        res = 640 if model_size == "b5" else 512
        return f"nvidia/segformer-{model_size}-finetuned-ade-{res}-{res}"
    raise ValueError(f'BASE_MODEL must be "ade" or "cityscapes", got {base_model!r}')


# Model ID
MODEL_NAME = build_model_name(MODEL_SIZE, BASE_MODEL)

print(f"Dataset: {DATASET}")
print(f"Base model: {MODEL_NAME}")
print(f"Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"Epochs: {EPOCHS}, Batch: {BATCH_SIZE}, LR: {LEARNING_RATE}")
print(f"FP16: {FP16}, Gradient checkpointing: {GRADIENT_CHECKPOINTING}")

In [None]:
# Cell 3: Mount Google Drive and copy dataset to local disk
from google.colab import drive
drive.mount("/content/drive")

import os
import shutil

# Source on Drive: prefer an unzipped folder, fall back to the uploaded zip (DRIVE_ZIP).
# Either way the data ends up on the local SSD, which is faster to read during training.
DRIVE_DATASET = f"/content/drive/MyDrive/{DATASET}_processed"

if os.path.isdir(DATASET_DIR):
    print(f"Dataset already at {DATASET_DIR}")
elif os.path.isdir(DRIVE_DATASET):
    print(f"Copying {DRIVE_DATASET} to local disk...")
    shutil.copytree(DRIVE_DATASET, DATASET_DIR)
    print("Done.")
elif os.path.isfile(DRIVE_ZIP):
    print(f"Extracting {DRIVE_ZIP} to local disk...")
    # Extract into a staging folder first so a failed extraction never leaves a
    # half-filled DATASET_DIR that the "already at" branch would accept next run.
    staging_dir = DATASET_DIR + "_extracting"
    shutil.rmtree(staging_dir, ignore_errors=True)
    shutil.unpack_archive(DRIVE_ZIP, staging_dir)
    # The archive may hold train/val directly or wrap them in a "<dataset>_processed" folder.
    wrapped = os.path.join(staging_dir, os.path.basename(DATASET_DIR))
    shutil.move(wrapped if os.path.isdir(wrapped) else staging_dir, DATASET_DIR)
    shutil.rmtree(staging_dir, ignore_errors=True)
    print("Done.")
else:
    print(f"ERROR: {DRIVE_DATASET} not found on Drive!")
    print("Available folders:")
    for item in sorted(os.listdir("/content/drive/MyDrive/")):
        if "rellis" in item.lower() or "rugd" in item.lower():
            print(f"  {item}")
    # Stop "Run all" here instead of failing later with a less obvious error.
    raise FileNotFoundError(
        f"Neither {DRIVE_DATASET} nor {DRIVE_ZIP} exists on Google Drive"
    )

# Verify structure (label2id.json is included because Cell 5 reads it too)
for sub in ["train/images", "train/labels", "val/images", "val/labels",
            "id2label.json", "label2id.json"]:
    path = os.path.join(DATASET_DIR, sub)
    exists = os.path.exists(path)
    count = ""
    if exists and os.path.isdir(path):
        count = f" ({len(os.listdir(path))} files)"
    print(f"  {'OK' if exists else 'MISSING'}: {sub}{count}")

In [None]:
# Cell 4: Check GPU
# Prints the PyTorch version and, when CUDA is usable, the name and memory size of device 0.
import torch

has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {has_cuda}")

if not has_cuda:
    print("WARNING: No GPU! Go to Runtime → Change runtime type → A100 GPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    # Read the memory size from `total_memory`, falling back to `total_mem`, then to 0.
    vram = getattr(props, "total_memory", None)
    if not vram:
        vram = getattr(props, "total_mem", 0)
    print(f"VRAM: {vram / 1e9:.1f} GB")

In [None]:
# Cell 5: Load dataset and model
import json
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
    TrainingArguments,
    Trainer,
)

dataset_dir = Path(DATASET_DIR)

# JSON object keys are always strings, so convert class ids back to int
# (id2label: int -> name, label2id: name -> int) before handing them to the model.
with open(dataset_dir / "id2label.json") as f:
    id2label = {int(class_id): name for class_id, name in json.load(f).items()}
with open(dataset_dir / "label2id.json") as f:
    label2id = {name: int(class_id) for name, class_id in json.load(f).items()}

# Keep the mapping ordered by class id so the printed class list is deterministic.
id2label = dict(sorted(id2label.items()))

num_classes = len(id2label)
print(f"Classes ({num_classes}): {list(id2label.values())}")


class SegmentationDataset(Dataset):
    """Paired image / label-mask dataset for semantic segmentation.

    Images (.jpg/.jpeg/.png) in `image_dir` are matched to label masks (.png) in
    `label_dir` by file stem; files without a partner are ignored.

    Args:
        image_dir: Directory containing the input images.
        label_dir: Directory containing the label PNGs.
        processor: Image processor called as
            `processor(images=..., do_resize=False, return_tensors="pt")`;
            it is used for everything except resizing.
        max_size: Side length in pixels. Every image and label is resized to a
            `max_size` x `max_size` square (aspect ratio is not preserved).

    Raises:
        ValueError: If no image/label pair with a matching stem is found.
    """

    def __init__(self, image_dir, label_dir, processor, max_size=512):
        self.image_dir = Path(image_dir)
        self.label_dir = Path(label_dir)
        self.processor = processor
        self.max_size = max_size
        image_files = {f.stem: f for f in self.image_dir.iterdir()
                       if f.suffix.lower() in (".jpg", ".jpeg", ".png")}
        label_files = {f.stem: f for f in self.label_dir.iterdir()
                       if f.suffix.lower() == ".png"}
        # Sorted stems give a stable sample order regardless of directory listing order.
        common = sorted(set(image_files) & set(label_files))
        self.pairs = [(image_files[s], label_files[s]) for s in common]
        if not self.pairs:
            raise ValueError(f"No matching pairs in {image_dir} / {label_dir}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return `{"pixel_values": tensor, "labels": int64 tensor}` for sample `idx`."""
        img_path, lbl_path = self.pairs[idx]
        target_size = (self.max_size, self.max_size)
        # Context managers close the underlying files right away; convert()/resize()
        # return new in-memory images, so they remain valid afterwards.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB").resize(target_size, Image.BILINEAR)
        with Image.open(lbl_path) as raw_label:
            # NEAREST keeps class ids intact (no interpolated, non-existent classes).
            label = raw_label.resize(target_size, Image.NEAREST)
        # do_resize=False: the image is already at the requested size. Without it the
        # processor would rescale again to its own configured size, so the model
        # would never see max_size inputs while the labels stay at max_size.
        encoded = self.processor(images=image, do_resize=False, return_tensors="pt")
        pixel_values = encoded["pixel_values"].squeeze(0)
        # NOTE(review): assumes label PNGs are single-channel class-id masks — confirm.
        labels = torch.from_numpy(np.array(label, dtype=np.int64))
        return {"pixel_values": pixel_values, "labels": labels}


# Load the image processor and the pretrained model, then build both dataset splits
print(f"Loading {MODEL_NAME} ...")
processor = SegformerImageProcessor.from_pretrained(MODEL_NAME)
processor.do_reduce_labels = False

# Head configuration for the new label set; ignore_mismatched_sizes is passed so the
# checkpoint can be loaded with a different number of labels.
head_config = dict(
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME, **head_config)

# NOTE(review): confirm the installed SegFormer implementation supports gradient checkpointing.
if GRADIENT_CHECKPOINTING:
    model.gradient_checkpointing_enable()
    print("Gradient checkpointing enabled")

# One SegmentationDataset per split, built in the order train -> val
splits = {
    split: SegmentationDataset(
        dataset_dir / split / "images",
        dataset_dir / split / "labels",
        processor, max_size=IMAGE_SIZE,
    )
    for split in ("train", "val")
}
train_dataset = splits["train"]
val_dataset = splits["val"]
print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")

In [None]:
# Cell 6: Train
import gc
import traceback
from datetime import datetime

LOG_FILE = "/content/drive/MyDrive/segformer_outputs/training_log.txt"

def log_to_drive(msg, log_file=None):
    """Print a message and append it, timestamped, to a log file on Drive.

    Logging is best-effort: if the file cannot be written (for example Drive is
    not mounted), a warning is printed instead of raising, so a logging problem
    never hides the training error that is being reported.

    Args:
        msg: Text to log.
        log_file: Path of the log file; defaults to `LOG_FILE`.
    """
    # Print first so the message is visible even when the Drive write fails.
    print(msg)
    target = LOG_FILE if log_file is None else log_file
    line = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {msg}\n"
    try:
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(target, "a", encoding="utf-8") as f:
            f.write(line)
    except OSError as err:
        print(f"WARNING: could not write to log file {target}: {err}")

def compute_metrics_factory(num_classes, ignore_index=255, chunk_size=8):
    """Build a `compute_metrics` callback for semantic segmentation.

    Args:
        num_classes: Number of classes; class ids 0 .. num_classes-1 are scored.
        ignore_index: Label value excluded from every metric.
        chunk_size: Number of samples upsampled at a time. Full-resolution logits
            are only ever held for one chunk, which bounds peak memory.

    Returns:
        A function taking `(logits, labels)` — logits as (N, C, h, w) and labels as
        (N, H, W) arrays — and returning a dict with `mean_iou`, `mean_accuracy`
        and `overall_accuracy` as floats. Classes that appear in neither the
        predictions nor the labels are left out of the IoU mean; classes absent
        from the labels are left out of the accuracy mean.
    """
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        labels = np.asarray(labels)
        h, w = labels.shape[1], labels.shape[2]

        # Running per-class pixel counts over all valid (non-ignored) pixels.
        intersection = np.zeros(num_classes, dtype=np.int64)
        pred_count = np.zeros(num_classes, dtype=np.int64)
        label_count = np.zeros(num_classes, dtype=np.int64)
        total_valid = 0
        total_correct = 0

        step = max(int(chunk_size), 1)
        for start in range(0, labels.shape[0], step):
            stop = start + step
            chunk_logits = torch.from_numpy(np.asarray(logits[start:stop])).float()
            # Upsample the logits to label resolution, then take the per-pixel argmax.
            upsampled = torch.nn.functional.interpolate(
                chunk_logits, size=(h, w), mode="bilinear", align_corners=False
            )
            preds = upsampled.argmax(dim=1).numpy()
            del chunk_logits, upsampled

            chunk_labels = labels[start:stop]
            valid = chunk_labels != ignore_index
            preds_valid = preds[valid]
            labels_valid = chunk_labels[valid]
            hits = preds_valid == labels_valid

            total_valid += int(labels_valid.size)
            total_correct += int(hits.sum())
            # Slicing to num_classes ignores predicted ids outside the scored range.
            intersection += np.bincount(
                preds_valid[hits], minlength=num_classes)[:num_classes]
            pred_count += np.bincount(
                preds_valid, minlength=num_classes)[:num_classes]
            # Labels outside 0..num_classes-1 still count toward overall accuracy
            # (as misses) but belong to no scored class.
            in_range = (labels_valid >= 0) & (labels_valid < num_classes)
            label_count += np.bincount(
                labels_valid[in_range], minlength=num_classes)

        # |pred OR label| = |pred| + |label| - |pred AND label|
        union = pred_count + label_count - intersection
        seen = union > 0
        present = label_count > 0

        mean_iou = float(np.mean(intersection[seen] / union[seen])) if seen.any() else 0.0
        mean_acc = (
            float(np.mean(intersection[present] / label_count[present]))
            if present.any() else 0.0
        )
        overall_acc = float(total_correct / max(total_valid, 1))

        return {
            "mean_iou": mean_iou,
            "mean_accuracy": mean_acc,
            "overall_accuracy": overall_acc,
        }
    return compute_metrics


# Make sure the output folder exists before anything is written to it
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Mixed precision is only requested when a CUDA device is actually present
use_fp16 = FP16 and torch.cuda.is_available()

training_args = TrainingArguments(
    # Where checkpoints and logs go
    output_dir=OUTPUT_DIR,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    report_to="none",
    # Schedule and optimisation
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="polynomial",
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=use_fp16,
    # Batching and data loading
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=2,
    dataloader_num_workers=0,
    remove_unused_columns=False,
    # Logging, evaluation and checkpointing share the same step cadence
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    # Model selection: keep the checkpoint with the highest mean IoU
    load_best_model_at_end=True,
    metric_for_best_model="mean_iou",
    greater_is_better=True,
)

metrics_fn = compute_metrics_factory(num_classes)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=metrics_fn,
)

# Auto-resume from latest checkpoint if one exists on Drive
import glob
import re


def _list_checkpoints(output_dir):
    """Return resumable checkpoint directories in `output_dir`, oldest step first.

    Directories are ordered by the numeric step in "checkpoint-<step>". A plain
    string sort would place "checkpoint-500" after "checkpoint-1000". Folders
    without trainer_state.json are skipped: the Trainer reads that file when
    resuming, so a checkpoint lacking it (e.g. an interrupted save) cannot be used.
    """
    found = []
    for path in glob.glob(os.path.join(output_dir, "checkpoint-*")):
        match = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path))
        if match is None or not os.path.isdir(path):
            continue
        if not os.path.isfile(os.path.join(path, "trainer_state.json")):
            continue
        found.append((int(match.group(1)), path))
    return [path for _, path in sorted(found)]


checkpoints = _list_checkpoints(OUTPUT_DIR)
if checkpoints:
    log_to_drive(f"Resuming from {checkpoints[-1]}")
    resume_from = checkpoints[-1]
else:
    log_to_drive("No checkpoints found, starting fresh")
    resume_from = None

log_to_drive(f"=== Starting training: {DATASET} {MODEL_SIZE} {BASE_MODEL} {IMAGE_SIZE}px batch={BATCH_SIZE} ===")

try:
    trainer.train(resume_from_checkpoint=resume_from)
    log_to_drive("Training completed successfully!")
except Exception as e:
    # Record the failure on Drive so it can be read after the session is gone.
    log_to_drive(f"TRAINING CRASHED: {type(e).__name__}: {e}")
    log_to_drive(traceback.format_exc())
    raise  # still show the error in the notebook

In [None]:
# Cell 7: Save model, evaluate, and log results
import shutil
import json
from datetime import datetime

print(f"Saving model to {OUTPUT_DIR} ...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

# Copy class mappings alongside model
for mapping_name in ("id2label.json", "label2id.json"):
    shutil.copy2(dataset_dir / mapping_name, Path(OUTPUT_DIR) / mapping_name)

print("\n=== Final evaluation ===")
metrics = trainer.evaluate()
print(metrics)

# Save results to a JSON file alongside the model
results = {
    "dataset": DATASET,
    "model_size": MODEL_SIZE,
    "base_model": BASE_MODEL,
    "model_name": MODEL_NAME,
    "image_size": IMAGE_SIZE,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "learning_rate": LEARNING_RATE,
    "fp16": FP16,
    "gradient_checkpointing": GRADIENT_CHECKPOINTING,
    "num_classes": num_classes,
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset),
    "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
    "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
    "final_metrics": {
        "eval_loss": metrics["eval_loss"],
        "mean_iou": metrics["eval_mean_iou"],
        "mean_accuracy": metrics["eval_mean_accuracy"],
        "overall_accuracy": metrics["eval_overall_accuracy"],
    },
}

results_path = os.path.join(OUTPUT_DIR, "training_results.json")
with open(results_path, "w") as f:
    json.dump(results, f, indent=2)
print(f"\nResults saved to: {results_path}")

# Also append to a central log file on Drive
log_path = "/content/drive/MyDrive/segformer_outputs/all_results.json"
all_results = []
if os.path.exists(log_path):
    try:
        with open(log_path) as f:
            previous = json.load(f)
    except ValueError as err:  # includes json.JSONDecodeError
        previous = None
        print(f"WARNING: could not parse {log_path}: {err}")
    if isinstance(previous, list):
        all_results = previous
    else:
        # Don't lose this run's results over a damaged log: keep a copy of the old
        # file for inspection and start a new list.
        backup_path = log_path + ".bak"
        shutil.copy2(log_path, backup_path)
        print(f"WARNING: {log_path} did not contain a JSON list; old file copied to {backup_path}")
all_results.append(results)
with open(log_path, "w") as f:
    json.dump(all_results, f, indent=2)
print(f"Appended to central log: {log_path}")

print(f"\nDone! Model saved to Google Drive: {OUTPUT_DIR}")
print("Download the folder and place it in training/models/<name> to use in CARLA.")

In [None]:
# Cell 8 (Optional): Zip the model for easy download
zip_name = f"{DATASET}_segformer_{MODEL_SIZE}"
zip_path = f"/content/drive/MyDrive/{zip_name}.zip"

print(f"Zipping model to {zip_path} ...")
!cd /content/drive/MyDrive/segformer_outputs && zip -r "/content/drive/MyDrive/{zip_name}.zip" "{zip_name}/"
print(f"Done! Download {zip_path} from Google Drive.")