In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader 
import torchvision
from torchvision.transforms import Compose, Normalize, RandomHorizontalFlip, RandomVerticalFlip, ToTensor
from torchvision.datasets import CIFAR10, CIFAR100, LSUN

In [None]:
# Hyper-parameters and runtime configuration.
BATCH_SIZE = 512
LR = 1e-3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_CLASSES = 10
# Seed the RNGs so weight init, DataLoader shuffling and the random flips are
# repeatable across Restart & Run All (cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Training-time preprocessing: random flips for augmentation, then ToTensor
# (scales pixels to [0, 1]) and Normalize(0.5, 0.5), which maps them to [-1, 1].
# The flips run on the image before tensor conversion; flipping commutes with
# the per-channel ToTensor/Normalize steps, so the output is the same either way.
# NOTE(review): vertical flips are unusual for CIFAR-10 natural images — consider dropping.
transform = Compose([
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# CIFAR-10 train/test splits. Both splits share one root, so the archive is
# downloaded and extracted only once.
CIFAR10_ROOT = 'data_cifar10'
# The test split must not be randomly augmented. Reuse `transform` minus its
# random flips so the normalization stays in sync with the training pipeline.
eval_transform = Compose([
    t for t in transform.transforms
    if not isinstance(t, (RandomHorizontalFlip, RandomVerticalFlip))
])
train_data_cifar10 = CIFAR10(root=CIFAR10_ROOT, download=True, train=True, transform=transform)
test_data_cifar10 = CIFAR10(root=CIFAR10_ROOT, download=True, train=False, transform=eval_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to train_data_cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting train_data_cifar10/cifar-10-python.tar.gz to train_data_cifar10
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to test_data_cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting test_data_cifar10/cifar-10-python.tar.gz to test_data_cifar10


In [None]:
# Batch both splits; only the training split is shuffled (evaluation order is irrelevant).
train_data_cifar10_dl = DataLoader(
    dataset=train_data_cifar10,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_data_cifar10_dl = DataLoader(
    dataset=test_data_cifar10,
    batch_size=BATCH_SIZE,
)

In [None]:
# GoogLeNet-style layer spec consumed by google_net.create_net.
#   conv tuples      -> conv_block(in_channels, out_channels, kernel_size, stride, padding)
#   inception tuples -> inception_block(in_channels, out_1, in_3, out_3, in_5, out_5, out_1_pool)
#                       (block output channels = out_1 + out_3 + out_5 + out_1_pool)
#   'M' / 'A'        -> pooling tokens. NOTE(review): create_net only honours 'M' in
#                       conv/inception3/inception4 and 'A' in inception5, so the 'A'
#                       after inception 4e and the 'M' inside inception5 are skipped.
#   linear           -> out_features of each fully-connected layer, in order.
architecture = dict(
    conv=[
        (3, 64, 7, 2, 3),
        'M',
        (64, 192, 3, 1, 1),
        'M',
    ],
    inception3=[
        (192, 64, 96, 128, 16, 32, 32),     # 3a -> 256 channels
        (256, 128, 128, 192, 32, 96, 64),   # 3b -> 480 channels
        'M',
    ],
    inception4=[
        (480, 192, 96, 208, 16, 48, 64),    # 4a -> 512 channels
        (512, 160, 112, 224, 24, 64, 64),   # 4b -> 512 channels
        (512, 128, 128, 256, 24, 64, 64),   # 4c -> 512 channels
        (512, 112, 144, 256, 32, 64, 64),   # 4d -> 496 channels
        (496, 256, 160, 320, 32, 128, 128), # 4e -> 832 channels
        'A',
    ],
    inception5=[
        (832, 256, 160, 320, 32, 128, 128), # 5a -> 832 channels
        'M',
        (832, 384, 192, 384, 48, 128, 128), # 5b -> 1024 channels
    ],
    linear=[4096, 1024, 512, 256, NUM_CLASSES],
)

In [None]:
class conv_block(nn.Module):
    """Convolution followed by batch normalization and a ReLU.

    The constructor arguments are passed positionally to ``nn.Conv2d``; the
    defaults give a 1x1, stride-1, unpadded convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        # Registration order (conv1, relu, batchnorm) fixes the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.relu = nn.ReLU()
        self.batchnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv -> batchnorm -> relu to ``x``."""
        features = self.conv1(x)
        normalized = self.batchnorm(features)
        return self.relu(normalized)

In [None]:
class inception_block(nn.Module):
    """Inception module: four parallel branches concatenated on dim 1.

    Branches (all keep the spatial size, via padding / stride-1 pooling):
      branch1: 1x1 conv -> out_1 channels
      branch2: 1x1 conv to in_3, then 3x3 conv -> out_3 channels
      branch3: 1x1 conv to in_5, then 5x5 conv -> out_5 channels
      branch4: 3x3 max-pool (stride 1), then 1x1 conv -> out_1_pool channels
    The output therefore has out_1 + out_3 + out_5 + out_1_pool channels.
    """

    def __init__(self, in_channels, out_1, in_3, out_3, in_5, out_5, out_1_pool):
        super().__init__()
        self.branch1 = conv_block(in_channels, out_1, kernel_size=1)

        reduce_then_3x3 = [
            conv_block(in_channels, in_3, kernel_size=1),
            conv_block(in_3, out_3, kernel_size=3, padding=1),
        ]
        self.branch2 = nn.Sequential(*reduce_then_3x3)

        reduce_then_5x5 = [
            conv_block(in_channels, in_5, kernel_size=1),
            conv_block(in_5, out_5, kernel_size=5, padding=2),
        ]
        self.branch3 = nn.Sequential(*reduce_then_5x5)

        pool_then_project = [
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            conv_block(in_channels, out_1_pool, kernel_size=1),
        ]
        self.branch4 = nn.Sequential(*pool_then_project)

    def forward(self, x):
        """Run every branch on ``x`` and stack the results along dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [None]:
class google_net(nn.Module):
    """GoogLeNet-style classifier built from an ``architecture`` spec dict.

    Args:
        architecture: dict whose keys are visited in insertion order.
            - 'conv' tuples are unpacked into ``conv_block``.
            - 'inception3', 'inception4' and 'inception5' tuples are unpacked
              into ``inception_block``.
            - 'linear' lists the out_features of each fully-connected layer.
            - String tokens insert pooling layers ('M' -> MaxPool2d(3, 2, 1),
              'A' -> AvgPool2d(7, 1)); ``_STAGE_POOL_TOKENS`` lists which
              tokens each stage honours.
        in_features: flattened feature count fed to the first nn.Linear.
        in_channels: stored as ``self.in_channels``; the first 'conv' tuple
            is what actually fixes the input channel count.
        NUM_CLASSES: stored as ``self.classes``; the output width comes from
            the last entry of ``architecture['linear']``.
    """

    # Pooling tokens honoured per stage; any other token is skipped. Because of
    # this, the 'A' listed in inception4 and the 'M' listed in inception5 are not
    # applied. For 32x32 inputs that keeps the last feature map at 2x2, which is
    # what in_features=1024*2*2 assumes.
    # NOTE(review): honouring those tokens would also require changing in_features.
    _STAGE_POOL_TOKENS = {
        'conv': ('M',),
        'inception3': ('M',),
        'inception4': ('M',),
        'inception5': ('A',),
    }

    def __init__(self, architecture, in_features, in_channels, NUM_CLASSES=10):
        super().__init__()
        self.architecture = architecture
        self.in_channels = in_channels
        self.in_features = in_features
        self.classes = NUM_CLASSES
        self.net = self.create_net(architecture)

    @staticmethod
    def _make_pool(token):
        """Return a fresh pooling layer for a spec token ('M' or 'A')."""
        if token == 'M':
            return nn.MaxPool2d(3, 2, 1)
        return nn.AvgPool2d(7, 1)

    def _build_stage(self, stage, spec, block_cls):
        """Turn one stage's spec list into modules: tuples -> block_cls, tokens -> pools."""
        layers = []
        for item in spec:
            if isinstance(item, tuple):
                layers.append(block_cls(*item))
            elif item in self._STAGE_POOL_TOKENS[stage]:
                layers.append(self._make_pool(item))
        return layers

    def _build_classifier(self, widths):
        """Flatten, then chain nn.Linear layers with a ReLU between consecutive ones."""
        layers = [nn.Flatten()]
        in_f = self.in_features
        for i, out_f in enumerate(widths):
            layers.append(nn.Linear(in_f, out_f))
            if i < len(widths) - 1:
                # Without a nonlinearity, stacked Linear layers collapse into
                # a single affine map.
                layers.append(nn.ReLU())
            # Each layer must consume the previous layer's width. Feeding
            # self.in_features to every layer breaks with a shape mismatch as
            # soon as a width differs from in_features.
            in_f = out_f
        return layers

    def create_net(self, architecture):
        """Assemble the full nn.Sequential from ``architecture``."""
        layers = []
        for key in architecture:
            if key == 'conv':
                layers += self._build_stage('conv', architecture['conv'], conv_block)
            elif key == 'inception3':
                # All three inception stages are emitted together when the
                # 'inception3' key is reached. The 'inception4'/'inception5'
                # keys are not matched on their own.
                for stage in ('inception3', 'inception4', 'inception5'):
                    layers += self._build_stage(stage, architecture[stage], inception_block)
            elif key == 'linear':
                layers += self._build_classifier(architecture['linear'])
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the input batch through the assembled layer stack."""
        return self.net(x)

In [None]:
# For 32x32 inputs the backbone ends in 1024 channels at 2x2, hence 1024 * 2 * 2
# flattened features. The input images have 3 (RGB) channels.
model = google_net(
    architecture=architecture,
    in_features=1024 * 2 * 2,
    in_channels=3,
).to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=LR)
cel = nn.CrossEntropyLoss()

In [None]:
# Show the assembled layer stack; nn.Module's repr lists every submodule.
print(repr(model))

google_net(
  (net): Sequential(
    (0): conv_block(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (relu): ReLU()
      (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (2): conv_block(
      (conv1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU()
      (batchnorm): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): inception_block(
      (branch1): conv_block(
        (conv1): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU()
        (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (branch2): Sequential(
        (0): conv_block(
          (conv1): Conv2d(192, 96

In [None]:
def validate(model, data, device=DEVICE):
    """Return the top-1 accuracy (percent) of ``model`` over ``data``.

    Args:
        model: classifier whose output holds per-class scores along dim 1.
        data: iterable of (images, labels) batches, e.g. a DataLoader.
        device: device each batch is moved to before the forward pass.

    Returns:
        A 0-dim tensor equal to 100 * correct / total.

    Raises:
        ZeroDivisionError: if ``data`` yields no samples.

    The model is put in eval mode for the pass, so BatchNorm uses its running
    statistics instead of batch statistics and does not update them. Its
    previous train/eval mode is restored afterwards, even on error.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in data:
                images, labels = images.to(device), labels.to(device)
                output = model(images)
                pred = output.argmax(dim=1)
                correct += (pred == labels).sum()
                total += output.size(0)
    finally:
        model.train(was_training)
    return (correct / total) * 100

In [None]:
def train(model, data, epochs=20, device=DEVICE, val_data=None, opt=None, loss_fn=None):
    """Train ``model`` for ``epochs`` passes over ``data``, validating after each.

    Args:
        model: the network to optimise (updated in place).
        data: iterable of (images, labels) training batches.
        epochs: number of passes over ``data``.
        device: device batches are moved to; also used for validation.
        val_data: batches for per-epoch validation; defaults to the global
            ``test_data_cifar10_dl``.
        opt: optimizer; defaults to the global ``optimizer``.
        loss_fn: loss criterion; defaults to the global ``cel``.

    Returns:
        A list with the validation accuracy (percent, float) of every epoch.
    """
    val_data = test_data_cifar10_dl if val_data is None else val_data
    opt = optimizer if opt is None else opt
    loss_fn = cel if loss_fn is None else loss_fn

    history = []
    for e in range(epochs):
        # Re-enter training mode every epoch in case an evaluation step left
        # the model in eval mode (BatchNorm must update its statistics here).
        model.train()
        running_loss = 0.0
        n_batches = 0
        for images, labels in data:
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            opt.step()
            # .item() yields a plain float. Accumulating the loss tensor itself
            # would keep every iteration's autograd history alive.
            running_loss += loss.item()
            n_batches += 1
        avg_loss = running_loss / max(n_batches, 1)
        accuracy = float(validate(model, val_data, device=device))
        history.append(accuracy)
        print(f'Epoch {e}, loss {avg_loss:.4f}, The accuracy is {accuracy}')
    return history

In [None]:
# Train for 50 epochs; train() prints the validation accuracy after each epoch.
# Binding the result keeps the cell from echoing a return value.
accuracy_history = train(model, train_data_cifar10_dl, epochs=50)


Epcoh 0, The accuracy is 48.0
Epcoh 1, The accuracy is 55.05999755859375
Epcoh 2, The accuracy is 60.89999771118164
Epcoh 3, The accuracy is 62.18000030517578
Epcoh 4, The accuracy is 65.68000030517578
Epcoh 5, The accuracy is 68.11000061035156
Epcoh 6, The accuracy is 69.47999572753906
Epcoh 7, The accuracy is 69.18000030517578
Epcoh 8, The accuracy is 70.76000213623047
Epcoh 9, The accuracy is 70.91999816894531
Epcoh 10, The accuracy is 71.9000015258789
Epcoh 11, The accuracy is 73.72999572753906
Epcoh 12, The accuracy is 74.15999603271484
Epcoh 13, The accuracy is 74.6199951171875
Epcoh 14, The accuracy is 74.79999542236328
Epcoh 15, The accuracy is 74.83999633789062
Epcoh 16, The accuracy is 75.68000030517578
Epcoh 17, The accuracy is 75.5199966430664
Epcoh 18, The accuracy is 76.0
Epcoh 19, The accuracy is 76.94999694824219
Epcoh 20, The accuracy is 75.76000213623047
Epcoh 21, The accuracy is 76.0999984741211
Epcoh 22, The accuracy is 76.9000015258789
Epcoh 23, The accuracy is 76.

'\nx = torch.randn(1, 3, 32, 32).to(DEVICE)\noutput = model(x)\nprint(output.shape)\n'