In [2]:
from pathlib import Path

import torch
import wandb
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchinfo import summary
from torchvision import datasets, models, transforms

from src.engine import train

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

cuda


## Implementation

In [60]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the basic unit every Inception block is built from.

    The convolution is created with ``bias=False`` because the BatchNorm that follows
    adds its own learnable shift. Any extra keyword arguments (``kernel_size``,
    ``stride``, ``padding``, ...) are forwarded unchanged to ``nn.Conv2d``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution (and normalised by BatchNorm).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, bias=False, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        # Kept under the attribute name ``conv`` so state_dict keys stay "conv.0.weight", ...
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.conv(x)
        return activated


class AuxiliaryClassifier(nn.Module):
    """Auxiliary classification head.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 ConvBlock to 128 channels ->
    global average pool -> MLP 128 -> 1024 -> ``n_classes``.

    Because of the global average pool, the linear layers do not depend on the input's
    spatial size. InceptionNetV3 feeds this head the 768-channel output of inception_4d.

    Args:
        in_channels: channels of the incoming feature map.
        n_classes: number of logits produced.
    """

    def __init__(self, in_channels, n_classes):
        super().__init__()
        # Creation order (and therefore parameter/state_dict order) matches the stages below.
        self.avg_pool = nn.AvgPool2d(5, stride=3)
        self.conv1 = ConvBlock(in_channels, 128, kernel_size=1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_classes),
        )

    def forward(self, x):
        # Apply each stage in sequence: pool -> squeeze -> global pool -> MLP.
        for stage in (self.avg_pool, self.conv1, self.gap, self.classifier):
            x = stage(x)
        return x


class InceptionBlockA(nn.Module):
    """Inception module with four parallel branches concatenated along channels.

    Every branch preserves the spatial size (1x1 convs, padded 3x3 convs, stride-1 padded pool):
      1. 1x1 -> 64
      2. 1x1 -> 48, 3x3 -> 64
      3. 1x1 -> 64, 3x3 -> 96, 3x3 -> 96
      4. 3x3 max pool (stride 1), 1x1 -> 64

    The output always has 64 + 64 + 96 + 64 = 288 channels, whatever ``in_channels`` is.

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch1 = ConvBlock(in_channels, 64, kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, 48, kernel_size=1),
            ConvBlock(48, 64, kernel_size=3, padding=1),
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, 64, kernel_size=1),
            ConvBlock(64, 96, kernel_size=3, padding=1),
            ConvBlock(96, 96, kernel_size=3, padding=1),
        )
        # NOTE(review): torchvision's InceptionA pools with an *average* in this branch, and
        # InceptionBlockC in this notebook uses AvgPool2d — confirm MaxPool2d is intended.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channels, 64, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)


class ReductionA(nn.Module):
    """Grid-size reduction block: three stride-2 branches concatenated along channels.

    Branches (none padded at the strided step, so H -> (H - 3) // 2 + 1):
      1. 3x3 max pool, stride 2                      -> in_channels
      2. 1x1 -> 64, 3x3 stride 2 -> 384              -> 384
      3. 1x1 -> 64, 3x3 (padded) -> 96, 3x3 stride 2 -> 96

    Output channels: ``in_channels + 480`` (e.g. 288 -> 768, 35x35 -> 17x17).

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, 64, kernel_size=1),
            ConvBlock(64, 384, kernel_size=3, stride=2),
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, 64, kernel_size=1),
            ConvBlock(64, 96, kernel_size=3, padding=1),
            ConvBlock(96, 96, kernel_size=3, stride=2),
        )

    def forward(self, x):
        parts = [branch(x) for branch in (self.branch1, self.branch2, self.branch3)]
        return torch.cat(parts, dim=1)


