In [1]:
import os
import cv2
import requests
from PIL import Image

import torch
import torch.nn as nn
from torch.hub import load
from torchvision import transforms
from transformers import AutoModel

In [4]:
# Load DETR
detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)

Using cache found in C:\Users\citak/.cache\torch\hub\facebookresearch_detr_main


In [5]:
# Load DINOv2 Backbone
dino_model = AutoModel.from_pretrained("facebook/dinov2-large")

config.json:   0%|          | 0.00/549 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

In [13]:
dino_model

Dinov2Model(
  (embeddings): Dinov2Embeddings(
    (patch_embeddings): Dinov2PatchEmbeddings(
      (projection): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Dinov2Encoder(
    (layer): ModuleList(
      (0-23): 24 x Dinov2Layer(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attention): Dinov2SdpaAttention(
          (attention): Dinov2SdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): Dinov2SelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (layer_scale1): Dinov2LayerScale()
      

In [6]:
# Freeze DINOv2 backbone to avoid training it
for param in dino_model.parameters():
    param.requires_grad = False

# Define Preprocessing for DINOv2
transform = transforms.Compose([
    transforms.Resize((518, 518)),  # Resize image to match DINOv2 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [10]:
class DETRWithDINO(nn.Module):
    def __init__(self, detr, dino_model):
        super().__init__()
        self.detr = detr
        self.dino = dino_model

        # Remove the original CNN backbone
        self.detr.backbone = None

        # Define a linear layer to project DINOv2 embeddings to match DETR
        self.feature_proj = nn.Linear(1024, 256)  # DINOv2 -> DETR feature dim

    def forward(self, images):
        """ Extract features from DINOv2 and pass them to DETR """

        with torch.no_grad():
            dino_features = self.dino(images).last_hidden_state  # [B, Tokens, 1024]

        # Project DINOv2 output to match DETR's expected feature size
        projected_features = self.feature_proj(dino_features)  # [B, Tokens, 256]

        # Reshape to match CNN-like feature maps (B, C, H, W)
        B, T, C = projected_features.shape
        H = W = int(T**0.5)  # Assuming square token grid (H=W=√T)
        projected_features = projected_features.permute(0, 2, 1).view(B, C, H, W)

        # Generate positional encoding
        pos_embed = self.detr.transformer.encoder.pos_encoding(projected_features)

        # Create dummy mask (assumes no padding)
        mask = torch.zeros((B, H, W), dtype=torch.bool, device=images.device)

        # Pass through DETR’s transformer
        memory = self.detr.transformer.encoder(projected_features, mask, pos_embed)

        # Object queries for detection
        query_embed = self.detr.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)

        # Decoder stage
        hs = self.detr.transformer.decoder(query_embed, memory, mask, pos_embed)

        return hs

In [11]:
# Initialize the modified DETR
model = DETRWithDINO(detr, dino_model)

In [12]:
# Load an example image
url = "https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg"
image = Image.open(requests.get(url, stream=True).raw)
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Run inference
output = model(input_tensor)
print(output.shape)  # Expected shape: [1, num_queries, 256]

RuntimeError: shape '[1, 256, 37, 37]' is invalid for input of size 350720