<a href="https://colab.research.google.com/github/ajw1587/Pytorch_Study/blob/main/31_ResNext.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cuda


In [2]:
def conv_start():
    """Build the network stem.

    The stem is a 7x7 stride-2 convolution (3 -> 64 channels, padding 4),
    batch normalisation, an in-place ReLU and a 3x3 stride-2 max-pool.

    Returns:
        An ``nn.Sequential`` holding the four stem layers in that order.
    """
    stem_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=4)
    stem_norm = nn.BatchNorm2d(64)
    stem_act = nn.ReLU(inplace=True)
    stem_pool = nn.MaxPool2d(kernel_size=3, stride=2)
    return nn.Sequential(stem_conv, stem_norm, stem_act, stem_pool)

def bottleneck_block(in_dim, mid_dim, out_dim, groups=32, down=False):
    """Build the main branch of a grouped-convolution bottleneck.

    The branch is 1x1 conv -> BN -> ReLU -> grouped 3x3 conv -> BN -> ReLU
    -> 1x1 conv -> BN. There is no activation after the last BN.

    Args:
        in_dim: Number of input channels.
        mid_dim: Stage width used to derive the inner channel count.
        out_dim: Number of output channels.
        groups: Number of groups for the 3x3 convolution.
        down: When true, the first 1x1 convolution uses stride 2.

    Returns:
        An ``nn.Sequential`` with the eight layers listed above.
    """
    # Inner (bottleneck) width. The literal 32 here is independent of the
    # ``groups`` argument.
    inner = mid_dim // 64 * 32 * 4
    first_stride = 2 if down else 1

    return nn.Sequential(
        nn.Conv2d(in_dim, inner, kernel_size=1, stride=first_stride, padding=0),
        nn.BatchNorm2d(inner),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner, inner, kernel_size=3, stride=1, padding=1, groups=groups),
        nn.BatchNorm2d(inner),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner, out_dim, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(out_dim),
    )

class Bottleneck(nn.Module):
    """Residual unit: bottleneck branch plus a projected shortcut, then ReLU.

    The shortcut is always a 1x1 convolution followed by batch normalisation,
    using the same stride as the main branch.

    Args:
        in_dim: Number of input channels.
        mid_dim: Stage width forwarded to ``bottleneck_block``.
        out_dim: Number of output channels.
        down: Request stride-2 downsampling in this unit.
        starting: When true, ``down`` is ignored and no downsampling happens.
    """

    def __init__(self, in_dim, mid_dim, out_dim, down:bool = False, starting:bool=False) -> None:
        super().__init__()
        # A starting unit never downsamples, whatever ``down`` says.
        downsample = False if starting else down

        self.block = bottleneck_block(in_dim, mid_dim, out_dim, down=downsample)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut; with stride 2 the spatial size shrinks.
        shortcut_stride = 2 if downsample else 1
        projection = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=shortcut_stride, padding=0)
        self.changedim = nn.Sequential(projection, nn.BatchNorm2d(out_dim))

    def forward(self, x):
        """Return ``relu(block(x) + changedim(x))``."""
        shortcut = self.changedim(x)
        out = self.block(x)
        out += shortcut
        return self.relu(out)

def make_layer(in_dim, mid_dim, out_dim, repeats, starting=False):
    """Stack ``repeats`` bottleneck units into one stage.

    The first unit maps ``in_dim`` to ``out_dim`` and requests downsampling
    (which ``Bottleneck`` suppresses when ``starting`` is true). The remaining
    ``repeats - 1`` units keep ``out_dim`` channels and do not downsample.

    Args:
        in_dim: Channels entering the stage.
        mid_dim: Stage width passed to every unit.
        out_dim: Channels leaving every unit of the stage.
        repeats: Total number of units in the stage.
        starting: Forwarded to the first unit only.

    Returns:
        An ``nn.Sequential`` of the units in order.
    """
    head = Bottleneck(in_dim, mid_dim, out_dim, down=True, starting=starting)
    tail = [
        Bottleneck(out_dim, mid_dim, out_dim, down=False)
        for _ in range(repeats - 1)
    ]
    return nn.Sequential(head, *tail)

class ResNeXt(nn.Module):
    """Image classifier made of a stem, four bottleneck stages and a linear head.

    Args:
        repeats: Number of bottleneck units in each of the four stages
            (``conv2`` .. ``conv5``). Only indices 0-3 are read. Any indexable
            sequence (list or tuple) works.
        num_classes: Number of outputs of the final linear layer.
    """

    # The default is an immutable tuple so it cannot be shared and mutated
    # across instances.
    def __init__(self, repeats=(3, 4, 6, 3), num_classes=1000):
        super(ResNeXt, self).__init__()
        self.num_classes = num_classes
        self.conv1 = conv_start()

        base_dim = 64
        self.conv2 = make_layer(base_dim, base_dim, base_dim*4, repeats[0], starting=True)
        self.conv3 = make_layer(base_dim*4, base_dim*2, base_dim*8, repeats[1])
        self.conv4 = make_layer(base_dim*8, base_dim*4, base_dim*16, repeats[2])
        self.conv5 = make_layer(base_dim*16, base_dim*8, base_dim*32, repeats[3])

        # Global average pooling. For the 7x7 map produced by a 224x224 input
        # this equals AvgPool2d(kernel_size=7), but it also handles other
        # input resolutions instead of breaking the flatten/linear step.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Input width matches conv5's output channels (base_dim*32 == 2048).
        # The attribute name keeps its original spelling so existing
        # state_dict keys remain valid.
        self.classifer = nn.Linear(base_dim*32, self.num_classes)

    def forward(self, x):
        """Run the stem, the four stages, pooling and the linear head.

        Returns:
            A tensor with ``num_classes`` values per input sample.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifer(x)
        return x

In [3]:
# resnet = ResNeXt()
# param = list(resnet.parameters())
# # print(len(param))
# for i in param:
#     print(i.shape)

In [6]:
from torchsummary import summary

# Build the model with its default configuration and move it to the selected device.
resnet = ResNeXt().to(device)
# Print a per-layer table of output shapes and parameter counts for a 3x224x224 input.
summary(resnet, (3, 224, 224), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 113, 113]           9,472
       BatchNorm2d-2         [-1, 64, 113, 113]             128
              ReLU-3         [-1, 64, 113, 113]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5          [-1, 256, 56, 56]          16,640
       BatchNorm2d-6          [-1, 256, 56, 56]             512
            Conv2d-7          [-1, 128, 56, 56]           8,320
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10          [-1, 128, 56, 56]           4,736
      BatchNorm2d-11          [-1, 128, 56, 56]             256
             ReLU-12          [-1, 128, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]          33,024
      BatchNorm2d-14          [-1, 256,