[![Labellerr](https://storage.googleapis.com/labellerr-cdn/%200%20Labellerr%20template/notebook.webp)](https://www.labellerr.com)

# Mask2Former

---

[![labellerr](https://img.shields.io/badge/Labellerr-BLOG-black.svg)](https://www.labellerr.com/blog/)
[![Youtube](https://img.shields.io/badge/Labellerr-YouTube-b31b1b.svg)](https://www.youtube.com/@Labellerr)
[![Github](https://img.shields.io/badge/Labellerr-GitHub-green.svg)](https://github.com/Labellerr/Hands-On-Learning-in-Computer-Vision)
[![Scientific Paper](https://img.shields.io/badge/Official-Paper-blue.svg)](https://arxiv.org/abs/2112.01527)

# About Mask2Former

Mask2Former is a universal image segmentation architecture developed by Meta AI Research in 2022. It is designed to handle all major segmentation tasks—semantic, instance, and panoptic segmentation—with a single, unified framework. This model builds on the MaskFormer architecture, introducing key innovations to improve both performance and efficiency, most notably masked attention, which restricts each query's cross-attention to the region of its predicted mask.

![Labellerr](https://storage.googleapis.com/labellerr-cdn/%200%20Mask2Former/main.webp)

In [None]:
! uv pip install transformers torch scipy matplotlib opencv-python pillow scikit-image

In [None]:
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
import torch
from PIL import Image
import requests
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from skimage.measure import regionprops


# Helper Functions

In [None]:
def visualize_segmentation(results, image, model, alpha=0.4):
    """
    Overlay instance/panoptic segments on the image and draw a legend.

    Each segment gets a random RGB color, and its pixels are alpha-blended
    toward that color. Pixels that belong to no listed segment keep their
    original value. A legend (color swatch plus "label (score)") is stacked
    down the right edge of the image.

    Parameters
    ----------
    results : dict
        One entry of a ``post_process_*_segmentation`` output: a
        ``'segmentation'`` map of segment ids (torch tensor or array-like) and
        a ``'segments_info'`` list of dicts with ``'id'``, ``'label_id'`` and
        ``'score'``.
    image : PIL.Image.Image or array-like
        Image the map was computed for. PIL images are converted to RGB
        before blending.
    model : Mask2FormerForUniversalSegmentation
        Only ``model.config.id2label`` is used, to name each segment.
    alpha : float, default 0.4
        Mask opacity: 0 leaves the image untouched, 1 paints solid colors.
    """
    segmentation = results['segmentation']
    # .cpu() is a no-op for CPU tensors and lets GPU results be plotted too.
    if hasattr(segmentation, 'cpu'):
        seg_np = segmentation.cpu().numpy()
    else:
        seg_np = np.asarray(segmentation)
    segments_info = results['segments_info']
    width = seg_np.shape[1]

    # randint's upper bound is exclusive, so 256 makes 255 reachable.
    instance_colors = {
        segment['id']: np.random.randint(0, 256, size=3)
        for segment in segments_info
    }

    # RGBA or greyscale inputs would not broadcast against 3-channel colors,
    # so normalise to RGB first.
    if hasattr(image, 'convert'):
        image = image.convert('RGB')
    overlay = np.array(image, dtype=np.float64)

    # Blend only the pixels of each segment. Pixels outside every segment
    # stay as they are instead of being darkened toward black.
    for segment in segments_info:
        mask = seg_np == segment['id']
        overlay[mask] = (1 - alpha) * overlay[mask] + alpha * instance_colors[segment['id']]
    overlay = overlay.astype(np.uint8)

    plt.figure(figsize=(12, 8))
    plt.imshow(overlay)
    ax = plt.gca()
    plt.axis('off')

    # Draw color lines and names on the right edge
    line_height = 25
    spacing = 10
    y0 = spacing
    for segment in segments_info:
        label_name = model.config.id2label[segment['label_id']]
        color = instance_colors[segment['id']] / 255

        # Color swatch, positioned in image (data) coordinates.
        rect = mpatches.Rectangle(
            (width - 30, y0), 20, line_height,
            linewidth=0, edgecolor=None, facecolor=color, alpha=1.0, transform=ax.transData, clip_on=False
        )
        ax.add_patch(rect)

        # Label text to the left of the swatch; the text color is chosen to
        # contrast with the swatch brightness.
        ax.text(
            width - 35, y0 + line_height / 2,
            f"{label_name} ({segment['score']:.2f})",
            va='center', ha='right', fontsize=11,
            color='white' if np.mean(color) < 0.5 else 'black',
            bbox=dict(facecolor=(0, 0, 0, 0.2), edgecolor='none', boxstyle='round,pad=0.2')
        )
        y0 += line_height + spacing

    plt.show()
    
def draw_binary_mask(results, model):
    """
    Plot the raw segment-id map and write each segment's class at its centroid.

    Parameters
    ----------
    results : dict
        One entry of a ``post_process_*_segmentation`` output: a
        ``'segmentation'`` map of segment ids (torch tensor or array-like) and
        a ``'segments_info'`` list of dicts carrying ``'id'`` and ``'label_id'``.
    model : Mask2FormerForUniversalSegmentation
        Only ``model.config.id2label`` is used, to turn label ids into names.
    """
    segmentation = results['segmentation']
    segments_info = results['segments_info']
    # .cpu() is a no-op for CPU tensors and lets GPU results be plotted too.
    if hasattr(segmentation, 'cpu'):
        seg_np = segmentation.cpu().numpy()
    else:
        seg_np = np.asarray(segmentation)

    plt.figure(figsize=(15, 10))
    plt.imshow(seg_np)
    ax = plt.gca()

    for segment in segments_info:
        label_name = model.config.id2label[segment['label_id']]
        mask = seg_np == segment['id']
        props = regionprops(mask.astype(np.uint8))
        # An empty mask yields no regions, so there is nowhere to put a label.
        if not props:
            continue
        # regionprops centroid is (row, col), i.e. (y, x) in image coordinates.
        y, x = props[0].centroid
        ax.text(
            x, y, label_name,
            color='white', fontsize=8, weight='bold',
            ha='center', va='center',
            bbox=dict(facecolor='black', alpha=0.5, boxstyle='round,pad=0.2')
        )
    plt.axis('off')
    plt.show()

def visualize_semantic_map(predicted_map, original_image, model, alpha=0.5):
    """
    Overlay a per-pixel class map on the original image and display it.

    Parameters
    ----------
    predicted_map : torch.Tensor or array-like, 2-D
        Class id for every pixel, e.g. one entry of
        ``processor.post_process_semantic_segmentation``.
    original_image : PIL.Image.Image or array-like
        Image the map was computed for. PIL images are converted to RGB
        before blending.
    model : Mask2FormerForUniversalSegmentation
        Only ``len(model.config.id2label)`` is used, to size the palette.
    alpha : float, default 0.5
        Mask opacity (0 = transparent, 1 = opaque).
    """
    # .cpu() is a no-op for CPU tensors and lets GPU results be plotted too.
    if hasattr(predicted_map, 'cpu'):
        seg = predicted_map.cpu().numpy()
    else:
        seg = np.asarray(predicted_map)

    # One random color per class. randint's upper bound is exclusive, so
    # 256 makes 255 reachable.
    color_palette = np.random.randint(
        0, 256, size=(len(model.config.id2label), 3), dtype=np.uint8
    )
    # Fancy indexing maps every class id to its color in one vectorised step.
    color_seg = color_palette[seg]

    # RGBA PNGs or greyscale images would not broadcast against the 3-channel
    # color map, so normalise to RGB first.
    if hasattr(original_image, 'convert'):
        original_image = original_image.convert('RGB')

    # Blend the original image and the color mask using the alpha parameter.
    img = np.asarray(original_image) * (1 - alpha) + color_seg * alpha
    img = img.astype(np.uint8)

    plt.figure(figsize=(15, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

In [None]:
# Sample images used by the demo cells below. Pick a different entry to try
# another picture, or set `url` to any image URL or local file path.
SAMPLE_IMAGE_URLS = [
    "https://i.pinimg.com/736x/1e/de/d8/1eded82fbcfc288327d9f12795dbab7b.jpg",
    "https://i.pinimg.com/736x/da/c6/12/dac612292f2cb04c82d6b41fdb6a1c6a.jpg",
    "https://i.pinimg.com/736x/4e/9f/7b/4e9f7ba5e373204ce3a0737494ab76c3.jpg",
    "https://i.pinimg.com/736x/9a/67/16/9a6716e63ff0afb1cee5931ade66e388.jpg",
    "https://i.pinimg.com/736x/59/e6/ae/59e6ae53b2bcee83dd57e4e3b1ce7d7c.jpg",
    "https://i.pinimg.com/736x/07/18/53/07185391756d6eef0ef73ea9d0b56ef9.jpg",
    "https://i.pinimg.com/736x/16/4f/21/164f213bb141fa9957926531d4056ceb.jpg",
]
url = SAMPLE_IMAGE_URLS[-1]

# SEMANTIC SEGMENTATION

In [None]:
def run_semantic_segmentation(image_path, checkpoint="facebook/mask2former-swin-large-ade-semantic"):
    """
    Run Mask2Former semantic segmentation on a single image.

    Parameters
    ----------
    image_path : str
        Local file path, or an ``http://`` / ``https://`` URL to download.
    checkpoint : str, optional
        Hugging Face Hub id of the Mask2Former model to load. Defaults to the
        Swin-L checkpoint trained on ADE20K semantic segmentation.

    Returns
    -------
    predicted_map : torch.Tensor
        Per-pixel class ids at the image's original resolution, from
        ``processor.post_process_semantic_segmentation``.
    image : PIL.Image.Image
        The input image converted to RGB.
    model : Mask2FormerForUniversalSegmentation
        The loaded model (its ``config.id2label`` names the classes).

    Raises
    ------
    requests.HTTPError
        If downloading a URL returns an error status.
    """
    from io import BytesIO

    # Load model and processor
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

    # Load image. The timeout prevents the cell from hanging forever on a
    # stalled server, and raise_for_status surfaces HTTP errors clearly
    # instead of as an unreadable-image error.
    if image_path.startswith(('http://', 'https://')):
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path
    # convert() returns a new in-memory image, so the file can be closed.
    # Forcing RGB handles RGBA/palette PNGs and greyscale inputs.
    with Image.open(source) as img:
        image = img.convert("RGB")

    # Process and predict
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize predictions back to (height, width); PIL's size is (width, height).
    predicted_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    return predicted_map, image, model




In [None]:
# Semantic segmentation demo. Uncomment a line below to override `url`.
# url = "https://cdn-media.huggingface.co/Inference-API/Sample-results-on-the-Cityscapes-dataset-The-above-images-show-how-our-method-can-handle.png"
# url = "https://i.pinimg.com/736x/cb/b1/3f/cbb13f0a3ef98d8e180a19ff8bd7f43a.jpg"
predicted_map, image, model = run_semantic_segmentation(image_path=url)

# First a blended overlay, then the pure class-color map (fully opaque).
visualize_semantic_map(predicted_map=predicted_map, original_image=image, model=model, alpha=0.6)
visualize_semantic_map(predicted_map=predicted_map, original_image=image, model=model, alpha=1)

# INSTANCE SEGMENTATION

In [None]:

def run_instance_segmentation(image_path, checkpoint="facebook/mask2former-swin-large-coco-instance"):
    """
    Run Mask2Former instance segmentation on a single image.

    Nothing is plotted here. Pass the returned ``results`` to
    ``draw_binary_mask`` or ``visualize_segmentation`` to view them.

    Parameters
    ----------
    image_path : str
        Local file path, or an ``http://`` / ``https://`` URL to download.
    checkpoint : str, optional
        Hugging Face Hub id of the Mask2Former model to load. Defaults to the
        Swin-L COCO instance-segmentation checkpoint.

    Returns
    -------
    results : dict
        ``processor.post_process_instance_segmentation`` output for this
        image (keys ``'segmentation'`` and ``'segments_info'``), at the
        image's original resolution.
    image : PIL.Image.Image
        The input image converted to RGB.
    model : Mask2FormerForUniversalSegmentation
        The loaded model (its ``config.id2label`` names the classes).

    Raises
    ------
    requests.HTTPError
        If downloading a URL returns an error status.
    """
    from io import BytesIO

    # Load model and processor
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

    # Load image. The timeout prevents the cell from hanging forever on a
    # stalled server, and raise_for_status surfaces HTTP errors clearly.
    if image_path.startswith(('http://', 'https://')):
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path
    # convert() returns a new in-memory image, so the file can be closed.
    with Image.open(source) as img:
        image = img.convert("RGB")

    # Preprocess and inference
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize predictions back to (height, width); PIL's size is (width, height).
    results = processor.post_process_instance_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    return results, image, model

In [None]:
# Instance segmentation demo. Uncomment a line below to override `url`.
# url = "https://i.pinimg.com/736x/cb/b1/3f/cbb13f0a3ef98d8e180a19ff8bd7f43a.jpg"
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
instance_results, image, model = run_instance_segmentation(image_path=url)

# Raw segment-id map with class names, then the colored overlay with legend.
draw_binary_mask(results=instance_results, model=model)
visualize_segmentation(results=instance_results, image=image, model=model, alpha=0.5)

# PANOPTIC SEGMENTATION

In [None]:
def run_panoptic_segmentation(image_path, checkpoint="facebook/mask2former-swin-base-coco-panoptic"):
    """
    Run Mask2Former panoptic segmentation on a single image.

    Nothing is plotted here. Pass the returned ``results`` to
    ``draw_binary_mask`` or ``visualize_segmentation`` to view them.

    Parameters
    ----------
    image_path : str
        Local file path, or an ``http://`` / ``https://`` URL to download.
    checkpoint : str, optional
        Hugging Face Hub id of the Mask2Former model to load. Defaults to the
        Swin-B COCO panoptic checkpoint.

    Returns
    -------
    results : dict
        ``processor.post_process_panoptic_segmentation`` output for this
        image (keys ``'segmentation'`` and ``'segments_info'``), at the
        image's original resolution.
    image : PIL.Image.Image
        The input image converted to RGB.
    model : Mask2FormerForUniversalSegmentation
        The loaded model (its ``config.id2label`` names the classes).

    Raises
    ------
    requests.HTTPError
        If downloading a URL returns an error status.
    """
    from io import BytesIO

    # Load model and processor
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

    # Load image. The timeout prevents the cell from hanging forever on a
    # stalled server, and raise_for_status surfaces HTTP errors clearly.
    if image_path.startswith(('http://', 'https://')):
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path
    # Force RGB so RGBA/greyscale inputs work with the processor and with the
    # 3-channel overlay in the visualisation helpers. convert() returns a new
    # in-memory image, so the file can be closed.
    with Image.open(source) as img:
        image = img.convert("RGB")

    # Process and predict
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize predictions back to (height, width); PIL's size is (width, height).
    results = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    return results, image, model

In [None]:
# Panoptic segmentation demo. Uncomment a line below to override `url`.
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# url = "https://i.pinimg.com/736x/cb/b1/3f/cbb13f0a3ef98d8e180a19ff8bd7f43a.jpg"
panoptic_results, image, model = run_panoptic_segmentation(image_path=url)

# Raw segment-id map with class names, then the colored overlay with legend.
draw_binary_mask(results=panoptic_results, model=model)
visualize_segmentation(results=panoptic_results, image=image, model=model, alpha=0.6)