### Imports

In [1]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torchinfo import summary as Model_Summary
import torch.optim as optim
from typing import Optional

### Inception Model

Convolutional block: 1 convolutional layer + batch normalisation, with ReLU activation

In [8]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1, padding='same'):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        # could also use default momentum and eps values for batch norm
        self.batch_norm = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.9997)
        
    def forward(self, x):
        return F.relu(self.batch_norm(self.conv(x)))

Inception block: contains various branches, output from these branches are concatenated to produce a final output in the ```forward()``` function

In [3]:
class InceptionBlockA(nn.Module):
    def __init__(self, 
                in_channels: int,
                pool_features: int
            ):
        super(InceptionBlockA, self).__init__()
        self.branch1x1 = ConvBlock(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = ConvBlock(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = ConvBlock(48, 64, kernel_size=5)

        self.branch3x3dbl_1 = ConvBlock(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = ConvBlock(64, 96, kernel_size=3)
        self.branch3x3dbl_3 = ConvBlock(96, 96, kernel_size=3)

        self.branch_pool = ConvBlock(in_channels, pool_features, kernel_size=1)
    
    def forward(self, x: Tensor):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        branches = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(branches, 1)

class InceptionBlockB(nn.Module):
    def __init__(self, 
                in_channels: int,
            ):
        super(InceptionBlockB, self).__init__()

        self.branch3x3 = ConvBlock(in_channels, 384, kernel_size=3, stride=2, padding='valid')
        self.branch3x3dbl_1 = ConvBlock(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = ConvBlock(64, 96, kernel_size=3)
        self.branch3x3dbl_3 = ConvBlock(96, 96, kernel_size=3, stride=2, padding='valid')
    
    def forward(self, x: Tensor):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2, padding=0)

        branches = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(branches, 1)

class InceptionBlockC(nn.Module):
    def __init__(self, 
                in_channels: int,
                channels_7x7: int,
            ):
        super(InceptionBlockC, self).__init__()

        self.branch1x1 = ConvBlock(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = ConvBlock(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = ConvBlock(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = ConvBlock(c7, 192, kernel_size=(7, 1), padding=(3, 0))
        # self.branch7x7_2 = ConvBlock(c7, c7, kernel_size=(1, 7))
        # self.branch7x7_3 = ConvBlock(c7, 192, kernel_size=(7, 1))

        self.branch7x7dbl_1 = ConvBlock(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = ConvBlock(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = ConvBlock(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = ConvBlock(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = ConvBlock(c7, 192, kernel_size=(1, 7), padding=(0, 3))
        # self.branch7x7dbl_2 = ConvBlock(c7, c7, kernel_size=(7, 1))
        # self.branch7x7dbl_3 = ConvBlock(c7, c7, kernel_size=(1, 7))
        # self.branch7x7dbl_4 = ConvBlock(c7, c7, kernel_size=(7, 1))
        # self.branch7x7dbl_5 = ConvBlock(c7, 192, kernel_size=(1, 7))

        self.branch_pool = ConvBlock(in_channels, 192, kernel_size=1)
    
    def forward(self, x: Tensor):
        
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        branches = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(branches, 1)

class InceptionBlockD(nn.Module):
    def __init__(self, 
                in_channels: int,
            ):
        super(InceptionBlockD, self).__init__()

        self.branch3x3_1 = ConvBlock(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = ConvBlock(192, 320, kernel_size=3, stride=2, padding='valid')

        self.branch7x7x3_1 = ConvBlock(in_channels, 192, kernel_size=1)
        # self.branch7x7x3_2 = ConvBlock(192, 192, kernel_size=(1, 7), padding=(0, 3))
        # self.branch7x7x3_3 = ConvBlock(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_2 = ConvBlock(192, 192, kernel_size=(1, 7))
        self.branch7x7x3_3 = ConvBlock(192, 192, kernel_size=(7, 1))
        self.branch7x7x3_4 = ConvBlock(192, 192, kernel_size=3, stride=2, padding='valid')
    
    def forward(self, x: Tensor):

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2, padding=0)

        branches = [branch3x3, branch7x7x3, branch_pool]
        return torch.cat(branches, 1)

class InceptionBlockE(nn.Module):
    def __init__(self, 
                in_channels: int,
            ):
        super(InceptionBlockE, self).__init__()
        
        self.branch1x1 = ConvBlock(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = ConvBlock(in_channels, 384, kernel_size=1)
        # self.branch3x3_2a = ConvBlock(384, 384, kernel_size=(1, 3), padding=(0, 1))
        # self.branch3x3_2b = ConvBlock(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch3x3_2a = ConvBlock(384, 384, kernel_size=(1, 3))
        self.branch3x3_2b = ConvBlock(384, 384, kernel_size=(3, 1))

        self.branch3x3dbl_1 = ConvBlock(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = ConvBlock(448, 384, kernel_size=3)
        # self.branch3x3dbl_3a = ConvBlock(384, 384, kernel_size=(1, 3), padding=(0, 1))
        # self.branch3x3dbl_3b = ConvBlock(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch3x3dbl_3a = ConvBlock(384, 384, kernel_size=(1, 3))
        self.branch3x3dbl_3b = ConvBlock(384, 384, kernel_size=(3, 1))

        self.branch_pool = ConvBlock(in_channels, 192, kernel_size=1)
    
    def forward(self, x: Tensor):

        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        branches = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(branches, 1)

Auxiliary classifiers

In [4]:
class InceptionAux(nn.Module):
    def __init__(self, 
                 in_channels: int, 
                 num_classes: int, 
            ):
        super(InceptionAux, self).__init__()
        # pooling operation can be applied as layer or defined in forward function dependent
        # on whether it is to be trained as well
        # self.pool0 = nn.AvgPool2d(kernel_size=5, stride=3)
        # self.conv0 = ConvBlock(in_channels, 128, kernel_size=1, stride=stride, padding=padding)
        self.conv0 = ConvBlock(in_channels, 128, kernel_size=1)
        self.conv1 = ConvBlock(128, 768, kernel_size=5)
        # self.pool1 = nn.AvgPool2d(kernel_size=1)
        # self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        # self.fc.stddev = 0.001
    
    def forward(self, x):
        # x = self.pool0(x)
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        x = self.conv0(x)
        x = self.conv1(x)
        x = F.adaptive_avg_pool2d(x, (1 ,1))
        # x = self.pool1(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

Full model

In [6]:
class InceptionV3(nn.Module):
    def __init__(self, 
                 num_classes:int=2,
                 aux_logits: bool=True
                 ):
        super(InceptionV3, self).__init__()

        self.aux_logits = aux_logits
        
        # Initial convolutional and pooling layers
        # Note: padding = valid = 0 = no_padding; padding = same = 1
        self.conv0 = nn.Conv2d(3, 32, kernel_size=3, stride=2)
        self.conv1 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(64, 80, kernel_size=1)
        self.conv4 = nn.Conv2d(80, 192, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Inception blocks
        self.inception_mixed_0 = InceptionBlockA(192, pool_features=32)
        self.inception_mixed_1 = InceptionBlockA(256, pool_features=64)
        self.inception_mixed_2 = InceptionBlockA(288, pool_features=64)
        self.inception_mixed_3 = InceptionBlockB(288)
        self.inception_mixed_4 = InceptionBlockC(768, channels_7x7=128)
        self.inception_mixed_5 = InceptionBlockC(768, channels_7x7=160)
        self.inception_mixed_6 = InceptionBlockC(768, channels_7x7=160)
        self.inception_mixed_7 = InceptionBlockC(768, channels_7x7=192)
        self.inception_mixed_8 = InceptionBlockD(768)
        self.inception_mixed_9 = InceptionBlockE(1280)
        self.inception_mixed_10 = InceptionBlockE(2048)
        
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)

        # Final layers
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.8)
        self.fc = nn.Linear(in_features=2048, out_features=num_classes)

    def forward(self, x):
        # Initial layers
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool2(x)
        # Inception blocks
        x = self.inception_mixed_0(x)
        x = self.inception_mixed_1(x)
        x = self.inception_mixed_2(x)
        x = self.inception_mixed_3(x)
        x = self.inception_mixed_4(x)
        x = self.inception_mixed_5(x)
        x = self.inception_mixed_6(x)
        x = self.inception_mixed_7(x)
        # Auxiliary heads
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # more Inception
        x = self.inception_mixed_8(x)
        x = self.inception_mixed_9(x)
        x = self.inception_mixed_10(x)

        x = self.avg_pool(x)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x, aux

Check model

In [9]:
model = InceptionV3()
print(model)

# Model_Summary(model, test_input.shape)

InceptionV3(
  (conv0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
  (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1))
  (conv4): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (inception_mixed_0): InceptionBlockA(
    (branch1x1): ConvBlock(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (batch_norm): BatchNorm2d(64, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
    )
    (branch5x5_1): ConvBlock(
      (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (batch_norm): BatchNorm2d(48, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
    )
    (branch5x5_2): ConvBlock(
      

In [11]:
device = torch.device("mps" if torch.has_mps else "cpu")
print(device)

mps


In [33]:
# Test model: sanity check with dummy input
test_input = torch.randn(32, 3, 299, 299)
output, aux = model(test_input)
print(output.shape)
print(output)

torch.Size([32, 2])
tensor([[-0.2828, -0.3216],
        [ 0.3339, -0.5673],
        [-0.4273,  0.0761],
        [ 0.6038,  0.9882],
        [-0.0190,  0.1430],
        [-0.3380,  0.0111],
        [ 0.0303, -0.1941],
        [-0.0349,  0.1665],
        [ 0.5243, -0.1355],
        [ 0.0596,  0.2506],
        [-0.1671, -0.0450],
        [ 0.4931,  0.1848],
        [ 0.1625, -0.5538],
        [-0.1402,  0.3171],
        [ 0.1562, -0.9295],
        [-0.5895,  0.3743],
        [ 0.3200, -0.3243],
        [ 0.6614, -0.1649],
        [-0.1347, -0.7843],
        [-0.1391,  0.2504],
        [-0.0815, -0.1637],
        [ 1.0770, -0.5914],
        [ 0.6380, -0.6112],
        [-0.0785,  0.0452],
        [ 0.3207, -0.1734],
        [ 0.0546,  0.0732],
        [-0.3749, -0.4200],
        [ 0.3292, -1.0037],
        [ 0.2760, -0.7996],
        [-0.5322, -1.6115],
        [ 0.3166,  0.7640],
        [ 0.5762, -0.2912]], grad_fn=<AddmmBackward0>)


### Helper functions

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Parameters

In [2]:
# Number of classes in the dataset
num_classes = 2
# Batch size for training (change depending on how much memory you have)
batch_size = 32
# Number of epochs to train for
# * paper: 100
num_epochs = 20

### Initialise model

In [None]:
# move function to `initalise_models.py`
def inception(num_classes):

    # Hyperparameters
    WEIGHT_DECAY = 0.9                  # Decay term for RMSProp.
    # weight_decay = 0.00004?
    # from inception_v3_parameters
    MOMENTUM = 0.9                      # Momentum in RMSProp.
    EPSILON = 1.0                       # Epsilon term for RMSProp.
    INITIAL_LEARNING_RATE = 0.1         # Initial learning rate.
    NUM_EPOCHS_PER_DECAY = 30.0         # Epochs after which learning rate decays.
    LEARNING_RATE_DECAY_FACTOR = 0.16   # Learning rate decay factor.

    model = InceptionV3()
    # * set parameters correct?
    optimiser = optim.RMSprop(lr=INITIAL_LEARNING_RATE, momentum=MOMENTUM, eps=EPSILON, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()
    # * complete
    parameters = {"learning_rate": INITIAL_LEARNING_RATE, "momentum": MOMENTUM}

    return model, optimiser, criterion, parameters

In [None]:
# Define the loss function with weight decay
loss_fn = nn.CrossEntropyLoss()
weight_decay = 0.00004
l2_reg = torch.tensor(0.)
for param in conv_layer.parameters():
    l2_reg += torch.norm(param)
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn(output, target) + weight_decay * l2_reg

### Load data

### Model

In [None]:
# Hyperparameters
RMSPROP_DECAY = 0.9                # Decay term for RMSProp.
MOMENTUM = 0.9                     # Momentum in RMSProp.
RMSPROP_EPSILON = 1.0              # Epsilon term for RMSProp.
INITIAL_LEARNING_RATE = 0.1        # Initial learning rate.
NUM_EPOCHS_PER_DECAY = 30.0        # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.16  # Learning rate decay factor.