class InceptionBlockB(nn.Module):
    """Inception module whose inner branches use factorised 1x7 / 7x1 convolutions.

    All branches preserve the spatial size (the asymmetric kernels are padded by 3
    along their long axis):
      1. 1x1 -> 192
      2. 1x1 -> ch7, 1x7 -> ch7, 7x1 -> 192
      3. 1x1 -> ch7, then alternating 7x1 / 1x7 convs, ending 7x1 -> 192
      4. 3x3 max pool (stride 1), 1x1 -> 192

    The output always has 4 * 192 = 768 channels.

    Args:
        in_channels: channels of the input feature map.
        ch7: width of the intermediate layers in the factorised branches.
    """

    def __init__(self, in_channels, ch7):
        super().__init__()
        self.branch1 = ConvBlock(in_channels, 192, kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, ch7, kernel_size=1),
            ConvBlock(ch7, ch7, kernel_size=(1, 7), padding=(0, 3)),
            ConvBlock(ch7, 192, kernel_size=(7, 1), padding=(3, 0)),
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, ch7, kernel_size=1),
            ConvBlock(ch7, ch7, kernel_size=(1, 7), padding=(0, 3)),
            ConvBlock(ch7, ch7, kernel_size=(7, 1), padding=(3, 0)),
            ConvBlock(ch7, ch7, kernel_size=(1, 7), padding=(0, 3)),
            ConvBlock(ch7, 192, kernel_size=(7, 1), padding=(3, 0)),
        )
        # NOTE(review): torchvision's InceptionC uses average pooling here, as does
        # InceptionBlockC in this notebook — confirm MaxPool2d is intended.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)


