In [1]:
import math
import copy
import numpy as np
from scipy.optimize import linear_sum_assignment
from typing import List, Optional
from collections import defaultdict
from PIL import Image
from PIL import Image, ImageDraw #sample data generation
import io

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
import torchvision
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models._utils import IntermediateLayerGetter

  from .autonotebook import tqdm as notebook_tqdm


## Panoptic API Utils

In [2]:
def rgb2id(color):
    if isinstance(color, np.ndarray) and len(color.shape) == 3:
        if color.dtype == np.uint8:
            color = color.astype(np.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])


def id2rgb(id_map):
    if isinstance(id_map, np.ndarray):
        id_map_copy = id_map.copy()
        rgb_shape = tuple(list(id_map.shape) + [3])
        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
        for i in range(3):
            rgb_map[..., i] = id_map_copy % 256
            id_map_copy //= 256
        return rgb_map
    color = []
    for _ in range(3):
        color.append(id_map % 256)
        id_map //= 256
    return color

## Positional Encoding

### Changes Made:

- **PositionEmbeddingSine**: It now calculates the sinusoidal embeddings based on the positions of elements in the image grid, suitable for transformers.
- **PositionEmbeddingLearned**: Adjusted to initialize weights more robustly and provided parameters for max dimensions to accommodate different image sizes.
- **build_position_encoding**: This function now supports creating either 'sine' or 'learned' embeddings based on the input arguments, making it flexible for different model configurations.

In [3]:
class PositionEmbeddingSine(nn.Module):
    """Sinusoidal Position Embedding for image-based inputs."""
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale if scale is not None else 2 * math.pi

    def forward(self, x):
        assert x.mask is not None, "Mask cannot be None"
        mask = x.mask
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
    
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.tensors.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        
        return pos

In [4]:
class PositionEmbeddingLearned(nn.Module):
    """Learned Position Embedding for image-based inputs."""
    def __init__(self, num_pos_feats=256, max_height=50, max_width=50):
        super().__init__()
        self.row_embed = nn.Embedding(max_height, num_pos_feats)
        self.col_embed = nn.Embedding(max_width, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight, -0.05, 0.05)
        nn.init.uniform_(self.col_embed.weight, -0.05, 0.05)

    def forward(self, x):
        assert x.mask is not None, "Mask cannot be None"
        h, w = x.tensors.shape[-2:]
        i = torch.arange(w, device=x.tensors.device)
        j = torch.arange(h, device=x.tensors.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1)], dim=-1)
        pos = pos.permute(2, 0, 1).unsqueeze(0).repeat(x.tensors.size(0), 1, 1, 1)
        
        return pos

In [5]:
def build_position_encoding(hidden_dim, position_embedding='sine', max_height=50, max_width=50):
    """Factory function to build position encoding."""
    num_pos_feats = hidden_dim // 2
    if position_embedding == 'sine':
        return PositionEmbeddingSine(num_pos_feats, normalize=True)
    elif position_embedding == 'learned':
        return PositionEmbeddingLearned(num_pos_feats, max_height, max_width)
    else:
        raise ValueError(f"Not supported position embedding type {position_embedding}")

## Backbone

### Changes Made:

This refactored code maintains the same overall structure but simplifies the definition and initialization of the backbone and its components. The `FrozenBatchNorm2d` function is streamlined by using a regular `BatchNorm2d` and setting the gradient flags to `False`. The `BackboneBase` is merged and simplified into `SimplifiedBackbone`, and the functionality of handling different returning layers and feature channels is preserved.

In [6]:
def frozen_batch_norm2d(num_features, eps=1e-5):
    """Creates a frozen batch normalization layer."""
    layer = nn.BatchNorm2d(num_features, affine=True)
    layer.weight.requires_grad = False
    layer.bias.requires_grad = False
    layer.running_mean.requires_grad = False
    layer.running_var.requires_grad = False
    layer.eps = eps
    
    return layer

In [7]:
class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

