In [17]:
####

In [18]:
import torch 
from torch import nn

In [19]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [20]:
class Conv(nn.Module):
    """Conv2d followed by optional InstanceNorm, LeakyReLU(0.2) and 2x2 max-pool.

    Each optional stage is switched by a boolean flag. The flags are stored as
    attributes and consulted again on every forward pass.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: forwarded to ``nn.Conv2d``.
        stride: forwarded to ``nn.Conv2d``.
        padding: forwarded to ``nn.Conv2d``. The default of 1 keeps H/W unchanged
            only for the default 3x3 kernel with stride 1.
        use_norm: apply ``nn.InstanceNorm2d(out_channels)`` after the conv.
        use_activation: apply ``nn.LeakyReLU(0.2)``.
        use_pool: apply ``nn.MaxPool2d`` (2x2 kernel, stride 2), halving H/W.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                 padding=1, use_norm=True, use_activation=True, use_pool=False):
        super().__init__()

        # Stage switches; forward() re-reads them on each call.
        self.use_norm, self.use_activation, self.use_pool = use_norm, use_activation, use_pool

        # The attribute names (conv1 / norm / activation / max_pool) determine
        # the state_dict keys, so they stay fixed.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)
        if use_pool:
            self.max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def forward(self, x):
        out = self.conv1(x)
        if self.use_norm:
            out = self.norm(out)
        if self.use_activation:
            out = self.activation(out)
        return self.max_pool(out) if self.use_pool else out

In [21]:
class Residual_Block(nn.Module):
    """One 3x3 ``Conv`` unit wrapped with an identity shortcut.

    Args:
        in_channels: input channel count. The output has the same count
            because the shortcut is added element-wise.
        use_pool: if True, halve H/W with a 2x2 / stride-2 max pool applied
            *after* the shortcut has been added.
    """

    def __init__(self ,
                 in_channels ,
                 use_pool = False):
        super(Residual_Block , self).__init__()

        self.use_pool = use_pool
        # Pooling must not happen inside the conv branch. It would shrink the
        # branch's H/W but not the shortcut's, and the sum would fail.
        self.conv1 = Conv(in_channels , in_channels)
        if self.use_pool:
            self.max_pool = nn.MaxPool2d(kernel_size=(2 , 2) , stride=(2 , 2))

    def forward(self , x):
        # No clone needed: conv1 returns a new tensor and x is never mutated.
        out = self.conv1(x) + x
        if self.use_pool:
            out = self.max_pool(out)
        return out

In [22]:
class Darknet_Block(nn.Module):
    """Unit made of four stages with an outer identity shortcut:

    1. 1x1 conv down to ``in_channels // 2`` channels,
    2. 3x3 conv back to ``in_channels``,
    3. a ``Residual_Block``,
    4. an identity shortcut from the block input around stages 1-3.

    Args:
        in_channels: input (and output) channel count.
        use_pool: if True, halve H/W with a 2x2 / stride-2 max pool applied
            after the outer shortcut has been added.
    """

    def __init__(self ,
                 in_channels ,
                 use_pool = False):
        super(Darknet_Block , self).__init__()

        self.use_pool = use_pool
        self.conv1 = Conv(in_channels , in_channels // 2 , kernel_size=(1 , 1) , stride=(1 , 1) , padding = 0)
        self.conv2 = Conv(in_channels // 2 , in_channels)
        # use_pool is deliberately NOT forwarded to the inner residual. It would
        # shrink the branch but not `x`, and the outer addition would fail.
        self.residual = Residual_Block(in_channels)
        if self.use_pool:
            self.max_pool = nn.MaxPool2d(kernel_size=(2 , 2) , stride=(2 , 2))

    def forward(self , x):
        # No clone needed: the branch only produces new tensors, so x is intact.
        out = self.residual(self.conv2(self.conv1(x)))
        out = out + x
        if self.use_pool:
            out = self.max_pool(out)
        return out

In [26]:
# Layer specification consumed by `Darknet`. Two kinds of entry:
#   list  [out_channels, kernel_size, stride, padding, use_norm, use_activation, use_pool]
#         -> one Conv layer (here always 3x3 with padding 1; stride (2, 2) downsamples)
#   tuple (channels, repeats)
#         -> `repeats` residual blocks at `channels` channels
# Layout: a stride-1 stem conv, then five stages of
# (stride-2 conv, residual stack) with 1, 2, 8, 8, 4 repeats.
config = [
    [32,   (3, 3), (1, 1), 1, True, True, False],
    [64,   (3, 3), (2, 2), 1, True, True, False],
    (64, 1),
    [128,  (3, 3), (2, 2), 1, True, True, False],
    (128, 2),
    [256,  (3, 3), (2, 2), 1, True, True, False],
    (256, 8),
    [512,  (3, 3), (2, 2), 1, True, True, False],
    (512, 8),
    [1024, (3, 3), (2, 2), 1, True, True, False],
    (1024, 4),
]

In [29]:
class Darknet(nn.Module):
    """Classifier built from a layer ``config``, global average pooling and a Linear head.

    Each ``config`` entry is either
      * a list ``[out_channels, kernel_size, stride, padding, use_norm,
        use_activation, use_pool]``, which builds one ``Conv`` layer, or
      * a tuple ``(channels, repeats)``, which builds ``repeats`` residual blocks.
        These keep ``channels`` channels, so ``channels`` must equal the channel
        count produced by the preceding entries.
    After the layers, the feature map is pooled to 1x1 and flattened. The result
    goes through ``nn.Linear(final_channels, out_channels_last)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels_last: output features of the final Linear layer.
        config: layer specification as above. Defaults to the module-level
            ``config`` as it was when this class was defined.
        block: callable used for each residual repeat, as ``block(channels)``.
            ``None`` (the default) means ``Residual_Block``, looked up when the
            model is constructed.

    Raises:
        ValueError: a tuple entry's channel count differs from the channels
            produced by the preceding layers.
        TypeError: a config entry is neither a list nor a tuple.
    """

    def __init__(self ,
                 in_channels ,
                 out_channels_last = 1000 ,
                 config = config ,
                 block = None):
        super(Darknet , self).__init__()

        if block is None:
            # Resolved at call time (not as a default value) so re-running the
            # Residual_Block cell is picked up without re-running this one.
            block = Residual_Block

        self.layers = nn.ModuleList()
        self.adp_pool = nn.AdaptiveAvgPool2d((1 , 1))
        for idx , layer in enumerate(config):
            if isinstance(layer , list):
                out_channels , kernel_size , stride , padding , use_norm , use_activation , use_pool = layer
                self.layers.append(Conv(
                    in_channels ,
                    out_channels ,
                    kernel_size ,
                    stride ,
                    padding ,
                    use_norm ,
                    use_activation ,
                    use_pool
                ))
                in_channels = out_channels
            elif isinstance(layer , tuple):
                channels , repeats = layer
                # Residual blocks add their input to their output, so they can't
                # change the channel count. Report a mismatch here instead of as
                # a shape error deep inside forward().
                if channels != in_channels:
                    raise ValueError(
                        f"config entry {idx}: residual stage has {channels} channels "
                        f"but the preceding layers output {in_channels}"
                    )
                for _ in range(repeats):
                    self.layers.append(block(channels))
            else:
                raise TypeError(
                    f"config entry {idx} must be a list (Conv layer) or a tuple "
                    f"(channels, repeats), got {type(layer).__name__}"
                )
        self.linear = nn.Linear(in_channels , out_channels_last)

    def forward(self , x):
        for layer in self.layers:
            x = layer(x)
        x = self.adp_pool(x)
        # (N, C, 1, 1) -> (N, C) before the Linear head.
        return self.linear(torch.flatten(x , 1))

In [None]:
# Smoke test: two random 3x256x256 inputs through the default Darknet
# (out_channels_last=1000) should give one 1000-feature row per sample.
torch.manual_seed(0)  # make the random input and the weight init reproducible
x = torch.randn(2 , 3 , 256 , 256).to(device)
darknet = Darknet(3).to(device)
z = darknet(x)
assert z.shape == (2 , 1000) , f"unexpected output shape {tuple(z.shape)}"
z.shape