In [137]:
import torch
import torch.nn as nn

In [138]:
# Darknet-53 layer specification, read top to bottom when the network is built.
#   tuple (out_channels, kernel_size, stride) -> one convolution block
#   list  [channels, num_repeats]             -> a stack of residual units
# The tuple/list distinction matters: the builder dispatches on the entry type.
config = [
    (32, 3, 1),
    (64, 3, 2),
    [64, 1],
    (128, 3, 2),
    [128, 2],
    (256, 3, 2),
    [256, 8],
    (512, 3, 2),
    [512, 8],
    (1024, 3, 2),
    [1024, 4],
]

In [139]:
class CNNBlock(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and LeakyReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding applied by the convolution.
        bn_act: When True (default) the convolution has no bias and its output
            goes through batch norm and leaky ReLU. When False the block is a
            plain convolution with a bias term.
        negative_slope: Slope of the leaky ReLU for negative inputs. Defaults
            to 0.1, the value used by Darknet's ``leaky`` activation
            (``nn.LeakyReLU``'s own default is 0.01).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 bn_act=True, negative_slope=0.1):
        super().__init__()
        # A bias is redundant when batch norm follows: its shift absorbs it.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not bn_act)
        if bn_act:
            self.bn = nn.BatchNorm2d(out_channels)
            self.leaky_relu = nn.LeakyReLU(negative_slope)
        else:
            # Parameter-free placeholders: the attributes still exist, but no
            # unused batch-norm weights are registered on the module.
            self.bn = nn.Identity()
            self.leaky_relu = nn.Identity()
        self.use_bn_act = bn_act

    def forward(self, x):
        """Apply the convolution and, if enabled, batch norm + leaky ReLU."""
        x = self.conv(x)
        if self.use_bn_act:
            x = self.leaky_relu(self.bn(x))
        return x

In [140]:
class ResidualBlock(nn.Module):
    """A sequence of bottleneck units, each optionally wrapped in a skip connection.

    Every unit is a 1x1 CNNBlock that halves the channel count followed by a
    3x3 CNNBlock that restores it, so the channel count is unchanged.

    Args:
        in_channels: Channel count of the input (and of the output).
        use_residual: If True, each unit's output is added to its input.
        num_repeats: Number of units to stack.
    """

    def __init__(self, in_channels, use_residual=True, num_repeats=1):
        super().__init__()
        squeezed = in_channels // 2
        self.layers = nn.ModuleList([
            nn.Sequential(
                CNNBlock(in_channels, squeezed, kernel_size=1, stride=1, padding=0),
                CNNBlock(squeezed, in_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_repeats)
        ])
        self.use_residual = use_residual

    def forward(self, x):
        """Run the units in order, adding the skip connection when enabled."""
        for unit in self.layers:
            out = unit(x)
            x = x + out if self.use_residual else out
        return x


In [141]:
# This net is used in YOLOv3 as the backbone (feature extractor).
class Darknet53(nn.Module):
    """Darknet-53 convolutional feature extractor.

    The layer stack is assembled from the module-level ``config`` list: tuple
    entries become a ``CNNBlock`` and list entries become a ``ResidualBlock``.

    Args:
        image_channels: Number of channels of the input image.
        num_classes: Stored on the instance as ``num_classes``. This class
            builds no classification head, so the value does not influence
            the output.

    Attributes:
        in_channels: Number of input image channels (left unchanged by
            building the layers).
        out_channels: Channel count produced by the last layer.
        layers: ``nn.ModuleList`` holding the blocks in execution order.
    """

    def __init__(self, image_channels, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = image_channels
        self.layers = self.create_layers()

    def create_layers(self):
        """Build the block list described by ``config``.

        Returns:
            nn.ModuleList: the blocks in the order they are applied.

        Raises:
            ValueError: if a residual entry's channel count differs from the
                channel count produced by the preceding layers.
        """
        layers = nn.ModuleList()
        # Track the running width in a local so self.in_channels keeps meaning
        # "input image channels" and this method can safely be called again.
        channels = self.in_channels

        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0
                    )
                )
                channels = out_channels

            elif isinstance(module, list):
                # Fail at build time instead of with a shape error in forward().
                if module[0] != channels:
                    raise ValueError(
                        f"residual block expects {module[0]} channels but the "
                        f"preceding layers produce {channels}"
                    )
                layers.append(ResidualBlock(module[0], num_repeats=module[1]))
            # Entries of any other type are skipped.

        self.out_channels = channels
        return layers

    def forward(self, x):
        """Pass ``x`` through every block in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
        

In [142]:
# Factory for the Darknet53 network
def darknet53(img_channels=3, num_classes=1000):
    """Return a new Darknet53 built for the given input channels and class count."""
    model = Darknet53(img_channels, num_classes)
    return model

In [143]:
# Test
def test():
  net = darknet53()
  x = torch.rand(2,3,224,224)
  y = net(x)
  print(y.shape)

In [144]:
# Run the smoke test; it prints the shape of the network output.
test()

torch.Size([2, 1024, 7, 7])
