In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import transforms as trans

import numpy as np
from torch.utils.data import Subset


# Seed every RNG the notebook touches so that runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)  # numpy is imported above and was previously left unseeded
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # covers the current device as well (cuda.manual_seed)
# Deterministic cudnn kernels only help if autotuning (which picks kernels by timing) is off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
#@title Trainable

class Trainable:
    """Mixin that adds a train / validate / checkpoint loop to an ``nn.Module`` subclass.

    Subclasses inherit from ``nn.Module`` and ``Trainable``. They call ``nn.Module.__init__`` first
    and ``Trainable.__init__`` once all layers exist. They must create ``self.optim`` (and
    ``self.sched`` when ``has_sched`` is True) before calling ``fit``.

    ``ds`` arguments are dataset classes called as ``ds(root, train=..., transform=..., download=True)``.
    A model may return a tuple whose first element is the main logits (auxiliary heads). ``loss_f``
    always receives the raw model output.
    """

    # Root directory for dataset downloads and for the best-model checkpoint.
    data_root = '/content/'

    def __init__(self, epochs, loss_f, batch_size, in_shape = (224, 224), has_sched = False, clip = False):
        """
        Args:
            epochs: number of passes over the training split in ``fit``.
            loss_f: callable ``loss_f(model_output, targets)`` returning a scalar tensor.
            batch_size: mini-batch size for both the train and test loaders.
            in_shape: (H, W) that every image is resized to after ``ToTensor``.
            has_sched: call ``self.sched.step()`` once per epoch.
            clip: clip the global gradient norm to 2.0 before every optimizer step.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.loss_f = loss_f
        self.has_sched = has_sched
        self.clip = clip
        self.transforms = trans.Compose([trans.ToTensor(), trans.Resize(in_shape)])
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # nn.Module.to moves the parameters in place; rebinding `self` had no effect.
        self.to(self.device)

    def _checkpoint_path(self):
        """File the best model (lowest test loss) is written to / read from."""
        return self.data_root + self.__class__.__name__ + "_best_model.pth"

    def _loader(self, ds, train):
        """Build ``(dataset, DataLoader)`` for the train (shuffled) or test split of ``ds``."""
        dataset = ds(self.data_root, train = train, transform = self.transforms, download = True)
        loader = torch.utils.data.DataLoader(
            dataset, self.batch_size, shuffle = train,
            # Pinned host memory speeds up host->GPU copies; pointless on CPU.
            pin_memory = self.device != "cpu")
        return dataset, loader

    def _sample_loss(self, loss, batch_size):
        """Convert a batch loss into a *sum over samples*, whatever reduction ``loss_f`` uses.

        Losses with ``reduction='sum'`` already are sums. Anything else (default 'mean', or a
        custom callable without a ``reduction`` attribute) is treated as a per-batch mean.
        """
        if getattr(self.loss_f, "reduction", "mean") == "sum":
            return loss.item()
        return loss.item() * batch_size

    @staticmethod
    def _logits(out):
        """Main-head logits: models with auxiliary heads return a tuple with the main head first."""
        return out[0] if isinstance(out, (tuple, list)) else out

    def fit(self, ds):
        """Train for ``self.epochs`` epochs on the train split of ``ds``, validating after each.

        Prints the average per-sample training loss of every epoch.
        """
        train_set, train_loader = self._loader(ds, train = True)

        self.min_loss = float("inf")
        self.train()
        for e in range(self.epochs):
            total_loss = 0.0
            for x, y in train_loader:
                x, y = x.to(self.device), y.to(self.device)
                out = self(x)
                loss = self.loss_f(out, y)
                total_loss += self._sample_loss(loss, y.size(0))
                self.optim.zero_grad()
                loss.backward()
                if self.clip:
                    nn.utils.clip_grad_norm_(self.parameters(), 2.0)
                self.optim.step()
            if self.has_sched:
                self.sched.step()

            print("Train loss #", e, ":", round(total_loss / len(train_set), 3))
            self.validate(ds)

    def validate(self, ds):
        """Evaluate on the *test* split of ``ds``; checkpoint whenever the test loss improves.

        Prints the average per-sample test loss and the accuracy (in %). Leaves the model in
        training mode afterwards.
        """
        # Previously loaded train=True, so "test" metrics were measured on the training data.
        test_set, test_loader = self._loader(ds, train = False)

        self.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(self.device), y.to(self.device)
                out = self(x)
                total_loss += self._sample_loss(self.loss_f(out, y), y.size(0))
                pred = self._logits(out).argmax(1)
                correct += pred.eq(y.view_as(pred)).sum().item()

        # min_loss is only initialised by fit(); allow validate() to be called on its own.
        if total_loss < getattr(self, "min_loss", float("inf")):
            self.min_loss = total_loss
            checkpoint = {"model": self.state_dict()}
            optim = getattr(self, "optim", None)
            if optim is not None:
                checkpoint["optim"] = optim.state_dict()
            torch.save(checkpoint, self._checkpoint_path())
        print("Test loss:", round(total_loss / len(test_set), 3), "Accuracy:", 100 * correct / len(test_set))
        self.train()

    def load_model(self):
        """Restore the best checkpoint written by ``validate``.

        Loads the model weights and, when both exist, the optimizer state. A checkpoint that is a
        bare ``state_dict`` is also accepted.
        """
        checkpoint = torch.load(self._checkpoint_path(), map_location = self.device)
        # validate() saves {"model": ..., "optim": ...}; passing that wrapper straight to
        # load_state_dict (as before) fails on the unexpected keys.
        self.load_state_dict(checkpoint.get("model", checkpoint))
        optim = getattr(self, "optim", None)
        if optim is not None and "optim" in checkpoint:
            optim.load_state_dict(checkpoint["optim"])


In [None]:
#@title AlexNet
## https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf

class AlexNet(nn.Module, Trainable):
    """AlexNet (Krizhevsky et al., 2012) for 224x224 inputs and 10 classes.

    Local response normalisation is omitted. ``fit`` trains with Adam and a StepLR schedule,
    using a summed cross-entropy loss and batch size 8.
    """

    def __init__(self, in_channels):
        nn.Module.__init__(self)

        # Feature extractor: five convs, with max-pooling after the 1st, 2nd and 5th.
        features = [
            nn.Conv2d(in_channels, 96, (11, 11), 4), nn.ReLU(),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(96, 256, (5, 5), padding = 2), nn.ReLU(),
            nn.MaxPool2d(3, 2),
        ]
        for c_in, c_out in ((256, 384), (384, 384), (384, 256)):
            features += [nn.Conv2d(c_in, c_out, (3, 3), padding = 1), nn.ReLU()]
        features.append(nn.MaxPool2d(3, 2))
        self.conv = nn.Sequential(*features)

        # Classifier over the flattened 256 x 5 x 5 feature map.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 4096), nn.ReLU(), nn.Dropout(),
            nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(),
            nn.Linear(4096, 4096), nn.ReLU(),
            nn.Linear(4096, 10),
        )
        self.__init_weights()
        Trainable.__init__(self, 100, nn.CrossEntropyLoss(reduction = 'sum'), 8, has_sched = True)

    def __init_weights(self):
        # Conv kernels ~ N(0, 0.01) and zero biases; FC layers keep PyTorch's default init.
        conv_layers = (layer for layer in self.conv if isinstance(layer, nn.Conv2d))
        for conv in conv_layers:
            nn.init.normal_(conv.weight, 0, 0.01)
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        features = self.conv(x)
        return self.fc(features)

    def fit(self, ds):
        # Adam (lr 1e-4) with the learning rate divided by 10 every 30 epochs.
        self.optim = torch.optim.Adam(self.parameters(), 0.0001)
        self.sched = torch.optim.lr_scheduler.StepLR(self.optim, 30, gamma = 0.1)
        Trainable.fit(self, ds)


In [None]:
#@title ZFNet
# https://arxiv.org/pdf/1311.2901.pdf

class ZFNet(nn.Module, Trainable):
    """ZFNet (Zeiler & Fergus, 2013) for 224x224 inputs and 10 classes.

    The 3x3 convolutions are unpadded, so a 224x224 input leaves a 256 x 2 x 2 feature map.
    """

    def __init__(self, in_channels):
        nn.Module.__init__(self)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 96, 7, stride = 2),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
            nn.LocalResponseNorm(5, 0.0001, 0.75, 2),
            nn.Conv2d(96, 256, 5, stride = 2),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
            nn.LocalResponseNorm(5, 0.0001, 0.75, 2),
            # Layers 3-5 need their own non-linearities; without them they compose into one linear map.
            nn.Conv2d(256, 384, 3),
            nn.ReLU(),
            nn.Conv2d(384, 384, 3),
            nn.ReLU(),
            nn.Conv2d(384, 256, 3),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 10)
        )
        Trainable.__init__(self, 100, nn.CrossEntropyLoss(), 128)
        self.__init_weights()

    def __init_weights(self):
        """Init every conv and hidden FC layer with weights ~ N(0, 0.01) and zero biases.

        The output layer keeps PyTorch's default init.
        """
        # A *constant* 0.01 (as before) makes every filter identical, and identical filters receive
        # identical gradients, so the symmetry can never be broken.
        layers = [m for m in self.conv if isinstance(m, nn.Conv2d)]
        layers += [m for m in self.fc if isinstance(m, nn.Linear)][:-1]
        for layer in layers:
            nn.init.normal_(layer.weight, 0, 0.01)
            nn.init.constant_(layer.bias, 0)

    def forward(self, input):
        return self.fc(self.conv(input))

    def fit(self, ds):
        # Adam (lr 1e-4); no learning-rate schedule.
        self.optim = torch.optim.Adam(self.parameters(), 0.0001)
        Trainable.fit(self, ds)

In [None]:
#@title VGG16
# https://arxiv.org/pdf/1409.1556.pdf
# https://arxiv.org/pdf/1505.06798.pdf

class VGG16(nn.Module, Trainable):
    """VGG-16 (configuration D of Simonyan & Zisserman, 2014) for 224x224 inputs and 10 classes."""

    # Output widths of the 13 conv layers (each followed by ReLU); 'M' marks a 2x2 max-pool.
    CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

    def __init__(self, in_channels):
        nn.Module.__init__(self)
        layers, channels = [], in_channels
        for v in self.CFG:
            if v == 'M':
                layers.append(nn.MaxPool2d(2, 2))
            else:
                layers += [nn.Conv2d(channels, v, 3, padding = 1), nn.ReLU()]
                channels = v
        self.conv = nn.Sequential(*layers)

        # Five 2x2 pools reduce 224x224 to 7x7.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 10)
        )
        self.__init_weights()
        # Trainable sets self.device and moves the model there.
        Trainable.__init__(self, 100, nn.CrossEntropyLoss(), 16, has_sched = False)

    def __init_weights(self):
        """He (Kaiming) init for the convs, zero biases.

        With N(0, 0.01) the activation scale shrinks at every one of the 13 un-normalised conv
        layers. The signal and gradients vanish, and training stalls at chance accuracy.
        """
        for layer in self.conv:
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode = 'fan_out', nonlinearity = 'relu')
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        return self.fc(self.conv(x))

    def fit(self, ds):
        # Adam (lr 1e-4); the paper's SGD + step schedule is not used.
        self.optim = torch.optim.Adam(self.parameters(), lr = 0.0001)
        Trainable.fit(self, ds)


In [None]:

#@title ResNet50
# https://arxiv.org/pdf/1512.03385.pdf
# https://arxiv.org/pdf/1704.06904.pdf

class ConvBlock(nn.Module):
    """One ResNet stage: ``nr_blocks`` bottleneck blocks (1x1 -> 3x3 -> 1x1, channel expansion 4).

    Args:
        nr_blocks: number of bottleneck blocks in the stage.
        in_channels: channels entering the stage.
        out_channels: bottleneck width; the stage emits ``out_channels * 4`` channels.
        stride: stride of the first block's 3x3 conv (and of its projection shortcut).

    The first block's shortcut is a strided 1x1 projection + BN. Later blocks use identity shortcuts.
    """

    expansion = 4

    def __init__(self, nr_blocks, in_channels, out_channels, stride = 1):
        nn.Module.__init__(self)
        self.nr_blocks = nr_blocks
        out_full = out_channels * self.expansion
        # Projection shortcut: 1x1 conv followed by BN, like every other conv in the network.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_full, 1, stride = stride, bias = False),
            nn.BatchNorm2d(out_full),
        )
        blocks = []
        for block in range(nr_blocks):
            # Starting with the second block, input width is the expanded width and stride is 1.
            if block == 1:
                in_channels = out_full
                stride = 1
            blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1, bias = False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                    nn.Conv2d(out_channels, out_channels, 3, stride = stride, padding = 1, bias = False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                    nn.Conv2d(out_channels, out_full, 1, bias = False),
                    nn.BatchNorm2d(out_full),
                )
            )
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        # Only the first block needs the projected shortcut; no copy of x is required.
        identity = self.downsample(x)
        for block in self.blocks:
            x = F.relu(block(x) + identity)
            identity = x
        return x


class ResNet50(nn.Module, Trainable):
    """ResNet-50 (He et al., 2015): 7x7 stem, four bottleneck stages of 3/4/6/3 blocks, 10 classes."""

    def __init__(self, in_channels):
        nn.Module.__init__(self)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        stem = [
            nn.Conv2d(in_channels, 64, 7, stride = 2, padding = 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, 2, padding = 1),
        ]
        # (blocks, input channels, bottleneck width, stride) for each stage.
        stage_specs = ((3, 64, 64, 1), (4, 256, 128, 2), (6, 512, 256, 2), (3, 1024, 512, 2))
        stages = [ConvBlock(n, c_in, width, stride = s) for n, c_in, width, s in stage_specs]
        self.conv = nn.Sequential(*stem, *stages, nn.AdaptiveAvgPool2d(1))
        # The last stage emits 512 * 4 channels.
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(512 * 4, 10))
        Trainable.__init__(self, 100, nn.CrossEntropyLoss(), 64)
        self.to(self.device)

    def forward(self, x):
        pooled = self.conv(x)
        return self.fc(pooled)

    def fit(self, ds):
        # SGD with momentum 0.9 and weight decay 5e-4; no learning-rate schedule.
        self.optim = torch.optim.SGD(self.parameters(), lr = 0.01, momentum = 0.9, weight_decay = 0.0005)
        Trainable.fit(self, ds)

In [None]:
#@title GoogLeNet
# https://arxiv.org/abs/1409.4842

class ConvModule(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU; extra keyword arguments are forwarded to Conv2d."""

    def __init__(self, in_channels, out_channels, kernel, **kwargs):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.conv(x)
        return activated


