# Utils


In [1]:
import torch
import torch.distributed as dist
from typing import List, Optional, Tuple
from functools import partial


class NestedTensor:
    """
    A nested tensor is a tensor along with a mask indicating which elements are valid.
    This is useful for batching images of different sizes.
    """

    def __init__(self, tensors: torch.Tensor, mask: Optional[torch.Tensor]):
        self.tensors = tensors
        self.mask = mask

    def decompose(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.tensors, self.mask

    @property
    def device(self):
        return self.tensors.device

    def to(self, device):
        return NestedTensor(self.tensors.to(device), self.mask.to(device) if self.mask is not None else None)


def nested_tensor_from_tensor_list(tensor_list: List[torch.Tensor]) -> NestedTensor:
    """
    Create a nested tensor from a list of tensors.

    Pads the tensors to the same size and creates a mask indicating valid regions.

    Args:
        tensor_list: List of tensors with shape [C, H_i, W_i]

    Returns:
        NestedTensor containing:
            - tensors: Padded tensors [B, C, H_max, W_max]
            - mask: Binary mask [B, H_max, W_max] (False = valid, True = padding)
    """
    # Calculate maximum dimensions
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])

    # Create output tensor and mask
    batch_shape = [len(tensor_list)] + max_size
    b, c, h, w = batch_shape

    device = tensor_list[0].device
    dtype = tensor_list[0].dtype

    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=device)

    # Copy tensors and update mask
    for i, img in enumerate(tensor_list):
        tensor[i, :img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
        mask[i, :img.shape[1], :img.shape[2]] = False

    return NestedTensor(tensor, mask)


def _max_by_axis(the_list: List[List[int]]) -> List[int]:
    """Find maximum size along each axis."""
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def is_dist_avail_and_initialized() -> bool:
    """Check if distributed training is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size() -> int:
    """Get world size for distributed training."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    """Get rank for distributed training."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process() -> bool:
    """Check if current process is the main process."""
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """Save checkpoint only on the main process."""
    if is_main_process():
        torch.save(*args, **kwargs)


class ImageList:
    """
    Structure that holds a list of images (tensors) of possibly different sizes.

    Similar to detectron2's ImageList but without the detectron2 dependency.
    """

    def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]):
        """
        Args:
            tensor: Batched images tensor [B, C, H, W]
            image_sizes: List of (height, width) tuples for each image
        """
        self.tensor = tensor
        self.image_sizes = image_sizes

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    def __len__(self) -> int:
        return len(self.image_sizes)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Get image at index with its original size."""
        h, w = self.image_sizes[idx]
        return self.tensor[idx, :, :h, :w]

    @staticmethod
    def from_tensors(
        tensors: List[torch.Tensor],
        size_divisibility: int = 0,
        pad_value: float = 0.0
    ) -> "ImageList":
        """
        Create ImageList from list of image tensors.

        Args:
            tensors: List of image tensors
            size_divisibility: If > 0, pad dimensions to be divisible by this
            pad_value: Value to use for padding

        Returns:
            ImageList instance
        """
        assert len(tensors) > 0
        assert isinstance(tensors[0], torch.Tensor)

        image_sizes = [(tensor.shape[-2], tensor.shape[-1]) for tensor in tensors]

        # Calculate target size
        max_size = list(max(zip(*[tensor.shape for tensor in tensors])))

        if size_divisibility > 0:
            # Pad to be divisible by size_divisibility
            stride = size_divisibility
            max_size[-2] = (max_size[-2] + stride - 1) // stride * stride
            max_size[-1] = (max_size[-1] + stride - 1) // stride * stride

        # Create batched tensor
        batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size[-2:])
        device = tensors[0].device
        batched_tensor = torch.full(batch_shape, pad_value, dtype=tensors[0].dtype, device=device)

        # Copy images
        for i, tensor in enumerate(tensors):
            batched_tensor[i, ..., :tensor.shape[-2], :tensor.shape[-1]].copy_(tensor)

        return ImageList(batched_tensor, image_sizes)


def multi_apply(func, *args, **kwargs):
    """
    Apply function to multiple inputs and concatenate outputs.

    Useful for applying the same function to multiple levels of features.
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))


