# Library

In [None]:
!pip install torchinfo



In [None]:
import torch
from torch import nn
from torchinfo import summary

# 모형

## DenseNet layer

In [None]:
class DenseLayer(nn.Module):
  """DenseNet bottleneck layer: BN-ReLU-Conv1x1 followed by BN-ReLU-Conv3x3.

  The layer computes ``k`` new feature maps and concatenates them in front of
  the input along the channel dimension. The output therefore has
  ``in_channels + k`` channels. The spatial size is unchanged because the
  convolutions are a 1x1 conv and a 3x3 conv with padding 1.

  Args:
    in_channels: number of input channels.
    k: growth rate, i.e. how many channels this layer adds.
  """

  def __init__(self, in_channels, k):
    super().__init__()
    bottleneck = 4 * k  # width of the intermediate 1x1 bottleneck
    # Attribute name kept as ``residual`` so state_dict keys stay stable.
    self.residual = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, bottleneck, kernel_size=1, bias=False),
        nn.BatchNorm2d(bottleneck),
        nn.ReLU(inplace=True),
        nn.Conv2d(bottleneck, k, kernel_size=3, padding=1, bias=False),
    )

  def forward(self, x):
    new_features = self.residual(x)
    # New features first, then the untouched input.
    return torch.cat((new_features, x), dim=1)

In [None]:
class Transition(nn.Module):
  """DenseNet transition: BN-ReLU-Conv1x1 halving channels, then optional pooling.

  The output has ``in_channels // 2`` channels. Unless ``csp_transition`` is
  truthy, a 2x2 average pool follows the convolution and halves the spatial
  size.

  Args:
    in_channels: number of input channels.
    csp_transition: if truthy, omit the average pooling so the spatial size
      is preserved (used as the partial transition inside a CSP block).
  """

  def __init__(self, in_channels, csp_transition=False):
    super().__init__()
    transition_layers = [
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, in_channels // 2, 1, bias=False),
    ]

    # Test truthiness rather than identity with the True singleton, so that
    # values like 1 or numpy.bool_(True) also disable pooling.
    if not csp_transition:
      transition_layers.append(nn.AvgPool2d(2))

    self.transition = nn.Sequential(*transition_layers)

  def forward(self, x):
    return self.transition(x)

## CSPNet

In [None]:
class CSPDenseBlock(nn.Module):
  """Cross Stage Partial (CSP) stage built from DenseLayers.

  The input channels are split into two parts:

  * the leading ``in_channels - in_channels // 2`` channels bypass the
    dense path unchanged;
  * the remaining ``in_channels // 2`` channels go through ``num_blocks``
    DenseLayers. They then pass through a partial Transition that halves
    the channels without pooling.

  The two parts are concatenated as (dense output, bypass). The result goes
  through a final Transition (channels halved, 2x2 average pool). When
  ``last_stage`` is True it goes through BatchNorm + ReLU instead.

  Args:
    in_channels: number of input channels.
    num_blocks: number of DenseLayers in the dense path.
    k: growth rate; each DenseLayer adds ``k`` channels.
    last_stage: if True, finish with BatchNorm + ReLU instead of a Transition.

  Attributes:
    channels: number of output channels produced by this block.
  """

  def __init__(self, in_channels, num_blocks, k, last_stage=False):
    super().__init__()
    self.in_channels = in_channels
    # Integer split (nn.Conv2d / nn.BatchNorm2d require int channel counts):
    # the dense path gets floor(in/2) channels, the bypass gets the rest.
    csp_channels_1 = in_channels // 2
    csp_channels_2 = in_channels - csp_channels_1

    layers = []
    for _ in range(num_blocks):
      layers.append(DenseLayer(csp_channels_1, k))
      csp_channels_1 += k
    # Partial transition: halves channels but keeps the spatial size, so the
    # result can be concatenated with the bypass part.
    layers.append(Transition(csp_channels_1, csp_transition=True))
    csp_channels_1 //= 2
    self.dense_block = nn.Sequential(*layers)

    merged = csp_channels_1 + csp_channels_2
    if last_stage:
      self.last = nn.Sequential(nn.BatchNorm2d(merged), nn.ReLU(inplace=True))
      self.channels = merged
    else:
      self.last = Transition(merged)
      self.channels = merged // 2

  def forward(self, x):
    # The bypass takes the leading ceil(in/2) channels and the dense path the
    # trailing floor(in/2). These counts match those computed in __init__ for
    # both odd and even in_channels.
    split = self.in_channels - self.in_channels // 2
    csp_x_02 = x[:, :split, ...]
    csp_x_01 = self.dense_block(x[:, split:, ...])
    csp_x = torch.cat([csp_x_01, csp_x_02], dim=1)

    return self.last(csp_x)

In [None]:
class CSPDenseNet(nn.Module):
    """CSPDenseNet image classifier: stem, four CSP dense stages, linear head.

    Args:
        block_list: exactly four ints, the number of DenseLayers in each stage.
        growth_rate: DenseNet growth rate ``k``; the stem outputs ``2 * k``
            channels.
        n_classes: number of output logits.
    """

    def __init__(self, block_list, growth_rate, n_classes = 1000):
        super().__init__()

        assert len(block_list) == 4
        self.k = growth_rate
        stem_width = 2 * self.k

        # Stem: 7x7 stride-2 conv + BN + ReLU, then 3x3 stride-2 max pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, stem_width, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four CSP stages; each reads the previous stage's output width and
        # only the final one skips the downsampling transition.
        self.dense_block_01 = CSPDenseBlock(stem_width, block_list[0], self.k)
        width = self.dense_block_01.channels
        self.dense_block_02 = CSPDenseBlock(width, block_list[1], self.k)
        width = self.dense_block_02.channels
        self.dense_block_03 = CSPDenseBlock(width, block_list[2], self.k)
        width = self.dense_block_03.channels
        self.dense_block_04 = CSPDenseBlock(width, block_list[3], self.k, last_stage=True)

        # Head: global average pool followed by a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.dense_block_04.channels, n_classes)

    def forward(self, x):
        x = self.maxpool(self.conv1(x))
        stages = (
            self.dense_block_01,
            self.dense_block_02,
            self.dense_block_03,
            self.dense_block_04,
        )
        for stage in stages:
            x = stage(x)
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)