<a href="https://colab.research.google.com/github/AriPathak/Dino-DETR/blob/main/DETR_DinoV2_Hybrid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [72]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision
%config InlineBackend.figure_format = 'retina'
import torch
from torchvision.models import resnet50
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torchvision.transforms as T

In [73]:
class DinoV2Encoder(nn.Module):
  """DINOv2 ViT-S/14 backbone followed by a linear projection of the patch tokens.

  Args:
    learnable_modules: indices into ``model.blocks`` whose parameters stay
      trainable; every other backbone parameter is frozen.
    hidden_dim: output feature size of the projection (default 256).

  ``forward`` returns the projected patch tokens, batch-first:
  ``(batch, num_patches, hidden_dim)``.
  """
  def __init__(self, learnable_modules:list, hidden_dim:int=256):
    super().__init__()
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
    self.learnable_modules = learnable_modules
    # Freeze the whole backbone, then re-enable only the requested blocks.
    for param in model.parameters():
      param.requires_grad = False
    for i in self.learnable_modules:
      for param in model.blocks[i].parameters():
        param.requires_grad = True
    # Keep the complete backbone rather than Sequential(patch_embed, *blocks):
    # that shortcut bypassed the positional embedding, class token and final
    # norm, so the patch features carried no information about where a patch
    # sits in the image -- which a box-regression head cannot work without.
    self.DinoV2 = model
    # Input width comes from the backbone instead of being inferred lazily on
    # the first forward pass, so the layer is fully initialised at construction.
    self.ffn = nn.Linear(model.embed_dim, hidden_dim)
  def forward(self, x):
    # 'x_norm_patchtokens' holds the normalised per-patch tokens only
    # (class/register tokens are excluded by the backbone).
    tokens = self.DinoV2.forward_features(x)['x_norm_patchtokens']
    return self.ffn(tokens)

