# COCO-style Semantic Segmentation with Mask2Former
This notebook demonstrates how to set up a semantic segmentation pipeline using a COCO-style dataset and the Mask2Former model. It covers dataset preparation, label remapping, color palette setup, data loading for training and evaluation, and a fine-tuning loop that tracks validation loss and mean IoU.

In [1]:
# Consolidated Imports
import os
import random
import json
import pickle

from PIL import Image
import numpy as np
import skimage.draw
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

from sklearn.decomposition import PCA
from scipy.special import softmax
from tqdm.auto import tqdm

from torchvision import transforms

import evaluate

from transformers import (
    Mask2FormerForUniversalSegmentation,
    Mask2FormerImageProcessor,
    pipeline
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set Random Seeds for Reproducibility
seed = 78
# Seed every RNG the notebook touches (PyTorch, stdlib random, NumPy), in that order.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(seed)

In [3]:
# COCO Dataset Class Definition
class COCODataset(Dataset):
    """
    A custom Dataset class for COCO-format JSON annotations and images.

    Each item is a dict with:
      - 'image': a PIL Image converted to RGB (or the result of ``transform`` if one is given)
      - 'semantic_map': a 2D uint8 tensor where each pixel's value is the category ID
        (0 where no polygon covers the pixel)
      - 'image_id': the original COCO image ID
      - 'width', 'height': dimensions of the image as recorded in the JSON

    Args:
        coco_file: Path to the COCO-format JSON annotation file.
        root_dir: Directory that image paths are resolved against.
        split: One of 'train', 'valid', 'test' to keep only that split's image IDs.
            Any other value (including None) keeps every image in the file.
        transform: Optional callable applied to the image only; the semantic map
            is always rasterised at the size recorded in the JSON.
    """

    # Prefix stored in front of every image path in the annotation file.
    _PATH_PREFIX = '/datasets/'

    def __init__(self, coco_file: str, root_dir: str, split: str = None, transform=None):
        with open(coco_file, 'r') as f:
            self.coco_data = json.load(f)
        self.split_image_ids = { # dictionary mapping split names to image IDs
            'train': list(range(283, 314)) + list(range(314, 345)) + list(range(408, 471)),
            'valid': list(range(345, 377)) + list(range(533, 564)),
            'test':  list(range(377, 408)) + list(range(471, 533))
        }
        all_images = self.coco_data['images']
        if split in self.split_image_ids:
            valid_ids = set(self.split_image_ids[split])
            self.images = [img for img in all_images if img['id'] in valid_ids]
        else:
            # Unknown split name or None: fall back to every image in the file.
            self.images = all_images
        self.annotations = self.coco_data['annotations']
        self.categories = {
            cat['id']: {
                'name': cat['name'],
                'color': cat.get('color', "#000000"),
                'supercategory': cat['supercategory']
            }
            for cat in self.coco_data['categories']
        }
        print("Category IDs and their names:")
        for cat_id, cat_info in self.categories.items():
            print(f"  ID {cat_id}: {cat_info['name']}")
        self.root_dir = root_dir
        self.transform = transform
        # Index annotations by image ID once so __getitem__ is a dict lookup.
        self.image_id_to_annotations = {}
        for anno in self.annotations:
            self.image_id_to_annotations.setdefault(anno['image_id'], []).append(anno)

    @staticmethod
    def _relative_path(path: str) -> str:
        """Return ``path`` without the leading '/datasets/' prefix and without leading slashes.

        ``str.lstrip('/datasets/')`` must not be used here: it strips any leading
        run of the characters '/', 'd', 'a', 't', 's', 'e', which also eats the
        start of directory names such as 'data/' or 'test/'.
        """
        prefix = COCODataset._PATH_PREFIX
        if path.startswith(prefix):
            path = path[len(prefix):]
        # A leading '/' would make os.path.join discard root_dir entirely.
        return path.lstrip('/')

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx: int):
        image_info = self.images[idx]
        image_id = image_info['id']
        width, height = image_info['width'], image_info['height']
        relative_path = self._relative_path(image_info['path'])
        image_path = os.path.join(self.root_dir, relative_path)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        annotations = self.image_id_to_annotations.get(image_id, [])
        # uint8 storage: category IDs are assumed to be below 256.
        semantic_map = np.zeros((height, width), dtype=np.uint8)
        for anno in annotations:
            cat_id = anno['category_id']
            for poly in anno.get('segmentation', []):
                # COCO polygons are flat [x0, y0, x1, y1, ...] lists.
                coords = np.array(poly).reshape(-1, 2)
                # skimage expects (rows, cols) == (y, x); later annotations overwrite earlier ones.
                rr, cc = skimage.draw.polygon(coords[:, 1], coords[:, 0], semantic_map.shape)
                semantic_map[rr, cc] = cat_id
        semantic_map_tensor = torch.tensor(semantic_map, dtype=torch.uint8)
        return {
            'image': image,
            'semantic_map': semantic_map_tensor,
            'image_id': image_id,
            'width': width,
            'height': height
        }