class InceptionModule(nn.Module):
    """GoogLeNet inception block: four parallel branches concatenated along the channel axis.

    ``kwargs`` is a *list* of six filter counts in the paper's table order:
    [#1x1, #3x3 reduce, #3x3, #5x5 reduce, #5x5, pool proj]. The output has
    ``kwargs[0] + kwargs[2] + kwargs[4] + kwargs[5]`` channels and the input's spatial size.
    """
    def __init__(self, in_channels, kwargs):
        super().__init__()
        self.conv1x1_1 = ConvModule(in_channels, kwargs[0], 1)
        self.conv3x3 = nn.Sequential(
            ConvModule(in_channels, kwargs[1], 1),
            ConvModule(kwargs[1], kwargs[2], 3, padding = 1)
        )
        # The "5x5" branch previously used a 3x3 kernel; padding 2 keeps the spatial size.
        self.conv5x5 = nn.Sequential(
            ConvModule(in_channels, kwargs[3], 1),
            ConvModule(kwargs[3], kwargs[4], 5, padding = 2)
        )
        self.conv1x1_2 = nn.Sequential(
            nn.MaxPool2d(3, 1, padding = 1),
            ConvModule(in_channels, kwargs[5], 1)
        )

    def forward(self, x):
        branches = (self.conv1x1_1, self.conv3x3, self.conv5x5, self.conv1x1_2)
        return torch.cat([branch(x) for branch in branches], 1)


