In [1]:
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader
from torchinfo import summary
from torchvision import datasets, transforms, models

# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

cuda


## Implementation

In [2]:
class GroupedConvolution(nn.Module):
    """3x3 grouped convolution built from one independent ``nn.Conv2d`` per group.

    The input channels are cut into ``n_groups`` contiguous slices of equal
    width. Each slice goes through its own bias-free 3x3 convolution
    (padding 1), and the per-group results are concatenated along the channel
    dimension. This is the explicit, per-path form of
    ``nn.Conv2d(..., groups=n_groups)``.

    Args:
        in_channels: Channel count of the input; must be divisible by ``n_groups``.
        out_channels: Channel count of the output; must be divisible by ``n_groups``.
        n_groups: Number of parallel convolution paths (positive).
        stride: Stride applied by every path.

    Raises:
        ValueError: If ``n_groups`` is not positive, or if either channel count
            is not a multiple of ``n_groups``. Without this check the integer
            division below would silently drop the trailing channels.
    """

    def __init__(self, in_channels, out_channels, n_groups, stride):
        super().__init__()
        if n_groups <= 0:
            raise ValueError(f"n_groups must be positive, got {n_groups}")
        if in_channels % n_groups != 0 or out_channels % n_groups != 0:
            raise ValueError(
                f"in_channels ({in_channels}) and out_channels ({out_channels}) "
                f"must both be divisible by n_groups ({n_groups})"
            )
        # Number of input channels each path consumes.
        self.group_channels = in_channels // n_groups
        self.grouped_conv_paths = nn.ModuleList([
            nn.Conv2d(self.group_channels, out_channels // n_groups, kernel_size=3,
                      stride=stride, padding=1, bias=False)
            for _ in range(n_groups)
        ])

    def forward(self, x):
        """Convolve each channel slice with its own path and concatenate the results."""
        width = self.group_channels
        group_outputs = [
            # Path i sees channels [i * width, (i + 1) * width) of the input.
            path(x[:, width * i:width * (i + 1), ...])
            for i, path in enumerate(self.grouped_conv_paths)
        ]
        return torch.cat(group_outputs, dim=1)