In [4]:
# Custom Collate Function for DataLoader

def custom_collate_fn(batch):
    """Collate a list of COCODataset-style sample dicts into one batch dict.

    Tensors under 'image' are stacked and stored under 'images'; tensors under
    'semantic_map' are stacked under the same key; every other field is
    gathered into a plain Python list. Stacking requires same-shaped tensors.
    """
    stacked_keys = {'image': 'images', 'semantic_map': 'semantic_map'}
    result = {}
    for field in batch[0]:
        values = [sample[field] for sample in batch]
        if field in stacked_keys:
            result[stacked_keys[field]] = torch.stack(values)
        else:
            result[field] = values
    return result

In [5]:
# ID-to-Label and Color Palette Setup
# Original (sparse) category IDs from the annotation file -> class names.
id2label = {
    0: "bg",
    11: "pepper_kp",
    12: "pepper_red",
    13: "pepper_yellow",
    14: "pepper_green",
    15: "pepper_mixed",
    17: "pepper_mixed_red",
    18: "pepper_mixed_yellow"
}
# NOTE: despite its name, `label2id` maps ORIGINAL category IDs to contiguous IDs 0..N-1.
label2id = dict(zip(sorted(id2label), range(len(id2label))))
# Contiguous ID -> class name.
id2label_remapped = {}
for old_id, new_id in label2id.items():
    id2label_remapped[new_id] = id2label[old_id]
print("Remapped ID-to-label:", id2label_remapped)
# Display colour (hex) for each contiguous class ID.
id2color = {
    0: "#000000",
    1: "#0000ff",
    2: "#c7211c",
    3: "#fff700",
    4: "#00ff00",
    5: "#e100ff",
    6: "#ff6600",
    7: "#d1c415",
}
# One RGB row per contiguous class ID; classes without a colour fall back to black.
palette = np.array(
    [
        tuple(bytes.fromhex(id2color.get(class_id, "#000000").lstrip("#")))
        for class_id in range(len(id2label_remapped))
    ],
    dtype=np.uint8,
)
print("Color palette (RGB):\n", palette)

Remapped ID-to-label: {0: 'bg', 1: 'pepper_kp', 2: 'pepper_red', 3: 'pepper_yellow', 4: 'pepper_green', 5: 'pepper_mixed', 6: 'pepper_mixed_red', 7: 'pepper_mixed_yellow'}
Color palette (RGB):
 [[  0   0   0]
 [  0   0 255]
 [199  33  28]
 [255 247   0]
 [  0 255   0]
 [225   0 255]
 [255 102   0]
 [209 196  21]]