class AuxiliaryOut(nn.Module):
    """GoogLeNet auxiliary classifier.

    Pools to 4x4, applies a 1x1 conv to 128 channels, then FC 1024 -> dropout 0.7 -> 10 logits.
    """

    def __init__(self, in_channels):
        super().__init__()
        pool_and_reduce = [nn.AdaptiveAvgPool2d(4), ConvModule(in_channels, 128, 1)]
        self.conv = nn.Sequential(*pool_and_reduce)
        # 128 channels x 4 x 4 positions = 2048 flattened features.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(1024, 10),
        )

    def forward(self, x):
        reduced = self.conv(x)
        return self.fc(reduced)
        

class GoogLeNet(nn.Module, Trainable):
    """GoogLeNet / Inception-v1 (Szegedy et al., 2014) with two auxiliary heads and 10 classes.

    While training (``self.training`` and ``training_mode``), ``forward`` returns
    ``(logits, aux_logits_1, aux_logits_2)``. In eval mode it returns only ``logits``, so callers
    can take ``argmax`` directly. ``loss`` accepts both forms.
    """
    def __init__(self, in_channels):
        nn.Module.__init__(self)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Set to False to skip the auxiliary heads even in training mode.
        self.training_mode = True
        self.conv = nn.Sequential(
            ConvModule(in_channels, 64, 7, stride = 2),
            nn.MaxPool2d(3, 2),
            ConvModule(64, 64, 1),
            # NOTE(review): the paper's second stem conv is 3x3 (padding 1); this one is 1x1.
            ConvModule(64, 192, 1),
            nn.MaxPool2d(3, 2),
        )
        self.inception3a = InceptionModule(192, [64, 96, 128, 16, 32, 32])
        self.inception3b = InceptionModule(256, [128, 128, 192, 32, 96, 64])
        self.inception4a = InceptionModule(480, [192, 96, 208, 16, 48, 64])
        self.aux_out_1 = AuxiliaryOut(512)
        self.inception4b = InceptionModule(512, [160, 112, 224, 24, 64, 64])
        self.inception4c = InceptionModule(512, [128, 128, 256, 24, 64, 64])
        self.inception4d = InceptionModule(512, [112, 144, 288, 32, 64, 64])
        self.aux_out_2 = AuxiliaryOut(528)
        self.inception4e = InceptionModule(528, [256, 160, 320, 32, 128, 128])
        self.inception5a = InceptionModule(832, [256, 160, 320, 32, 128, 128])
        self.inception5b = InceptionModule(832, [384, 192, 384, 48, 128, 128])
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(1024, 10))

        Trainable.__init__(self, 100, self.loss, 32)
        self.to(self.device)

    def forward(self, x):
        # Auxiliary heads only regularise training. Skipping them in eval also avoids returning a
        # tuple where plain logits are expected (e.g. argmax in validation).
        use_aux = self.training and self.training_mode

        x = self.conv(x)
        x = self.inception3a(x)
        x = self.inception3b(x)
        # The paper max-pools between inception 3b and 4a (this pool previously came after 4a).
        x = F.max_pool2d(x, 3, 2)

        x = self.inception4a(x)
        if use_aux:
            aux_out_1 = self.aux_out_1(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if use_aux:
            aux_out_2 = self.aux_out_2(x)
        x = self.inception4e(x)
        x = F.max_pool2d(x, 3, 2)

        x = self.inception5a(x)
        x = self.inception5b(x)
        x = F.adaptive_avg_pool2d(x, 1)
        # Pass `training` so that model.eval() actually disables this dropout.
        x = F.dropout(x, 0.4, training = self.training)
        x = self.fc(x)
        if use_aux:
            return x, aux_out_1, aux_out_2
        return x

    def fit(self, ds):
        self.training_mode = True
        self.optim = torch.optim.SGD(self.parameters(), lr = 0.001, momentum = 0.9)
        Trainable.fit(self, ds)

    def loss(self, out, y):
        """Main-head cross-entropy, plus 0.3 x each auxiliary head's loss when they are present."""
        if not isinstance(out, tuple):
            return F.cross_entropy(out, y)
        x, aux_out_1, aux_out_2 = out
        l1 = F.cross_entropy(x, y)
        l2 = F.cross_entropy(aux_out_1, y)
        l3 = F.cross_entropy(aux_out_2, y)
        return 0.3 * (l2 + l3) + l1



In [None]:
#@title Inception-v3
# https://arxiv.org/pdf/1512.00567.pdf

class ConvModule(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU; extra keyword arguments are forwarded to Conv2d.

    Identical to the ConvModule defined in the GoogLeNet cell; redefined so this cell runs on its own.
    """

    def __init__(self, in_channels, out_channels, kernel, **kwargs):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.conv(x)
        return activated

class InceptionA(nn.Module):
    """Inception-v3 35x35 block with a double-3x3 branch, a 5x5 branch, a pool branch and a 1x1 branch.

    Output channels: 96 + 64 + ``pool1x1_filters`` + 64; the spatial size is preserved.
    """
    def __init__(self, in_channels, pool1x1_filters):
        super().__init__()

        self.conv3x3 = nn.Sequential(
           ConvModule(in_channels, 64, 1),
           ConvModule(64, 96, 3, padding = 1),
           ConvModule(96, 96, 3, padding = 1)
        )

        self.conv5x5 = nn.Sequential(
           ConvModule(in_channels, 48, 1),
           ConvModule(48, 64, 5, padding = 2)
        )

        # The pool branch was missing its pooling layer (it was just another 1x1 conv).
        # Stride 1 / padding 1 keep the spatial size so the branches can be concatenated.
        self.pool1x1 = nn.Sequential(
           nn.AvgPool2d(3, stride = 1, padding = 1),
           ConvModule(in_channels, pool1x1_filters, 1)
        )

        self.conv1x1 = nn.Sequential(
           ConvModule(in_channels, 64, 1)
        )

    def forward(self, x):
        x1 = self.conv3x3(x)
        x2 = self.conv5x5(x)
        x3 = self.pool1x1(x)
        x4 = self.conv1x1(x)
        return torch.cat([x1, x2, x3, x4], 1)

class InceptionB(nn.Module):
    """Inception-v3 reduction block: halves the spatial size with stride-2 branches.

    Output channels: 96 (double 3x3) + 384 (3x3) + ``in_channels`` (max-pool).
    """

    def __init__(self, in_channels):
        super().__init__()
        # 1x1 -> unpadded 3x3 -> stride-2 3x3 with padding 1.
        self.conv3x3_3x3 = nn.Sequential(
            ConvModule(in_channels, 64, 1),
            ConvModule(64, 96, 3),
            ConvModule(96, 96, 3, stride = 2, padding = 1),
        )
        self.conv3x3 = ConvModule(in_channels, 384, 3, stride = 2)
        self.pool = nn.MaxPool2d(3, 2)

    def forward(self, x):
        branches = (self.conv3x3_3x3, self.conv3x3, self.pool)
        return torch.cat([branch(x) for branch in branches], 1)

class InceptionC(nn.Module):
    """Inception-v3 17x17 block with factorised 7x7 convolutions (1x7 / 7x1 pairs).

    ``nr_7x7s`` is the width of the intermediate factorised layers. The output has 4 * 192 = 768
    channels, and the spatial size is preserved.
    """
    def __init__(self, in_channels, nr_7x7s):
        super().__init__()

        self.conv7x7_7x7 = nn.Sequential(
           ConvModule(in_channels, nr_7x7s, 1),
           ConvModule(nr_7x7s, nr_7x7s, (7, 1), padding = (3, 0)),
           ConvModule(nr_7x7s, nr_7x7s, (1, 7), padding = (0, 3)),
           ConvModule(nr_7x7s, nr_7x7s, (7, 1), padding = (3, 0)),
           ConvModule(nr_7x7s, 192, (1, 7), padding = (0, 3))
        )

        self.conv7x7 = nn.Sequential(
           ConvModule(in_channels, nr_7x7s, 1),
           ConvModule(nr_7x7s, nr_7x7s, (1, 7), padding = (0, 3)),
           ConvModule(nr_7x7s, 192, (7, 1), padding = (3, 0))
        )

        # The pool branch was missing its pooling layer; stride 1 / padding 1 keep the spatial size.
        self.pool1x1 = nn.Sequential(
           nn.AvgPool2d(3, stride = 1, padding = 1),
           ConvModule(in_channels, 192, 1)
        )

        self.conv1x1 = ConvModule(in_channels, 192, 1)

    def forward(self, x):
        x1 = self.conv7x7_7x7(x)
        x2 = self.conv7x7(x)
        x3 = self.pool1x1(x)
        x4 = self.conv1x1(x)
        return torch.cat([x1, x2, x3, x4], 1)

class InceptionD(nn.Module):
    """Inception-v3 reduction block: halves the spatial size with stride-2 branches.

    Output channels: 192 (7x7 branch) + 320 (3x3 branch) + ``in_channels`` (max-pool).
    """

    def __init__(self, in_channels):
        super().__init__()
        # 1x1 -> 1x7 -> 7x1 -> stride-2 3x3.
        factorised = [
            ConvModule(in_channels, 192, 1),
            ConvModule(192, 192, (1, 7), padding = (0, 3)),
            ConvModule(192, 192, (7, 1), padding = (3, 0)),
            ConvModule(192, 192, 3, stride = 2),
        ]
        self.conv7x7 = nn.Sequential(*factorised)
        self.conv3x3 = nn.Sequential(
            ConvModule(in_channels, 192, 1),
            ConvModule(192, 320, 3, stride = 2),
        )
        self.pool = nn.MaxPool2d(3, 2)

    def forward(self, x):
        outputs = [self.conv7x7(x), self.conv3x3(x), self.pool(x)]
        return torch.cat(outputs, 1)

class InceptionE(nn.Module):
    """Inception-v3 8x8 block with expanded (split 1x3 / 3x1) filter-bank outputs.

    Output channels: 768 (3x3 -> 1x3|3x1) + 768 (1x3|3x1) + 192 (pool) + 320 (1x1) = 2048.
    The spatial size is preserved.
    """
    def __init__(self, in_channels):
        super().__init__()

        self.conv3x3_3x3 = nn.Sequential(
           ConvModule(in_channels, 448, 1),
           ConvModule(448, 384, 3, padding = 1),
           ConvModule(384, 384, (1, 3), padding = (0, 1)),
           ConvModule(384, 384, (3, 1), padding = (1, 0))
        )

        self.conv3x3 = nn.Sequential(
           ConvModule(in_channels, 384, 1),
           ConvModule(384, 384, (1, 3), padding = (0, 1)),
           ConvModule(384, 384, (3, 1), padding = (1, 0))
        )

        # The pool branch was missing its pooling layer; stride 1 / padding 1 keep the spatial size.
        self.pool1x1 = nn.Sequential(
           nn.AvgPool2d(3, stride = 1, padding = 1),
           ConvModule(in_channels, 192, 1)
        )

        self.conv1x1 = ConvModule(in_channels, 320, 1)

    def forward(self, x):
        # The last two layers of each 3x3 branch are applied in parallel (not in sequence) and
        # concatenated, so the Sequentials are indexed by hand.
        x1 = self.conv3x3_3x3[0](x)
        x1 = self.conv3x3_3x3[1](x1)
        x1_1x3 = self.conv3x3_3x3[2](x1)
        x1_3x1 = self.conv3x3_3x3[3](x1)
        x1 = torch.cat((x1_1x3, x1_3x1), 1)

        x2 = self.conv3x3[0](x)
        x2_1x3 = self.conv3x3[1](x2)
        x2_3x1 = self.conv3x3[2](x2)
        x2 = torch.cat((x2_1x3, x2_3x1), 1)

        x3 = self.pool1x1(x)
        x4 = self.conv1x1(x)
        return torch.cat([x1, x2, x3, x4], 1)

class AuxiliaryOut(nn.Module):
    """Inception-v3 auxiliary classifier.

    Applies a 5x5/3 average pool, a 1x1 conv to 128 channels, a 5x5 conv to 768 channels and a
    global average pool, then a linear 10-way head.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.AvgPool2d(5, 3),
            ConvModule(in_channels, 128, 1),
            ConvModule(128, 768, 5)
        )
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(768, 10)
        )
        # Intended init std of the 5x5 conv and of the output layer. Before, these were only stored
        # as attributes and nothing read them, so they had no effect. They are now applied below.
        self.conv[2].stddev = 0.01
        self.fc[2].stddev = 0.001
        for layer, std in ((self.conv[2].conv[0], self.conv[2].stddev), (self.fc[2], self.fc[2].stddev)):
            nn.init.trunc_normal_(layer.weight, mean = 0.0, std = std, a = -2, b = 2)

    def forward(self, x):
        return self.fc(self.conv(x))