In [74]:
class DETRDecoder(nn.Module):
  """Detection head: learned object queries decoded against encoder features.

  Args:
    num_classes: number of object classes; one extra logit is added by
      ``linear_class`` (``num_classes + 1`` outputs).
    decoder: module called as ``decoder(tgt, memory)`` with sequence-first
      tensors (the call convention of ``nn.TransformerDecoder``).
    hidden_dim: feature size of queries, memory and decoder output.
    nheads, num_encoder_layers, num_decoder_layers: accepted for signature
      compatibility; not used by this class.

  ``forward`` takes batch-first features ``(batch, tokens, hidden_dim)`` and
  returns a dict with ``'pred_logits'`` ``(batch, 100, num_classes + 1)`` and
  ``'pred_boxes'`` ``(batch, 100, 4)`` squashed to [0, 1] by a sigmoid.
  """
  def __init__(self, num_classes, decoder, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
    super().__init__()
    self.transformer = decoder
    self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
    self.linear_bbox = nn.Linear(hidden_dim, 4)
    # 100 learned object queries
    self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

  def forward(self, x):
    # batch-first -> sequence-first: (tokens, batch, hidden_dim)
    memory = x.transpose(0, 1)
    # Give every image in the batch its own copy of the queries; a fixed batch
    # dimension of 1 only works when a single image is passed.
    queries = self.query_pos.unsqueeze(1).repeat(1, memory.shape[1], 1)
    # first arg is the target sequence (object queries); second is the memory
    # (encoder features) that the queries attend to
    h = self.transformer(queries, memory).transpose(0, 1)
    return {'pred_logits': self.linear_class(h),
            'pred_boxes': self.linear_bbox(h).sigmoid()}

In [75]:
class DETR(nn.Module):
  """Minimal DETR: transformer + prediction heads over a precomputed feature map.

  Args:
    num_classes: number of object classes; ``linear_class`` outputs
      ``num_classes + 1`` logits.
    hidden_dim: transformer width (must be even: the two positional tables
      each supply ``hidden_dim // 2`` channels).
    nheads, num_encoder_layers, num_decoder_layers: forwarded to
      ``nn.Transformer``.

  ``forward`` expects a feature map ``(batch, hidden_dim, H, W)`` with
  ``H, W <= 50`` (size of the learned positional tables) and returns a dict
  with ``'pred_logits'`` ``(batch, 100, num_classes + 1)`` and ``'pred_boxes'``
  ``(batch, 100, 4)`` in [0, 1].
  """
  def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
    super().__init__()
    # 2048 -> hidden_dim 1x1 projection. forward() does not apply it; it is
    # kept because dropping it would change this module's state_dict keys.
    self.conv = nn.Conv2d(2048, hidden_dim, 1)

    # create a default PyTorch transformer
    self.transformer = nn.Transformer(
        hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

    # prediction heads, one extra class for predicting non-empty slots
    # note that in baseline DETR linear_bbox layer is 3-layer MLP
    self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
    self.linear_bbox = nn.Linear(hidden_dim, 4)

    # output positional encodings (object queries)
    self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

    # spatial positional encodings
    # note that in baseline DETR we use sine positional encodings
    self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
    self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

  def forward(self, h):
    # construct positional encodings: one (H*W, 1, hidden_dim) table that
    # broadcasts over the batch dimension
    H, W = h.shape[-2:]
    pos = torch.cat([
        self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
        self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
    ], dim=-1).flatten(0, 1).unsqueeze(1)

    # (batch, C, H, W) -> (H*W, batch, C), plus positions
    src = pos + 0.1 * h.flatten(2).permute(2, 0, 1)
    # One copy of the object queries per image; a fixed batch dimension of 1
    # only works for single-image input.
    queries = self.query_pos.unsqueeze(1).repeat(1, h.shape[0], 1)

    # first arg is input sequence to the encoder; second is input to the decoder
    h = self.transformer(src, queries).transpose(0, 1)

    # finally project transformer outputs to class labels and bounding boxes
    return {'pred_logits': self.linear_class(h),
            'pred_boxes': self.linear_bbox(h).sigmoid()}

In [76]:
# Container with the same transformer/head layout as the demo checkpoint.
detr_trans = DETR(91)
state_dict = torch.hub.load_state_dict_from_url(
    url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
    map_location='cpu', check_hash=True)

# Keep exactly the checkpoint entries that `DETR` declares. Selecting by name
# rather than by position in the file (previously: first 191 entries plus
# entries 456-457) does not depend on the checkpoint's key order.
expected_keys = set(detr_trans.state_dict())
detr_state_dict = {name: tensor for name, tensor in state_dict.items()
                   if name in expected_keys}

# load_state_dict is strict by default, so a key missing from the checkpoint
# still raises here instead of leaving weights randomly initialised.
detr_trans.load_state_dict(detr_state_dict)

# Reuse only the pretrained transformer decoder, fully frozen.
detr_decoder = detr_trans.transformer.decoder
detr_decoder.requires_grad_(False)
#TODO: unfreeze learnable modules for DETR Decoder
# NOTE: the decoder is left in training mode (dropout active); call .eval()
# on the final model before inference.



In [77]:
class DETR_DinoV2(nn.Module):
  """Hybrid detector: a ``DinoV2Encoder`` feeding a ``DETRDecoder``.

  Args:
    learnable_modules: passed through to ``DinoV2Encoder``.
    num_classes: passed through to ``DETRDecoder``.
    decoder: decoder module handed to ``DETRDecoder``.
  """
  def __init__(self, learnable_modules, num_classes, decoder):
    super().__init__()
    # Head is built before the backbone (construction order kept as-is).
    self.decoder = DETRDecoder(num_classes, decoder)
    self.encoder = DinoV2Encoder(learnable_modules)

  def forward(self, x):
    """Encode ``x`` and return whatever the decoder head produces for it."""
    encoded = self.encoder(x)
    predictions = self.decoder(encoded)
    return predictions

In [80]:
DETR_DinoV2_Hybrid = DETR_DinoV2([10, 11], 91, detr_decoder)
test_img = torch.randn(1, 3, 672, 672)
# Shape smoke test: run the model once and reuse the result for both printouts;
# no_grad because no gradients are needed here.
with torch.no_grad():
  test_out = DETR_DinoV2_Hybrid(test_img)
logits_shape = list(test_out['pred_logits'].shape)
boxes_shape = list(test_out['pred_boxes'].shape)
# 100 queries, 91 classes + 1 extra logit, 4 box coordinates
assert logits_shape == [1, 100, 92], logits_shape
assert boxes_shape == [1, 100, 4], boxes_shape
print("Predicted Classification Logits Output Shape: " , logits_shape)
print("Predicted Bounding Box Output Shape: " , boxes_shape)

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Predicted Classification Logits Output Shape:  [1, 100, 92]
Predicted Bounding Box Output Shape:  [1, 100, 4]