In [8]:
class SimplifiedBackbone(nn.Module):
    """ Simplified backbone model for feature extraction. """
    def __init__(self, model_name, train_layers=False, num_channels=2048, return_interm_layers=False):
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT if not train_layers else None, norm_layer=lambda num_features: frozen_batch_norm2d(num_features))
        for name, parameter in backbone.named_parameters():
            if not train_layers or all(not name.startswith(f'layer{i}') for i in [2, 3, 4]):
                parameter.requires_grad = False
        
        # Setup return_layers to map the desired layer to '0'
        return_layers = {'layer4': '0'} if not return_interm_layers else {f'layer{i}': str(i-1) for i in range(1, 5)}
        
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, nested_tensor):
        tensors = nested_tensor.tensors
        mask = nested_tensor.mask
        xs = self.body(tensors)
        
        return {name: NestedTensor(x, mask) for name, x in xs.items()}

In [9]:
class BackboneWithPositionEmbedding(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.num_channels = backbone.num_channels

    def forward(self, nested_tensor):
        xs = self[0](nested_tensor)  # Calls forward on backbone
        pos = {name: self[1](x) for name, x in xs.items()}  # Generate position embeddings
        return xs, pos

In [10]:
def build_refactored_backbone(model_name, train_layers, num_channels, return_interm_layers, hidden_dim, position_embedding_type):
    """Builds a backbone with integrated position encoding."""
    position_embedding = build_position_encoding(hidden_dim, position_embedding_type)
    backbone = SimplifiedBackbone(model_name, train_layers, num_channels, return_interm_layers)
    return BackboneWithPositionEmbedding(backbone, position_embedding)

## Matcher

### Changes Made:

This version of the Hungarian matcher will compute matching costs using classification scores, point distances, and optionally GIoU scores (if applicable). However, since the original paper discusses a KMO-based matcher, and your model doesn't include explicit GIoU computation for bounding boxes, we'll focus only on classification and point distances for simplicity.


1. **Logits to Probabilities**: The matcher uses `softmax` to convert the logits to probabilities. This is crucial for calculating the classification cost as negative log likelihood, which is more stable and interpretable.
2. **Cost Calculation**: The matcher combines classification and point costs linearly using specified weights. It supports batch processing where each element in the batch has its own set of targets and predictions.
3. **Linear Sum Assignment**: This uses the `linear_sum_assignment` from `SciPy`, which directly finds the minimum cost matching. It's applied batch-wise.

In [11]:
class HungarianMatcher(nn.Module):
    """Computes an assignment between predictions and ground truth targets."""
    def __init__(self, cost_class: float = 1.0, cost_point: float = 1.0):
        super().__init__()
        self.cost_class = cost_class
        self.cost_point = cost_point

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Arguments:
            outputs: Dict containing at least 'pred_logits' and 'pred_points'.
                - 'pred_logits': Tensor of shape [batch_size, num_queries, num_classes]
                - 'pred_points': Tensor of shape [batch_size, num_queries, 4] (for bounding box points)
            targets: List of dictionaries containing:
                - 'labels': Tensor of shape [num_target_points]
                - 'points': Tensor of shape [num_target_points, 4]

        Returns:
            List of tuples for each batch element, containing:
            - Index of selected predictions.
            - Index of corresponding selected targets.
        """
        bs, num_queries = outputs['pred_logits'].shape[:2]
        out_prob = outputs['pred_logits'].softmax(-1)  # Convert logits to probabilities
        out_points = outputs['pred_points'].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Concatenate all targets across the batch
        tgt_points = torch.cat([t['points'] for t in targets]).to(out_points.device)
        tgt_labels = torch.cat([t['labels'] for t in targets])

        # Compute classification cost using negative log likelihood
        cost_class = -out_prob[:, tgt_labels].flatten(0, 1)

        # Compute L1 cost between predicted points and target points
        cost_point = torch.cdist(out_points, tgt_points, p=1)

        # Combine costs
        C = self.cost_class * cost_class + self.cost_point * cost_point
        C = C.view(bs, num_queries, -1).cpu()

        # Compute assignment for each batch element
        indices = [linear_sum_assignment(c) for c in C]
        return [(torch.as_tensor(i, dtype=torch.long), torch.as_tensor(j, dtype=torch.long)) for i, j in indices]

In [12]:
def build_matcher(cost_class=1.0, cost_point=1.0):
    return HungarianMatcher(cost_class, cost_point)

## Multihead Attention

### Changes Made:

This `MultiheadAttention`:
- Directly uses the input dimensions of query, key, and value without additional projection, simplifying the architecture.
- Uses scaled dot-product attention mechanism, following the principle outlined in "Attention is All You Need".
- Provides dropout for regularization and a final linear projection to match the output dimensions to the input dimensions.
- Is flexible with regard to whether it returns attention weights, allowing for easier debugging or further manipulation.

In [13]:
class MultiheadAttention(nn.Module):
    """Custom implementation of MultiheadAttention to support different dimensions
    for query, key, and value without separate projection matrices for each."""

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads  # Make sure this is set
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Dropout
        self.dropout_layer = nn.Dropout(dropout)

        # Parameter initialization
        self._reset_parameters()

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            constant_(self.out_proj.bias, 0)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        """ Forward pass for custom multihead attention.
        Assumes L = target sequence length, S = source sequence length, N = batch size, E = embedding dimension.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert query.size() == key.size() == value.size(), "Query, Key and Value must be of the same size"

        # Scale query for dot product attention
        scaling = self.head_dim ** -0.5
        query = query * scaling

        # Calculate Q, K, V
        q = query.reshape(tgt_len, bsz * self.num_heads, self.head_dim)
        k = key.reshape(-1, bsz * self.num_heads, self.head_dim)
        v = value.reshape(-1, bsz * self.num_heads, self.head_dim)

        # Dot product of Q and K (transpose)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, -1)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, -1)

        # Apply softmax and dropout on attention weights
        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = self.dropout_layer(attn_output_weights)

        # Multiply weights by V
        attn_output = torch.bmm(attn_output_weights, v)

        # Transpose and reshape to bring back to original dimensions
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        # Apply final linear projection
        attn_output = self.out_proj(attn_output)

        if need_weights:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, -1)
            attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