class InceptionV3(nn.Module, Trainable):
    """Inception-v3 (Szegedy et al., 2015) for 299x299 inputs and 10 classes, with one auxiliary head.

    While training (``self.training`` and ``training_mode``), ``forward`` returns
    ``(logits, aux_logits)``. Otherwise it returns only ``logits``. ``loss`` accepts both forms.
    """
    def __init__(self, in_channels):
        nn.Module.__init__(self)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Enabled by fit(); the auxiliary head is only evaluated when this is True.
        self.training_mode = False
        self.conv = nn.Sequential(
            ConvModule(in_channels, 32, 3, stride = 2),
            ConvModule(32, 32, 3),
            ConvModule(32, 64, 3, padding = 1),
            nn.MaxPool2d(3, 2),
            ConvModule(64, 80, 1),
            ConvModule(80, 192, 3),
            nn.MaxPool2d(3, 2),
        )
        self.inception_a = nn.Sequential(
            InceptionA(192, 32),
            InceptionA(256, 64),
            InceptionA(288, 64)
        )
        self.inception_b = InceptionB(288) # reduction
        self.inception_c = nn.Sequential(
            InceptionC(768, 128),
            InceptionC(768, 160),
            InceptionC(768, 160),
            InceptionC(768, 192)
        )

        self.inception_d = InceptionD(768) # reduction
        self.inception_e = nn.Sequential(
            InceptionE(1280),
            InceptionE(2048)
        )
        self.out_aux = AuxiliaryOut(768)
        self.fc = nn.Sequential(nn.AvgPool2d(1), nn.Dropout(), nn.Flatten(), nn.Linear(2048, 10))
        # epochs, loss, batch size, input size, has_sched, clip
        Trainable.__init__(self, 100, self.loss, 32, (299, 299), True, True)
        self.to(self.device)

    def forward(self, x):
        # Before, x_out_aux was unbound (UnboundLocalError) whenever training_mode was False, and a
        # tuple was returned in eval mode too.
        use_aux = self.training and self.training_mode
        x = self.conv(x)
        x = self.inception_a(x)
        x = self.inception_b(x)
        x = self.inception_c(x)
        x_out_aux = self.out_aux(x) if use_aux else None
        x = self.inception_d(x)
        x = self.inception_e(x)
        # Global average pooling. The previous max_pool2d(x, 8) max-pooled and assumed an exact 8x8 map.
        x = F.adaptive_avg_pool2d(x, 1)
        logits = self.fc(x)
        if use_aux:
            return logits, x_out_aux
        return logits

    def fit(self, ds):
        self.training_mode = True
        # RMSprop (eps=1.0), lr 0.045 decayed by 0.94 every 2 epochs.
        self.optim = torch.optim.RMSprop(self.parameters(), lr = 0.045, eps = 1.0, momentum = 0.9)
        self.sched = torch.optim.lr_scheduler.StepLR(self.optim, 2, gamma = 0.94)
        Trainable.fit(self, ds)

    def loss(self, out, y):
        """Main-head cross-entropy, plus the auxiliary head's cross-entropy when present."""
        if not isinstance(out, tuple):
            return F.cross_entropy(out, y)
        x, aux_out = out
        l1 = F.cross_entropy(x, y)
        l2 = F.cross_entropy(aux_out, y)
        return l1 + l2