class ReductionB(nn.Module):
    """Second grid-size reduction block: three stride-2 branches concatenated along channels.

    Branches (unpadded at the strided step, so H -> (H - 3) // 2 + 1):
      1. 3x3 max pool, stride 2                         -> in_channels
      2. 1x1 -> 192, 3x3 stride 2 -> 320                -> 320
      3. 1x1 -> 192, 3x3 (padded) -> 192, 3x3 stride 2  -> 192

    Output channels: ``in_channels + 512`` (e.g. 768 -> 1280, 17x17 -> 8x8).

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, 192, kernel_size=1),
            ConvBlock(192, 320, kernel_size=3, stride=2),
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, 192, kernel_size=1),
            ConvBlock(192, 192, kernel_size=3, padding=1),
            ConvBlock(192, 192, kernel_size=3, stride=2),
        )

    def forward(self, x):
        parts = [branch(x) for branch in (self.branch1, self.branch2, self.branch3)]
        return torch.cat(parts, dim=1)


class InceptionBlockC(nn.Module):
    """Inception module with expanded filter banks: two branches split into parallel 1x3 / 3x1 convs.

    All paths preserve the spatial size. Concatenated outputs, in order:
      branch1   1x1                                   -> 320
      branch2   1x1 -> 448, 3x3 -> 384, then split:   1x3 -> 384 and 3x1 -> 384
      branch3   1x1 -> 384, then split:               1x3 -> 384 and 3x1 -> 384
      branch4   3x3 average pool (stride 1), 1x1      -> 192

    The output always has 320 + 2 * 768 + 192 = 2048 channels.

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch1 = ConvBlock(in_channels, 320, kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, 448, kernel_size=1),
            ConvBlock(448, 384, kernel_size=3, padding=1),
        )
        self.branch2_1 = ConvBlock(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch2_2 = ConvBlock(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch3 = ConvBlock(in_channels, 384, kernel_size=1)
        self.branch3_1 = ConvBlock(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3_2 = ConvBlock(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        # Submodules are invoked in the same order as the concatenation below.
        outputs = [self.branch1(x)]
        split2 = self.branch2(x)
        outputs += [self.branch2_1(split2), self.branch2_2(split2)]
        split3 = self.branch3(x)
        outputs += [self.branch3_1(split3), self.branch3_2(split3)]
        outputs.append(self.branch4(x))
        return torch.cat(outputs, dim=1)

# Smoke-test each block on the grid size it sees inside InceptionNetV3
# (35x35 -> 17x17 -> 8x8) and assert the channel/spatial bookkeeping that the
# network's wiring relies on, rather than only eyeballing printed shapes.
shape_checks = [
    (InceptionBlockA(288), (1, 288, 35, 35), (1, 288, 35, 35)),
    (ReductionA(288), (1, 288, 35, 35), (1, 768, 17, 17)),
    (InceptionBlockB(768, 128), (1, 768, 17, 17), (1, 768, 17, 17)),
    (ReductionB(768), (1, 768, 17, 17), (1, 1280, 8, 8)),
    (InceptionBlockC(1280), (1, 1280, 8, 8), (1, 2048, 8, 8)),
]
with torch.no_grad():  # shapes only — no autograd graph needed
    for block, in_shape, expected in shape_checks:
        out_shape = block(torch.rand(*in_shape)).shape
        print(out_shape)
        assert out_shape == torch.Size(expected), (
            f"{type(block).__name__}: got {tuple(out_shape)}, expected {expected}"
        )

torch.Size([1, 288, 35, 35])
torch.Size([1, 768, 17, 17])
torch.Size([1, 768, 17, 17])
torch.Size([1, 1280, 8, 8])
torch.Size([1, 2048, 8, 8])


In [61]:
class InceptionNetV3(nn.Module):
    """Inception-v3 style classifier assembled from the blocks defined in this notebook.

    Spatial sizes in the comments below assume a ``(N, 3, 299, 299)`` input, the size
    used in the summary cell.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output logits (for both the main and the auxiliary head).
        use_aux: attach an ``AuxiliaryClassifier`` after the 17x17 stage.

    Forward returns:
        ``logits`` when ``use_aux`` is False. Otherwise a tuple ``(logits, aux_logits)``,
        where ``aux_logits`` is ``None`` unless the module is in training mode.
    """

    def __init__(self, n_channels, n_classes, use_aux):
        super().__init__()

        self.use_aux = use_aux

        # Stem: 299 -> 149 -> 147 -> 147 -> (pool) 73
        self.conv1_1 = ConvBlock(n_channels, 32, kernel_size=3, stride=2)
        self.conv1_2 = ConvBlock(32, 32, kernel_size=3)
        self.conv1_3 = ConvBlock(32, 64, kernel_size=3, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

        # 73 -> 71 -> 35 -> 35, ending at 288 channels for the first Inception stage.
        self.conv2_1 = ConvBlock(64, 80, kernel_size=3)
        self.conv2_2 = ConvBlock(80, 192, kernel_size=3, stride=2)
        self.conv2_3 = ConvBlock(192, 288, kernel_size=3, stride=1, padding=1)

        # 35x35x288 -> reduction -> 17x17x768
        self.inception_3a = InceptionBlockA(288)
        self.inception_3b = InceptionBlockA(288)
        self.reduction1 = ReductionA(288)

        # 17x17x768 stage; the optional auxiliary head taps its output.
        self.inception_4a = InceptionBlockB(768, 128)
        self.inception_4b = InceptionBlockB(768, 160)
        self.inception_4c = InceptionBlockB(768, 160)
        self.inception_4d = InceptionBlockB(768, 192)
        self.aux_classifier = AuxiliaryClassifier(768, n_classes) if use_aux else None
        self.reduction2 = ReductionB(768)

        # 8x8: 1280 -> 2048 -> 2048
        self.inception_5a = InceptionBlockC(1280)
        self.inception_5b = InceptionBlockC(2048)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(2048, n_classes),
        )

    def forward(self, x):
        x = self.conv1_1(x)
        x = self.conv1_2(x)
        x = self.conv1_3(x)
        x = self.maxpool(x)

        x = self.conv2_1(x)
        x = self.conv2_2(x)
        x = self.conv2_3(x)

        x = self.inception_3a(x)
        x = self.inception_3b(x)
        x = self.reduction1(x)

        x = self.inception_4a(x)
        x = self.inception_4b(x)
        x = self.inception_4c(x)
        x = self.inception_4d(x)
        # Auxiliary logits are only computed while training. The attribute is
        # `aux_classifier`; `aux_classifier2` does not exist and raised AttributeError
        # whenever use_aux=True in training mode.
        aux_logits = self.aux_classifier(x) if self.use_aux and self.training else None
        x = self.reduction2(x)

        x = self.inception_5a(x)
        x = self.inception_5b(x)

        x = self.avgpool(x)
        x = self.classifier(x)

        return (x, aux_logits) if self.use_aux else x

In [62]:
# Architecture overview on CPU at 299x299 with ImageNet's 1000 classes; the training
# cell builds its own, separately configured model.
inception_v3_model = InceptionNetV3(3, 1000, use_aux=True).to('cpu')
summary_columns = ['output_size', 'num_params', 'mult_adds']
summary(
    inception_v3_model,
    input_size=(1, 3, 299, 299),
    col_names=summary_columns,
    device='cpu',
    depth=2,
)

Layer (type:depth-idx)                        Output Shape              Param #                   Mult-Adds
InceptionNetV3                                [1, 1000]                 1,255,656                 --
├─ConvBlock: 1-1                              [1, 32, 149, 149]         --                        --
│    └─Sequential: 2-1                        [1, 32, 149, 149]         928                       19,181,728
├─ConvBlock: 1-2                              [1, 32, 147, 147]         --                        --
│    └─Sequential: 2-2                        [1, 32, 147, 147]         9,280                     199,148,608
├─ConvBlock: 1-3                              [1, 64, 147, 147]         --                        --
│    └─Sequential: 2-3                        [1, 64, 147, 147]         18,560                    398,297,216
├─MaxPool2d: 1-4                              [1, 64, 73, 73]           --                        --
├─ConvBlock: 1-5                              [1, 80, 71, 

## Training

In [None]:
# NOTE(review): leftover exploration cell (never executed, `In [None]`). Evaluating it only
# displays torchvision's `inception_v3` constructor; it does not build or train anything.
# Presumably kept as a pointer to the reference implementation to compare with InceptionNetV3.
models.inception_v3

In [6]:
# CIFAR-100 images are upscaled to 299x299, the input size used in the summary cell.
# NOTE(review): there is no transforms.Normalize step, so inputs stay in ToTensor()'s [0, 1] range.
SEED = 42
TRAIN_RATIO = 0.8
IMAGE_SIZE = 299
data_dir = Path('./data/')

# Seed the global RNG so later weight init and DataLoader shuffling are repeatable as well.
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

full_train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform)

# Integer lengths give an exact split (float fractions such as 1 - 0.8 can round a sample
# off), and a dedicated seeded generator keeps the split identical across Restart & Run All.
n_train = int(len(full_train_ds) * TRAIN_RATIO)
train_ds, val_ds = random_split(
    full_train_ds,
    (n_train, len(full_train_ds) - n_train),
    generator=torch.Generator().manual_seed(SEED),
)
# Both subsets read samples through full_train_ds and therefore already use `transform`.
# A Subset ignores a `.transform` attribute, so per-split transforms must be set elsewhere.

test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
config = dict(batch_size=64, lr=1e-5, epochs=20, dataset='CIFAR100')
with wandb.init(project='pytorch-study', name='InceptionV3', config=config) as run:
    # Read hyper-parameters from the run's config rather than the local dict.
    w_config = run.config
    train_dl = DataLoader(train_ds, batch_size=w_config.batch_size, shuffle=True)
    # Evaluation does not need a random order.
    val_dl = DataLoader(val_ds, batch_size=w_config.batch_size, shuffle=False)

    # Size the output head from the dataset's class list instead of hard-coding
    # ImageNet's 1000 classes.
    n_classes = len(train_ds.dataset.classes)
    # use_aux=False makes forward return a single logits tensor.
    # NOTE(review): presumably what src.engine.train expects — confirm before enabling the aux head.
    inception_model = InceptionNetV3(3, n_classes, use_aux=False).to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(inception_model.parameters(), lr=w_config.lr)

    loss_history, acc_history = train(
        inception_model, train_dl, val_dl, criterion, optimizer, w_config.epochs, DEVICE, run
    )

Epoch=5: 100%|██████████| 5/5 [05:26<00:00, 65.36s/it, train_loss=2.303, train_acc=10.37%, val_loss=2.303, val_acc=8.81%]