In [6]:
# Utility Function to Remap Mask Labels
def remap_labels(mask: np.ndarray, label2id_map: dict) -> torch.Tensor:
    """Translate every ID in ``mask`` through ``label2id_map``.

    Non-tensor input is converted to an int64 tensor; a tensor input keeps its
    own dtype. IDs that are not keys of ``label2id_map`` come out as 0.
    """
    source = mask if isinstance(mask, torch.Tensor) else torch.tensor(mask, dtype=torch.int64)
    result = torch.zeros_like(source)
    for old_id, new_id in label2id_map.items():
        # Selection is always computed on the untouched source, so writes cannot chain.
        result.masked_fill_(source == old_id, new_id)
    return result

In [7]:
# Wrapper Dataset: ImageSegmentationDataset
class ImageSegmentationDataset(Dataset):
    """
    A wrapper around a base dataset (e.g. COCODataset) to:
      - Remap original class IDs in the mask to contiguous IDs
      - Apply image transforms (normalization, augmentation) to the input image
      - Optionally apply target transforms to the segmentation mask

    Args:
        base_dataset: Dataset whose items are dicts with 'image' (PIL image) and
            'semantic_map' (2D tensor or array of original category IDs).
        transform: Optional callable applied to the PIL image. Without it the image
            is converted to a float32 CHW tensor with unscaled pixel values.
        target_transform: Optional callable applied to the remapped mask as a PIL image.
        label2id_map: Optional {original_id: contiguous_id} mapping. When omitted,
            the notebook-level ``label2id`` is looked up at access time.

    Returns a tuple: (image_tensor, remapped_mask_tensor, original_image_numpy, original_mask_numpy)
    """
    def __init__(self, base_dataset: Dataset, transform=None, target_transform=None, label2id_map=None):
        self.dataset = base_dataset
        self.transform = transform
        self.target_transform = target_transform
        self.label2id_map = label2id_map

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int):
        sample = self.dataset[idx]
        orig_pil_image = sample['image']
        orig_image_np = np.array(orig_pil_image)

        # Convert tensors explicitly: np.array(tensor) goes through Tensor.__array__
        # and emits a warning on every sample.
        semantic_map = sample['semantic_map']
        if isinstance(semantic_map, torch.Tensor):
            orig_mask_np = semantic_map.cpu().numpy()
        else:
            orig_mask_np = np.asarray(semantic_map)

        id_map = self.label2id_map if self.label2id_map is not None else label2id
        remapped_mask = remap_labels(orig_mask_np, id_map)

        if self.transform:
            image_tensor = self.transform(orig_pil_image)
        else:
            image_tensor = torch.tensor(orig_image_np, dtype=torch.float32).permute(2, 0, 1)

        if self.target_transform:
            # PIL cannot build an image from int64 data; pick the narrowest dtype it supports.
            mask_np = remapped_mask.numpy()
            pil_dtype = np.uint8 if mask_np.max(initial=0) < 256 else np.int32
            mask_transformed = self.target_transform(Image.fromarray(mask_np.astype(pil_dtype)))
            mask_tensor = torch.tensor(np.array(mask_transformed), dtype=torch.int64)
        else:
            mask_tensor = remapped_mask
        return image_tensor, mask_tensor, orig_image_np, orig_mask_np


In [8]:
# Define Image and Target Transforms
# Per-channel mean / std given on the 0-255 scale, rescaled to the [0, 1] range
# that ToTensor produces.
_PIXEL_SCALE = 255.0
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / _PIXEL_SCALE
ADE_STD  = np.array([58.395,  57.120,  57.375]) / _PIXEL_SCALE


def _make_image_transform():
    """Build the shared image pipeline: PIL -> float tensor in [0, 1] -> normalized."""
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ]
    return transforms.Compose(steps)


train_transform = _make_image_transform()

# Placeholder for mask-side transforms (currently none),
# e.g. transforms.RandomHorizontalFlip(p=1.0)
target_transform = transforms.Compose([])

test_transform = _make_image_transform()


In [9]:
# Instantiate Base and Wrapped Datasets
# Locations of the COCO annotation file and the directory image paths are resolved against.
coco_file_path   = os.path.expanduser("~/Downloads/Thesis/CKA_sweet_pepper_2020_summer/CKA_sweet_pepper_2020_summer.json")
dataset_root_dir = os.path.expanduser("~/Downloads/Thesis")

