In [2]:
import torch 
import torch.nn as nn

## Prerequisites
1. I assume the reader has a good understanding of [DETR](https://arxiv.org/abs/2005.12872).

Deformable DETR is an improvement on DETR, which was one of the first attempts at using transformers for object detection. The main problems with DETR were:
### 1. Not able to detect small objects properly. 
Most modern object detectors use some kind of feature pyramid to detect objects across different scales. DETR cannot do this, because its self-attention spans the entire feature map. Applying it to large, multi-scale feature maps would cause a computational explosion, since the cost grows quadratically with the number of positions. Deformable DETR solves this by attending to features sampled from feature maps at several different levels.
### 2. Too much training time.
This is because the network has to learn which features to attend to across the full feature map, and this takes a long time to converge. Deformable DETR solves this by introducing a deformable attention module that does not look at all the keys, but only at a small subset of them.


In [3]:
# Fix the seed so every random stand-in tensor below (and the conv weights) is reproducible.
torch.manual_seed(0)

# Dummy batch of 4 RGB images, shape (B, 3, H, W). It only motivates the feature-map sizes below:
# at strides 8/16/32/64 a 1066-pixel side gives ceil(1066/8)=134, 67, 34 and 17 positions.
pixel_values = torch.randn(4,3,1065,1066)

# 1. Initially we pass the images through the backbone/FPN and collect feature maps from several layers.
# 2. We also need a positional embedding for each feature map. It is the usual sine-cosine positional
#    embedding, extended to 2-D: with an embedding dim of 256, the first 128 channels encode the vertical (y)
#    position and the next 128 encode the horizontal (x) position, so the full 256-d vector encodes both.
#    https://github.com/fundamentalvision/Deformable-DETR/blob/11169a60c33333af00a4849f1808023eba96a931/models/position_encoding.py#L55
# 3. Suppose we get feature maps from 4 layers with shapes (4,512,134,134), (4,1024,67,67), (4,2048,34,34)
#    and (4,2048,17,17). [Note: in the original implementation each map is additionally projected with a
#    conv + GroupNorm.] Their positional embeddings have the same spatial size but the embedding dim as
#    channel dim: (4,256,134,134), (4,256,67,67), (4,256,34,34), (4,256,17,17).

EMBED_DIM = 256  # transformer hidden size; every level is projected to this many channels

feature_shapes = [
    (4, 512, 134, 134),
    (4, 1024, 67, 67),
    (4, 2048, 34, 34),
    (4, 2048, 17, 17)
]

# Positional embedding shapes: same batch and spatial dims as the features, channel dim = EMBED_DIM.
embedding_shapes = [(B, EMBED_DIM, H, W) for B, _, H, W in feature_shapes]

# original implementation here https://github.com/fundamentalvision/Deformable-DETR/blob/11169a60c33333af00a4849f1808023eba96a931/models/backbone.py#L71
feature_maps = [torch.randn(shape) for shape in feature_shapes]

# original implementation here https://github.com/fundamentalvision/Deformable-DETR/blob/11169a60c33333af00a4849f1808023eba96a931/models/position_encoding.py#L55
positional_embeddings = [torch.randn(shape) for shape in embedding_shapes]

# 4. A 1x1 conv per level reduces the channel dimension of each feature map to the embedding dim (256).
#    in_channels is read from feature_shapes, so changing or adding a level only needs one edit.
conv_layers = nn.ModuleList([
    nn.Conv2d(in_channels=C, out_channels=EMBED_DIM, kernel_size=1)
    for _, C, _, _ in feature_shapes
])

# Apply the 1x1 conv layers
reduced_feature_maps = [conv(feature) for conv, feature in zip(conv_layers, feature_maps)]

for i, fmap in enumerate(reduced_feature_maps, start=1):
    print(f"Reduced feature map {i} shape:", fmap.shape)

Reduced feature map 1 shape: torch.Size([4, 256, 134, 134])
Reduced feature map 2 shape: torch.Size([4, 256, 67, 67])
Reduced feature map 3 shape: torch.Size([4, 256, 34, 34])
Reduced feature map 4 shape: torch.Size([4, 256, 17, 17])


In [4]:
# 5. We also need a learnable level embedding: one embedding-dim vector per feature level. With 4 levels and
#    a 256-d embedding it has shape (num_levels, embedding_dim) = (4, 256). That 4 is the number of levels,
#    not the batch size, so both sizes are read from the feature maps instead of being hard-coded.
# Learnable level embedding (in the actual model this would be an nn.Parameter)
num_feature_levels = len(reduced_feature_maps)
level_embedding = torch.randn((num_feature_levels, reduced_feature_maps[0].shape[1]))  # (num_levels, embedding_dim)


# 6. Flatten and transpose the features and positional embeddings into token_len x embedding_dim form,
#    e.g. the first feature map becomes (4, 134*134, 256). We do this for every feature map and positional
#    embedding, and additionally add the level embedding to the positional embedding so each token also
#    knows which level it came from.

features_flatten = []
positional_and_level_embedding_flattened = []

for level, (feature, pos_emb) in enumerate(zip(reduced_feature_maps, positional_embeddings)):
    # Flatten and transpose: (B, C, H, W) -> (B, HW, C); every spatial position becomes one token
    feature_flatten = feature.flatten(2).transpose(1, 2)
    # level_embedding[level] is (C,); viewing it as (1, 1, C) broadcasts it over batch and tokens
    positional_plus_level_embed = pos_emb.flatten(2).transpose(1, 2) + level_embedding[level].view(1, 1, -1)

    features_flatten.append(feature_flatten)
    positional_and_level_embedding_flattened.append(positional_plus_level_embed)

    # Print shapes
    print(f"Level {level + 1}:")
    print(f"  Feature shape: {feature_flatten.shape}")
    print(f"  Positional + Level Embedding shape: {positional_plus_level_embed.shape}")

    

Level 1:
  Feature shape: torch.Size([4, 17956, 256])
  Positional + Level Embedding shape: torch.Size([4, 17956, 256])
Level 2:
  Feature shape: torch.Size([4, 4489, 256])
  Positional + Level Embedding shape: torch.Size([4, 4489, 256])
Level 3:
  Feature shape: torch.Size([4, 1156, 256])
  Positional + Level Embedding shape: torch.Size([4, 1156, 256])
Level 4:
  Feature shape: torch.Size([4, 289, 256])
  Positional + Level Embedding shape: torch.Size([4, 289, 256])


In [5]:
# Step 7: join all levels into a single token sequence by concatenating along the sequence dim (dim=1).
# Tokens of level 1 come first, then level 2, and so on; each result is (B, total_seq_len, 256).
inputs_embeds, position_embeddings = (
    torch.cat(per_level_tokens, dim=1)
    for per_level_tokens in (features_flatten, positional_and_level_embedding_flattened)
)

for label, tensor in (
    ("Concatenated Inputs Embeds shape:", inputs_embeds),
    ("Concatenated Position Embeddings shape:", position_embeddings),
):
    print(label, tensor.shape)

Concatenated Inputs Embeds shape: torch.Size([4, 23890, 256])
Concatenated Position Embeddings shape: torch.Size([4, 23890, 256])


In [8]:
# 8. Apply an initial dropout before passing the tokens to the encoder.
#    `nn.functional.dropout` defaults to training=True; it is spelled out here (a module would normally pass
#    training=self.training). The input is rebuilt from `features_flatten` rather than overwriting
#    `inputs_embeds` with a dropped-out copy of itself, so re-running this cell never applies dropout twice.
inputs_embeds = nn.functional.dropout(torch.cat(features_flatten, dim=1), p=0.1, training=True)
batch_size = inputs_embeds.shape[0]

# 9. Generate the reference points. The idea is similar to deformable convolution: each query (one per
#    feature-map position) gets a reference point, and we look at the corresponding point in the other feature
#    maps as well. Coordinates are normalized by each map's width and height, so the same (x, y) locates the
#    corresponding position on maps of any resolution.
#    Original implementation: https://github.com/fundamentalvision/Deformable-DETR/blob/11169a60c33333af00a4849f1808023eba96a931/models/deformable_transformer.py#L238
#    The original also scales the points by per-level `valid_ratios` (the fraction of each map not covered by
#    padding). There is no padding mask here, so those ratios are all 1 and plain expansion is equivalent.
spatial_shapes_list = [tuple(fmap.shape[-2:]) for fmap in reduced_feature_maps]  # [(134, 134), (67, 67), (34, 34), (17, 17)]

reference_points_list = []
for H_, W_ in spatial_shapes_list:
    # Pixel centres 0.5, 1.5, ..., size-0.5; 'ij' indexing makes ref_y vary along rows and ref_x along columns
    ref_y, ref_x = torch.meshgrid(
        torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32),
        torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32),
        indexing='ij',
    )
    # Flatten row-major (same token order as feature.flatten(2)) and normalize into (0, 1)
    ref_y = ref_y.reshape(-1) / H_
    ref_x = ref_x.reshape(-1) / W_

    # Stack as (x, y) and expand to the batch size
    ref = torch.stack((ref_x, ref_y), dim=-1)  # shape: (H_*W_, 2)
    reference_points_list.append(ref[None].expand(batch_size, -1, -1))  # shape: (B, H_*W_, 2)

# Concatenate all levels: one reference point per token
reference_points = torch.cat(reference_points_list, dim=1)  # shape: (B, total_seq_len, 2)
assert reference_points.shape[1] == inputs_embeds.shape[1], "reference points and tokens are out of sync"

# Add a level dimension and repeat the same normalized point across every level
num_levels = len(spatial_shapes_list)
reference_points = reference_points[:, :, None, :].expand(-1, -1, num_levels, -1)  # shape: (B, total_seq_len, L, 2)
print("Reference points shape input to encoder ",reference_points.shape)



Reference points shape input to encoder  torch.Size([4, 23890, 4, 2])


In [9]:
# So each query now has 4 reference points (x, y), one per feature level (not per channel), and this is what
# is passed to the encoder. The first query is the top-left position of the 134x134 map, whose centre is at
# 0.5/134 ~= 0.0037 in both x and y. It sits at that same normalized location on all 4 levels.

reference_points[0, 0]

tensor([[0.0037, 0.0037],
        [0.0037, 0.0037],
        [0.0037, 0.0037],
        [0.0037, 0.0037]])