In [1]:
import torch
from torch import nn
from torchinfo import summary

In [6]:
class WideResnetBlock(nn.Module):
    """Pre-activation residual block: two BN -> ReLU -> 3x3 conv stages plus a shortcut.

    Args:
        in_channel: number of channels of the incoming feature map.
        out_channel: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution; the second one keeps the
            default stride.
        projection: optional module applied to the block input to build the
            shortcut. When ``None`` the input itself is used as the shortcut.
        drop_p: dropout probability applied between the two convolutions.
    """

    def __init__(self, in_channel, out_channel, stride=1, projection=None, drop_p=0.3):
        super().__init__()

        # Pre-activation ordering: normalisation and ReLU come before each convolution.
        first_stage = [
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False),
        ]
        second_stage = [
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Dropout(p=drop_p),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=False),
        ]
        self.residual = nn.Sequential(*first_stage, *second_stage)

        self.projection = projection

    def forward(self, x):
        """Return the residual branch output added to the (optionally projected) input."""
        branch = self.residual(x)
        shortcut = x if self.projection is None else self.projection(x)
        return branch + shortcut

### Input shape: (3, 32, 32) [CIFAR-10]

In [18]:
class WideResnet(nn.Module):
    """Wide residual network made of three groups of pre-activation blocks.

    The network is a 3x3 stem convolution (3 -> 16 channels), three groups of
    ``N = (depth - 4) // 6`` residual blocks with ``16*k``, ``32*k`` and
    ``64*k`` channels (the last two groups downsample with stride 2), then
    BatchNorm -> ReLU -> global average pooling -> linear classifier.

    Args:
        depth: total network depth; must satisfy ``depth = 6*N + 4`` with
            ``N >= 1`` (e.g. 10, 16, 22, 28, 40).
        k: widening factor multiplying the base channel widths 16/32/64.
        num_classes: number of outputs of the final linear layer.
        drop_p: dropout probability forwarded to every residual block.

    Raises:
        ValueError: if ``depth`` is smaller than 10 or not of the form 6*N + 4.
    """

    def __init__(self, depth, k, num_classes, drop_p=0.3):
        # Reject depths that would otherwise be silently truncated to a
        # different network than the one requested.
        if depth < 10 or (depth - 4) % 6 != 0:
            raise ValueError(
                f"depth must be of the form 6*N + 4 with N >= 1, got {depth}"
            )

        super().__init__()

        N = int((depth - 4) // 6)  # residual blocks per group
        self.in_channel = 16  # running channel count, updated by make_layers
        self.drop_p = drop_p

        self.conv1 = nn.Conv2d(3, 16, 3, padding=1, bias=False)
        self.conv2 = self.make_layers(16*k, N, stride=1)
        self.conv3 = self.make_layers(32*k, N, stride=2)
        self.conv4 = self.make_layers(64*k, N, stride=2)

        # Final activation is needed because every block ends with a bare convolution.
        self.bn = nn.BatchNorm2d(64*k)
        self.relu = nn.ReLU()

        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64*k, num_classes)

        # Weight initialisation: Kaiming-normal for convolutions,
        # N(0, 0.01) for the classifier, zeros for every bias.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.01)
                nn.init.constant_(module.bias, 0.0)

    def make_layers(self, out_channel, N, stride):
        """Build one group of ``N`` residual blocks producing ``out_channel`` channels.

        Only the first block changes the channel count and applies ``stride``;
        it gets a 1x1 convolution shortcut whenever the shape changes. Updates
        ``self.in_channel`` to ``out_channel`` for the next group.
        """
        if stride != 1 or self.in_channel != out_channel:
            # Pre-activation design, so no BatchNorm2d(out_channel) follows the projection.
            projection = nn.Conv2d(self.in_channel, out_channel, 1, stride=stride, bias=False)
        else:
            projection = None

        layers = [
            WideResnetBlock(
                self.in_channel,
                out_channel,
                stride=stride,
                projection=projection,
                drop_p=self.drop_p,
            )
        ]

        self.in_channel = out_channel

        # Remaining blocks keep both the channel count and the resolution.
        layers += [
            WideResnetBlock(self.in_channel, out_channel, drop_p=self.drop_p)
            for _ in range(N - 1)
        ]

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image tensor to ``(batch, num_classes)`` outputs."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = self.bn(x)
        x = self.relu(x)
        x = self.GlobalAvgPool(x)

        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

# Instantiate the depth-28, k=10 variant with 10 output classes, then print a
# per-layer summary for an input of size (2, 3, 32, 32).
model = WideResnet(num_classes=10, k=10, depth=28)
summary(model, input_size=(2, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
WideResnet                               [2, 10]                   --
├─Conv2d: 1-1                            [2, 16, 32, 32]           432
├─Sequential: 1-2                        [2, 160, 32, 32]          --
│    └─WideResnetBlock: 2-1              [2, 160, 32, 32]          --
│    │    └─Sequential: 3-1              [2, 160, 32, 32]          253,792
│    │    └─Conv2d: 3-2                  [2, 160, 32, 32]          2,560
│    └─WideResnetBlock: 2-2              [2, 160, 32, 32]          --
│    │    └─Sequential: 3-3              [2, 160, 32, 32]          461,440
│    └─WideResnetBlock: 2-3              [2, 160, 32, 32]          --
│    │    └─Sequential: 3-4              [2, 160, 32, 32]          461,440
│    └─WideResnetBlock: 2-4              [2, 160, 32, 32]          --
│    │    └─Sequential: 3-5              [2, 160, 32, 32]          461,440
├─Sequential: 1-3                        [2, 320, 16, 16]    