def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Inverse of sigmoid function.

    Args:
        x: Input tensor (should be in range [0, 1])
        eps: Small epsilon to avoid numerical issues

    Returns:
        Inverse sigmoid of x
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

# Losses

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from typing import Dict, List, Tuple


class HungarianMatcher(nn.Module):
    """
    This class computes an assignment between the targets and the predictions.

    For efficiency reasons, the targets don't include the no_object class. Because of this,
    in general, there are more predictions than targets. We do a 1-to-1 matching,
    and the remaining predictions are un-matched (and thus treated as background).
    """

    def __init__(
        self,
        cost_class: float = 1.0,
        cost_mask: float = 1.0,
        cost_dice: float = 1.0,
        num_points: int = 0,
    ):
        """
        Args:
            cost_class: Weight of the classification cost
            cost_mask: Weight of the sigmoid CE cost for masks
            cost_dice: Weight of the dice cost for masks
            num_points: Number of points to sample for matching
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        self.num_points = num_points

    @torch.no_grad()
    def forward(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: List[Dict[str, torch.Tensor]]
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Performs the matching.

        Args:
            outputs: Model outputs containing:
                - "pred_logits": [batch_size, num_queries, num_classes]
                - "pred_masks": [batch_size, num_queries, H, W]

            targets: List of target dictionaries (one per batch item) containing:
                - "labels": [num_target_masks] containing class labels
                - "masks": [num_target_masks, H, W] containing target masks

        Returns:
            List of tuples (index_i, index_j) where:
                - index_i: indices of selected predictions (in order)
                - index_j: indices of corresponding selected targets (in order)
            For each batch element, it returns (index_i, index_j) such that:
                - len(index_i) = len(index_j) = min(num_queries, num_target_masks)
                - For all k: prediction index_i[k] is matched to target index_j[k]
        """
        batch_size, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_mask = outputs["pred_masks"].flatten(0, 1)  # [batch_size * num_queries, H, W]

        # Concatenate target labels and masks
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_mask = torch.cat([v["masks"] for v in targets])

        # Compute classification cost
        # out_prob: [batch_size * num_queries, num_classes]
        # tgt_ids: [total_num_targets_in_batch]
        cost_class = -out_prob[:, tgt_ids]

        # Compute mask costs
        if self.num_points > 0:
            # Sample points for efficiency
            out_mask_sampled = self._sample_points_for_matching(out_mask, self.num_points)
            tgt_mask_sampled = self._sample_points_for_matching(tgt_mask, self.num_points)
        else:
            out_mask_sampled = out_mask.flatten(1)
            tgt_mask_sampled = tgt_mask.flatten(1)

        # Compute sigmoid cross-entropy cost
        with torch.cuda.amp.autocast(enabled=False):
            out_mask_sampled = out_mask_sampled.float()
            tgt_mask_sampled = tgt_mask_sampled.float()

            # Sigmoid CE cost
            cost_mask = self._batch_sigmoid_ce_cost(out_mask_sampled, tgt_mask_sampled)

            # Dice cost
            cost_dice = self._batch_dice_cost(out_mask_sampled, tgt_mask_sampled)

        # Final cost matrix
        C = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
        C = C.reshape(batch_size, num_queries, -1)

        # Perform Hungarian matching for each batch element
        indices = []
        sizes = [len(v["labels"]) for v in targets]

        for i, c in enumerate(C.cpu()):
            if sizes[i] == 0:
                # No targets in this batch element
                indices.append((
                    torch.empty(0, dtype=torch.long),
                    torch.empty(0, dtype=torch.long)
                ))
            else:
                # Extract cost matrix for this batch element
                c_i = c[:, sum(sizes[:i]):sum(sizes[:i+1])]

                # Solve assignment problem
                row_ind, col_ind = linear_sum_assignment(c_i.numpy())

                indices.append((
                    torch.as_tensor(row_ind, dtype=torch.long),
                    torch.as_tensor(col_ind, dtype=torch.long)
                ))

        return [(torch.as_tensor(i, dtype=torch.long), torch.as_tensor(j, dtype=torch.long))
                for i, j in indices]

    def _sample_points_for_matching(self, masks: torch.Tensor, num_points: int) -> torch.Tensor:
        """
        Uniformly sample points from masks for matching.

        Args:
            masks: [N, H, W] tensor of masks
            num_points: Number of points to sample

        Returns:
            [N, num_points] tensor of sampled mask values
        """
        N, H, W = masks.shape

        # Generate random point coordinates
        # Use same points for all masks to ensure fair comparison
        with torch.no_grad():
            points_idx = torch.randperm(H * W, device=masks.device)[:num_points]

        # Sample mask values at these points
        masks_flat = masks.flatten(1)  # [N, H*W]
        masks_sampled = masks_flat[:, points_idx]  # [N, num_points]

        return masks_sampled

    def _batch_sigmoid_ce_cost(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute sigmoid cross-entropy cost between all pairs.

        Args:
            inputs: [N, P] tensor of predicted logits
            targets: [M, P] tensor of target values (0 or 1)

        Returns:
            [N, M] cost matrix
        """
        # Compute pairwise sigmoid CE
        # Note: This is memory efficient for reasonable batch sizes
        N, P = inputs.shape
        M = targets.shape[0]

        # Expand dimensions for broadcasting
        inputs_exp = inputs.unsqueeze(1)  # [N, 1, P]
        targets_exp = targets.unsqueeze(0)  # [1, M, P]

        # Compute sigmoid CE for all pairs
        ce = F.binary_cross_entropy_with_logits(
            inputs_exp.expand(N, M, P),
            targets_exp.expand(N, M, P),
            reduction='none'
        )

        return ce.mean(dim=2)  # [N, M]

    def _batch_dice_cost(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute dice cost between all pairs.

        Args:
            inputs: [N, P] tensor of predicted logits
            targets: [M, P] tensor of target values (0 or 1)

        Returns:
            [N, M] cost matrix
        """
        N, P = inputs.shape
        M = targets.shape[0]

        # Apply sigmoid to get probabilities
        inputs = inputs.sigmoid()

        # Expand dimensions for broadcasting
        inputs_exp = inputs.unsqueeze(1)  # [N, 1, P]
        targets_exp = targets.unsqueeze(0)  # [1, M, P]

        # Compute dice coefficient for all pairs
        numerator = 2 * (inputs_exp * targets_exp).sum(dim=2)  # [N, M]
        denominator = inputs_exp.sum(dim=2) + targets_exp.sum(dim=2)  # [N, M]

        # Dice cost (1 - dice coefficient)
        dice_cost = 1 - (numerator + 1) / (denominator + 1)

        return dice_cost  # [N, M]

In [3]:
import numpy as np


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
) -> torch.Tensor:
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: Predicted masks [N, H, W]
        targets: Target masks [N, H, W]
        num_masks: Number of masks to normalize the loss

    Returns:
        Normalized dice loss
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)

    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)

    return loss.sum() / num_masks


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
) -> torch.Tensor:
    """
    Sigmoid cross-entropy loss.

    Args:
        inputs: Predicted logits [N, H, W]
        targets: Target masks [N, H, W]
        num_masks: Number of masks to normalize the loss

    Returns:
        Normalized sigmoid CE loss
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    return loss.mean(1).sum() / num_masks


class SetCriterion(nn.Module):
    """
    This class computes the loss for Mask2Former.
    The process happens in two steps:
        1) We compute hungarian assignment between ground truth masks and predictions
        2) We supervise each pair of matched ground-truth / prediction (supervise class and mask)
    """

    def __init__(
        self,
        num_classes: int,
        matcher: HungarianMatcher,
        weight_dict: Dict[str, float],
        eos_coef: float,
        losses: List[str],
        num_points: int = 12544,
        oversample_ratio: float = 3.0,
        importance_sample_ratio: float = 0.75,
    ):
        """
        Args:
            num_classes: Number of object categories
            matcher: Module to compute assignment between targets and predictions
            weight_dict: Dict containing weights for different losses
            eos_coef: Relative classification weight of the no-object category
            losses: List of losses to be applied
            num_points: Number of points sampled for point losses
            oversample_ratio: Oversampling ratio for point sampling
            importance_sample_ratio: Ratio of points sampled via importance sampling
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses

        # Point sampling parameters
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio

        # Create no-object class weight
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: List[Dict[str, torch.Tensor]],
        indices: List[Tuple[torch.Tensor, torch.Tensor]],
        num_masks: int,
    ) -> Dict[str, torch.Tensor]:
        """Classification loss."""
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        # Prepare target classes
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        # Compute cross-entropy loss
        loss_ce = F.cross_entropy(
            src_logits.transpose(1, 2),
            target_classes,
            self.empty_weight
        )

        losses = {"loss_ce": loss_ce}
        return losses

    def loss_masks(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: List[Dict[str, torch.Tensor]],
        indices: List[Tuple[torch.Tensor, torch.Tensor]],
        num_masks: int,
    ) -> Dict[str, torch.Tensor]:
        """Compute mask losses (dice + sigmoid CE)."""
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]

        # Get target masks
        target_masks = [t["masks"] for t in targets]
        target_masks, valid = nested_tensor_from_tensor_list(target_masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # Point sampling for efficient computation
        with torch.no_grad():
            point_coords = self._sample_points(
                src_masks,
                target_masks,
                self.oversample_ratio,
                self.importance_sample_ratio
            )

        # Sample points from masks
        src_masks_sampled = self._point_sample(src_masks, point_coords, align_corners=False)
        target_masks_sampled = self._point_sample(target_masks, point_coords, align_corners=False)

        src_masks_sampled = src_masks_sampled.flatten(1)
        target_masks_sampled = target_masks_sampled.flatten(1)

        losses = {
            "loss_mask": sigmoid_ce_loss(src_masks_sampled, target_masks_sampled, num_masks),
            "loss_dice": dice_loss(src_masks_sampled, target_masks_sampled, num_masks),
        }

        return losses

    def _get_src_permutation_idx(self, indices):
        """Get source permutation indices."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Get target permutation indices."""
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def _sample_points(
        self,
        src_masks: torch.Tensor,
        tgt_masks: torch.Tensor,
        oversample_ratio: float,
        importance_sample_ratio: float,
    ) -> torch.Tensor:
        """
        Sample points for mask loss computation.

        Combines uniform sampling with importance sampling based on uncertainty.
        """
        N, H, W = src_masks.shape

        num_points = int(self.num_points * oversample_ratio)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points

        # Uniform random sampling
        point_coords = torch.rand(N, num_random_points, 2, device=src_masks.device)
        point_coords = point_coords * torch.tensor([W, H], device=src_masks.device) - 0.5

        # Importance sampling based on prediction uncertainty
        if num_uncertain_points > 0:
            point_logits = src_masks.flatten(1)
            point_uncertainties = -torch.abs(point_logits)

            # Sample points with high uncertainty
            _, uncertain_idx = torch.topk(point_uncertainties, num_uncertain_points, dim=1)

            # Convert flat indices to 2D coordinates
            uncertain_coords = torch.stack([
                uncertain_idx % W,
                uncertain_idx // W
            ], dim=-1).float()

            # Add small random noise
            uncertain_coords += torch.rand_like(uncertain_coords) - 0.5

            # Combine uniform and importance samples
            point_coords = torch.cat([point_coords, uncertain_coords], dim=1)

        return point_coords

    def _point_sample(
        self,
        input: torch.Tensor,
        point_coords: torch.Tensor,
        align_corners: bool = False,
    ) -> torch.Tensor:
        """
        Sample features at specified point coordinates.

        Args:
            input: Input tensor [N, H, W]
            point_coords: Point coordinates [N, P, 2]
            align_corners: Whether to align corners in grid_sample

        Returns:
            Sampled features [N, P, 1]
        """
        N, H, W = input.shape

        # Normalize coordinates to [-1, 1]
        point_coords = point_coords.clone()
        point_coords[..., 0] = 2.0 * point_coords[..., 0] / W - 1.0
        point_coords[..., 1] = 2.0 * point_coords[..., 1] / H - 1.0

        # Add batch dimension and sample
        output = F.grid_sample(
            input.unsqueeze(1),
            point_coords.unsqueeze(1),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=align_corners,
        )

        return output.squeeze(1).transpose(1, 2)

    def get_loss(
        self,
        loss: str,
        outputs: Dict[str, torch.Tensor],
        targets: List[Dict[str, torch.Tensor]],
        indices: List[Tuple[torch.Tensor, torch.Tensor]],
        num_masks: int,
    ) -> Dict[str, torch.Tensor]:
        """Dispatch to specific loss function."""
        loss_map = {
            "labels": self.loss_labels,
            "masks": self.loss_masks,
        }

        assert loss in loss_map, f"Loss {loss} not supported"
        return loss_map[loss](outputs, targets, indices, num_masks)

    def forward(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: List[Dict[str, torch.Tensor]],
    ) -> Dict[str, torch.Tensor]:
        """
        Compute all losses.

        Args:
            outputs: Model outputs containing:
                - pred_logits: [B, num_queries, num_classes]
                - pred_masks: [B, num_queries, H, W]
                - aux_outputs: List of intermediate outputs
            targets: List of target dictionaries containing:
                - labels: [num_objects] tensor of class labels
                - masks: [num_objects, H, W] tensor of target masks

        Returns:
            Dict of losses
        """
        # Retrieve outputs
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Match predictions with targets
        indices = self.matcher(outputs_without_aux, targets)

        # Count total number of masks across batch
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=outputs["pred_logits"].device)

        # Ensure at least 1 to avoid division by zero
        num_masks = torch.clamp(num_masks, min=1)

        # Compute all requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))

        # Process auxiliary outputs (intermediate layer predictions)
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

