In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [54]:
class ConvBlock(nn.Module):
    def __init__(self,inchannels,outchannels,kernel_size,stride,padding):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(inchannels,outchannels,kernel_size,stride,padding),
            nn.BatchNorm2d(outchannels),
            nn.ReLU(),
        )
    def forward(self,x):
        return self.block(x)
    
class YOLOv1(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.Sequential(
            ConvBlock(3,64,7,2,3),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=1),
            ConvBlock(64,192,3,1,0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=1),

            ConvBlock(192,128,1,1,0),
            ConvBlock(128,256,3,1,0),
            ConvBlock(256,256,1,1,0),
            ConvBlock(256,512,1,1,0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=1),

            ConvBlock(512,256,1,1,0),
            ConvBlock(256,512,3,1,1),
            ConvBlock(512,256,1,1,0),
            ConvBlock(256,512,3,1,1),
            ConvBlock(512,256,1,1,0),
            ConvBlock(256,512,3,1,1),
            ConvBlock(512,256,1,1,0),
            ConvBlock(256,512,3,1,1),
            ConvBlock(512,512,1,1,0),
            ConvBlock(512,1024,3,1,0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=1),

            ConvBlock(1024,512,1,1,0),
            ConvBlock(512,1024,1,1,1),
            ConvBlock(1024,512,1,1,0),
            ConvBlock(512,1024,1,1,1),
            ConvBlock(1024,1024,3,1,0),
            ConvBlock(1024,1024,3,2,0),

            ConvBlock(1024,1024,3,1,1),
            ConvBlock(1024,1024,3,1,1),

        )
        self.fcfn = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(7 * 7 * 1024,4096),
            nn.ReLU(),
            nn.Linear(4096,7*7*30),
        )
    def forward(self,x):
        x = self.convs(x)
        x = self.fcfn(x)
        return x

In [55]:
model = YOLOv1()
model(torch.randn((1,3,448,448))).view(1,7,7,30).shape

torch.Size([1, 7, 7, 30])

In [56]:
sum([p.numel() for p in model.parameters()]) / 1e6

262.292734