<a href="https://colab.research.google.com/github/KSI000321/Legend-13-/blob/main/DenseNet_%EC%8B%A4%EC%8A%B5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [None]:
class Bottleneck(nn.Module):
  """DenseNet-B bottleneck layer.

  Computes ``k`` new feature maps with BN -> ReLU -> 1x1 conv (to ``4*k``
  channels) -> BN -> ReLU -> 3x3 conv (to ``k`` channels) and concatenates
  them onto the input along the channel axis. The output therefore has
  ``in_channels + k`` channels and the same spatial size as the input
  (the 3x3 conv uses padding=1).
  """

  def __init__(self, in_channels, k):
    super().__init__()
    inter_channels = 4 * k
    stages = [
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False),
        nn.BatchNorm2d(inter_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inter_channels, k, kernel_size=3, padding=1, bias=False),
    ]
    # Attribute name kept as `residual` so state_dict keys stay unchanged,
    # although the branch output is concatenated, not added.
    self.residual = nn.Sequential(*stages)

  def forward(self, x):
    new_features = self.residual(x)
    # Dense connectivity: keep all incoming channels and append the new ones.
    return torch.cat((x, new_features), dim=1)

class TransitionBlock(nn.Module):
  """Transition layer placed between two dense blocks.

  Applies BN -> ReLU -> 1x1 conv (``in_channels`` -> ``out_channels``) ->
  2x2 average pooling with stride 2, so the spatial size is halved
  (rounded down for odd sizes).
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    norm = nn.BatchNorm2d(in_channels)
    act = nn.ReLU(inplace=True)
    proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
    pool = nn.AvgPool2d(kernel_size=2, stride=2)
    self.transition = nn.Sequential(norm, act, proj, pool)

  def forward(self, x):
    out = self.transition(x)
    return out

class DenseNet(nn.Module):
  """DenseNet-BC image classifier built from Bottleneck and TransitionBlock layers.

  Layout: a stem (7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool), then one
  dense block per entry of ``num_block_list`` with a TransitionBlock between
  consecutive dense blocks, a final BN + ReLU, global average pooling and a
  linear classifier.

  Args:
    num_block_list: number of Bottleneck layers in each dense block; must be
      non-empty.
    growth_rate: channels added by every Bottleneck layer (``k``).
    num_classes: size of the output logits.
    compression: fraction of channels kept by each transition block. The
      default 0.5 reproduces the original halving exactly.

  Raises:
    ValueError: if ``num_block_list`` is empty or ``compression`` is not in
      the interval (0, 1].
  """

  def __init__(self, num_block_list, growth_rate, num_classes=1000, compression=0.5):
    super().__init__()
    num_block_list = list(num_block_list)
    if not num_block_list:
      raise ValueError("num_block_list must contain at least one dense block")
    if not 0 < compression <= 1:
      raise ValueError(f"compression must be in (0, 1], got {compression}")

    self.k = growth_rate
    inner_channels = 2 * self.k  # stem width is twice the growth rate

    self.conv1 = nn.Sequential(nn.Conv2d(3, inner_channels, kernel_size=7, stride=2, padding=3, bias=False),
                               nn.BatchNorm2d(inner_channels),
                               nn.ReLU(inplace=True),
                               nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    layers = []
    for num_blocks in num_block_list[:-1]:
      layers.append(self.make_dense_block(inner_channels, num_blocks))
      inner_channels += num_blocks * self.k

      # Transition: compress the channel count and halve the spatial size.
      out_channels = int(inner_channels * compression)
      layers.append(TransitionBlock(inner_channels, out_channels))
      inner_channels = out_channels

    # The last dense block is handled separately because no transition
    # block follows it.
    layers.append(self.make_dense_block(inner_channels, num_block_list[-1]))
    inner_channels += num_block_list[-1] * self.k

    # Bottleneck layers are pre-activation (BN-ReLU-Conv), so the output of
    # the final dense block still needs its own BN + ReLU before pooling.
    layers += [nn.BatchNorm2d(inner_channels), nn.ReLU(inplace=True)]

    self.features = nn.Sequential(*layers)

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.linear = nn.Linear(inner_channels, num_classes)

  def forward(self, x):
    x = self.conv1(x)
    x = self.features(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    out = self.linear(x)

    return out

  def make_dense_block(self, in_channels, nblocks):
    """Stack ``nblocks`` Bottleneck layers; each adds ``self.k`` channels."""
    dense_block = []
    for _ in range(nblocks):
      dense_block.append(Bottleneck(in_channels, self.k))
      in_channels += self.k

    return nn.Sequential(*dense_block)

In [None]:
def _build_densenet(block_config, **kwargs):
    """Instantiate DenseNet with growth rate 32 and the given per-block layer counts."""
    return DenseNet(list(block_config), growth_rate=32, **kwargs)


def densenet121(**kwargs):
    """DenseNet-121: dense blocks of (6, 12, 24, 16) bottleneck layers, k=32."""
    return _build_densenet((6, 12, 24, 16), **kwargs)


def densenet169(**kwargs):
    """DenseNet-169: dense blocks of (6, 12, 32, 32) bottleneck layers, k=32."""
    return _build_densenet((6, 12, 32, 32), **kwargs)


def densenet201(**kwargs):
    """DenseNet-201: dense blocks of (6, 12, 48, 32) bottleneck layers, k=32."""
    return _build_densenet((6, 12, 48, 32), **kwargs)


def densenet264(**kwargs):
    """DenseNet-264: dense blocks of (6, 12, 64, 48) bottleneck layers, k=32."""
    return _build_densenet((6, 12, 64, 48), **kwargs)

In [None]:
model = densenet264()
# print(model)
!pip install torchinfo
from torchinfo import summary
summary(model, input_size=(2,3,224,224), device='cpu')

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


Layer (type:depth-idx)                        Output Shape              Param #
DenseNet                                      [2, 1000]                 --
├─Sequential: 1-1                             [2, 64, 56, 56]           --
│    └─Conv2d: 2-1                            [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                       [2, 64, 112, 112]         128
│    └─ReLU: 2-3                              [2, 64, 112, 112]         --
│    └─MaxPool2d: 2-4                         [2, 64, 56, 56]           --
├─Sequential: 1-2                             [2, 2688, 7, 7]           --
│    └─Sequential: 2-5                        [2, 256, 56, 56]          --
│    │    └─Bottleneck: 3-1                   [2, 96, 56, 56]           45,440
│    │    └─Bottleneck: 3-2                   [2, 128, 56, 56]          49,600
│    │    └─Bottleneck: 3-3                   [2, 160, 56, 56]          53,760
│    │    └─Bottleneck: 3-4                   [2, 192, 56, 56]          57,920


In [None]:
# Smoke test: a random batch of two 3x224x224 inputs should yield (2, 1000) logits.
# Run in eval mode without autograd so this check neither updates BatchNorm
# running statistics with random data nor stores activations for backprop;
# the model's previous train/eval mode is restored afterwards.
x = torch.randn(2, 3, 224, 224)
was_training = model.training
model.eval()
with torch.no_grad():
    print(model(x).shape)
model.train(was_training)

torch.Size([2, 1000])