# Models


In [4]:
from typing import Dict, List, Tuple, Optional



class Mask2Former(nn.Module):
    """
    PyTorch-native implementation of Mask2Former.

    Main architecture for mask classification, supporting:
    - Instance segmentation
    - Semantic segmentation
    - Panoptic segmentation
    """

    def __init__(
        self,
        backbone: nn.Module,
        pixel_decoder: nn.Module,
        transformer_decoder: nn.Module,
        num_queries: int = 100,
        object_mask_threshold: float = 0.0,
        overlap_threshold: float = 0.8,
        metadata: Optional[Dict] = None,
        pixel_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
        pixel_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
        semantic_on: bool = True,
        instance_on: bool = False,
        panoptic_on: bool = False,
        test_topk_per_image: int = 100,
        num_classes: int = 133,  # COCO panoptic classes
    ):
        """
        Args:
            backbone: Backbone network (e.g., ResNet)
            pixel_decoder: Pixel decoder module
            transformer_decoder: Transformer decoder for mask prediction
            num_queries: Number of query embeddings
            object_mask_threshold: Threshold for object masks
            overlap_threshold: Threshold for mask overlap
            metadata: Dataset metadata
            pixel_mean: Mean values for input normalization
            pixel_std: Std values for input normalization
            semantic_on: Enable semantic segmentation
            instance_on: Enable instance segmentation
            panoptic_on: Enable panoptic segmentation
            test_topk_per_image: Number of top predictions per image
            num_classes: Number of classes
        """
        super().__init__()

        self.backbone = backbone
        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_decoder

        self.num_queries = num_queries
        self.object_mask_threshold = object_mask_threshold
        self.overlap_threshold = overlap_threshold
        self.metadata = metadata if metadata is not None else {}

        # Normalization parameters
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1))

        # Task flags
        self.semantic_on = semantic_on
        self.instance_on = instance_on
        self.panoptic_on = panoptic_on

        self.test_topk_per_image = test_topk_per_image
        self.num_classes = num_classes

    def forward(self, images: torch.Tensor, targets: Optional[List[Dict]] = None):
        """
        Args:
            images: Batched images tensor [B, C, H, W]
            targets: Ground truth targets (for training)

        Returns:
            dict: Model outputs containing:
                - 'pred_logits': Classification logits [B, num_queries, num_classes]
                - 'pred_masks': Mask predictions [B, num_queries, H, W]
                - 'aux_outputs': Auxiliary outputs from intermediate layers
        """
        # Normalize images
        images = (images - self.pixel_mean) / self.pixel_std

        # Extract features from backbone
        features = self.backbone(images)

        # Generate mask features and positional embeddings
        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder(features)

        # Predict masks and classes
        predictions = self.predictor(
            multi_scale_features,
            mask_features,
            targets=targets
        )

        # Postprocess for inference
        if not self.training:
            # Process predictions for the specific task
            if self.semantic_on:
                predictions = self.semantic_inference(predictions, images)
            elif self.instance_on:
                predictions = self.instance_inference(predictions, images)
            elif self.panoptic_on:
                predictions = self.panoptic_inference(predictions, images)

        return predictions

    def semantic_inference(self, outputs: Dict, images: torch.Tensor) -> Dict:
        """Semantic segmentation inference."""
        mask_cls = outputs["pred_logits"]
        mask_pred = outputs["pred_masks"]

        # For semantic segmentation, we average over all queries
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]  # Remove no-object class
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)

        # Resize to original image size
        B, _, H_img, W_img = images.shape
        semseg = F.interpolate(
            semseg,
            size=(H_img, W_img),
            mode="bilinear",
            align_corners=False
        )

        return {"sem_seg": semseg}

    def instance_inference(self, outputs: Dict, images: torch.Tensor) -> List[Dict]:
        """Instance segmentation inference."""
        mask_cls = outputs["pred_logits"]
        mask_pred = outputs["pred_masks"]

        # Get image size
        B, _, H_img, W_img = images.shape

        results = []

        # Process each image in the batch
        for i in range(B):
            scores, labels = mask_cls[i].max(-1)
            mask_pred_i = mask_pred[i]

            # Remove background predictions
            keep = labels != self.num_classes
            scores = scores[keep]
            labels = labels[keep]
            mask_pred_i = mask_pred_i[keep]

            # Apply threshold
            keep = scores > self.object_mask_threshold
            scores = scores[keep]
            labels = labels[keep]
            mask_pred_i = mask_pred_i[keep]

            # Get top-k predictions
            if len(scores) > self.test_topk_per_image:
                indices = torch.argsort(scores, descending=True)[:self.test_topk_per_image]
                scores = scores[indices]
                labels = labels[indices]
                mask_pred_i = mask_pred_i[indices]

            # Resize masks to original size
            mask_pred_i = F.interpolate(
                mask_pred_i.unsqueeze(0),
                size=(H_img, W_img),
                mode="bilinear",
                align_corners=False
            ).squeeze(0)

            # Binarize masks
            mask_pred_i = mask_pred_i.sigmoid() > 0.5

            results.append({
                "instances": {
                    "pred_masks": mask_pred_i,
                    "scores": scores,
                    "pred_classes": labels
                }
            })

        return results

    def panoptic_inference(self, outputs: Dict, images: torch.Tensor) -> List[Dict]:
        """Panoptic segmentation inference."""
        # Panoptic segmentation combines instance and semantic results
        # This is a simplified version - full implementation would need
        # more sophisticated merging of instance and stuff predictions

        mask_cls = outputs["pred_logits"]
        mask_pred = outputs["pred_masks"]

        B, _, H_img, W_img = images.shape

        results = []

        for i in range(B):
            scores, labels = mask_cls[i].max(-1)
            mask_pred_i = mask_pred[i].sigmoid()

            # Keep high-scoring predictions
            keep = scores > self.object_mask_threshold
            scores = scores[keep]
            labels = labels[keep]
            mask_pred_i = mask_pred_i[keep]

            # Resize masks
            mask_pred_i = F.interpolate(
                mask_pred_i.unsqueeze(0),
                size=(H_img, W_img),
                mode="bilinear",
                align_corners=False
            ).squeeze(0)

            # Generate panoptic segmentation
            panoptic_seg = torch.zeros((H_img, W_img), dtype=torch.int64, device=mask_pred_i.device)
            segments_info = []

            current_segment_id = 1

            # Process each prediction
            for j in range(len(scores)):
                mask = mask_pred_i[j] > 0.5

                # Skip empty masks
                if not mask.any():
                    continue

                # Check overlap with existing segments
                overlap = panoptic_seg > 0
                if (mask & overlap).sum() / mask.sum() > self.overlap_threshold:
                    continue

                # Add to panoptic segmentation
                panoptic_seg[mask] = current_segment_id

                segments_info.append({
                    "id": current_segment_id,
                    "category_id": labels[j].item(),
                    "score": scores[j].item()
                })

                current_segment_id += 1

            results.append({
                "panoptic_seg": (panoptic_seg, segments_info)
            })

        return results