## Transformer Model

In [14]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)]

In [15]:
def _get_clones(module, N):
    """Create N identical layers by deep copying the given module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    elif activation == "glu":
        return F.glu
    raise RuntimeError("Activation function {} is not supported".format(activation))

### Encoder Model

In [16]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.activation = _get_activation_fn(activation)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout(src2)
        src = self.norm2(src)
        return src

In [17]:
class TransformerEncoder(nn.Module):
    def __init__(self, layer, n_layers):
        super().__init__()
        self.layers = _get_clones(layer, n_layers)
        self.norm = nn.LayerNorm(layer.self_attn.embed_dim)

    def forward(self, src, mask=None, src_key_padding_mask=None):
        for layer in self.layers:
            src = layer(src, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        src = self.norm(src)
        return src

### Decoder Model

In [18]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.activation = _get_activation_fn(activation)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm3(tgt)
        return tgt

In [19]:
class TransformerDecoder(nn.Module):
    def __init__(self, layer, n_layers):
        super().__init__()
        self.layers = _get_clones(layer, n_layers)
        self.norm = nn.LayerNorm(layer.self_attn.embed_dim)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        for layer in self.layers:
            tgt = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)

        tgt = self.norm(tgt)
        
        return tgt

In [20]:
class TransformerModel(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_embed):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.nhead = encoder.layers[0].self_attn.num_heads
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.pos_embed = pos_embed
        self.d_model = encoder.layers[0].self_attn.embed_dim  # Assuming all layers have the same embed_dim

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        memory = self.encoder(self.pos_embed(self.src_embed(src)), src_mask, src_key_padding_mask)
        output = self.decoder(self.pos_embed(self.tgt_embed(tgt)), memory, tgt_mask, None, tgt_key_padding_mask, None)
        return output

In [21]:
def build_transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, activation='relu'):
    """Function to build the transformer model."""
    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
    
    decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
    decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
    
    src_embed = nn.Embedding(num_embeddings=10, embedding_dim=d_model)  # Placeholder for actual embedding logic
    tgt_embed = nn.Embedding(num_embeddings=10, embedding_dim=d_model)  # Placeholder for actual embedding logic
    pos_encoder = PositionalEncoding(d_model=d_model)
    
    return TransformerModel(encoder, decoder, src_embed, tgt_embed, pos_encoder)

## Segmentation

### Custom Modules

In [22]:
class MLP(nn.Module):
    """ Multi-layer Perceptron (MLP) with variable number of layers """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

In [23]:
class AttentionMap(nn.Module):
    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask: Optional[Tensor] = None):
        q = self.q_linear(q)
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
        weights = self.dropout(weights)
        return weights

In [24]:
class ConvolutionalMaskHead(nn.Module):
    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()

        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = nn.GroupNorm(8, dim)
        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = nn.GroupNorm(8, inter_dims[1])
        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = nn.GroupNorm(8, inter_dims[2])
        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = nn.GroupNorm(8, inter_dims[3])
        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = nn.GroupNorm(8, inter_dims[4])
        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
        x = torch.cat([self._expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = self._expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = self._expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = self._expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        x = self.out_lay(x)
        return x

    def _expand(self, tensor, length: int):
        return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)

In [25]:
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    if tensor_list[0].ndim == 3:
        max_size = [max(s) for s in zip(*[img.shape for img in tensor_list])]
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

In [26]:
class SegmentationModel(nn.Module):
    def __init__(self, backbone, transformer, num_queries, aux_loss=False):
        super().__init__()
        self.backbone = backbone
        self.transformer = transformer
        self.num_queries = num_queries
        self.aux_loss = aux_loss

        hidden_dim, nheads = transformer.d_model, transformer.nhead
        self.bbox_attention = AttentionMap(hidden_dim, hidden_dim, nheads)
        self.mask_head = ConvolutionalMaskHead(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.class_embed = nn.Linear(hidden_dim, 91)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

    def forward(self, samples):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        src_proj = self.input_proj(src)
        hs, memory = self.transformer(src_proj, mask, self.query_embed.weight, pos[-1])

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
        seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_seg_masks = seg_masks.view(samples.tensors.shape[0], self.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        out["pred_masks"] = outputs_seg_masks
        return out

    def _set_aux_loss(self, outputs_class, outputs_coord):
        return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

In [27]:
class PostProcess(nn.Module):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    @torch.no_grad()
    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()
        outputs_masks = outputs["pred_masks"].squeeze(2)
        outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
        outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()

        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = t[0], t[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            results[i]["masks"] = F.interpolate(results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest").byte()

        return results

In [28]:
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


def masks_to_boxes(masks):
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]
    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)

In [29]:
class PostProcessPanoptic(nn.Module):
    def __init__(self, is_thing_map, threshold=0.85):
        super().__init__()
        self.threshold = threshold
        self.is_thing_map = is_thing_map

    def forward(self, outputs, processed_sizes, target_sizes=None):
        if target_sizes is None:
            target_sizes = processed_sizes
        assert len(processed_sizes) == len(target_sizes)
        out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
        assert len(out_logits) == len(raw_masks) == len(target_sizes)
        preds = []

        def to_tuple(tup):
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.cpu().tolist())

        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes):
            scores, labels = cur_logits.softmax(-1).max(-1)
            keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
            cur_scores = cur_scores[keep]
            cur_classes = cur_classes[keep]
            cur_masks = cur_masks[keep]
            cur_masks = F.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
            cur_boxes = box_cxcywh_to_xyxy(cur_boxes[keep])

            h, w = cur_masks.shape[-2:]
            assert len(cur_boxes) == len(cur_classes)

            cur_masks = cur_masks.flatten(1)
            stuff_equiv_classes = defaultdict(lambda: [])
            for k, label in enumerate(cur_classes):
                if not self.is_thing_map[label.item()]:
                    stuff_equiv_classes[label.item()].append(k)

            def get_ids_area(masks, scores, dedup=False):
                m_id = masks.transpose(0, 1).softmax(-1)
                if m_id.shape[-1] == 0:
                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
                else:
                    m_id = m_id.argmax(-1).view(h, w)
                if dedup:
                    for equiv in stuff_equiv_classes.values():
                        if len(equiv) > 1:
                            for eq_id in equiv:
                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
                final_h, final_w = to_tuple(target_size)
                seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
                np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
                m_id = torch.from_numpy(rgb2id(np_seg_img))
                area = []
                for i in range(len(scores)):
                    area.append(m_id.eq(i).sum().item())
                return area, seg_img

            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
            if cur_classes.numel() > 0:
                while True:
                    filtered_small = torch.as_tensor([area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device)
                    if filtered_small.any().item():
                        cur_scores = cur_scores[~filtered_small]
                        cur_classes = cur_classes[~filtered_small]
                        cur_masks = cur_masks[~filtered_small]
                        area, seg_img = get_ids_area(cur_masks, cur_scores)
                    else:
                        break
            else:
                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

            segments_info = []
            for i, a in enumerate(area):
                cat = cur_classes[i].item()
                segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
            del cur_classes

            with io.BytesIO() as out:
                seg_img.save(out, format="PNG")
                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
            preds.append(predictions)
        return preds

In [30]:
class DETRsegm(nn.Module):
    def __init__(self, detr, freeze_detr=False):
        super().__init__()
        self.detr = detr

        if freeze_detr:
            for p in self.detr.parameters():
                p.requires_grad_(False)

        hidden_dim = self.detr.transformer.d_model
        nheads = self.detr.transformer.nhead  # This assumes nhead is correctly exposed
        self.bbox_attention = AttentionMap(hidden_dim, hidden_dim, nheads)
        self.mask_head = ConvolutionalMaskHead(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

    def forward(self, samples: NestedTensor):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.detr.backbone(samples)

        # Here, ensure '0' or the correct key is used
        bs = features['0'].tensors.shape[0]

        src, mask = features['0'].decompose()
        assert mask is not None
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.point_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.detr.aux_loss:
            out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord)

        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)

        seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        out["pred_masks"] = outputs_seg_masks
        
        return out

In [31]:
class PostProcessSegm(nn.Module):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    @torch.no_grad()
    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()
        outputs_masks = outputs["pred_masks"].squeeze(2)
        outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
        outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()

        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = t[0], t[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            results[i]["masks"] = F.interpolate(
                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
            ).byte()

        return results

## Conditional DETR Model

### Utility Functions

In [32]:
# def box_cxcywh_to_xyxy(x):
#     """ Convert box format from [center_x, center_y, width, height] to [x_min, y_min, x_max, y_max] """
#     x_c, y_c, w, h = x.unbind(-1)
#     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
#     return torch.stack(b, dim=-1)

def inverse_sigmoid(x, eps=1e-5):
    """ Convert probabilities to logits with safe clamping """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

In [33]:
def is_dist_avail_and_initialized():
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return True

def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return torch.distributed.get_world_size()

In [34]:
def accuracy(output, target, topk=(1,)):
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [35]:
def dice_loss(inputs, targets, num_boxes):
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes

def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes

### Conditional DETR Model

In [36]:
class ConditionalDETR(nn.Module):
    """ Conditional DETR module for object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, channel_point, aux_loss=False):
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.point_embed = MLP(hidden_dim, hidden_dim, channel_point, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

        # Initialize biases for focal loss
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value

        # Initialize point embedding layers
        nn.init.constant_(self.point_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.point_embed.layers[-1].bias.data, 0)

    def forward(self, samples: NestedTensor):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs, reference = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])

        reference_before_sigmoid = inverse_sigmoid(reference)
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            tmp = self.point_embed(hs[lvl])
            tmp[..., :2] += reference_before_sigmoid
            outputs_coord = tmp.sigmoid()
            outputs_coords.append(outputs_coord)
        outputs_coord = torch.stack(outputs_coords)

        outputs_class = self.class_embed(hs)
        out = {'pred_logits': outputs_class[-1], 'pred_points': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        return [{'pred_logits': a, 'pred_points': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

### Set Criterion for Loss Computation

In [37]:
class SetCriterion(nn.Module):
    """ Computes the loss for Conditional DETR """
    def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses):
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha

    def loss_labels(self, outputs, targets, indices, num_points, log=True):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]).cuda()
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:, :, :-1]
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_points, alpha=self.focal_alpha, gamma=2) * \
                  src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_points):
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_points(self, outputs, targets, indices, num_points):
        assert 'pred_points' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_points = outputs['pred_points'][idx]
        target_points = torch.cat([t['points'][i] for t, (_, i) in zip(targets, indices)], dim=0).cuda()

        loss_point = F.l1_loss(src_points, target_points, reduction='none')
        losses = {'loss_point': loss_point.sum() / num_points}
        return losses

    def loss_masks(self, outputs, targets, indices, num_points):
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        src_masks = F.interpolate(src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_points),
            "loss_dice": dice_loss(src_masks, target_masks, num_points),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_points, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'points': self.loss_points,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'Unsupported loss: {loss}'
        return loss_map[loss](outputs, targets, indices, num_points, **kwargs)

    def forward(self, outputs, targets):
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        indices = self.matcher(outputs_without_aux, targets)

        num_points = sum(len(t["labels"]) for t in targets)
        num_points = torch.as_tensor([num_points], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_points)
        num_points = torch.clamp(num_points / get_world_size(), min=1).item()

        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_points))

        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_points, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