# Base datasets return raw PIL images; all tensor conversion happens in the wrapper.
base_train_ds = COCODataset(coco_file=coco_file_path, root_dir=dataset_root_dir, split='train', transform=None)
base_val_ds   = COCODataset(coco_file=coco_file_path, root_dir=dataset_root_dir, split='valid', transform=None)
base_test_ds  = COCODataset(coco_file=coco_file_path, root_dir=dataset_root_dir, split='test', transform=None)

train_dataset = ImageSegmentationDataset(base_train_ds, transform=train_transform, target_transform=None)
# Validation uses the evaluation pipeline so that augmentations later added to
# train_transform can never leak into validation metrics.
valid_dataset = ImageSegmentationDataset(base_val_ds,   transform=test_transform,  target_transform=None)
test_dataset  = ImageSegmentationDataset(base_test_ds,  transform=test_transform,  target_transform=None)

# Quick sanity check: print shapes for first sample
image_tensor, mask_tensor, orig_img_np, orig_mask_np = train_dataset[0]
print("Sample shapes (train_dataset[0]):")
print("  image tensor shape =", image_tensor.shape)
print("  remapped mask shape =", mask_tensor.shape)
print("  original image shape =", orig_img_np.shape)
print("  original mask shape =", orig_mask_np.shape)
assert image_tensor.shape[-2:] == mask_tensor.shape, "image and mask spatial sizes differ"


Category IDs and their names:
  ID 11: pepper_kp
  ID 12: red
  ID 13: yellow
  ID 14: green
  ID 15: mixed
  ID 17: mixed_red
  ID 18: mixed_yellow
Category IDs and their names:
  ID 11: pepper_kp
  ID 12: red
  ID 13: yellow
  ID 14: green
  ID 15: mixed
  ID 17: mixed_red
  ID 18: mixed_yellow
Category IDs and their names:
  ID 11: pepper_kp
  ID 12: red
  ID 13: yellow
  ID 14: green
  ID 15: mixed
  ID 17: mixed_red
  ID 18: mixed_yellow
Category IDs and their names:
  ID 11: pepper_kp
  ID 12: red
  ID 13: yellow
  ID 14: green
  ID 15: mixed
  ID 17: mixed_red
  ID 18: mixed_yellow
Category IDs and their names:
  ID 11: pepper_kp
  ID 12: red
  ID 13: yellow
  ID 14: green
  ID 15: mixed
  ID 17: mixed_red
  ID 18: mixed_yellow
Sample shapes (train_dataset[0]):
  image tensor shape = torch.Size([3, 1280, 720])
  remapped mask shape = torch.Size([1280, 720])
  original image shape = (1280, 720, 3)
  original mask shape = (1280, 720)
Sample shapes (train_dataset[0]):
  image tenso

  orig_mask_np = np.array(sample['semantic_map'])


In [10]:
# Prepare Mask2Former Processor and DataLoaders
# The dataset transforms already convert and normalize the images, so every
# image-side step of the processor is switched off; it is used here to turn
# the semantic maps into per-class mask / class-label targets and to batch.
preprocessor = Mask2FormerImageProcessor(
    ignore_index=255,
    reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
    num_labels=len(id2label_remapped)
)

def segmentation_collate_fn(batch):
    """Collate ImageSegmentationDataset samples into Mask2Former model inputs.

    Each sample is (image_tensor, mask_tensor, original_image, original_mask).
    The first two go through the processor; the untouched originals are attached
    to the result as tuples under 'original_images' and
    'original_segmentation_maps'.
    """
    image_tensors, mask_tensors, raw_images, raw_masks = zip(*batch)
    encoded = preprocessor(
        list(image_tensors),
        segmentation_maps=list(mask_tensors),
        return_tensors="pt"
    )
    # Keep the originals around for later use (e.g. metric computation)
    encoded["original_images"] = raw_images
    encoded["original_segmentation_maps"] = raw_masks
    return encoded