# Usage

In [5]:
def create_dummy_backbone():
    """Create a simple backbone that returns multi-scale features."""
    class DummyBackbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1)
            self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
            self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
            self.conv4 = nn.Conv2d(256, 512, 3, stride=2, padding=1)

        def forward(self, x):
            # Return dict of features at different scales
            features = {}
            x = self.conv1(x)
            features['res2'] = x  # 1/4 scale
            x = self.conv2(x)
            features['res3'] = x  # 1/8 scale
            x = self.conv3(x)
            features['res4'] = x  # 1/16 scale
            x = self.conv4(x)
            features['res5'] = x  # 1/32 scale
            return features

    return DummyBackbone()


def create_dummy_pixel_decoder():
    """Create a simple pixel decoder."""
    class DummyPixelDecoder(nn.Module):
        def __init__(self, in_channels=[64, 128, 256, 512], mask_dim=256):
            super().__init__()
            self.mask_dim = mask_dim
            # Simple FPN-like structure
            self.lateral_convs = nn.ModuleList([
                nn.Conv2d(in_c, mask_dim, 1) for in_c in in_channels
            ])
            self.output_conv = nn.Conv2d(mask_dim, mask_dim, 3, padding=1)

        def forward(self, features):
            # Simple feature pyramid
            feature_list = [features['res2'], features['res3'], features['res4'], features['res5']]

            # Upsample and combine features
            mask_features = None
            for i in range(len(feature_list)-1, -1, -1):
                lateral = self.lateral_convs[i](feature_list[i])
                if mask_features is None:
                    mask_features = lateral
                else:
                    # Upsample and add
                    mask_features = F.interpolate(
                        mask_features,
                        size=lateral.shape[-2:],
                        mode='bilinear',
                        align_corners=False
                    )
                    mask_features = mask_features + lateral

            mask_features = self.output_conv(mask_features)

            # For transformer: return mask features and multi-scale features
            multi_scale_features = [lateral_conv(feat)
                                  for lateral_conv, feat in zip(self.lateral_convs, feature_list)]

            return mask_features, None, multi_scale_features

    return DummyPixelDecoder()