### Post-Processing Module

In [38]:
class PostProcess(nn.Module):
    """ Convert model's output to the format expected by the COCO API """
    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        out_logits, out_points = outputs['pred_logits'], outputs['pred_points']
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
        scores = topk_values
        topk_points = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        points = box_cxcywh_to_xyxy(out_points)
        points = torch.gather(points, 1, topk_points.unsqueeze(-1).repeat(1, 1, 4))

        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        points = points * scale_fct[:, None, :]

        results = [{'scores': s, 'labels': l, 'points': b} for s, l, b in zip(scores, labels, points)]
        return results

### Function to Build the Model

In [39]:
def build_conditional_detr(args):
    """
    Builds the Conditional DETR model.
    
    Arguments:
    - args: configuration arguments for building the model.
    
    Returns:
    - model: the Conditional DETR model.
    - criterion: loss function used for training.
    - postprocessors: post-processing modules for transforming model outputs.
    """
    num_classes = 2 if args.dataset_file != 'coco' else 91
    if args.dataset_file == "coco_panoptic":
        num_classes = 250

    device = torch.device(args.device)

    # Build backbone using your refactored method
    backbone = build_refactored_backbone(
        model_name=args.backbone,
        train_layers=args.lr_backbone > 0,
        num_channels=args.num_channels,
        return_interm_layers=args.masks,
        hidden_dim=args.hidden_dim,
        position_embedding_type=args.position_embedding
    )

    # Build transformer using your provided method
    transformer = build_transformer(
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation='relu'  # or use args.activation if provided
    )

    # Create the ConditionalDETR model
    model = ConditionalDETR(
        backbone=backbone,
        transformer=transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        channel_point=args.channel_point,
        aux_loss=args.aux_loss,
    )

    # Wrap with DETRsegm if masks are enabled
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))

    # Build matcher using your provided method
    matcher = build_matcher(cost_class=args.cost_class, cost_point=args.cost_point)
    
    # Define weight dictionary for losses
    weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_point': args.point_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef

    # Add auxiliary loss weights if aux_loss is enabled
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    # Define the losses to be used
    losses = ['labels', 'points', 'cardinality']
    if args.masks:
        losses += ["masks"]

    # Create the criterion (loss function)
    criterion = SetCriterion(
        num_classes=num_classes, 
        matcher=matcher, 
        weight_dict=weight_dict,
        focal_alpha=args.focal_alpha, 
        losses=losses
    )
    criterion.to(device)

    # Define postprocessors
    postprocessors = {'point': PostProcess()}
    if args.masks:
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)

    return model, criterion, postprocessors