class ResNeXtBlock(nn.Module):
    """Residual bottleneck block whose 3x3 convolution is split into ``n_groups`` groups.

    Main path: 1x1 conv to ``inner_channels`` -> grouped 3x3 conv (carrying
    ``stride``) -> 1x1 conv to ``inner_channels * 2``. Every convolution is
    bias-free and followed by batch normalisation; the first two are also
    followed by ReLU. The shortcut is added to the main-path output before the
    final ReLU.

    Args:
        input_channels: Channel count of the incoming tensor.
        inner_channels: Width of the bottleneck; the block outputs twice this.
        n_groups: Number of groups of the 3x3 convolution.
        stride: Stride of the 3x3 convolution (default 1).
        projection: Optional module applied to the input to produce the
            shortcut. When ``None`` the input itself is the shortcut, so it
            must be addable to the main-path output.
    """

    def __init__(self, input_channels, inner_channels, n_groups, stride=1, projection=None):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, inner_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            # Grouped 3x3 convolution via Conv2d's `groups` argument;
            # GroupedConvolution(inner_channels, inner_channels, n_groups, stride)
            # is the explicit per-path alternative.
            nn.Conv2d(inner_channels, inner_channels, kernel_size=3, stride=stride,
                      padding=1, bias=False, groups=n_groups),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(inner_channels, inner_channels * 2, kernel_size=1, bias=False),
            nn.BatchNorm2d(inner_channels * 2),
        )
        self.projection = projection
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # Compare against None explicitly: a module's truthiness follows
        # __len__ when it defines one, so a plain `if self.projection` could
        # skip a projection that was actually supplied.
        shortcut = x if self.projection is None else self.projection(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.relu(x + shortcut)
        return x


class ResNeXt(nn.Module):
    """ResNeXt classifier: 7x7 stem, four stages of ``ResNeXtBlock``, pooling and a linear head.

    Args:
        n_channels: Channel count of the input images.
        n_classes: Number of outputs of the final linear layer.
        n_groups: Group count forwarded to every ``ResNeXtBlock``.
    """

    def __init__(self, n_channels, n_classes, n_groups):
        super().__init__()
        stem_layers = [
            nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*stem_layers)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (input_channels, inner_channels, n_blocks, stride) for layers1..layers4.
        # Each stage outputs inner_channels * 2, which is the next stage's input.
        stage_specs = (
            (64, 128, 3, 1),
            (256, 256, 4, 2),
            (512, 512, 6, 2),
            (1024, 1024, 3, 2),
        )
        for stage_no, (in_ch, width, depth, stage_stride) in enumerate(stage_specs, start=1):
            stage = self.make_layers(in_ch, width, n_groups, depth, stage_stride)
            setattr(self, f"layers{stage_no}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # 2048 = 1024 * 2, the channel count leaving the last stage.
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(2048, n_classes))

    def make_layers(self, input_channels, inner_channels, n_groups, n_blocks, stride):
        """Build one stage of ``n_blocks`` ResNeXt blocks.

        The first block carries ``stride`` and a 1x1 conv + batch-norm
        projection that maps the input to ``inner_channels * 2`` channels for
        the shortcut. The remaining blocks use stride 1 and no projection.
        """
        out_channels = inner_channels * 2
        shortcut = nn.Sequential(
            nn.Conv2d(input_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        blocks = [ResNeXtBlock(input_channels, inner_channels, n_groups, stride, shortcut)]
        blocks.extend(
            ResNeXtBlock(out_channels, inner_channels, n_groups) for _ in range(n_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run stem, max-pool, the four stages, global average pool and the classifier."""
        features = self.maxpool(self.conv1(x))
        for stage in (self.layers1, self.layers2, self.layers3, self.layers4):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.classifier(pooled)

In [3]:
# Instantiate the hand-written model (3 input channels, 1000 classes, 32 groups)
# and report per-layer input/output shapes and mult-adds on the CPU.
resnext_model = ResNeXt(n_channels=3, n_classes=1000, n_groups=32)
summary(
    resnext_model,
    input_size=(1, 3, 224, 224),
    col_names=['input_size', 'output_size', 'mult_adds'],
    depth=3,
    device='cpu',
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Mult-Adds
ResNeXt                                  [1, 3, 224, 224]          [1, 1000]                 --
├─Sequential: 1-1                        [1, 3, 224, 224]          [1, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 3, 224, 224]          [1, 64, 112, 112]         118,013,952
│    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         [1, 64, 112, 112]         128
│    └─ReLU: 2-3                         [1, 64, 112, 112]         [1, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [1, 64, 112, 112]         [1, 64, 56, 56]           --
├─Sequential: 1-3                        [1, 64, 56, 56]           [1, 256, 56, 56]          --
│    └─ResNeXtBlock: 2-4                 [1, 64, 56, 56]           [1, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [1, 64, 56, 56]           [1, 256, 56, 56]          51,380,736
│    │    └─Seq

In [4]:
# torchvision's ResNeXt-50 32x4d built with default arguments, summarised with
# the same settings as the hand-written model so the two tables can be compared.
resnext_torch_model = models.resnext50_32x4d()
summary(
    resnext_torch_model,
    input_size=(1, 3, 224, 224),
    col_names=['input_size', 'output_size', 'mult_adds'],
    depth=3,
    device='cpu',
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Mult-Adds
ResNet                                   [1, 3, 224, 224]          [1, 1000]                 --
├─Conv2d: 1-1                            [1, 3, 224, 224]          [1, 64, 112, 112]         118,013,952
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 112, 112]         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 64, 56, 56]           [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           [1, 128, 56, 56]          25,690,112
│    │    └─BatchNorm2d: 3-2             [1, 128, 56, 56]          [1, 128, 56, 56]          256
│    │    └─Re

## Training

In [6]:
from pathlib import Path

TRAIN_RATIO = 0.8
SPLIT_SEED = 42  # fixes the train/validation partition so reruns see the same split
data_dir = Path('./data/')

# Resize to 224 (torchvision matches the shorter edge when given an int),
# then convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform)
# Both subsets index into the dataset built above, so they already share its
# `transform`; assigning a `transform` attribute on a Subset has no effect.
# The seeded generator makes the partition reproducible across kernel restarts.
train_ds, val_ds = random_split(
    train_ds,
    (TRAIN_RATIO, 1 - TRAIN_RATIO),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
import wandb
from src.engine import *

# Hyperparameters; also logged to Weights & Biases as the run config.
config = dict(batch_size=64, lr=5e-4, epochs=20, dataset='CIFAR100')
with wandb.init(project='pytorch-study', name='ResNext50', config=config) as run:
    # Read values back from the run so what is used is exactly what is logged.
    w_config = run.config
    train_dl = DataLoader(train_ds, batch_size=w_config.batch_size, shuffle=True)
    # Validation only measures the model, so sample order is irrelevant;
    # keep it fixed instead of reshuffling every epoch.
    val_dl = DataLoader(val_ds, batch_size=w_config.batch_size, shuffle=False)

    # train_ds is a Subset here; the class list lives on the wrapped dataset.
    n_classes = len(train_ds.dataset.classes)
    # NOTE(review): `resnext_mdoel` is misspelled but kept as-is. Later cells may
    # reference it, and `resnext_model` is already bound to the 1000-class model.
    resnext_mdoel = ResNeXt(3, n_classes, 32).to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(resnext_mdoel.parameters(), lr=w_config.lr)

    # `train` is expected to come from the `src.engine` star import; its
    # contract is not visible in this notebook.
    loss_history, acc_history = train(resnext_mdoel, train_dl, val_dl, criterion, optimizer, w_config.epochs, DEVICE, run)

Epoch=20: 100%|██████████| 20/20 [2:04:48<00:00, 374.41s/it, train_loss=0.151, train_acc=95.05%, val_loss=2.684, val_acc=54.66%] 