def create_dummy_transformer_decoder():
    """Create a simple transformer decoder."""
    class DummyTransformerDecoder(nn.Module):
        def __init__(self, hidden_dim=256, num_queries=100, num_classes=80):
            super().__init__()
            self.num_queries = num_queries
            self.hidden_dim = hidden_dim

            # Query embeddings
            self.query_embed = nn.Embedding(num_queries, hidden_dim)

            # Simple transformer layer
            self.transformer = nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True
            )

            # Output projections
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)  # +1 for no-object
            self.mask_embed = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )

        def forward(self, multi_scale_features, mask_features, targets=None):
            B = mask_features.shape[0]

            # Get query embeddings
            query_embeds = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)

            # Simple attention on mask features (placeholder)
            # In real implementation, this would involve proper cross-attention
            mask_h, mask_w = mask_features.shape[-2:]
            mask_features_flat = mask_features.flatten(2).transpose(1, 2)  # [B, H*W, C]

            # Transformer decoder
            outputs = self.transformer(query_embeds, mask_features_flat)

            # Predict classes
            outputs_class = self.class_embed(outputs)

            # Predict masks
            mask_embeds = self.mask_embed(outputs)
            # Simple dot product with mask features
            outputs_mask = torch.einsum("bqc,bhwc->bqhw",
                                      mask_embeds,
                                      mask_features.permute(0, 2, 3, 1))

            return {
                "pred_logits": outputs_class,
                "pred_masks": outputs_mask,
                "aux_outputs": []  # No auxiliary outputs in this simple version
            }

    return DummyTransformerDecoder()