# Sample Testing

## Data Generation

In [40]:
def generate_synthetic_data(num_samples, image_size=(256, 256), num_objects=5, seed=54):
    """
    Generates synthetic images and annotations for testing.
    
    Args:
    - num_samples: Number of samples to generate.
    - image_size: Size of the synthetic image (height, width).
    - num_objects: Number of objects per image.
    
    Returns:
    - images: List of synthetic images as Tensors.
    - targets: List of dictionaries containing 'labels' and 'points'.
    """
    images = []
    targets = []

    rng = np.random.default_rng(seed)  # Create a random number generator instance
    
    for _ in range(num_samples):
        image = Image.new('RGB', image_size, (255, 255, 255))
        draw = ImageDraw.Draw(image)
        
        sample_labels = []
        sample_points = []
        sample_masks = []
        
        for obj_id in range(num_objects):
            # Random object size and position
            w, h = rng.integers(10, 40, size=2)
            x = rng.integers(0, image_size[1] - w)
            y = rng.integers(0, image_size[0] - h)
            
            # Draw rectangle on the image
            draw.rectangle([x, y, x+w, y+h], outline='black', fill=(rng.integers(255), rng.integers(255), rng.integers(255)))
            
            # Generate label and point
            sample_labels.append(torch.tensor([obj_id % 2]))  # Label alternating between 0 and 1 for binary classification
            sample_points.append(torch.tensor([x + w/2, y + h/2, w, h], dtype=torch.float32))
            
            # Generate a binary mask
            mask = torch.zeros(image_size, dtype=torch.uint8)
            mask[y:y+h, x:x+w] = 1
            sample_masks.append(mask)
        
        images.append(torch.tensor(np.array(image).transpose(2, 0, 1), dtype=torch.float32) / 255.0)
        
        targets.append({
            'labels': torch.stack(sample_labels),
            'points': torch.stack(sample_points),
            'masks': torch.stack(sample_masks)
        })
    
    return images, targets

