In [1]:
import math
import copy
import numpy as np
from scipy.optimize import linear_sum_assignment
from typing import Optional

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
import torchvision
from torchvision.models._utils import IntermediateLayerGetter

  from .autonotebook import tqdm as notebook_tqdm


## Positional Encoding

### Changes Made:

- **PositionEmbeddingSine**: It now calculates the sinusoidal embeddings based on the positions of elements in the image grid, suitable for transformers.
- **PositionEmbeddingLearned**: Adjusted to initialize weights more robustly and provided parameters for max dimensions to accommodate different image sizes.
- **build_position_encoding**: This function now supports creating either 'sine' or 'learned' embeddings based on the input arguments, making it flexible for different model configurations.

In [None]:
class PositionEmbeddingSine(nn.Module):
    """Fixed sine/cosine position embedding derived from a mask.

    The input to ``forward`` must expose ``.mask`` and ``.tensors`` attributes.
    Positions are the running count of un-masked cells along dims 1 and 2 of
    the mask; the result has ``2 * num_pos_feats`` channels in dim 1.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        # Fall back to a full period (2*pi) when no explicit scale is given.
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        """Return the position embedding for ``x`` (channels in dim 1)."""
        assert x.mask is not None, "Mask cannot be None"
        valid = ~x.mask
        # Cumulative counts of valid cells give 1-based row / column positions.
        row_pos = valid.cumsum(1, dtype=torch.float32)
        col_pos = valid.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            # Divide by the last (largest) cumulative value, then rescale.
            row_pos = row_pos / (row_pos[:, -1:, :] + eps) * self.scale
            col_pos = col_pos / (col_pos[:, :, -1:] + eps) * self.scale

        channel_idx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.tensors.device)
        # Consecutive channel pairs share one frequency.
        dim_t = self.temperature ** (2 * (channel_idx // 2) / self.num_pos_feats)

        def _interleave_sin_cos(coords):
            # Even channels take the sine, odd channels the cosine.
            angles = coords[:, :, :, None] / dim_t
            return torch.stack((angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4).flatten(3)

        encoded_cols = _interleave_sin_cos(col_pos)
        encoded_rows = _interleave_sin_cos(row_pos)
        # Row (y) channels come first, then column (x) channels.
        return torch.cat((encoded_rows, encoded_cols), dim=3).permute(0, 3, 1, 2)

In [None]:
class PositionEmbeddingLearned(nn.Module):
    """Learned position embedding built from a row table and a column table.

    Args:
        num_pos_feats: Width of each table; the output has ``2 * num_pos_feats`` channels.
        max_height: Number of rows in the row table (largest supported feature height).
        max_width: Number of rows in the column table (largest supported feature width).
    """
    def __init__(self, num_pos_feats=256, max_height=50, max_width=50):
        super().__init__()
        self.row_embed = nn.Embedding(max_height, num_pos_feats)
        self.col_embed = nn.Embedding(max_width, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both tables uniformly in [-0.05, 0.05]."""
        nn.init.uniform_(self.row_embed.weight, -0.05, 0.05)
        nn.init.uniform_(self.col_embed.weight, -0.05, 0.05)

    def forward(self, x):
        """Return a ``(batch, 2 * num_pos_feats, h, w)`` embedding for ``x.tensors``.

        Raises:
            IndexError: if the feature map is larger than the embedding tables.
        """
        assert x.mask is not None, "Mask cannot be None"
        h, w = x.tensors.shape[-2:]

        # Fail early with a readable message instead of an opaque lookup error
        # inside the embedding kernel.
        max_height = self.row_embed.num_embeddings
        max_width = self.col_embed.num_embeddings
        if h > max_height or w > max_width:
            raise IndexError(
                f"Feature map of size {h}x{w} exceeds the learned position tables "
                f"(max_height={max_height}, max_width={max_width})"
            )

        device = x.tensors.device
        x_emb = self.col_embed(torch.arange(w, device=device))  # (w, F)
        y_emb = self.row_embed(torch.arange(h, device=device))  # (h, F)

        # Broadcast both tables over the (h, w) grid; column features come first.
        pos = torch.cat(
            [x_emb.unsqueeze(0).expand(h, -1, -1), y_emb.unsqueeze(1).expand(-1, w, -1)],
            dim=-1,
        )
        # (h, w, 2F) -> (2F, h, w), then copy once per batch element.
        pos = pos.permute(2, 0, 1).unsqueeze(0).repeat(x.tensors.size(0), 1, 1, 1)

        return pos

In [None]:
def build_position_encoding(hidden_dim, position_embedding='sine', max_height=50, max_width=50):
    """Create the position-embedding module named by ``position_embedding``.

    Each spatial axis gets ``hidden_dim // 2`` features. ``'sine'`` builds a
    normalised ``PositionEmbeddingSine``; ``'learned'`` builds a
    ``PositionEmbeddingLearned`` sized by ``max_height`` / ``max_width``.

    Raises:
        ValueError: for any other ``position_embedding`` value.
    """
    feats_per_axis = hidden_dim // 2

    if position_embedding == 'sine':
        encoding = PositionEmbeddingSine(feats_per_axis, normalize=True)
    elif position_embedding == 'learned':
        encoding = PositionEmbeddingLearned(feats_per_axis, max_height, max_width)
    else:
        raise ValueError(f"Not supported position embedding type {position_embedding}")

    return encoding

## Backbone

### Changes Made:

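This refactored code maintains the same overall structure but simplifies the definition and initialization of the backbone and its components. The `FrozenBatchNorm2d` layer is replaced by the `frozen_batch_norm2d` factory function, which uses a regular `BatchNorm2d` with its gradient flags set to `False`. Note that disabling gradients alone does not stop a `BatchNorm2d` from updating its running statistics in training mode, so the layer must also be kept in eval mode to be truly frozen. The `BackboneBase` is merged and simplified into `SimplifiedBackbone`, and the functionality of handling different returning layers and feature channels is preserved.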

In [None]:
class _FrozenBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that can never leave inference mode.

    It keeps the parameter and buffer names of ``nn.BatchNorm2d``, but
    ``train()`` is overridden so that a parent module switching to training
    mode cannot make it use batch statistics or update its running stats.
    """

    def train(self, mode=True):
        # Ignore the requested mode: a frozen layer always stays in eval mode.
        return super().train(False)


def frozen_batch_norm2d(num_features, eps=1e-5):
    """Creates a frozen batch normalization layer.

    Args:
        num_features: Number of channels to normalise.
        eps: Value added to the variance for numerical stability.

    Returns:
        An ``nn.BatchNorm2d`` subclass instance whose affine parameters do not
        require gradients and whose running statistics are never updated.
    """
    layer = _FrozenBatchNorm2d(num_features, eps=eps, affine=True)
    # Only parameters carry gradient flags; the running-stat buffers are
    # frozen by keeping the layer in eval mode instead.
    layer.weight.requires_grad_(False)
    layer.bias.requires_grad_(False)
    layer.eval()

    return layer

In [None]:
class NestedTensor(object):
    """Bundle of a tensor and an optional mask that travel together.

    NOTE(review): elsewhere in this file the mask is negated before counting
    positions, so True presumably marks padded cells — confirm with callers.
    """
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor whose tensor (and mask, if any) are on ``device``."""
        moved_tensors = self.tensors.to(device)
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        # Only the tensor is shown; the mask is omitted from the repr.
        return str(self.tensors)

In [None]:
class SimplifiedBackbone(nn.Module):
    """Simplified backbone model for feature extraction.

    Args:
        model_name: Name of a constructor in ``torchvision.models``.
        train_layers: When True, parameters whose names start with ``layer2``,
            ``layer3`` or ``layer4`` are left as the constructor created them;
            every other parameter is frozen. When False, all are frozen.
        num_channels: Stored as ``self.num_channels``; it is not derived from
            the chosen model, so the caller must pass a matching value.
        return_interm_layers: When True, ``layer1``..``layer4`` are returned
            under keys ``'0'``..``'3'``; otherwise only ``layer4`` under ``'0'``.
    """
    def __init__(self, model_name, train_layers=False, num_channels=2048, return_interm_layers=False):
        super().__init__()
        # NOTE(review): newer torchvision releases prefer ``weights=`` over
        # ``pretrained=True`` — confirm against the installed version.
        backbone = getattr(torchvision.models, model_name)(pretrained=True, norm_layer=frozen_batch_norm2d)

        trainable_prefixes = ('layer2', 'layer3', 'layer4')
        for name, parameter in backbone.named_parameters():
            if not train_layers or not name.startswith(trainable_prefixes):
                parameter.requires_grad = False

        return_layers = {f'layer{i}': str(i-1) for i in range(1, 5)} if return_interm_layers else {'layer4': '0'}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, x):
        """Run the backbone on ``x.tensors`` and wrap each feature map in a NestedTensor.

        The input mask is resized to each feature map's spatial size so that
        it stays aligned with the features it describes. A missing mask is
        passed through as ``None``.
        """
        features = self.body(x.tensors)
        mask = x.mask

        out = {}
        for name, feature in features.items():
            feature_mask = None
            if mask is not None:
                # assumes mask is (batch, H, W) — TODO confirm; the temporary
                # leading dim gives interpolate the 4D input it needs.
                feature_mask = F.interpolate(mask[None].float(), size=feature.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(feature, feature_mask)

        return out

In [None]:
class BackboneWithPositionEmbedding(nn.Sequential):
    """Pairs a backbone (slot 0) with a position-embedding module (slot 1)."""
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, x):
        """Return the backbone's feature mapping and one position encoding per feature."""
        backbone, position_embedding = self[0], self[1]
        features = backbone(x)

        # Encodings follow the iteration order of the feature mapping.
        encodings = []
        for feature in features.values():
            encodings.append(position_embedding(feature))

        return features, encodings

In [None]:
def build_refactored_backbone(model_name, train_layers, num_channels, return_interm_layers, hidden_dim, position_embedding_type):
    """Assemble a ``SimplifiedBackbone`` together with its position encoding.

    The position encoding is built first from ``hidden_dim`` and
    ``position_embedding_type``; the backbone is then built from the remaining
    arguments and both are wrapped in a ``BackboneWithPositionEmbedding``.
    """
    position_embedding = build_position_encoding(
        hidden_dim,
        position_embedding_type,
    )
    backbone = SimplifiedBackbone(
        model_name,
        train_layers,
        num_channels,
        return_interm_layers,
    )
    combined = BackboneWithPositionEmbedding(backbone, position_embedding)
    return combined

## Matcher

### Changes Made:

This version of the Hungarian matcher computes matching costs from classification scores and point distances. The original paper discusses a KMO-based matcher, and GIoU costs apply to bounding boxes, which this model does not compute, so for simplicity only the classification and point-distance costs are used.


1. **Logits to Probabilities**: The matcher uses `softmax` to convert the logits to probabilities. The classification cost is the negative probability of the target class, which serves as a cheap stand-in for the negative log likelihood when ranking candidate matches.
2. **Cost Calculation**: The matcher combines classification and point costs linearly using specified weights. It supports batch processing where each element in the batch has its own set of targets and predictions.
3. **Linear Sum Assignment**: This uses the `linear_sum_assignment` from `SciPy`, which directly finds the minimum cost matching. It's applied batch-wise.

In [None]:
class HungarianMatcher(nn.Module):
    """Computes a one-to-one assignment between predictions and ground truth targets.

    Args:
        cost_class: Weight of the classification term in the matching cost.
        cost_point: Weight of the L1 point-distance term in the matching cost.
    """
    def __init__(self, cost_class: float = 1.0, cost_point: float = 1.0):
        super().__init__()
        self.cost_class = cost_class
        self.cost_point = cost_point

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Arguments:
            outputs: Dict containing at least 'pred_logits' and 'pred_points'.
                - 'pred_logits': Tensor of shape [batch_size, num_queries, num_classes]
                - 'pred_points': Tensor of shape [batch_size, num_queries, D]
            targets: List (one entry per batch element) of dictionaries containing:
                - 'labels': Integer tensor of shape [num_target_points]
                - 'points': Tensor of shape [num_target_points, D]

        Returns:
            List of tuples, one per batch element, containing:
            - Index of selected predictions.
            - Index of corresponding selected targets (relative to that
              element's own targets).
        """
        bs, num_queries = outputs['pred_logits'].shape[:2]

        # Flatten batch and query dims so every prediction is one row of the cost matrix.
        out_prob = outputs['pred_logits'].flatten(0, 1).softmax(-1)  # [bs * num_queries, num_classes]
        out_points = outputs['pred_points'].flatten(0, 1)  # [bs * num_queries, D]

        # Concatenate all targets across the batch
        tgt_points = torch.cat([t['points'] for t in targets]).to(out_points.device)
        tgt_labels = torch.cat([t['labels'] for t in targets]).to(out_prob.device)

        # Classification cost: negative probability of each target's class.
        cost_class = -out_prob[:, tgt_labels]  # [bs * num_queries, total_targets]

        # Compute L1 cost between predicted points and target points
        cost_point = torch.cdist(out_points, tgt_points, p=1)

        # Combine costs
        C = self.cost_class * cost_class + self.cost_point * cost_point
        C = C.view(bs, num_queries, -1).cpu()

        # Each batch element is matched only against its own slice of the
        # concatenated targets, so returned target indices are per-element.
        sizes = [len(t['points']) for t in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.long), torch.as_tensor(j, dtype=torch.long)) for i, j in indices]

In [None]:
def build_matcher(cost_class=1.0, cost_point=1.0):
    """Create a ``HungarianMatcher`` with the given cost weights."""
    matcher = HungarianMatcher(cost_class=cost_class, cost_point=cost_point)
    return matcher

## Multihead Attention

### Changes Made:

This `MultiheadAttention`:
- Directly uses the input dimensions of query, key, and value without additional projection, simplifying the architecture.
- Uses scaled dot-product attention mechanism, following the principle outlined in "Attention is All You Need".
- Provides dropout for regularization and a final linear projection to match the output dimensions to the input dimensions.
- Is flexible with regard to whether it returns attention weights, allowing for easier debugging or further manipulation.

In [None]:
class MultiheadAttention(nn.Module):
    """Custom multi-head attention without input projection matrices.

    Query, key and value are split into heads directly along the embedding
    dimension; the only learned weights are those of the output projection.

    Args:
        embed_dim: Embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the output projection has a bias term.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Dropout
        self.dropout_layer = nn.Dropout(dropout)

        # Parameter initialization
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the output projection and zero its bias."""
        xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            constant_(self.out_proj.bias, 0)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        """Forward pass for custom multihead attention.

        With L = target length, S = source length, N = batch size and
        E = embedding dimension:

        Args:
            query: Tensor of shape (L, N, E).
            key: Tensor of shape (S, N, E).
            value: Tensor of shape (S, N, E).
            key_padding_mask: Optional (N, S) mask; True entries are excluded
                from attention.
            need_weights: Whether to also return head-averaged attention weights.
            attn_mask: Optional mask broadcastable to (N * num_heads, L, S).
                Boolean masks exclude True entries; other masks are added to
                the attention scores.

        Returns:
            Tuple ``(attn_output, attn_output_weights)`` where ``attn_output``
            is (L, N, E) and ``attn_output_weights`` is (N, L, S) or ``None``.
        """
        tgt_len, bsz, embed_dim = query.size()
        src_len = key.size(0)
        assert embed_dim == self.embed_dim, "Query embedding size must equal embed_dim"
        assert key.size() == value.size(), "Key and Value must be of the same size"
        assert key.size(1) == bsz and key.size(2) == embed_dim, \
            "Key/Value batch and embedding sizes must match Query"

        # Scale query for dot product attention
        scaling = self.head_dim ** -0.5
        query = query * scaling

        # Split heads, then move the batch*head axis to the front: torch.bmm
        # treats dim 0 as the batch, so attention must run over dims 1 and 2.
        q = query.reshape(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = key.reshape(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = value.reshape(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # Dot product of Q and K (transpose): (N * H, L, S)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights = attn_output_weights.masked_fill(attn_mask, float('-inf'))
            else:
                # Out-of-place so a broadcast mask never has to match in place.
                attn_output_weights = attn_output_weights + attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.to(torch.bool).unsqueeze(1).unsqueeze(2),
                float('-inf')
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # Apply softmax and dropout on attention weights
        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = self.dropout_layer(attn_output_weights)

        # Multiply weights by V: (N * H, L, head_dim)
        attn_output = torch.bmm(attn_output_weights, v)

        # Merge heads back into the embedding dimension: (L, N, E)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        # Apply final linear projection
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # Average the per-head weights.
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

## Transformer Model

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    Args:
        d_model: Size of each encoding vector (odd sizes are supported).
        max_len: Number of positions precomputed in the table.

    ``forward`` returns the table slice only; it does not add it to the input.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer follows the module across .to()/.cuda(); non-persistent so
        # the state_dict keeps the same keys as before.
        self.register_buffer('encoding', encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, shape (1, x.size(1), d_model)."""
        return self.encoding[:, :x.size(1)]

In [None]:
def _get_clones(module, N):
    """Create N identical layers by deep copying the given module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    elif activation == "glu":
        return F.glu
    raise RuntimeError("Activation function {} is not supported".format(activation))

### Encoder Model

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then a feed-forward block, each
    followed by a residual connection and layer normalisation."""
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super().__init__()
        # Attention sub-layer.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward sub-layer; one dropout module is shared by every
        # dropout site in this layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One normalisation per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.activation = _get_activation_fn(activation)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Apply the layer to ``src`` and return a tensor of the same shape."""
        # Self-attention with residual + norm.
        attended = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout(attended))

        # Feed-forward with residual + norm.
        hidden = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout(projected))

In [None]:
class TransformerEncoder(nn.Module):
    """Stack of ``n_layers`` deep copies of ``layer`` followed by a final LayerNorm."""
    def __init__(self, layer, n_layers):
        super().__init__()
        self.layers = _get_clones(layer, n_layers)
        # The final norm is sized from the template layer's attention width.
        self.norm = nn.LayerNorm(layer.self_attn.embed_dim)

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Pass ``src`` through every layer in order, then normalise the result."""
        encoded = src
        for encoder_layer in self.layers:
            encoded = encoder_layer(encoded, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return self.norm(encoded)

### Decoder Model

In [None]:
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over the encoder memory,
    then a feed-forward block, each with a residual connection and LayerNorm."""
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super().__init__()
        # Attention over the target itself and over the encoder memory.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward sub-layer; one dropout module is shared by every
        # dropout site in this layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One normalisation per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.activation = _get_activation_fn(activation)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Apply the layer to ``tgt`` given ``memory``; the result has ``tgt``'s shape."""
        # Self-attention over the target.
        self_attended = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = self.norm1(tgt + self.dropout(self_attended))

        # Attention from the target (queries) to the memory (keys/values).
        cross_attended = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = self.norm2(tgt + self.dropout(cross_attended))

        # Feed-forward block.
        hidden = self.activation(self.linear1(tgt))
        projected = self.linear2(self.dropout(hidden))
        return self.norm3(tgt + self.dropout(projected))

In [None]:
class TransformerDecoder(nn.Module):
    """Stack of ``n_layers`` deep copies of ``layer`` followed by a final LayerNorm."""
    def __init__(self, layer, n_layers):
        super().__init__()
        self.layers = _get_clones(layer, n_layers)
        # The final norm is sized from the template layer's attention width.
        self.norm = nn.LayerNorm(layer.self_attn.embed_dim)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Pass ``tgt`` through every layer (all attending to ``memory``), then normalise."""
        # The same mask arguments are forwarded unchanged to every layer.
        mask_kwargs = dict(
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        decoded = tgt
        for decoder_layer in self.layers:
            decoded = decoder_layer(decoded, memory, **mask_kwargs)

        return self.norm(decoded)

In [None]:
class TransformerModel(nn.Module):
    """Container module hosting the encoder and decoder.

    Args:
        encoder: Module called as ``encoder(src, src_mask, src_key_padding_mask)``.
        decoder: Module called as ``decoder(tgt, memory, tgt_mask, memory_mask,
            tgt_key_padding_mask, memory_key_padding_mask)``.
        src_embed: Maps source inputs to embeddings.
        tgt_embed: Maps target inputs to embeddings.
        pos_embed: Returns a positional encoding that broadcasts against the
            embeddings it is given; it is added to them here.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_embed):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.pos_embed = pos_embed

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Encode ``src``, decode ``tgt`` against it and return the decoder output.

        Embeddings are treated as batch-first (batch, length, dim), matching
        how ``pos_embed`` is queried with ``size(1)`` as the length. They are
        transposed to sequence-first for the encoder/decoder, whose attention
        layers use ``nn.MultiheadAttention``'s default layout, and the output
        is transposed back to batch-first.
        """
        # Positional information is ADDED to the token embeddings; using the
        # positional table alone would discard the token content.
        src_emb = self.src_embed(src)
        src_emb = src_emb + self.pos_embed(src_emb)
        tgt_emb = self.tgt_embed(tgt)
        tgt_emb = tgt_emb + self.pos_embed(tgt_emb)

        memory = self.encoder(src_emb.transpose(0, 1), src_mask, src_key_padding_mask)
        # Padded source positions are also hidden from the decoder's
        # attention over the memory.
        output = self.decoder(
            tgt_emb.transpose(0, 1),
            memory,
            tgt_mask,
            None,
            tgt_key_padding_mask,
            src_key_padding_mask,
        )
        return output.transpose(0, 1)

In [None]:
def build_transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, activation='relu',
                      src_vocab_size=10, tgt_vocab_size=10):
    """Function to build the transformer model.

    Args:
        d_model: Embedding / model width.
        nhead: Number of attention heads per layer.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden width of each feed-forward block.
        dropout: Dropout probability used inside the layers.
        activation: Name of the feed-forward activation.
        src_vocab_size: Number of rows in the source embedding table. The
            default of 10 is the previous placeholder value.
        tgt_vocab_size: Number of rows in the target embedding table. The
            default of 10 is the previous placeholder value.

    Returns:
        A ``TransformerModel`` wiring the encoder, decoder, both embedding
        tables and a ``PositionalEncoding`` together.
    """
    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers)

    decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
    decoder = TransformerDecoder(decoder_layer, num_decoder_layers)

    # Embedding tables; pass real vocabulary sizes instead of the placeholders.
    src_embed = nn.Embedding(num_embeddings=src_vocab_size, embedding_dim=d_model)
    tgt_embed = nn.Embedding(num_embeddings=tgt_vocab_size, embedding_dim=d_model)
    pos_encoder = PositionalEncoding(d_model=d_model)

    return TransformerModel(encoder, decoder, src_embed, tgt_embed, pos_encoder)