# Settings shared by all three loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=2, collate_fn=segmentation_collate_fn)
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)
test_dataloader  = DataLoader(test_dataset,  shuffle=False, **_loader_kwargs)

print(f"Number of train batches: {len(train_dataloader)}")
print(f"Number of valid batches: {len(valid_dataloader)}")


Number of train batches: 62
Number of valid batches: 32


In [13]:
# Training Loop for Semantic Segmentation
# Relies on train_dataloader, valid_dataloader, preprocessor, id2label_remapped,
# label2id and remap_labels from previous cells, plus os / torch / evaluate / tqdm
# from the imports cell. The model itself is created in this cell.

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

NUM_EPOCHS = 100
LOG_EVERY_N_BATCHES = 100

metric = evaluate.load("mean_iou")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the Mask2Former model for universal segmentation.
# ignore_mismatched_sizes lets the ADE-pretrained class head be re-initialised
# for this dataset's number of classes.
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic",
    num_labels=len(id2label_remapped),
    ignore_mismatched_sizes=True
)
model.to(device)

optimizer = optim.SGD(model.parameters(), lr=2e-4)
# NOTE(review): the scheduler is constructed but never stepped (as in the original
# cell), so the learning rate stays at 2e-4. Call scheduler.step() once per epoch
# to enable the warm-restart schedule.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=5e-6)

best_val_loss = float('inf')
best_epoch = 0


def forward_batch(batch):
    """Run the model on one collated batch and return its outputs (including the loss).

    The class labels must be the ones the image processor produced alongside the
    binary masks: entry i of ``class_labels`` names the class of mask i.
    """
    return model(
        pixel_values=batch["pixel_values"].to(device),
        mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
        class_labels=[labels.to(device) for labels in batch["class_labels"]],
    )


for epoch in range(NUM_EPOCHS):
    print("Epoch:", epoch)

    # ---- training ----
    model.train()
    running_loss = 0.0  # reset every epoch so the logged value is this epoch's mean
    for idx, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        outputs = forward_batch(batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        # outputs.loss is already a per-batch value, so average over batches.
        running_loss += loss.item()
        if idx % LOG_EVERY_N_BATCHES == 0:
            print("Loss:", running_loss / (idx + 1))

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    for batch in tqdm(valid_dataloader):
        with torch.no_grad():
            outputs = forward_batch(batch)
        val_loss += outputs.loss.item()

        # Turn query-level outputs into one class-ID map per image, at the size
        # of the corresponding original image.
        target_sizes = [image.shape[:2] for image in batch["original_images"]]
        predicted_maps = preprocessor.post_process_semantic_segmentation(
            outputs, target_sizes=target_sizes
        )
        # The stored originals carry the raw category IDs; remap them to the
        # contiguous IDs the model predicts before comparing.
        reference_maps = [
            remap_labels(mask, label2id).numpy()
            for mask in batch["original_segmentation_maps"]
        ]
        metric.add_batch(
            predictions=[pred.cpu().numpy() for pred in predicted_maps],
            references=reference_maps,
        )

    # ignore_index=0 leaves background pixels out of the IoU computation.
    print("Mean IoU:", metric.compute(num_labels = len(id2label_remapped ), ignore_index=0)['mean_iou'])
    avg_val_loss = val_loss / len(valid_dataloader)
    print("Validation Loss:", avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        # "~" is not expanded by open()/torch.save, so resolve it explicitly.
        model_save_path = os.path.expanduser(f"~/best_model_epoch_{best_epoch}.pt")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved at epoch {best_epoch} with validation loss: {best_val_loss}")


Some weights of Mask2FormerForUniversalSegmentation were not initialized from the model checkpoint at facebook/mask2former-swin-large-ade-semantic and are newly initialized because the shapes did not match:
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([9]) in the model instantiated
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([9, 256]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([9]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 0


  orig_mask_np = np.array(sample['semantic_map'])
  orig_mask_np = np.array(sample['semantic_map'])


: 