In [22]:
import torch
import torch.nn as nn
from torchvision.ops import generalized_box_iou
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from typing import Tuple

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [23]:
def cxcywh_to_xyxy(boxes):
    cx, cy, w, h = boxes.unbind(-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack((x1, y1, x2, y2), dim=-1)

In [24]:
class HungarianMatcher(nn.Module):
    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 2):
        super(HungarianMatcher, self).__init__()

        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    def _get_index_map(self, targets):
        batch_idx = torch.cat([torch.full((len(t['labels']),), i, dtype=torch.int64)
                               for i, t in enumerate(targets)])
        gt_idx = torch.cat([torch.arange(len(t['labels']), dtype=torch.int64)
                            for t in targets])
        return batch_idx, gt_idx

    #Manhattan distance between predicted and GTBoxes

    def _bbox_distance(self, pred_boxes, targets):
        boxes = torch.cat([t['boxes'] for t in targets])
        pred_boxes = pred_boxes.view(-1, 4)
        cost = torch.cdist(pred_boxes, boxes, p=1)
        return cost

    def _giou_loss(self, pred_boxes, targets):
        all_cost = []

        for i, t in enumerate(targets):
            preds = pred_boxes[i]
            gts = t['boxes']

            if len(gts) == 0:
                all_cost.append(torch.zeros(preds.size(0), device=preds.device))
                continue
            pred_xyxy = cxcywh_to_xyxy(preds)
            targetxyxy = cxcywh_to_xyxy(gts)

            giou = generalized_box_iou(pred_xyxy, targetxyxy)
            best_giou_per_pred = giou.max(dim=-1)[0]
            all_cost.append(-best_giou_per_pred)
        return all_cost

    @torch.no_grad()
    def forward(self, outputs, targets):
        bs, num_queries = outputs['pred_logits'].shape[:2]
        indices = []

        for i in range(bs):
            # --- extract predictions for this image ---
            out_prob = outputs['pred_logits'][i].softmax(-1)  # [N, C]
            out_bbox = outputs['pred_boxes'][i]  # [N, 4]

            tgt_ids = targets[i]['labels']  # [K]
            tgt_bbox = targets[i]['boxes']  # [K, 4]

            # --- compute costs ---
            cost_class = -out_prob[:, tgt_ids]  # [N, K]
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)  # [N, K]
            cost_giou = -generalized_box_iou(
            cxcywh_to_xyxy(out_bbox), cxcywh_to_xyxy(tgt_bbox)
            )  # [N, K]

            # --- total cost ---
            C = (self.cost_class * cost_class +
                 self.cost_bbox * cost_bbox +
                 self.cost_giou * cost_giou)

            i_idx, j_idx = linear_sum_assignment(C.cpu())
            indices.append((
                torch.as_tensor(i_idx, dtype=torch.int64),
                torch.as_tensor(j_idx, dtype=torch.int64)
            ))

        return indices



