# **Pytorch CSPNet implemention**
- A simple implemention of CSPDarkNet52 with Pytorch

In [3]:
import torch
import torch.nn as nn

---

## Mish activation
- 

In [8]:
class Mish(nn.Module):

    def __init__(self):
        super(Mish, self).__init__()
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()

    def forward(self, x):
        ori_input = x
        out = self.softplus(x)
        out = self.tanh(out)
        out *= ori_input
        return out

---

## The conv layers for CSPDarkNet
- 

In [9]:
class CSPDConv(nn.Module):

    def __init__(self, in_dim, out_dim, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(CSPDConv, self).__init__()
        #
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
        self.bn = nn.BatchNorm2d(out_dim)
        self.mish = Mish()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.mish(out)
        return out

---

## Define the residual blocks
- the fusion happens before we activate the final output, so the gradient flow is interrupted

In [10]:
class ResidualBolck(nn.Module):

    def __init__(self, in_dim, inside_dim=None):
        super(ResidualBolck, self).__init__()
        # thus the dim will not change in the block
        if inside_dim is None:
            inside_dim = in_dim
        self.conv1 = CSPDConv(in_dim, inside_dim)
        self.conv2 = nn.Conv2d(inside_dim, in_dim, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(in_dim)
        self.mish = Mish()

    def forward(self, x):
        ori_input = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn(out)
        # fusion before activation
        out += ori_input
        out = self.mish(out)
        return out

---

## Contruct the CSPDarkNet
- There exits minor differences between the first block and other blocks

### CSPDarkNetFirst

In [15]:
class CSPDarkNetFirst(nn.Module):

    def __init__(self, in_dim, out_dim, downsample=None):
        super(CSPDarkNetFirst, self).__init__()
        if downsample is None:
            self.downsample = CSPDConv(in_dim, out_dim, kernel_size=3, stride=2, padding=1)
        # the input will be transformed into two parts during the first stage
        self.trans1 = CSPDConv(in_dim, out_dim)
        self.trans2 = CSPDConv(in_dim, out_dim)
        # the first residual block
        self.resblock = ResidualBolck(out_dim, inside_dim=(out_dim // 2))
        # adjust the dim after each concation
        self.concat = CSPDConv(out_dim * 2, out_dim)
    
    def forward(self, x):
        # downsample
        x = self.downsample(x)
        # transform the x into two parts
        x_1 = self.trans1(x)
        x_2 = self.trans2(x)
        # part 2 will enter the blocks
        out = self.resblock(x_2)
        # concate the two tensor then adjust the dim
        out = torch.cat((x_1, out), dim=1)
        out = self.concat(out)
        return out

### CSPDarkNetBody

In [16]:
class CSPDarkNetBody(nn.Module):

    def __init__(self, in_dim, out_dim, block_num, downsample=None):
        super(CSPDarkNetBody, self).__init__()
        if downsample is None:
            self.downsample = CSPDConv(in_dim, out_dim, kernel_size=3, stride=2, padding=1)
        # the input will be split into two parts during the first stage
        self.trans1 = CSPDConv(in_dim, out_dim // 2)
        self.trans2 = CSPDConv(in_dim, out_dim // 2)
        # the first residual block
        self.resblocks = nn.Sequential(*[ResidualBolck(out_dim // 2) for _ in range(block_num)])
        # adjust the dim after each concation, this time the dim should be already aligned
        self.concat = CSPDConv(out_dim, out_dim)
        
    def forward(self, x):
        # downsample
        x = self.downsample(x)
        # transform the x into two parts
        x_1 = self.trans1(x)
        x_2 = self.trans2(x)
        # part 2 will enter the blocks
        out = self.resblocks(x_2)
        # concate the two tensor then adjust the dim
        out = torch.cat((x_1, out), dim=1)
        out = self.concat(out)
        return out

### CSPDarkNet

In [18]:
class CSPDarkNet(nn.Module):

    def __init__(self, out_dim, blocks, downsample=None):
        super(CSPDarkNet, self).__init__()
        dim_list = [32, 64, 128, 256, 512, 1024]
        # images enter the network
        self.conv1 = CSPDConv(in_dim=3, out_dim=dim_list[0], kernel_size=3, stride=1, padding=1)
        # the first residual block part
        self.resblock1 = CSPDarkNetFirst(in_dim=dim_list[0], out_dim=dim_list[1])
        # other residual blocks
        self.resblocks = nn.Sequential(*[CSPDarkNetBody(in_dim=dim_list[i], out_dim=dim_list[i+1], block_num=blocks[i-1]) for i in range(1, len(blocks)+1)])
        # classcification
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(dim_list[5], out_dim)
    
    def forward(self, x):
        # forward propagation
        out = self.conv1(x)
        out = self.resblock1(out)
        out = self.resblocks(out)
        out = self.avgpool(out)
        # flatten the feature matrix
        out = torch.flatten(out)
        out = self.fc(out)
        out = torch.softmax(out)
        return out           

---

## Contruct the CSPDarkNet53 and check its structure

In [20]:
def CSPDarkNet53(out_dim):
    return CSPDarkNet(out_dim=out_dim, blocks=[2, 8, 8, 4])

mynet = CSPDarkNet53(800)
print(mynet)

CSPDarkNet(
  (conv1): CSPDConv(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (mish): Mish(
      (tanh): Tanh()
      (softplus): Softplus(beta=1, threshold=20)
    )
  )
  (resblock1): CSPDarkNetFirst(
    (downsample): CSPDConv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (mish): Mish(
        (tanh): Tanh()
        (softplus): Softplus(beta=1, threshold=20)
      )
    )
    (trans1): CSPDConv(
      (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (mish): Mish(
        (tanh): Tanh()
        (softplus): Softplus(beta=1, threshold=20)
      )
    )
    (trans2): CSPDConv(
      (conv): 