In [7]:
# File description: Defines a CNN model (ObjectDetectionNetwork) with a classification head and a bounding-box regression head
import torch
from torch import nn

In [None]:
class ObjectDetectionNetwork(nn.Module):
    """Small fully-convolutional detector with a shared backbone and two heads.

    The backbone (``conv_stack``) turns a 3-channel image batch into a
    128-channel feature map. Two convolutional heads read that feature map:

    * ``cls_head`` emits ``num_classes`` channels of class logits and keeps
      the backbone's spatial size (3x3 conv with padding 1, then a 1x1 conv).
    * ``box_head`` emits 4 channels of box values. Its 7x7 (padding 1) and
      5x5 (no padding) convolutions shrink each spatial dimension by 8, so
      the box map is smaller than the class map.
      NOTE(review): confirm this grid mismatch between heads is intentional.

    Because of the unpadded kernels, inputs must be at least 83 pixels per
    side; smaller inputs make the box head's output size non-positive.

    Args:
        num_classes: Number of output channels of the classification head.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Backbone: three Conv -> BatchNorm -> ReLU -> MaxPool stages that
        # widen the channels (3 -> 32 -> 64 -> 128) while downsampling.
        self.conv_stack = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=7, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=5, stride=2),  # downsample
            nn.Conv2d(32, 64, kernel_size=4, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Classification head: one logit channel per class at every cell of
        # the backbone's feature map (spatial size preserved).
        self.cls_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

        # Box regression head: 4 channels per output cell. The large kernels
        # with little/no padding reduce each spatial dimension by 8.
        self.box_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=7, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 4, kernel_size=5)
        )

    def forward(self, x):
        """Run the backbone and both heads on an image batch.

        Args:
            x: Input tensor with 3 channels in dimension 1 (the first
                backbone layer is ``Conv2d(3, ...)``).

        Returns:
            A tuple ``(class_logits, box_logits, box_preds)``.
            ``class_logits`` has ``num_classes`` channels and ``box_logits``
            has 4 channels. ``box_preds`` is the same tensor object as
            ``box_logits`` (raw head output, no activation applied); the
            intended channel order is [x_center, y_center, width, height].
        """
        features = self.conv_stack(x)
        class_logits = self.cls_head(features)

        # Evaluate the box head exactly once. Calling it a second time on the
        # same features yields identical values but doubles the compute and,
        # in training mode, updates the head's BatchNorm running statistics
        # twice per step.
        box_logits = self.box_head(features)
        box_preds = box_logits  # format: [x_center, y_center, width, height]
        return class_logits, box_logits, box_preds


In [None]:
# Pick the compute device (CUDA when available, otherwise CPU), build the
# network on it, and print both as a quick sanity check.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ObjectDetectionNetwork(num_classes=10)
model = model.to(device)
print(device, model, sep="\n")

cpu
ObjectDetectionNetwork(
  (conv_stack): Sequential(
    (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=5, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (cls_head): Sequential(
    (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 10, kernel_size=(1, 1), stride=(1, 1))
  )
  (box_head): Sequential(
    (0): Conv2d(128, 64, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 4, kernel_size=(5, 5), stride=(1, 1))
  )
)
