In [None]:
from __future__ import annotations

In [None]:
!pip install -q einops

In [None]:
import torch
import torch.nn as nn

from typing import NamedTuple

import einops

In [None]:
class FakeBackboneResult(NamedTuple):
  """Bundle of the three feature maps returned by the backbone's forward pass.

  The hl/ml/ll prefixes presumably stand for high-, mid- and low-level
  features; the shapes noted below are the ones FakeBackbone produces in
  this notebook.
  """
  # (1, 512, 13, 13) in this notebook: most channels, smallest grid.
  hl_features: torch.Tensor
  # (1, 256, 26, 26) in this notebook.
  ml_features: torch.Tensor
  # (1, 128, 52, 52) in this notebook: fewest channels, largest grid.
  ll_features: torch.Tensor

In [None]:
class FakeBackbone(nn.Module):
  """Stand-in backbone that emits random feature maps at three scales.

  No real computation is performed. The input is only inspected for its batch
  size, spatial size, dtype and device; the returned maps are freshly sampled
  noise with 512, 256 and 128 channels at strides 32, 16 and 8 respectively.
  A 416x416 input therefore yields 13x13, 26x26 and 52x52 grids.
  """

  def __init__(self):
    super().__init__()

  def forward(self, x: torch.Tensor) -> FakeBackboneResult:
    """Return random feature maps whose shapes follow the input batch.

    Args:
      x: Image batch of shape (batch, channels, height, width).

    Returns:
      FakeBackboneResult holding tensors of shape
      (batch, 512, height // 32, width // 32),
      (batch, 256, height // 16, width // 16) and
      (batch, 128, height // 8, width // 8).

    Raises:
      ValueError: if `x` does not have exactly four dimensions.
    """
    batch_size, _, height, width = x.shape

    # randn cannot sample integer dtypes, so fall back to the default float
    # dtype when the images are not floating point (e.g. uint8).
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # Keep the fake features on the same device as the input so downstream
    # layers do not hit a device mismatch.
    options = {"dtype": dtype, "device": x.device}

    # Sampling order (hl, ml, ll) is kept stable so seeded runs stay reproducible.
    hl_fm = torch.randn(size=(batch_size, 512, height // 32, width // 32), **options)
    ml_fm = torch.randn(size=(batch_size, 256, height // 16, width // 16), **options)
    ll_fm = torch.randn(size=(batch_size, 128, height // 8, width // 8), **options)

    return FakeBackboneResult(
        hl_features=hl_fm,
        ml_features=ml_fm,
        ll_features=ll_fm
    )

In [None]:
# Build the stand-in backbone, then push one random 3x416x416 image through it.
backbone = FakeBackbone()
backbone_output = backbone(torch.randn(1, 3, 416, 416))

In [None]:
# Shape check: 512 channels on a 13x13 grid for the 416x416 input.
backbone_output.hl_features.shape

torch.Size([1, 512, 13, 13])

In [None]:
# Shape check: 256 channels on a 26x26 grid for the 416x416 input.
backbone_output.ml_features.shape

torch.Size([1, 256, 26, 26])

In [None]:
# Shape check: 128 channels on a 52x52 grid for the 416x416 input.
backbone_output.ll_features.shape

torch.Size([1, 128, 52, 52])

In [None]:
class DetectionHead(nn.Module):
  """Per-cell prediction layer: a single 1x1 convolution over a feature map.

  For every grid cell the layer outputs, for each of `num_boxes_per_cell`
  boxes, 4 + 1 + `num_classes` values stacked along the channel axis.
  """

  def __init__(self,
        in_channels: int,
        num_boxes_per_cell: int,
        num_classes: int):
    """Create the 1x1 convolution that produces the raw predictions.

    Args:
      in_channels: Number of channels of the incoming feature map.
      num_boxes_per_cell: Number of boxes predicted at each grid cell.
      num_classes: Number of class scores predicted per box.
    """
    super().__init__()

    # Every box carries 4 values, 1 extra value and one value per class.
    values_per_box = 4 + 1 + num_classes

    # A 1x1 kernel with the default stride of 1 keeps the spatial grid unchanged.
    self.conv = nn.Conv2d(
        in_channels,
        num_boxes_per_cell * values_per_box,
        kernel_size=1,
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the 1x1 convolution to `x` and return the raw predictions."""
    return self.conv(x)

In [None]:
# Head for the 512-channel feature map: 3 boxes per cell and 3 classes,
# i.e. 3 * (4 + 1 + 3) = 24 output channels.
hl_detector = DetectionHead(in_channels=512, num_boxes_per_cell=3, num_classes=3)

# The 1x1 convolution keeps the 13x13 grid and only changes the channel count.
hl_detections = hl_detector(backbone_output.hl_features)
hl_detections.shape

torch.Size([1, 24, 13, 13])

In [None]:
# Same box/class configuration as hl_detector; only in_channels differs to
# match the 256-channel and 128-channel feature maps.
ml_detector = DetectionHead(in_channels=256, num_boxes_per_cell=3, num_classes=3)
ll_detector = DetectionHead(in_channels=128, num_boxes_per_cell=3, num_classes=3)

In [None]:
# Raw predictions on the 26x26 grid: 24 channels per cell.
ml_detections = ml_detector(backbone_output.ml_features)
ml_detections.shape

torch.Size([1, 24, 26, 26])

In [None]:
# Raw predictions on the 52x52 grid: 24 channels per cell.
ll_detections = ll_detector(backbone_output.ll_features)
ll_detections.shape

torch.Size([1, 24, 52, 52])

In [None]:
# Split the 24 channels into (box index, per-box values) and move the per-box
# values to the last axis: (b, 24, h, w) -> (b, 3, h, w, 8).
hl_detections_for_training = einops.rearrange(hl_detections,
                                          "b (num_anchors_per_cell p) h w -> b num_anchors_per_cell h w p",
                                          num_anchors_per_cell=3)

hl_detections_for_training.shape

torch.Size([1, 3, 13, 13, 8])

In [None]:
# Pick the grid cell at h=1, w=5 of the first image and read the prediction
# vector of each of its three boxes (axes: batch, box, h, w).
pred_for_box0_at_cell_15 = hl_detections_for_training[0, 0, 1, 5]
pred_for_box1_at_cell_15 = hl_detections_for_training[0, 1, 1, 5]
pred_for_box2_at_cell_15 = hl_detections_for_training[0, 2, 1, 5]

pred_for_box0_at_cell_15.shape

torch.Size([8])

In [None]:
# Split the 8-value prediction vector of box 0: the first 4 entries are the
# box coordinates, entry 4 is the objectness and the remaining 3 are the classes.
box_coordinates, box_objectness, box_classes = (
    pred_for_box0_at_cell_15[:4],
    pred_for_box0_at_cell_15[4],
    pred_for_box0_at_cell_15[5:],
)

box_coordinates, box_objectness, box_classes

(tensor([-0.8510,  0.5855,  0.5353,  0.0370], grad_fn=<SliceBackward0>),
 tensor(0.4091, grad_fn=<SelectBackward0>),
 tensor([-0.5069,  0.3296, -1.0508], grad_fn=<SliceBackward0>))

In [None]:
# Same split as for a single box, but over the whole tensor at once:
# the first 4 values of every box prediction.
pred_for_coordinates = hl_detections_for_training[...,:4]

pred_for_coordinates.shape

torch.Size([1, 3, 13, 13, 4])

In [None]:
# Value at index 4 of every box prediction; integer indexing drops the last axis.
pred_for_objectness = hl_detections_for_training[..., 4]

pred_for_objectness.shape

torch.Size([1, 3, 13, 13])

In [None]:
# Everything after index 4 of every box prediction: the 3 class values.
pred_for_classes = hl_detections_for_training[..., 5:]

pred_for_classes.shape

torch.Size([1, 3, 13, 13, 3])

In [None]:
# Flatten boxes and grid cells into one axis so that each row holds the values
# of a single predicted box: (b, 24, h, w) -> (b, 3 * h * w, 8).
# The grid size is inferred from the tensor rather than hard-coded, so the same
# expression also works for the 26x26 and 52x52 detection maps.
hl_detections_for_final_prediction = einops.rearrange(
    hl_detections,
    "b (num_anchors_per_cell p) h w -> b (num_anchors_per_cell h w) p",
    num_anchors_per_cell=3,
)

hl_detections_for_final_prediction.shape

torch.Size([1, 507, 8])