In [1]:
import torch
from torch import nn
from torchvision.models import resnet50

In [64]:
class DETR(nn.Module):
    """Minimal DETR-style object detector.

    Pipeline: ResNet-50 feature extractor -> 1x1 conv projection to
    ``hidden_dim`` -> ``nn.Transformer`` (encoder over the flattened feature
    map, decoder over learned object queries) -> linear class / box heads.

    Args:
        num_classes: Number of object classes. The class head emits
            ``num_classes + 1`` logits (one extra slot, presumably "no object").
        hidden_dim: Transformer model width. Must be even, because the
            positional encoding concatenates a column half and a row half of
            ``hidden_dim // 2`` channels each.
        nheads: Number of attention heads.
        num_encoder_layers: Number of transformer encoder layers.
        num_decoder_layers: Number of transformer decoder layers.
        num_queries: Number of learned object queries, i.e. predictions
            produced per image. Defaults to 100.
        max_pos_size: Number of learned row/column positional embeddings.
            The backbone feature map may be at most ``max_pos_size`` cells
            high and wide. Defaults to 50.
    """

    def __init__(self, num_classes, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers,
                 num_queries=100, max_pos_size=50):
        super().__init__()

        # ResNet-50 with its last two children dropped, so it returns a
        # spatial feature map with 2048 channels instead of class scores.
        # NOTE(review): recent torchvision releases prefer `weights=` over
        # `pretrained=`; kept unchanged so older installs keep working — confirm
        # against the installed torchvision version.
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])

        # 1x1 conv: project 2048 backbone channels down to the transformer width.
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        self.transformer = nn.Transformer(hidden_dim, nheads,
                                          num_encoder_layers, num_decoder_layers)

        # Prediction heads, applied independently to every decoder output.
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_box = nn.Linear(hidden_dim, 4)  # box coordinates

        # Learned object queries: one hidden_dim-wide vector per prediction slot.
        self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim))

        # Learned 2-D positional encoding, stored as separate row and column
        # tables. Each covers half of the channels; they are concatenated in
        # forward() so the result can be added to the projected feature map.
        self.row_embed = nn.Parameter(torch.rand(max_pos_size, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(max_pos_size, hidden_dim // 2))

    def forward(self, inputs):
        """Run detection on a batch of images.

        Args:
            inputs: Image tensor of shape ``(B, 3, height, width)``.

        Returns:
            Tuple ``(logits, boxes)``:
              * ``logits``: ``(num_queries, B, num_classes + 1)`` raw class scores.
              * ``boxes``: ``(num_queries, B, 4)`` box coordinates passed through
                a sigmoid, so every value lies in (0, 1) (relative to the image).

        Raises:
            RuntimeError: If the backbone feature map is larger than the
                learned positional embedding tables.
        """
        features = self.backbone(inputs)        # (B, 2048, H, W)
        h = self.conv(features)                 # (B, hidden_dim, H, W)
        batch_size = h.shape[0]
        H, W = h.shape[-2:]

        # Slicing an embedding table past its end silently truncates, which
        # would only surface later as an opaque torch.cat size mismatch.
        max_h, max_w = self.row_embed.shape[0], self.col_embed.shape[0]
        if H > max_h or W > max_w:
            raise RuntimeError(
                f"feature map of size {H}x{W} exceeds the positional embedding "
                f"capacity {max_h}x{max_w}; use a smaller image or a larger max_pos_size"
            )

        # col_embed[:W]: (W, d/2) -> (1, W, d/2) -> (H, W, d/2)
        # row_embed[:H]: (H, d/2) -> (H, 1, d/2) -> (H, W, d/2)
        # concat -> (H, W, d) -> (H*W, d) -> (H*W, 1, d); the singleton axis
        # broadcasts over the batch when added to the features below.
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # Flatten the spatial grid into a sequence of H*W tokens:
        # (B, d, H, W) -> (B, d, H*W) -> (H*W, B, d), sequence-first as
        # nn.Transformer expects by default.
        src = pos + h.flatten(2).permute(2, 0, 1)

        # nn.Transformer requires src and tgt to share the batch size, so the
        # learned queries are replicated once per image: (num_queries, B, d).
        tgt = self.query_pos.unsqueeze(1).repeat(1, batch_size, 1)

        h = self.transformer(src, tgt)          # (num_queries, B, d)

        # Sigmoid keeps box coordinates in (0, 1), i.e. relative to the image.
        return self.linear_class(h), self.linear_box(h).sigmoid()

In [65]:
    
# Seed first so the randomly initialised queries/positional tables and the
# dummy input below are reproducible across kernel restarts.
torch.manual_seed(0)

# 91 classes, 256-wide transformer with 8 heads, 6 encoder + 6 decoder layers.
detr = DETR(num_classes=91, hidden_dim=256, nheads=8,
            num_encoder_layers=6, num_decoder_layers=6)

print(sum(p.nelement() for p in detr.parameters()))  # total number of parameters

detr.eval()  # inference mode for dropout / batch-norm layers
inputs = torch.randn(1, 3, 800, 1200)  # dummy image batch: (b, c, h, w)

# Forward-only smoke test: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    logits, bboxes = detr(inputs)  # one class/box prediction per object query

41459616
torch.Size([950, 1, 256])


In [66]:
# Shapes of the class logits and the box predictions, printed side by side.
print(*(out.shape for out in (logits, bboxes)))

torch.Size([100, 1, 92]) torch.Size([100, 1, 4])
