In [16]:
import os
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Callable, List, Optional, Tuple
from easydict import EasyDict

# Model

In [17]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
    
    def forward(self, x:Tensor):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        return x

In [18]:
class Inception(nn.Module):
    def __init__(self, in_channels, n1x1, n3x3_reduce, n3x3, n5x5_reduce, n5x5, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = ConvBlock(in_channels, n1x1, kernel_size = 1, stride = 1, padding = 0)

        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, n3x3_reduce, kernel_size = 1, stride = 1, padding = 0),
            ConvBlock(n3x3_reduce, n3x3, kernel_size = 3, stride = 1, padding = 1)
        )
        
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, n5x5_reduce, kernel_size = 1, stride = 1, padding = 0),
            ConvBlock(n5x5_reduce, n5x5, kernel_size = 5, stride = 1, padding = 2)
        )
        
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1),
            ConvBlock(in_channels, pool_proj, kernel_size = 1, stride = 1, padding = 0)
        )

    def forward(self, x:Tensor):
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)

        return torch.cat([x1, x2, x3, x4], dim = 1)

In [4]:
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.avgpool = nn.AvgPool2d(kernel_size = 5, stride = 3)
        self.conv = ConvBlock(in_channels, 128, kernel_size = 1, stride = 1, padding = 0)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(p = 0.7)
        self.relu = nn.ReLU()
    
    def forward(self, x:Tensor):
        x = self.avgpool(x)
        x = self.conv(x)
        # flatten
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

In [5]:
class GoogLeNet(nn.Module):
    def __init__(self, aux_logits = True, num_classes = 1000):
        super(GoogLeNet, self).__init__()
        # if aux_logits not boolean, assert
        assert aux_logits == True or aux_logits == False
        self.aux_logits = aux_logits

        # 224 x 224 x 3 -> 112 x 112 x 64, 7x7+2(S)
        self.conv1 = ConvBlock(in_channels = 3, out_channels = 64, kernel_size = 7, stride = 2, padding = 3)
        # maxpool with ceil, 3x3 + 2(S), 112 x 112 x 64 -> 56 x 56 x 64
        self.maxpool1 = nn.MaxPool2d(kernel_size = 3, stride = 2, ceil_mode = True)
        # 56 x 56 x 64 -> 56 x 56 x 64, 3x3 reduce
        self.conv2 = ConvBlock(in_channels = 64, out_channels = 64, kernel_size = 1, stride = 1, padding = 0)
        # 56 x 56 x 64 -> 56 x 56 x 192
        self.conv3 = ConvBlock(in_channels = 64, out_channels = 192, kernel_size = 3, stride = 1, padding = 1)
        # 56 x 56 x 192 -> 28 x 28 x 192
        self.maxpool2 = nn.MaxPool2d(kernel_size = 3, stride = 2, ceil_mode = True)

        # 28 x 28 x 192 -> 28 x 28 x 256
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        # 28 x 28 x 256 -> 28 x 28 x 480
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        # 28 x 28 x 480 -> 14 x 14 x 480
        self.maxpool3 = nn.MaxPool2d(kernel_size = 3, stride = 2, ceil_mode = True)

        # 14 x 14 x 480 -> 14 x 14 x 512
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        # 14 x 14 x 512 -> 14 x 14 x 512
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        # 14 x 14 x 512 -> 14 x 14 x 512
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        # 14 x 14 x 512 -> 14 x 14 x 528
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        # 14 x 14 x 528 -> 14 x 14 x 832
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        # 14 x 14 x 832 -> 7 x 7 x 832
        self.maxpool4 = nn.MaxPool2d(kernel_size = 3, stride = 2, ceil_mode = True)

        # 7 x 7 x 832 -> 7 x 7 x 832
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        # 7 x 7 x 832 -> 7 x 7 x 1024
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # 7 x 7 x 1024 -> 1 x 1 x 1024
        self.avgpool = nn.AvgPool2d(kernel_size = 7, stride = 1)
        self.dropout = nn.Dropout(p = 0.4)
        # 1024 -> 1000(num_classes)
        self.linear = nn.Linear(1024, num_classes)


        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes) # inception4a
            self.aux2 = InceptionAux(528, num_classes) # inception4d

        else:
            self.aux1 = None
            self.aux2 = None
    
    def forward(self, x:Tensor):

        x = self.conv1(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)
        
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        
        x = self.inception4a(x)
        
        if self.aux_logits and self.training:
            aux1 = self.aux1(x)
        else:
            aux2 = None
        
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        
        if self.aux_logits and self.training:
            aux2 = self.aux2(x)
        else:
            aux2 = None

        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avgpool(x)
        # N x 1024 x 1 x 1 -> N x 1024
        x = x.view(x.size()[0], -1)
        x = self.linear(x)
        x = self.dropout(x)

        if self.aux_logits and self.training:
            return x, aux1, aux2
        else:
            return x