def main():
    """Example of how to use the PyTorch-native Mask2Former."""

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model components
    backbone = create_dummy_backbone()
    pixel_decoder = create_dummy_pixel_decoder()
    transformer_decoder = create_dummy_transformer_decoder()

    # Create Mask2Former model
    model = Mask2Former(
        backbone=backbone,
        pixel_decoder=pixel_decoder,
        transformer_decoder=transformer_decoder,
        num_queries=100,
        num_classes=80,  # COCO classes
        instance_on=True,  # Enable instance segmentation
        semantic_on=False,
        panoptic_on=False,
    ).to(device)

    # Create criterion
    matcher = HungarianMatcher(
        cost_class=2.0,
        cost_mask=5.0,
        cost_dice=5.0,
        num_points=12544,  # Sample points for efficiency
    )

    weight_dict = {
        "loss_ce": 2.0,
        "loss_mask": 5.0,
        "loss_dice": 5.0,
    }

    criterion = SetCriterion(
        num_classes=80,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=0.1,  # Coefficient for no-object class
        losses=["labels", "masks"],
    ).to(device)

    # Create dummy input
    batch_size = 2
    images = torch.randn(batch_size, 3, 640, 640).to(device)

    # Create dummy targets for training
    targets = []
    for i in range(batch_size):
        num_objects = 3  # 3 objects per image
        target = {
            "labels": torch.randint(0, 80, (num_objects,)).to(device),
            "masks": torch.rand(num_objects, 640, 640).to(device) > 0.5,  # Binary masks
        }
        targets.append(target)

    # Training mode
    model.train()

    # Forward pass
    outputs = model(images, targets)

    # Calculate losses
    losses = criterion(outputs, targets)

    # Total loss
    total_loss = sum(losses[k] * weight_dict.get(k, 1.0)
                    for k in losses if k in weight_dict)

    print("Training losses:")
    for k, v in losses.items():
        print(f"  {k}: {v.item():.4f}")
    print(f"Total loss: {total_loss.item():.4f}")

    # Inference mode
    model.eval()
    with torch.no_grad():
        predictions = model(images)

    print(f"\nInference output:")
    print(f"  Number of predictions: {len(predictions)}")
    if isinstance(predictions, list) and len(predictions) > 0:
        for i, pred in enumerate(predictions):
            if "instances" in pred:
                inst = pred["instances"]
                print(f"  Image {i}: {len(inst['scores'])} instances detected")
                print(f"    - Classes: {inst['pred_classes'].tolist()}")
                print(f"    - Scores: {[f'{s:.3f}' for s in inst['scores'].tolist()]}")