# Now you can call the generate_synthetic_data function without warnings
synthetic_images, synthetic_targets = generate_synthetic_data(num_samples=2)

## Build Model

In [41]:
class Args:
    dataset_file = 'test'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    backbone = 'resnet50'
    lr_backbone = 0.0
    num_channels = 2048
    return_interm_layers = False
    hidden_dim = 256
    position_embedding = 'sine'
    nheads = 8
    enc_layers = 6
    dec_layers = 6
    dim_feedforward = 1024
    dropout = 0.1
    num_queries = 10
    channel_point = 4
    aux_loss = False
    masks = True
    frozen_weights = None
    cls_loss_coef = 1.0
    point_loss_coef = 1.0
    giou_loss_coef = 1.0
    mask_loss_coef = 1.0
    dice_loss_coef = 1.0
    cost_class = 1.0
    cost_point = 1.0
    focal_alpha = 0.25

args = Args()

# Build model
model, criterion, postprocessors = build_conditional_detr(args)
model = model.to(args.device)

## Run Pipeline

In [42]:
# Convert to nested tensor format
def to_nested_tensor(tensors):
    if tensors[0].ndim == 3:
        max_size = [max(s) for s in zip(*[img.shape for img in tensors])]
        batch_shape = [len(tensors)] + max_size
        b, c, h, w = batch_shape
        dtype = tensors[0].dtype
        device = tensors[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensors, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('Unsupported tensor list format')
    return NestedTensor(tensor, mask)

# Preprocess data
nested_samples = to_nested_tensor([img.to(args.device) for img in synthetic_images])

# Forward pass through the model
model.eval()
with torch.no_grad():
    outputs = model(nested_samples)

# Post-process the results
results = postprocessors['point'](outputs, torch.tensor([list(img.shape[1:]) for img in synthetic_images]))

if args.masks:
    seg_results = postprocessors['segm'](results, outputs, torch.tensor([list(img.shape[1:]) for img in synthetic_images]), torch.tensor([list(img.shape[1:]) for img in synthetic_images]))

# Display results
for i, res in enumerate(results):
    print(f"Image {i}:")
    print(f"Scores: {res['scores']}")
    print(f"Labels: {res['labels']}")
    print(f"Points: {res['points']}")

    if args.masks:
        print(f"Masks: {seg_results[i]['masks'].shape}")

RuntimeError: Given groups=1, weight of size [256, 2048, 1, 1], expected input[2, 256, 64, 64] to have 2048 channels, but got 256 channels instead