In [25]:
#Understand the Set Criterion
class SetCriterion(nn.Module):
    def __init__(self, matcher, weight_dict, eos_coef=0.1, losses=None):
        super(SetCriterion, self).__init__()
        if losses is None:
            losses = ['labels', 'boxes']
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses

    def _get_src_batch_indices(self, indices):
        batch_idx = torch.cat(
            [torch.full((len(src),), i, dtype=torch.int64)
             for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat(
            [src for (src, _) in indices]
        )
        return batch_idx, src_idx

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Compute classification loss for matched predictions only"""
        batch_idx, src_idx = self._get_src_batch_indices(indices)
        pred_logits = outputs['pred_logits'][batch_idx, src_idx]
        #Ground truth label for the same box
        target_classes = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
        #Compute loss function
        loss_ce = F.cross_entropy(pred_logits, target_classes, reduction='none')
        weights = torch.ones_like(target_classes, dtype=torch.float)
        loss = (loss_ce * weights).sum() / num_boxes
        return {'loss_ce': loss}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute DETR box regress loss"""
        batch_idx, src_idx =self._get_src_batch_indices(indices)
        src_boxes = outputs['pred_boxes'][batch_idx, src_idx]

        #Get each box ground truth
        target_boxes = torch.cat(
            [t['boxes'][J] for t, (_, J) in zip(targets, indices)],
        )

         # L1 loss   numeric distance between boxes
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        loss_bbox = loss_bbox.sum()/ num_boxes

        #GIOU Loss (Overlapping Boxes)
        src_xyxy = cxcywh_to_xyxy(src_boxes)
        target_xyxy = cxcywh_to_xyxy(target_boxes)

        giou = generalized_box_iou(src_xyxy, target_xyxy)
        giou = torch.nan_to_num(giou, nan= 0.0, posinf=0.0, neginf=-1.0)
        loss_giou = (1 - torch.diag(giou)).sum() / num_boxes

        return {'loss_box':loss_bbox, 'loss_giou': loss_giou}


    def forward(self, outputs, targets ):
        """Compute the total DETR LOSS"""
        indices = self.matcher(outputs, targets)
        num_boxes = sum(len(t['labels']) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device = next(iter(outputs.values())).device)
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(num_boxes)
            num_boxes = num_boxes / torch.distributed.get_world_size()
        num_boxes = max(num_boxes.item(), 1.0)

        losses ={}

        for loss in self.losses:
            if loss == 'labels':
                losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
            elif loss == 'boxes':
                losses.update(self.loss_boxes(outputs, targets, indices, num_boxes))
        total_loss = sum(self.weight_dict[k] * v for k, v in losses.items() if k in self.weight_dict)
        return total_loss, losses

In [26]:
outputs = {
    "pred_logits": torch.randn(2, 5, 91),  # 2 images, 5 queries each
    "pred_boxes": torch.rand(2, 5, 4)
}
targets = [
    {"labels": torch.tensor([3, 5, 7]), "boxes": torch.rand(3, 4)},
    {"labels": torch.tensor([4, 8]), "boxes": torch.rand(2, 4)}
]
matcher = HungarianMatcher()
indices = matcher(outputs, targets)
print(indices)


[(tensor([0, 1, 2]), tensor([1, 0, 2])), (tensor([0, 4]), tensor([1, 0]))]


In [27]:
weight_dict = {'loss_ce': 1.0, 'loss_bbox': 5.0, 'loss_giou': 2.0}
outputs = {
    "pred_logits": torch.randn(2, 5, 91),
    "pred_boxes": torch.rand(2, 5, 4)
}
targets = [
    {"labels": torch.tensor([3, 5, 7]), "boxes": torch.rand(3, 4)},
    {"labels": torch.tensor([4, 8]), "boxes": torch.rand(2, 4)}
]

matcher = HungarianMatcher()
criterion = SetCriterion(matcher, weight_dict, eos_coef=0.1, losses=['labels', 'boxes'])

total_loss, loss_dict = criterion(outputs, targets)
print("Matched indices:", matcher(outputs, targets))
print("Total loss:", total_loss)
print("Loss breakdown:", loss_dict)


Matched indices: [(tensor([0, 1, 3]), tensor([2, 0, 1])), (tensor([3, 4]), tensor([1, 0]))]
Total loss: tensor(7.6155)
Loss breakdown: {'loss_ce': tensor(5.3163), 'loss_box': tensor(1.1238), 'loss_giou': tensor(1.1496)}


In [28]:
import torchvision.models as models
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
class Backbone(nn.Module):
    def __init__(self, backbone_name = 'resnet50', dilation = False, return_layers = None):
        super().__init__()
        if return_layers is None:
            return_layers = {'layer4': '0'}
        backbone = getattr(models, backbone_name)(pretrained=True, replace_stride_with_dilation = [False, False, dilation])
        self.body = backbone

        if dilation:
            self.body.layer4[0].downsample[0].stride = (1, 1)
            self.body.layer4[0].conv2.stride = (1, 1)
            self.body.layer4[2].conv1.stride = (1, 1)
            self.body.layer4[2].conv2.stride = (1, 1)
            self.body.layer4[2].downsample[0].stride = (1, 1)

        self.out_channels = 2048 if backbone_name == 'resnet101' else 256 * 8

    def forward(self, inputs):
        x = self.body.conv1(inputs)
        x =self.body.bn1(x)
        x = self.body.relu(x)
        x = self.body.maxpool(x)
        x = self.body.layer1(x)
        x = self.body.layer2(x)
        x = self.body.layer3(x)
        x = self.body.layer4(x)
        return x


In [33]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # Detect layout: batch-first or sequence-first
        if x.dim() == 3 and x.shape[0] < x.shape[1]:
            # Likely (seq_len, batch, d_model)
            seq_len = x.size(0)
            return x + self.pe[:, :seq_len, :].transpose(0, 1)
        else:
            # Likely (batch, seq_len, d_model)
            seq_len = x.size(1)
            return x + self.pe[:, :seq_len, :]


In [46]:
import torch
from torch import nn, Tensor
import math

class MultiheadAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask=None, key_padding_mask=None) -> Tensor:
        # --- Detect layout (batch-first vs seq-first) ---
        if query.dim() == 3 and query.shape[0] < query.shape[1]:
            # seq-first layout (L, B, D)
            query = query.transpose(0, 1)  # -> (B, L, D)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
            seq_first = True
        else:
            seq_first = False

        B, tgt_len, d = query.shape
        src_len = key.shape[1]

        # --- Project Q, K, V ---
        Q = self.q_proj(query).view(B, tgt_len, self.nhead, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(B, src_len, self.nhead, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, src_len, self.nhead, self.head_dim).transpose(1, 2)

        # --- Attention weights ---
        attn = (Q @ K.transpose(-2, -1)) / self.scale  # (B, nhead, tgt, src)
        if attn_mask is not None:
            attn += attn_mask
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # --- Weighted sum ---
        out = (attn @ V).transpose(1, 2).contiguous().view(B, tgt_len, d)
        out = self.out_proj(out)

        # --- Convert back if needed ---
        if seq_first:
            out = out.transpose(0, 1)  # -> (L, B, D)
        return out


In [47]:
from torch import nn
import torch

#Opt for pytorch builtin transformer
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn import TransformerDecoder, TransformerDecoderLayer

class DETR(nn.Module):
    def __init__(self, num_classes=20, num_queries=100, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, backbone='resnet50', dilation=False):
        super().__init__()
        self.num_queries = num_queries
        self.d_model = d_model
        self.num_classes = num_classes + 1  # + âˆ…

        # Backbone + projection
        self.backbone = Backbone(backbone, dilation)
        self.conv = nn.Conv2d(self.backbone.out_channels, d_model, 1)

        # Positional encodings
        self.encoder_pe = PositionalEncoding(d_model)
        self.decoder_pe = PositionalEncoding(d_model, num_queries)

        # Transformer (batch_first=True)
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_encoder_layers)

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_decoder_layers)

        # Object queries
        self.query_embed = nn.Embedding(num_queries, d_model)

        # Prediction heads
        self.class_embed = nn.Linear(d_model, num_classes)
        self.bbox_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 4),
            nn.Sigmoid()
        )

        # Criterion
        matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
        weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
        self.criterion = SetCriterion(matcher, weight_dict, eos_coef=0.1, losses=['labels', 'boxes'])

    def forward(self, x, targets=None):
        # Backbone feature extraction
        features = self.backbone(x)
        features = self.conv(features)
        B, C, H, W = features.shape
        src = features.flatten(2).transpose(1, 2)

        # Add positional encoding
        src = self.encoder_pe(src)

        # Transformer encoder
        memory = self.transformer_encoder(src)

        # Prepare queries
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        tgt = torch.zeros_like(query_embed)
        query_embed = self.decoder_pe(query_embed)

        # Transformer decoder
        hs = self.transformer_decoder(tgt, memory)  

        # Predictions
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs)
        out = {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}

        if self.training and targets is not None:
            loss_dict = self.criterion(out, targets)
            return loss_dict

        return out


In [48]:
model = DETR(num_classes=20)
x = torch.randn(2, 3, 800, 1333)
out = model(x)
print(out['pred_logits'].shape, out['pred_boxes'].shape)


torch.Size([2, 100, 20]) torch.Size([2, 100, 4])