In [6]:
main()

  with torch.cuda.amp.autocast(enabled=False):


Training losses:
  loss_ce: 4.3159
  loss_mask: 1.0485
  loss_dice: 0.4889
Total loss: 16.3188

Inference output:
  Number of predictions: 2
  Image 0: 99 instances detected
    - Classes: [18, 36, 78, 77, 5, 36, 55, 6, 3, 51, 65, 51, 70, 50, 7, 17, 6, 3, 59, 28, 65, 40, 7, 66, 35, 37, 68, 37, 5, 12, 33, 46, 7, 59, 6, 21, 50, 23, 3, 33, 25, 49, 56, 27, 33, 65, 44, 12, 52, 64, 40, 50, 9, 50, 21, 19, 33, 6, 75, 14, 7, 64, 44, 24, 59, 19, 66, 55, 19, 35, 33, 67, 40, 70, 59, 7, 70, 2, 50, 30, 37, 65, 42, 1, 17, 27, 41, 63, 65, 5, 6, 66, 35, 61, 37, 43, 75, 0, 2]
    - Scores: ['1.644', '1.259', '1.809', '1.978', '1.580', '1.480', '1.162', '1.043', '1.728', '1.400', '1.262', '1.053', '2.016', '1.346', '1.654', '1.376', '1.625', '1.000', '1.317', '1.035', '1.166', '1.434', '1.493', '1.471', '1.370', '1.518', '2.222', '1.972', '1.489', '1.460', '1.152', '1.635', '1.340', '1.400', '1.839', '1.634', '1.895', '1.183', '1.484', '1.508', '1.790', '1.260', '1.944', '0.965', '1.171', '1.652', '1.468