In [None]:
#@title Xception
# https://arxiv.org/abs/1610.02357

class ConvModule(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU; extra keyword arguments are forwarded to Conv2d.

    Identical to the ConvModule defined in the GoogLeNet cell; redefined so this cell runs on its own.
    """

    def __init__(self, in_channels, out_channels, kernel, **kwargs):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel, **kwargs)
        norm = nn.BatchNorm2d(out_channels)
        self.conv = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        activated = self.conv(x)
        return activated

class SeparableConv(nn.Module):
    """Depthwise-separable conv: per-channel spatial conv, then 1x1 channel mixing, then BatchNorm.

    Extra keyword arguments (e.g. ``padding``, ``stride``) go to the depthwise convolution only.
    No activation is applied.
    """

    def __init__(self, in_channels, out_channels, kernel, **kwargs):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel, groups = in_channels, **kwargs)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.pointwise(self.depthwise(x)))


class EntryModule(nn.Module):
    """Xception entry flow: two stem convs, then three strided residual blocks (128, 256, 728 channels).

    Each residual block halves the spatial size. Its shortcut is a strided 1x1 projection + BN.
    """

    @staticmethod
    def _projection(in_channels, out_channels):
        """Shortcut: strided 1x1 conv + BN with no ReLU.

        The previous ConvModule-based shortcut added a ReLU and a second BN to the residual path.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride = 2, bias = False),
            nn.BatchNorm2d(out_channels),
        )

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Sequential(
            ConvModule(in_channels, 32, 3, stride = 2),
            ConvModule(32, 64, 3),
        )

        self.identity_1 = self._projection(64, 128)
        # The first block follows the stem's ReLU, so it does not start with another ReLU.
        self.sep_conv_1 = nn.Sequential(
            SeparableConv(64, 128, 3, padding = 1),
            nn.ReLU(),
            SeparableConv(128, 128, 3, padding = 1),
            nn.MaxPool2d(3, 2, 1),
        )

        self.identity_2 = self._projection(128, 256)
        self.sep_conv_2 = nn.Sequential(
            nn.ReLU(),
            SeparableConv(128, 256, 3, padding = 1),
            nn.ReLU(),
            SeparableConv(256, 256, 3, padding = 1),
            nn.MaxPool2d(3, 2, 1),
        )

        self.identity_3 = self._projection(256, 728)
        self.sep_conv_3 = nn.Sequential(
            nn.ReLU(),
            SeparableConv(256, 728, 3, padding = 1),
            nn.ReLU(),
            SeparableConv(728, 728, 3, padding = 1),
            nn.MaxPool2d(3, 2, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        for shortcut, branch in ((self.identity_1, self.sep_conv_1),
                                 (self.identity_2, self.sep_conv_2),
                                 (self.identity_3, self.sep_conv_3)):
            identity = shortcut(x)
            x = branch(x) + identity
        return x


class MiddleModule(nn.Module):
    """Xception middle-flow block: three (ReLU -> 3x3 separable conv) layers plus an identity shortcut.

    Keeps 728 channels and the input's spatial size.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for _ in range(3):
            layers += [nn.ReLU(), SeparableConv(728, 728, 3, padding = 1)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # nn.ReLU is not in-place, so x is unchanged by self.conv and needs no copy.
        residual = self.conv(x)
        return residual + x

class ExitModule(nn.Module):
    """Xception exit flow.

    A strided residual block (728 -> 1024 channels), then two separable convs (1536 and 2048
    channels) each followed by ReLU, and global average pooling. Output: (N, 2048, 1, 1).
    """
    def __init__(self):
        super().__init__()
        self.conv_sep_1 = nn.Sequential(
            nn.ReLU(),
            SeparableConv(728, 728, 3, padding = 1),
            nn.ReLU(),
            SeparableConv(728, 1024, 3, padding = 1),
            nn.MaxPool2d(3, 2, 1)
        )
        # Shortcut: strided 1x1 conv + BN with no ReLU. The previous ConvModule added a ReLU and a
        # second BN to the residual path.
        self.identity = nn.Sequential(
            nn.Conv2d(728, 1024, 1, stride = 2, bias = False),
            nn.BatchNorm2d(1024)
        )
        # The paper's order is SepConv -> ReLU (previously ReLU -> SepConv with no final ReLU), and
        # it uses "same" padding.
        self.conv_sep_2 = nn.Sequential(
            SeparableConv(1024, 1536, 3, padding = 1),
            nn.ReLU(),
            SeparableConv(1536, 2048, 3, padding = 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

    def forward(self, x):
        identity = self.identity(x)
        x = self.conv_sep_1(x) + identity
        return self.conv_sep_2(x)


class Xception(nn.Module, Trainable):
    """Xception (Chollet, 2016): entry flow, 8 middle-flow blocks, exit flow and a small MLP head (10 classes)."""

    # The paper repeats the middle-flow block 8 times (only 7 were built before).
    NR_MIDDLE = 8

    def __init__(self, in_channels):
        nn.Module.__init__(self)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.conv = nn.Sequential(
            EntryModule(in_channels),
            *[MiddleModule() for _ in range(self.NR_MIDDLE)],
            ExitModule()
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 512),
            # Without a non-linearity the two Linear layers compose into a single linear map.
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(512, 10)
        )
        Trainable.__init__(self, 100, nn.CrossEntropyLoss(), 32)
        self.to(self.device)

    def forward(self, x):
        return self.fc(self.conv(x))

    def fit(self, ds):
        # SGD, lr 0.045, momentum 0.9, weight decay 1e-5; no learning-rate schedule.
        self.optim = torch.optim.SGD(self.parameters(), lr = 0.045, momentum = 0.9, weight_decay = 0.00001)
        Trainable.fit(self, ds)


In [None]:
#@title MobileNet
## https://arxiv.org/pdf/1704.04861.pdf

class ConvModule(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU; extra keyword arguments (e.g. ``groups``) go to Conv2d.

    Identical to the ConvModule defined in the GoogLeNet cell; redefined so this cell runs on its own.
    """

    def __init__(self, in_channels, out_channels, kernel, **kwargs):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel, **kwargs)
        norm = nn.BatchNorm2d(out_channels)
        self.conv = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        activated = self.conv(x)
        return activated

class ConvDW(nn.Module):
    """MobileNet depthwise-separable block: 3x3 depthwise ConvModule, then 1x1 pointwise ConvModule.

    ``alpha`` is the width multiplier. Both channel counts are scaled by it and truncated to an
    integer (at least 1). Extra keyword arguments (e.g. ``stride``) go to the depthwise conv.
    """
    def __init__(self, in_channels, out_channels, alpha = 1, **kwargs):
        super().__init__()
        # Conv2d needs integer channel counts; `*= alpha` produced floats for e.g. alpha = 0.75.
        in_channels = max(1, int(in_channels * alpha))
        out_channels = max(1, int(out_channels * alpha))
        self.depthwise = ConvModule(in_channels, in_channels, 3, padding = 1, groups = in_channels, **kwargs)
        self.pointwise = ConvModule(in_channels, out_channels, 1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class MobileNet(nn.Module, Trainable):
    """MobileNet v1 (Howard et al., 2017) for 10 classes.

    Args:
        in_channels: number of image channels.
        alpha: width multiplier applied to every layer's channel count (previously accepted but ignored).
        ro: resolution multiplier; images are resized to round(224 * ro) pixels per side
            (previously accepted but ignored).
    """

    def __init__(self, in_channels, alpha = 1, ro = 1):
        nn.Module.__init__(self)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        def width(c):
            # Channel count after the width multiplier (integer, at least 1).
            return max(1, int(c * alpha))

        # (input channels, output channels, depthwise stride) for the 13 depthwise-separable blocks.
        dw_specs = [
            (32, 64, 1), (64, 128, 2), (128, 128, 1), (128, 256, 2), (256, 256, 1),
            (256, 512, 2), (512, 512, 1), (512, 512, 1), (512, 512, 1), (512, 512, 1), (512, 512, 1),
            (512, 1024, 2), (1024, 1024, 1),
        ]
        self.conv = nn.Sequential(
            ConvModule(in_channels, width(32), 3, stride = 2, padding = 1),
            *[ConvDW(width(c_in), width(c_out), stride = s) for c_in, c_out, s in dw_specs],
            # Equal to AvgPool2d(7) for 224x224 inputs, but also works when ro != 1.
            nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(width(1024), 10)
        )
        size = max(1, int(round(224 * ro)))
        Trainable.__init__(self, 100, nn.CrossEntropyLoss(), 32, in_shape = (size, size))
        self.to(self.device)

    def forward(self, x):
        return self.fc(self.conv(x))

    def fit(self, ds):
        # Adam (lr 1e-3); no learning-rate schedule.
        self.optim = torch.optim.Adam(self.parameters(), lr = 0.001)
        Trainable.fit(self, ds)

In [None]:
# Train VGG-16 on single-channel MNIST; fit() downloads the data and validates after every epoch.
# (The unused random `x` tensor was removed; it only consumed RNG state before model init.)
net = VGG16(1)
net.fit(MNIST)

#TODO: VGG16

Train loss # 0 : 0.144
Test loss: 0.144 Accuracy: 11.236666666666666
Train loss # 1 : 0.144
Test loss: 0.144 Accuracy: 11.236666666666666