In [6]:
if __name__ == "__main__":
    x = torch.randn(3, 3, 224, 224)
    model = GoogLeNet(aux_logits = True, num_classes = 1000)
    print(model(x)[1].shape)

torch.Size([3, 1000])


# Train  
Train model by CIFAR-10 datasets

In [9]:
def load_dataset():
    # preprocess
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224))
    ])

    # load data
    train = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
    test = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
    train_loader = DataLoader(train, batch_size=args.batch_size, shuffle = True)
    test_loader = DataLoader(test, batch_size=args.batch_size, shuffle=True)
    return train_loader, test_loader

In [15]:
if __name__ == "__main__":
    args = EasyDict()
    args.batch_size = 100
    args.learning_rate = 0.0002
    args.n_epochs = 100
    args.plot = True

    np.random.seed(1)
    seed = torch.manual_seed(1)

    # load dataset
    train_loader, test_loader = load_dataset()

    # model, loss, optimizer
    losses = []
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        print("we use GPU")
    else:
        print("we use CPU")
    model = GoogLeNet(aux_logits = True, num_classes = 10).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # train
    for epoch in range(args.n_epochs):
        model.train()
        train_loss = 0
        correct, count = 0,0
        for batch_idx, (images, labels) in enumerate(train_loader, start=1):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output, aux1, aux2 = model.forward(images)
            loss_output = criterion(output, labels)
            loss_aux1 = criterion(aux1, labels)
            loss_aux2 = criterion(aux2, labels)
            loss = loss_output + 0.3 * (loss_aux1 + loss_aux2)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(output, 1) # torch max output is max, max_index
            count += labels.size(0)
            correct += torch.sum(preds==labels)

            if batch_idx % 10 == 0:
                print(f"[*] Epoch: {epoch} \tStep: {batch_idx}/{len(train_loader)}\tTrain accuracy: {correct/count} \tTrain Loss: {(train_loss/count)*100}")

Files already downloaded and verified
Files already downloaded and verified
we use GPU
[*] Epoch: 13 	Step: 260/500	Train accuracy: 0.6883845925331116 	Train Loss: 0.7185489911299485
[*] Epoch: 13 	Step: 270/500	Train accuracy: 0.6874814629554749 	Train Loss: 0.7199741412092138
[*] Epoch: 13 	Step: 280/500	Train accuracy: 0.6872857213020325 	Train Loss: 0.7211960792541504
[*] Epoch: 13 	Step: 290/500	Train accuracy: 0.6872414350509644 	Train Loss: 0.7212720498956483
[*] Epoch: 13 	Step: 300/500	Train accuracy: 0.6869666576385498 	Train Loss: 0.7218428881963095
[*] Epoch: 13 	Step: 310/500	Train accuracy: 0.6867096424102783 	Train Loss: 0.7223478286497055
[*] Epoch: 13 	Step: 320/500	Train accuracy: 0.6864375472068787 	Train Loss: 0.7223972564563155
[*] Epoch: 13 	Step: 330/500	Train accuracy: 0.6858181953430176 	Train Loss: 0.7234349940762376
[*] Epoch: 13 	Step: 340/500	Train accuracy: 0.68644118309021 	Train Loss: 0.7232403337955474


In [21]:
if __name__ == "__main__":
    model.eval()
    correct, count = 0,0
    valid_loss = 0
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader, start=1):
            images, labels = images.to(device), labels.to(device)
            output = model.forward(images)
            loss = criterion(output, labels)
            valid_loss += loss.item()
            _, preds = torch.max(output, 1)
            count += labels.size(0)
            correct += torch.sum(preds==labels)
            if batch_idx % 10 == 0:
                print(f"[*] Step: {batch_idx}/{len(test_loader)}\tValid accuracy: {correct/count} \tValid Loss: {(valid_loss/count)*100}")

[*] Step: 10/100	Valid accuracy: 0.8630000352859497 	Valid Loss: 0.45474811792373654
[*] Step: 20/100	Valid accuracy: 0.862000048160553 	Valid Loss: 0.4923530235886574
[*] Step: 30/100	Valid accuracy: 0.8659999966621399 	Valid Loss: 0.487832265595595
[*] Step: 40/100	Valid accuracy: 0.8662500381469727 	Valid Loss: 0.4847776528447866
[*] Step: 50/100	Valid accuracy: 0.8673999905586243 	Valid Loss: 0.48437747985124585
[*] Step: 60/100	Valid accuracy: 0.8693333268165588 	Valid Loss: 0.47821816181143123
[*] Step: 70/100	Valid accuracy: 0.8684285879135132 	Valid Loss: 0.4816110055361475
[*] Step: 80/100	Valid accuracy: 0.8676250576972961 	Valid Loss: 0.4828685568645597
[*] Step: 90/100	Valid accuracy: 0.8684444427490234 	Valid Loss: 0.48009414788749477
[*] Step: 100/100	Valid accuracy: 0.8672999739646912 	Valid Loss: 0.4842480953037739
