# Food Vision With SegFormer and Hugging Face

# Setup

In [1]:
import os
import random
import time
from typing import Tuple, Optional, Dict, List
from enum import Enum
import json
# Essentials
import cv2
import numpy as np
import matplotlib.pyplot as plt
# Pytorch
import torch
import torchvision
# Hugging Face
from datasets import load_dataset


## Install required libraries

In [2]:
from packaging import version

# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+
required_torch = "1.12.0"
required_torchvision = "0.13.0"

if version.parse(torch.__version__) < version.parse(required_torch) or \
   version.parse(torchvision.__version__) < version.parse(required_torchvision):
    print("[INFO] torch/torchvision versions not as required, installing latest versions.")
    # You can change cu121 to cu118 or cpu as needed
    import os
    !pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
    # Reload torch and torchvision to reflect new versions
    import importlib
    importlib.reload(torch)
    importlib.reload(torchvision)

print(f"‚úÖ torch version: {torch.__version__}")
print(f"‚úÖ torchvision version: {torchvision.__version__}")

‚úÖ torch version: 2.8.0+cu126
‚úÖ torchvision version: 0.23.0+cu126


## Setup Platform

## Mount Drive

Run the cell below to mount Google Drive if you are using **Google Colab**; skip it when running elsewhere.

In [3]:
# Google Colab only: expose Google Drive under /content/drive (dataset class map, checkpoints).
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Setup Device-Agnostic Code

In [4]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [5]:
# Report the device selected in the previous cell.
print("‚úÖ Using Device: {}".format(DEVICE))

‚úÖ Using Device: cpu


## Setup Seeds

In [6]:
def set_seeds(seed: int = 42):
    """
    Seed every RNG this notebook draws from, for reproducibility.

    Seeds Python's `random`, NumPy's global RNG (used e.g. for mask colour maps)
    and PyTorch on CPU and all CUDA devices. Also forces deterministic cuDNN
    kernels and disables cuDNN autotuning, which can slow training down.

    Args:
        seed: Seed value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [7]:
# Make the rest of the notebook reproducible (42 is set_seeds' default seed).
set_seeds(seed=42)

# Data Preparation

## Dataset Structure

In [8]:
import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Optional
from transformers import BaseImageProcessor
from torchvision import transforms

class FoodSeg103Dataset(Dataset):
    """
    A PyTorch Dataset wrapper for the Hugging Face FoodSeg103 dataset.

    Each underlying HF sample includes:
        - image (PIL Image)
        - label (segmentation mask as PIL Image)
        - classes_on_image (list of class IDs present)
        - id (int)

    Preprocessing, in order of preference:
        - a Hugging Face image processor (e.g. SegFormer's), which handles
          resizing/normalising both the image and the mask;
        - a torchvision-style transform applied to the image only;
        - a plain ``ToTensor()`` conversion of the image as a last resort.
    """

    def __init__(
        self,
        hf_dataset,
        processor: Optional[BaseImageProcessor] = None,
        transform: Optional[transforms.Compose] = None,
    ):
        """
        Args:
            hf_dataset: Hugging Face dataset split (e.g. from `datasets.load_dataset`).
            processor: Optional Hugging Face processor (handles resizing, normalization, etc.).
            transform: Optional torchvision-style transform (used only when no processor is given).
        """
        self.dataset = hf_dataset
        self.processor = processor
        self.transform = transform

        # Without either, images are only converted with ToTensor(); FoodSeg103 images
        # vary in size, so the default DataLoader collate cannot stack them.
        if self.processor is None and self.transform is None:
            print("‚ö†Ô∏è Warning: No processor or transform provided. "
                  "Images will only be converted with ToTensor(); differently sized "
                  "images cannot be batched by a DataLoader.")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Fetch and preprocess one sample.

        Returns:
            (pixel_values, labels): image tensor (C, H, W) and integer mask tensor (H, W), dtype long.
        """
        item = self.dataset[idx]
        image = item["image"]   # PIL Image
        label = item["label"]   # PIL Image mask

        # Grayscale/CMYK files would give the wrong channel count; processors and
        # ToTensor expect 3-channel RGB. The mask is left untouched.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        # --- Option 1: Hugging Face processor (preferred for SegFormer) ---
        if self.processor is not None:
            encoded = self.processor(image, segmentation_maps=label, return_tensors="pt")
            pixel_values = encoded["pixel_values"].squeeze(0)
            labels = encoded["labels"].squeeze(0).long()
            return pixel_values, labels

        # --- Option 2: torchvision transform on the image ---
        # NOTE(review): the mask is not transformed here, so a resizing/cropping
        # transform will leave image and mask misaligned.
        if self.transform is not None:
            image = self.transform(image)
        else:
            # Fallback: convert PIL → Tensor so the DataLoader receives tensors.
            image = transforms.ToTensor()(image)

        label = torch.as_tensor(np.array(label, dtype=np.int64), dtype=torch.long)
        return image, label

    def __repr__(self):
        """Return a clean summary of the dataset configuration."""
        transform_str = str(self.transform) if self.transform else "None"
        processor_str = self.processor.__class__.__name__ if self.processor else "None"
        split_name = getattr(self.dataset, "split", "unknown")
        return (
            f"Dataset: FoodSeg103\n"
            f"    Number of datapoints: {len(self)}\n"
            f"    Split: {split_name}\n"
            f"    Processor: {processor_str}\n"
            f"    Transform: {transform_str}\n"
        )


## Class Mappings

In [9]:
def load_class_mappings(json_path: str) -> Tuple[Dict[int, str], Dict[str, int]]:
    """
    Load id2label and label2id mappings from a JSON file.

    Args:
        json_path (str): Path to a JSON object mapping string ids to class names, e.g.
            {
              "0": "background",
              "1": "candy",
              "2": "egg tart",
              ...
            }

    Returns:
        tuple: (id2label: dict[int, str], label2id: dict[str, int])

    Note:
        If two ids share a class name, label2id keeps only the last one.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # JSON object keys are always strings; models expect integer class ids.
    id2label = {int(k): v for k, v in data.items()}
    label2id = {v: k for k, v in id2label.items()}

    return id2label, label2id

In [10]:
id2label, label2id = load_class_mappings("/content/drive/MyDrive/datasets/foodSeg103/classnames.json")

# Preview the first five entries of both directions of the mapping.
print("First 5 id2label:")
print({k: id2label[k] for k in list(id2label)[:5]})
print("\nFirst 5 label2id:")
print({k: label2id[k] for k in list(label2id)[:5]})

First 5 id2label:
{0: 'background', 1: 'candy', 2: 'egg tart', 3: 'french fries', 4: 'chocolate'}

First 5 label2id:
{'background': 0, 'candy': 1, 'egg tart': 2, 'french fries': 3, 'chocolate': 4}


## Create DataLoaders

In [11]:
class Split(Enum):
    """Which dataset split(s) `create_dataloaders_from_hf` should build."""

    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
    ALL = "all"


def create_dataloaders_from_hf(
    dataset_name: str,
    split: Split = Split.ALL,
    processor: Optional[BaseImageProcessor] = None,
    batch_size: int = 4,
    num_workers: int = 2,
    device: Optional[str] = None,
) -> Tuple[Optional[torch.utils.data.DataLoader], Optional[torch.utils.data.DataLoader], Optional[torch.utils.data.DataLoader]]:
    """
    Create PyTorch DataLoaders from a Hugging Face dataset like FoodSeg103.

    Args:
        dataset_name (str): Hugging Face dataset name.
        split (Split): Which split(s) to load (TRAIN, VALIDATION, TEST, or ALL).
        processor: Optional Hugging Face image processor forwarded to `FoodSeg103Dataset`.
        batch_size (int): Batch size for DataLoaders.
        num_workers (int): Number of DataLoader workers.
        device (str, optional): Target training device; only used to decide whether to
            pin host memory. ``None`` auto-detects CUDA.

    Returns:
        (train_loader, val_loader, test_loader): Tuple of DataLoaders.
        Splits that were not requested, or a missing 'test' split, are returned as None.
    """
    # Resolve the device once; pinned memory only helps host->CUDA copies.
    resolved_device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
    pin_mem = torch.device(resolved_device).type == "cuda"  # also matches "cuda:0"

    def build_loader(split_name: str, is_train: bool = False) -> torch.utils.data.DataLoader:
        """Load one HF split, wrap it, and return a DataLoader (shuffled + drop_last for training)."""
        ds = load_dataset(dataset_name, split=split_name)
        dataset = FoodSeg103Dataset(ds, processor=processor)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            drop_last=is_train,
            num_workers=num_workers,
            pin_memory=pin_mem,
        )

    # --- Load splits ---
    train_loader = val_loader = test_loader = None
    if split in (Split.TRAIN, Split.ALL):
        train_loader = build_loader("train", is_train=True)

    if split in (Split.VALIDATION, Split.ALL):
        val_loader = build_loader("validation")

    if split in (Split.TEST, Split.ALL):
        try:
            test_loader = build_loader("test")
        except ValueError:
            # `datasets` raises ValueError for an unknown split name.
            print(f"‚ö†Ô∏è No 'test' split found in dataset: {dataset_name}")

    return train_loader, val_loader, test_loader

# Model Pipeline

## SegFormer Model

In [12]:
from transformers import SegformerForSemanticSegmentation, SegformerConfig, SegformerImageProcessor, BaseImageProcessor

class Segformer:
    """
    Builder for a SegFormer semantic-segmentation model from Hugging Face Transformers.

    Bundles the model, its matching image processor and the class mappings, and
    provides helpers to freeze/unfreeze parts of the network.
    """

    def __init__(
        self,
        model_name: str = "nvidia/segformer-b0-finetuned-ade-512-512",
        num_classes: int = 104,
        ignore_mismatched_sizes: bool = True,
        dropout: float = 0.1,
        device: Optional[str] = None,
        class_map_path: str = "/content/drive/MyDrive/datasets/foodSeg103/classnames.json",
    ):
        """
        Args:
            model_name (str): Pretrained SegFormer checkpoint name or local path.
            num_classes (int): Number of segmentation classes for the new decode head.
            ignore_mismatched_sizes (bool): Re-initialise checkpoint weights whose shapes
                differ (needed when the classifier head has a different class count).
            dropout (float): Passed to the config as ``hidden_dropout_prob``.
            device (str, optional): Device to place the model on; ``None`` picks
                "cuda" when available, otherwise "cpu".
            class_map_path (str): JSON id -> class-name file read by `load_class_mappings`.
        """
        self.model_name = model_name
        self.num_classes = num_classes
        self.ignore_mismatched_sizes = ignore_mismatched_sizes
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        # Keep the mappings on the builder so callers (e.g. visualisation) can reuse them.
        self.id2label, self.label2id = load_class_mappings(class_map_path)
        if len(self.id2label) != num_classes:
            print(f"[WARN] num_classes={num_classes} does not match the "
                  f"{len(self.id2label)} entries in {class_map_path}.")

        self.config = SegformerConfig.from_pretrained(
            model_name,
            num_labels=num_classes,
            hidden_dropout_prob=dropout,
            id2label=self.id2label,
            label2id=self.label2id,
        )

        # Load model weights (the classifier head is re-initialised on size mismatch).
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            config=self.config,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
        )

        # Label reduction (0 -> 255, k -> k-1) is disabled explicitly: here id 0 is a real
        # "background" class. NOTE(review): ADE20K SegFormer checkpoints are typically
        # published with reduce_labels enabled in their processor config, which would
        # otherwise shift every FoodSeg103 target by one.
        self.processor = SegformerImageProcessor.from_pretrained(model_name, do_reduce_labels=False)

        self.model.to(self.device)

    def get_model(self) -> torch.nn.Module:
        """Return the PyTorch model."""
        return self.model

    def get_processor(self):
        """Return the Hugging Face image processor used for preprocessing."""
        return self.processor

    def freeze_encoder(self, freeze=True):
        """Freeze (or unfreeze, with freeze=False) the encoder parameters."""
        for param in self.model.segformer.encoder.parameters():
            param.requires_grad = not freeze
        print(f"Encoder frozen: {freeze}")

    def freeze_decoder(self, freeze=True):
        """Freeze (or unfreeze, with freeze=False) the decode-head parameters."""
        for param in self.model.decode_head.parameters():
            param.requires_grad = not freeze
        print(f"Decoder frozen: {freeze}")

    def unfreeze_all(self):
        """Make every model parameter trainable."""
        for param in self.model.parameters():
            param.requires_grad = True
        print("All model parameters are trainable.")

    def summary(self):
        """Print model info and the logits shape for a dummy input; returns None."""
        print("\nüß© SegFormer Model Summary\n")
        print(f"Model: {self.model_name}")
        print(f"Classes: {self.num_classes}")
        print(f"Device: {self.device}")
        print(f"Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"Trainable: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")
        # Run one dummy forward pass at the processor's working resolution to show output shapes.
        size = self.processor.size.get("height", 512)
        dummy = torch.randn(1, 3, size, size).to(self.device)
        with torch.no_grad():
            out = self.model(pixel_values=dummy)
        print(f"Output keys: {list(out.keys())}")
        print(f"Logits shape: {out.logits.shape}")

## Training Engine

### Visualization Helpers

In [13]:
def visualize_image_with_mask(image, mask=None, alpha=0.5):
    """
    Visualize an image, its segmentation mask, and their overlay side by side.

    Args:
        image (np.ndarray or PIL.Image.Image): RGB image, either uint8 in [0, 255]
            or float (assumed to be in [0, 1]).
        mask (np.ndarray or PIL.Image.Image, optional): Class-id mask (H x W); it is
            resized with nearest-neighbour sampling if its size differs from the image.
        alpha (float): Weight of the colour mask in the overlay blend.
    """
    # Convert PIL → NumPy if needed
    img = np.array(image)
    if img.dtype != np.uint8:
        if np.issubdtype(img.dtype, np.floating):
            # Clip so out-of-range floats don't wrap around on the uint8 cast.
            img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        else:
            # Non-uint8 integer images are assumed to already be on the 0-255 scale.
            img = np.clip(img, 0, 255).astype(np.uint8)
    h, w = img.shape[:2]

    mask_np = None
    if mask is not None:
        mask_np = np.array(mask, dtype=np.int64)

        # --- Resize mask to match image if needed ---
        if mask_np.shape[:2] != (h, w):
            # OpenCV 4.x has no int64 support; int32 is ample for class ids.
            mask_np = cv2.resize(
                mask_np.astype(np.int32), (w, h), interpolation=cv2.INTER_NEAREST
            ).astype(np.int64)

        # --- Random colour per class (background 0 stays uncoloured) ---
        unique_classes = np.unique(mask_np)
        color_map = {cls: np.random.randint(0, 255, size=3) for cls in unique_classes if cls != 0}

        color_mask = np.zeros_like(img)
        for cls_id, color in color_map.items():
            color_mask[mask_np == cls_id] = color

        # Blend image + colour mask
        blended = cv2.addWeighted(img, 1 - alpha, color_mask, alpha, 0)
    else:
        blended = img

    # --- Plot all ---
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.title("Original Image")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    if mask_np is not None:
        plt.imshow(mask_np, cmap="tab20")
        plt.title("Segmentation Mask (Class IDs)")
    else:
        plt.text(0.5, 0.5, "No mask", ha="center", va="center")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(blended)
    plt.title("Overlay (Image + Mask)")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

    # --- Metadata ---
    if mask_np is not None:
        coverage = np.count_nonzero(mask_np) / mask_np.size * 100
        print(f"‚úÖ Mask Coverage: {coverage:.2f}% of image")
        print(f"üü¢ Unique classes: {np.unique(mask_np).tolist()}")
        print(f"üìè Image size: {img.shape}")
        print(f"üé≠ Mask size: {mask_np.shape}")

In [14]:
def unnormalize_image(tensor: torch.Tensor, mean: Tuple[float, float, float]=(0.485, 0.456, 0.406), std: Tuple[float, float, float]=(0.229, 0.224, 0.225)):
    """
    Undo per-channel normalization (ImageNet statistics by default) for display.

    Each channel c is mapped back via ``x * std[c] + mean[c]`` on a CPU copy,
    so the input tensor is never modified.

    Args:
        tensor (torch.Tensor): Normalized image of shape (C, H, W).
        mean (tuple): Per-channel means that were subtracted.
        std (tuple): Per-channel standard deviations that were divided by.

    Returns:
        torch.Tensor: De-normalized image clamped to [0, 1].
    """
    restored = tensor.clone().cpu()
    for channel, (m, s) in zip(restored, zip(mean, std)):
        channel.mul_(s)
        channel.add_(m)
    return restored.clamp(0, 1)


def plot_segmentation(images: torch.Tensor, masks: torch.Tensor, preds: torch.Tensor, n:int=4, save_path: str=None, unnormalize:bool=True):
    """
    Plot input images, ground-truth masks, and predicted masks as rows of a grid.

    Args:
        images (torch.Tensor): Batch of input images (B, C, H, W)
        masks (torch.Tensor): Ground-truth segmentation masks (B, H, W)
        preds (torch.Tensor): Predicted segmentation masks (B, H, W)
        n (int): Maximum number of samples to show (capped at the batch size)
        save_path (str): Optional file path to save the figure; parent folders are created
        unnormalize (bool): Undo ImageNet normalization before plotting (else just clamp to [0, 1])
    """
    # Detach and move to CPU: matplotlib needs plain CPU tensors/arrays.
    images, masks, preds = images.detach().cpu(), masks.detach().cpu(), preds.detach().cpu()

    # Size the grid to the samples actually drawn, so small batches leave no empty rows.
    rows = min(n, images.size(0))
    if rows == 0:
        print("[WARN] plot_segmentation: nothing to plot (empty batch or n <= 0).")
        return

    plt.figure(figsize=(12, rows * 3))

    for i in range(rows):
        if unnormalize:
            img = unnormalize_image(images[i])
        else:
            img = torch.clamp(images[i], 0, 1)

        img = img.permute(1, 2, 0).numpy()  # CHW → HWC

        # --- Original Image ---
        plt.subplot(rows, 3, i * 3 + 1)
        plt.imshow(img)
        plt.title("Original Image")
        plt.axis("off")

        # --- Ground Truth Mask ---
        plt.subplot(rows, 3, i * 3 + 2)
        plt.imshow(masks[i], cmap="tab20")
        plt.title("Ground Truth")
        plt.axis("off")

        # --- Prediction Mask ---
        plt.subplot(rows, 3, i * 3 + 3)
        plt.imshow(preds[i], cmap="tab20")
        plt.title("Prediction")
        plt.axis("off")

    plt.tight_layout()

    # --- Optionally save figure ---
    if save_path:
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")
        print(f"üñºÔ∏è Saved segmentation visualization to: {save_path}")

    plt.show()


def plot_training_curves(results, save_path=None):
    """
    Plot train vs. validation loss, pixel accuracy and mIoU per epoch.

    Args:
        results (dict): Output of `train`, with per-epoch lists under "train_loss",
            "test_loss", "train_acc", "test_acc", "train_iou" and "test_iou".
        save_path (str, optional): File to save the figure to; parent folders are created.
    """
    train_loss, test_loss = results["train_loss"], results["test_loss"]
    train_acc, test_acc = results["train_acc"], results["test_acc"]
    train_iou, test_iou = results["train_iou"], results["test_iou"]
    epochs = range(1, len(train_loss) + 1)

    plt.figure(figsize=(16, 5))

    # ---- Loss Plot ----
    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_loss, "b-", label="Train Loss")
    plt.plot(epochs, test_loss, "r-", label="Test Loss")
    plt.title("Loss per Epoch")
    plt.xlabel("Epoch")
    plt.legend()

    # ---- Accuracy Plot ----
    plt.subplot(1, 3, 2)
    plt.plot(epochs, train_acc, "b-", label="Train Acc")
    plt.plot(epochs, test_acc, "r-", label="Test Acc")
    plt.title("Accuracy per Epoch")
    plt.xlabel("Epoch")
    plt.legend()

    # ---- mIoU Plot ----
    plt.subplot(1, 3, 3)
    plt.plot(epochs, train_iou, "b-", label="Train mIoU")
    plt.plot(epochs, test_iou, "r-", label="Test mIoU")
    plt.title("Mean IoU per Epoch")
    plt.xlabel("Epoch")
    plt.legend()

    plt.tight_layout()
    if save_path:
        # savefig does not create missing folders.
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")
        print(f"‚úÖ Saved training curves to {save_path}")
    plt.show()

In [15]:
def display_random_sample(dataset: torch.utils.data.Dataset, seed: int = 42, unnormalize: bool = False):
    """
    Display a randomly chosen (image, mask) pair from a PyTorch dataset.

    Args:
        dataset (Dataset): A dataset whose items are (image, mask) pairs; tensors are
            expected as image (C, H, W) and mask (H, W).
        seed (int, optional): Seed for picking the index, so the same seed shows the same
            sample. Pass None for a different sample on every call.
        unnormalize (bool): Undo ImageNet normalization on tensor images first (use this
            when the dataset applies a processor that normalizes pixel values).

    Raises:
        ValueError: If the dataset is empty.
    """
    if len(dataset) == 0:
        raise ValueError("display_random_sample: dataset is empty.")

    # A private RNG honours `seed` without disturbing the global random state.
    idx = random.Random(seed).randint(0, len(dataset) - 1)
    image, mask = dataset[idx]

    # Tensors → NumPy for plotting
    if isinstance(image, torch.Tensor):
        image_t = image.detach().cpu()
        if unnormalize:
            image_t = unnormalize_image(image_t)
        image_np = np.clip(image_t.permute(1, 2, 0).numpy(), 0, 1)  # CHW → HWC
    else:
        image_np = np.array(image)

    if isinstance(mask, torch.Tensor):
        mask_np = mask.detach().cpu().numpy()
    else:
        mask_np = np.array(mask)

    print(f"üñºÔ∏è Sample Index: {idx}")
    visualize_image_with_mask(image_np, mask_np)

### Metric Helpers

In [16]:
def pixel_accuracy(preds, labels, ignore_index: Optional[int] = 255):
    """
    Compute the fraction of labelled pixels whose prediction matches the label.

    Args:
        preds: Predicted class ids (tensor, same shape as labels).
        labels: Ground-truth class ids.
        ignore_index: Label value to exclude from scoring (255 by default, matching
            `intersection_over_union` and the loss); None scores every pixel.

    Returns:
        float: Accuracy in [0, 1]; 0.0 when there are no pixels to score.
    """
    if ignore_index is None:
        correct = (preds == labels).sum().item()
        total = labels.numel()
    else:
        valid = labels != ignore_index
        correct = ((preds == labels) & valid).sum().item()
        total = valid.sum().item()
    return correct / total if total else 0.0


def intersection_over_union(preds, labels, num_classes: int, ignore_index:int=255):
    """
    Compute mean Intersection-over-Union (mIoU) for one batch.

    Pixels whose label equals ``ignore_index`` are dropped before scoring, so they cannot
    inflate any class's union. Classes absent from both preds and labels are skipped, and
    class ``ignore_index`` (if it is a valid id) is excluded from the mean.

    Args:
        preds: Predicted class ids (tensor).
        labels: Ground-truth class ids (tensor, same shape as preds).
        num_classes: Number of valid classes, ids 0..num_classes-1.
        ignore_index: Label value to exclude.

    Returns:
        float: Mean IoU over the scored classes, or 0.0 if no class is present.
    """
    preds = preds.detach().cpu().numpy().reshape(-1).astype(np.int64)
    labels = labels.detach().cpu().numpy().reshape(-1).astype(np.int64)

    keep = labels != ignore_index
    preds, labels = preds[keep], labels[keep]

    # Ids outside [0, num_classes) share one overflow bucket, so they still count toward
    # the union of the class they are paired with, without corrupting the bincount.
    overflow = num_classes
    preds = np.where((preds >= 0) & (preds < num_classes), preds, overflow)
    labels = np.where((labels >= 0) & (labels < num_classes), labels, overflow)

    # Confusion matrix in one pass: rows = label, columns = prediction.
    size = num_classes + 1
    confusion = np.bincount(labels * size + preds, minlength=size * size).reshape(size, size)

    intersection = np.diag(confusion)[:num_classes]
    union = confusion.sum(axis=1)[:num_classes] + confusion.sum(axis=0)[:num_classes] - intersection

    scored = (union > 0) & (np.arange(num_classes) != ignore_index)
    if not scored.any():
        return 0.0
    return float(np.mean(intersection[scored] / union[scored]))

### Checkpoint Utilities

In [17]:
def save_checkpoint(
    model: torch.nn.Module,
    optimizer,
    scheduler,
    epoch: int,
    best_miou,
    path: Optional[str] = "ckpts",
    filename: Optional[str] = None,
):
    """
    Save model, optimizer, scheduler state and best mIoU to a checkpoint file.

    Args:
        model: Model whose ``state_dict`` is saved.
        optimizer: Optional optimizer; its state is saved when given.
        scheduler: Optional LR scheduler; its state is saved when given.
        epoch (int): 0-based epoch index stored in the checkpoint.
        best_miou: Best validation mIoU so far (stored as a plain float).
        path (str): Directory to write to (created if missing); None means "ckpts".
        filename (str, optional): File name inside ``path``; defaults to
            ``segformer_finetuned_{epoch}.ckpt``.

    Returns:
        str: Path of the written checkpoint.
    """
    path = path or "ckpts"
    os.makedirs(path, exist_ok=True)
    if filename is None:
        filename = f"segformer_finetuned_{epoch}.ckpt"
    checkpoint_path = os.path.join(path, filename)

    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        # Store a builtin float: NumPy scalars (e.g. from np.mean) are rejected by
        # torch.load(weights_only=True), the default in recent PyTorch.
        "best_miou": float(best_miou) if best_miou is not None else None,
    }, checkpoint_path)

    print(f"‚úÖ Checkpoint saved at: {checkpoint_path} with mIoU: {best_miou}")
    return checkpoint_path


def load_checkpoint(
    model: torch.nn.Module,
    path: str,
    optimizer=None,
    scheduler=None,
    device: Optional[str] = "cuda"):
    """
    Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    Args:
        model: Model to load ``model_state_dict`` into.
        path (str): Checkpoint file written by `save_checkpoint`.
        optimizer: Optional optimizer; restored only if the checkpoint holds its state.
        scheduler: Optional scheduler; restored only if the checkpoint holds its state.
        device: ``map_location`` for the tensors. None auto-selects; a CUDA device on a
            machine without CUDA falls back to "cpu".

    Returns:
        dict: The raw checkpoint dictionary.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    elif torch.device(device).type == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None and checkpoint.get("optimizer_state_dict"):
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if scheduler is not None and checkpoint.get("scheduler_state_dict"):
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    # The stored epoch is 0-based; report it 1-based like the training log.
    best = checkpoint.get("best_miou")
    best_str = f"{best:.4f}" if best is not None else "n/a"
    print(f"‚úÖ Loaded checkpoint from epoch {checkpoint['epoch'] + 1}, "
          f"best mIoU={best_str}")
    return checkpoint

### Training Loops

In [18]:
import time
from tqdm import tqdm
from torch.amp import autocast, GradScaler

# ============================================================
# 🔹 Training / Evaluation Steps (with AMP)
# ============================================================

def train_step(model, dataloader: torch.utils.data.DataLoader, scaler, optimizer, loss_fn, num_classes:int, device:str):
    """
    Run one training epoch.

    Mixed precision (autocast) is enabled only when ``device`` is a CUDA device; a
    ``scaler`` (GradScaler), if given, scales the loss before backward.

    Args:
        model: Segmentation model called as ``model(pixel_values=...)`` returning ``.logits``.
        dataloader: Yields (images, labels) batches.
        scaler: GradScaler or None.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Loss on (upsampled logits, labels); its ``ignore_index`` (255 if absent)
            is also used when scoring mIoU.
        num_classes (int): Number of classes for mIoU.
        device: Device (str or torch.device) the batches are moved to.

    Returns:
        (loss, pixel_acc, miou): Averages over batches; all 0.0 for an empty dataloader.
    """
    model.train()
    train_loss, train_acc, train_iou = 0.0, 0.0, 0.0

    # Autocast follows the device the batches actually run on.
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"
    ignore_index = getattr(loss_fn, "ignore_index", 255)

    for images, labels in tqdm(dataloader, leave=False, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # ---- Forward with AMP ----
        with autocast(device_type=device_type, enabled=use_amp):
            outputs = model(pixel_values=images)
            # SegFormer predicts at reduced resolution; upsample to the label size.
            logits = torch.nn.functional.interpolate(
                outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            loss = loss_fn(logits, labels)

        # ---- Backward (+ gradient scaling) ----
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # ---- Metrics ----
        train_loss += loss.item()
        with torch.no_grad():
            preds = torch.argmax(logits, dim=1)
        train_acc += pixel_accuracy(preds, labels)
        train_iou += intersection_over_union(preds, labels, num_classes, ignore_index=ignore_index)

    n = len(dataloader)
    if n == 0:
        return 0.0, 0.0, 0.0
    return train_loss / n, train_acc / n, train_iou / n


def eval_step(model, dataloader: torch.utils.data.DataLoader, loss_fn, num_classes:int, device:str):
    """
    Evaluate the model for one epoch without gradient tracking.

    Args:
        model: Segmentation model called as ``model(pixel_values=...)`` returning ``.logits``.
        dataloader: Yields (images, labels) batches.
        loss_fn: Loss on (upsampled logits, labels); its ``ignore_index`` (255 if absent)
            is also used when scoring mIoU.
        num_classes (int): Number of classes for mIoU.
        device: Device the batches are moved to.

    Returns:
        (loss, pixel_acc, miou): Averages over batches; all 0.0 for an empty dataloader.
    """
    model.eval()
    test_loss, test_acc, test_iou = 0.0, 0.0, 0.0
    ignore_index = getattr(loss_fn, "ignore_index", 255)

    with torch.inference_mode():
        for images, labels in tqdm(dataloader, leave=False, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(pixel_values=images)
            logits = torch.nn.functional.interpolate(
                outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            loss = loss_fn(logits, labels)
            preds = torch.argmax(logits, dim=1)

            # ---- Metrics ----
            test_loss += loss.item()
            test_acc += pixel_accuracy(preds, labels)
            test_iou += intersection_over_union(preds, labels, num_classes, ignore_index=ignore_index)

    n = len(dataloader)
    if n == 0:
        return 0.0, 0.0, 0.0
    return test_loss / n, test_acc / n, test_iou / n

# ============================================================
# 🔹 Full Training Loop (with Auto-Resume + Visualization)
# ============================================================

def train(model, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, optimizer, loss_fn, epochs: int, device: str,
          num_classes:int, scheduler=None, save_dir:str="checkpoints", min_best_miou: float =0.0, vis_every:int=5, auto_save: bool = True):
    """
    Full training loop for semantic segmentation.

    Features:
    - AMP (autocast + GradScaler) when training on a CUDA device.
    - Auto-resume: if ``<save_dir>/best_segformer.ckpt`` exists, model/optimizer/scheduler
      states and the best mIoU are restored first (the epoch counter still starts at 0).
    - Best-model checkpoint (``best_segformer.ckpt``) whenever validation mIoU improves
      and is at least ``min_best_miou``.
    - Every ``vis_every`` epochs (disabled when <= 0): a prediction grid, the training
      curves and an epoch checkpoint ``segformer_epoch_{n}.ckpt``.

    Args:
        device: Training device; a falsy value auto-selects CUDA when available.
        auto_save (bool): When False, no checkpoints are written (best mIoU is still tracked).

    Returns:
        dict: Per-epoch lists under train_loss/acc/iou, test_loss/acc/iou and epoch_time.
    """
    device = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    scaler = GradScaler("cuda") if use_amp else None
    model.to(device)
    os.makedirs(save_dir, exist_ok=True)
    best_miou = 0.0

    # ---- Auto Resume ----
    best_ckpt_name = "best_segformer.ckpt"
    best_ckpt_path = os.path.join(save_dir, best_ckpt_name)
    if os.path.exists(best_ckpt_path):
        ckpt = load_checkpoint(model, best_ckpt_path, optimizer=optimizer, scheduler=scheduler, device=device)
        best_miou = ckpt.get("best_miou") or 0.0

    # ---- Training Logs ----
    metric_keys = ["train_loss", "train_acc", "train_iou", "test_loss", "test_acc", "test_iou"]
    results = {k: [] for k in metric_keys + ["epoch_time"]}

    # ---- Training Loop ----
    for epoch in range(epochs):
        start_time = time.perf_counter()
        print(f"\nüå± Epoch [{epoch+1}/{epochs}]")

        # ---- Train and Evaluate ----
        train_loss, train_acc, train_iou = train_step(model, train_dataloader, scaler, optimizer, loss_fn, num_classes, device)
        test_loss, test_acc, test_iou = eval_step(model, test_dataloader, loss_fn, num_classes, device)

        if scheduler is not None:
            scheduler.step()

        # ---- Record metrics before any plotting so curves include this epoch ----
        epoch_time = time.perf_counter() - start_time
        results["epoch_time"].append(epoch_time)
        for k, v in zip(metric_keys, [train_loss, train_acc, train_iou, test_loss, test_acc, test_iou]):
            results[k].append(v)

        tqdm.write(
            f"üïí {epoch_time:.2f}s\n"
            f"üìà Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | mIoU: {train_iou:.4f}\n"
            f"üìà Validation Loss: {test_loss:.4f} | Acc: {test_acc:.4f} | mIoU: {test_iou:.4f}\n"
        )

        # ---- Save Best Model (under the name auto-resume looks for) ----
        if test_iou > best_miou and test_iou >= min_best_miou:
            best_miou = test_iou
            if auto_save:
                save_checkpoint(model, optimizer, scheduler, epoch, best_miou, path=save_dir, filename=best_ckpt_name)

        # ---- Periodic Tasks ----
        if vis_every > 0 and (epoch + 1) % vis_every == 0:
            model.eval()
            images, labels = next(iter(test_dataloader))
            images, labels = images.to(device), labels.to(device)

            with torch.inference_mode():
                outputs = model(pixel_values=images)
                logits = torch.nn.functional.interpolate(outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
                preds = torch.argmax(logits, dim=1)

            plot_segmentation(images, labels, preds, n=min(4, images.size(0)),
                              save_path=os.path.join(save_dir, f"vis_epoch_{epoch+1}.png"))
            plot_training_curves(results, save_path=os.path.join(save_dir, "training_curves.png"))
            if auto_save:
                save_checkpoint(model, optimizer, scheduler, epoch, best_miou, path=save_dir, filename=f"segformer_epoch_{epoch+1}.ckpt")

    print("\nüéØ Training complete!")
    print(f"üèÜ Best Validation mIoU: {best_miou:.4f}\n")
    avg_time = np.mean(results["epoch_time"]) if results["epoch_time"] else 0.0
    print(f"‚è±Ô∏è Average Epoch Time: {avg_time:.2f} sec")

    return results


## Fine-Tuning

In [None]:
EPOCHS = 30
LR = 1e-4
WEIGHT_DECAY = 1e-4
# Anneal over the whole run: CosineAnnealingLR's LR rises again after T_max steps.
T_MAX = EPOCHS
IGNORE_INDEX = 255
CHECKPOINT_DIR = "/content/drive/MyDrive/checkpoints"
NUM_WORKERS = os.cpu_count() or 2  # fall back to 2 workers if the CPU count is unknown
BATCH_SIZE = 8
NUM_CLASSES = 104
DROP_RATE = 0.1

# Initialize model and builder
builder = Segformer(num_classes=NUM_CLASSES, device=DEVICE, dropout=DROP_RATE)
model = builder.get_model()
# The processor comes from the same checkpoint as the model
processor = builder.get_processor()

# Initialize DataLoaders (FoodSeg103 has no 'test' split, so the third loader is None)
train_loader, validation_loader, _ = create_dataloaders_from_hf(
    dataset_name="EduardoPacheco/FoodSeg103",
    batch_size=BATCH_SIZE,
    processor=processor,
    num_workers=NUM_WORKERS,
    device=DEVICE,
)

# Loss: pixels labelled IGNORE_INDEX do not contribute
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

# Optimizer and scheduler (scheduler is stepped once per epoch by `train`)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_MAX)

# Model summary (summary() prints by itself and returns None)
builder.summary()

# Fine-tune
if __name__ == "__main__":
    results = train(
        model=model,
        train_dataloader=train_loader,
        test_dataloader=validation_loader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        epochs=EPOCHS,
        device=DEVICE,
        num_classes=NUM_CLASSES,
        scheduler=scheduler,
        save_dir=CHECKPOINT_DIR,
        vis_every=8,  # visualize predictions every 8 epochs
    )

# Later (for inference or resuming)
# load_checkpoint(model, path=os.path.join(CHECKPOINT_DIR, "best_segformer.ckpt"), device=builder.device)
# model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([104]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([104, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference

‚ö†Ô∏è No 'test' split found in dataset: EduardoPacheco/FoodSeg103

üß© SegFormer Model Summary

Model: nvidia/segformer-b0-finetuned-ade-512-512
Classes: 104
Device: cpu
Parameters: 3,740,872
Trainable: 3,740,872
Output keys: ['logits']
Logits shape: torch.Size([1, 104, 128, 128])
None

üå± Epoch [1/30]


Training:   1%|          | 6/622 [03:32<5:42:27, 33.36s/it]

## Save Model Checkpoint Manually

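Save the fine-tuned weights with `model.save_pretrained("./segformer_finetuned")` and the image-processor settings with `builder.get_processor().save_pretrained("./segformer_finetuned")`. The two calls save different things, and both are needed to reload the model for inference.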

Use this code to load the saved model and processor for inference:
```python
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
# Load the fine-tuned model weights and config
model = SegformerForSemanticSegmentation.from_pretrained("./segformer_finetuned")
# Load the matching image processor (AutoImageProcessor works as well)
processor = SegformerImageProcessor.from_pretrained("./segformer_finetuned")
```

In [None]:
model_path = "./segformer_finetuned"
# Save both the weights/config and the preprocessing config; reloading for inference needs both.
model.save_pretrained(model_path)
print(f"Saved fine-tuned model to: {model_path}")
builder.get_processor().save_pretrained(model_path)
print(f"üé´Saved Pretained Processor to: {model_path}")

# Model Analytics



## Plot Loss Curves

In [None]:
# savefig does not create missing folders, so make sure the target directory exists.
os.makedirs("models/checkpoints", exist_ok=True)
plot_training_curves(results, save_path="models/checkpoints/training_plots.png")

## Plot Training Time

In [None]:
def plot_training_time(results, save_path=None):
    """
    Plot the wall-clock time of each epoch and print the total and average.

    Args:
        results (dict): Output of `train`; reads the "epoch_time" list (seconds).
        save_path (str, optional): File to save the plot to; parent folders are created.
    """
    epoch_times = results.get("epoch_time", [])
    if not epoch_times:
        print("[WARN] No epoch_time data found in results.")
        return

    epochs = range(1, len(epoch_times) + 1)
    avg_time = np.mean(epoch_times)
    total_time = np.sum(epoch_times)

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, epoch_times, marker='o', color='mediumseagreen', linewidth=2)
    plt.title("‚è±Ô∏è Training Time per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Time (seconds)")
    plt.grid(alpha=0.3)
    plt.xticks(epochs)
    plt.tight_layout()

    if save_path:
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")
        print(f"‚úÖ Saved training time plot to {save_path}")

    plt.show()
    print(f"üìä Total Training Time: {total_time:.2f} seconds")
    print(f"‚öôÔ∏è  Average Epoch Time: {avg_time:.2f} seconds")

In [None]:
# savefig does not create missing folders, so make sure the target directory exists.
os.makedirs("models/checkpoints", exist_ok=True)
plot_training_time(results, save_path="models/checkpoints/training_time.png")

## Visualize Predictions

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def visualize_predictions(model, dataloader, device, id2label, num_samples=4):
    """
    Show image / ground truth / prediction for the first batch of a dataloader.

    Args:
        model: Segmentation model called as ``model(pixel_values=...)`` returning ``.logits``.
        dataloader: Yields (images, labels); images are assumed to be ImageNet-normalized
            (as produced by the SegFormer processor) and are un-normalized for display.
        device: Device to run inference on.
        id2label (dict[int, str] or None): Class names; used to fix a shared colour scale
            and to print the classes found in each prediction.
        num_samples (int): Maximum number of samples to show.
    """
    model.eval()
    model.to(device)

    images, labels = next(iter(dataloader))
    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(pixel_values=images)
        logits = torch.nn.functional.interpolate(
            outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )
        preds = torch.argmax(logits, dim=1).cpu()
    labels = labels.cpu()

    names = id2label or {}
    # A shared vmin/vmax gives GT and prediction the same class→colour mapping.
    vmax = max(len(names) - 1, 1) if names else None
    vmin = 0 if names else None

    for i in range(min(num_samples, len(images))):
        img = unnormalize_image(images[i]).permute(1, 2, 0).numpy()
        pred_mask = preds[i].numpy()
        true_mask = labels[i].numpy()

        fig, ax = plt.subplots(1, 3, figsize=(12, 4))
        ax[0].imshow(img)
        ax[0].set_title("Original Image")

        ax[1].imshow(true_mask, cmap="tab20", vmin=vmin, vmax=vmax)
        ax[1].set_title("Ground Truth")

        ax[2].imshow(pred_mask, cmap="tab20", vmin=vmin, vmax=vmax)
        ax[2].set_title("Predicted Mask")

        for a in ax:
            a.axis("off")
        plt.tight_layout()
        plt.show()

        found = [names.get(int(c), str(int(c))) for c in np.unique(pred_mask)]
        print(f"Sample {i}: predicted classes -> {', '.join(found)}")


In [None]:
# Uses the id2label mapping loaded in the "Class Mappings" section.
visualize_predictions(model, validation_loader, builder.device, id2label, num_samples=4)