# **ConvNeXt**
此份程式碼介紹如何使用 PyTorch 建構 ConvNeXt 的模型架構。

![image](https://hackmd.io/_uploads/r1UV1gH_a.png)

- [source paper](https://arxiv.org/abs/2201.03545)

## 匯入套件

In [None]:
# PyTorch 相關套件
import torch
import torch.nn as nn

## ConvNeXt Architecture

![image](https://hackmd.io/_uploads/rksNyeru6.png)


In [None]:
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension of an (N, C, H, W) tensor.

    ``nn.LayerNorm`` normalizes over trailing dimensions, so the channel
    axis is moved to the end, normalized with this module's
    ``normalized_shape``/``weight``/``bias``/``eps``, and moved back.
    The output has the same shape as the input.
    """

    def forward(self, x):
        # (N, C, H, W) -> (N, H, W, C): put channels last for layer_norm.
        channels_last = x.movedim(1, -1)
        normed = nn.functional.layer_norm(channels_last,
                                          self.normalized_shape,
                                          self.weight,
                                          self.bias,
                                          self.eps)
        # (N, H, W, C) -> (N, C, H, W): restore the original layout.
        return normed.movedim(-1, 1)


class ConvNeXtModule(nn.Module):
    """A single ConvNeXt residual module (inverted bottleneck).

    Computes ``shortcut(x) + pw2(gelu(pw1(norm(dw7x7(x)))))``, where
    ``pw1`` expands the width to ``4 * in_channels`` and ``pw2`` projects
    it to ``out_channels``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels of the output feature map. When
            it equals ``in_channels`` the input is added back unchanged.
            Otherwise a 1x1 convolution projects the input so that the
            residual addition is shape-compatible.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 7x7 depthwise conv (groups == channels): spatial mixing per channel.
        self.depthwise = nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=7,
                                   padding='same',
                                   groups=in_channels)
        self.layernorm = LayerNorm2d(in_channels)
        # Inverted bottleneck: 1x1 expand to 4x width, GELU, 1x1 project.
        self.pointwise1 = nn.Conv2d(in_channels,
                                    in_channels * 4,
                                    kernel_size=1)
        self.pointwise1_act = nn.GELU()
        self.pointwise2 = nn.Conv2d(in_channels * 4,
                                    out_channels,
                                    kernel_size=1)
        # Without a projection, `x + skip` cannot broadcast when the widths
        # differ. nn.Identity has no parameters, so the equal-width case
        # keeps exactly the same state_dict as before.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        skip = self.shortcut(x)
        x = self.depthwise(x)
        x = self.layernorm(x)
        x = self.pointwise1(x)
        x = self.pointwise1_act(x)
        x = self.pointwise2(x)
        return x + skip


class Downsample(nn.Module):
    """Halve the spatial resolution: channel LayerNorm, then 2x2/stride-2 conv.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the strided convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names `norm` / `conv` are kept so state_dict keys match.
        self.norm = LayerNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        return self.conv(self.norm(x))


class ConvNeXtBlock(nn.Module):
    """One downsampling ConvNeXt stage.

    The stage is a ``Downsample(in_channels, out_channels)`` followed by
    ``num_modules`` ``ConvNeXtModule``s that keep the width at
    ``out_channels``. All of them are applied in sequence.
    """

    def __init__(self, in_channels, out_channels, num_modules):
        super().__init__()
        stage = [Downsample(in_channels, out_channels)]
        stage.extend(
            ConvNeXtModule(out_channels, out_channels)
            for _ in range(num_modules))
        self.layers = nn.Sequential(*stage)

    def forward(self, x):
        out = self.layers(x)
        return out

In [None]:
class ConvNeXtNet(nn.Module):
    """ConvNeXt image classifier: patchify stem -> stages -> pooled head.

    Follows the macro design of "A ConvNet for the 2020s":
    - The 4x4/stride-4 stem reduces resolution by 4.
    - The first stage runs directly at that resolution.
    - Each later stage begins with a 2x stride-2 downsampling
      (via ``ConvNeXtBlock``), for an overall reduction of
      ``4 * 2 ** (len(channels) - 1)``.

    Args:
        in_channels: number of channels of the input image (3 for RGB).
        channels: output width of each stage.
        num_modules: number of ConvNeXt modules in each stage; must have
            the same length as ``channels``.
        num_classes: size of the output logit vector.

    Raises:
        ValueError: if ``channels`` is empty or its length differs from
            ``num_modules``.
    """

    def __init__(self, in_channels, channels, num_modules, num_classes):
        super().__init__()
        if len(channels) == 0:
            raise ValueError("channels must contain at least one stage width")
        if len(channels) != len(num_modules):
            raise ValueError(
                f"channels ({len(channels)}) and num_modules "
                f"({len(num_modules)}) must have the same length")

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], kernel_size=4, stride=4),
            LayerNorm2d(channels[0]))

        # Stage 1 consumes the stem output directly. The stem already
        # downsamples and normalizes, so prepending another Downsample here
        # would give an extra 2x reduction (x64 instead of x32 overall).
        layer_list = [
            nn.Sequential(*[
                ConvNeXtModule(channels[0], channels[0])
                for _ in range(num_modules[0])
            ])
        ]
        # Every subsequent stage starts with LayerNorm + 2x2/stride-2 conv.
        for prev_width, width, depth in zip(channels, channels[1:],
                                            num_modules[1:]):
            layer_list.append(ConvNeXtBlock(prev_width, width, depth))
        self.layers = nn.Sequential(*layer_list)

        out_channels = channels[-1]
        # Head: global average pooling first, then LayerNorm on the pooled
        # (N, C) vector, then the linear classifier.
        self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                        nn.Flatten(),
                                        nn.LayerNorm(out_channels),
                                        nn.Linear(out_channels, num_classes))

    def forward(self, x):
        x = self.stem(x)
        x = self.layers(x)
        x = self.classifier(x)
        return x

In [None]:
# Smoke test: push one random 224x224 RGB image through a ConvNeXt-T-sized
# configuration and print the shape of the resulting logits.
inputs = torch.randn(1, 3, 224, 224)
model = ConvNeXtNet(in_channels=3,
                    channels=[96, 192, 384, 768],
                    num_modules=[3, 3, 9, 3],
                    num_classes=1000)
model.eval()
# Inference only: skip building the autograd graph to save memory.
with torch.no_grad():
    outputs = model(inputs)
print(outputs.size())