In [1]:
import os
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Callable, List, Optional, Tuple
from easydict import EasyDict

# Mount Google Drive inside the Colab runtime; the training checkpoint path used
# later in this notebook lives under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


## Model
Building the Inception-ResNet-v2 model in PyTorch.

In [2]:
# Basic Convolutional Module
class BasicConv2d(nn.Module):
    """Convolution -> batch normalisation -> ReLU, packaged as a single module.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        # The convolution is created without a bias term; it is immediately
        # followed by batch normalisation.
        layers = [
            nn.Conv2d(in_channels, out_channels, bias=False, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv/BN/ReLU stack and return the result."""
        return self.conv(x)

In [3]:
# Make Stem of Inception-v4 and Inception-ResNet-v2
class Stem(nn.Module):
    """Stem shared by Inception-v4 and Inception-ResNet-v2.

    An initial stack of three convolutions is followed by three stages, each of
    which runs two branches on the same input and concatenates their outputs
    along the channel axis. For a (batch, 3, 299, 299) input the result is
    (batch, 384, 35, 35).
    """

    def __init__(self):
        super().__init__()

        def conv(c_in, c_out, kernel, stride=1, padding="same"):
            # Shorthand for the conv/BN/ReLU unit used throughout the stem.
            return BasicConv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=padding)

        # 299x299x3 -> 147x147x64
        self.conv1 = nn.Sequential(
            conv(3, 32, 3, stride=2, padding="valid"),
            conv(32, 32, 3, padding="valid"),
            conv(32, 64, 3),
        )

        # Stage 1: 147x147x64 -> 73x73x160 (64 from the pool + 96 from the conv)
        self.branch1a = nn.MaxPool2d(3, stride=2)
        self.branch1b = conv(64, 96, 3, stride=2, padding="valid")

        # Stage 2: 73x73x160 -> 71x71x192 (96 + 96)
        self.branch2a = nn.Sequential(
            conv(160, 64, 1),
            conv(64, 96, 3, padding="valid"),
        )
        self.branch2b = nn.Sequential(
            conv(160, 64, 1),
            conv(64, 64, (7, 1)),
            conv(64, 64, (1, 7)),
            conv(64, 96, 3, padding="valid"),
        )

        # Stage 3: 71x71x192 -> 35x35x384 (192 + 192)
        self.branch3a = conv(192, 192, 3, stride=2, padding="valid")
        # With padding 0 ("valid") the pool uses kernel 3 / stride 2;
        # with padding 1 it would need kernel 4 / stride 2.
        self.branch3b = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Return the stem features for a batch of images."""
        x = self.conv1(x)
        stages = (
            (self.branch1a, self.branch1b),
            (self.branch2a, self.branch2b),
            (self.branch3a, self.branch3b),
        )
        for first, second in stages:
            # dim=1 is the channel axis: inputs are mini-batches shaped
            # (batch_size, channels, height, width).
            x = torch.cat((first(x), second(x)), dim=1)
        return x

In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Smoke test: push a random batch of three 299x299 RGB images through the stem.
x = torch.randn(3, 3, 299, 299).to(device)
model = Stem().to(device)
output_Stem = model(x)
print('Input size: ', x.shape)
print('Stem output size: ', output_Stem.shape)

cuda
Input size:  torch.Size([3, 3, 299, 299])
Stem output size:  torch.Size([3, 384, 35, 35])


In [5]:
class Inception_ResNet_A(nn.Module):
    """Inception-ResNet-A residual block (used on the 35x35 grid in this notebook).

    Three parallel branches producing 32, 32 and 64 channels are concatenated
    (128 channels), projected back to ``in_channels`` with a 1x1 convolution and
    added to the block input. Batch normalisation and ReLU are applied after the
    addition, so the output has exactly the shape of the input.

    Args:
        in_channels: channel count of the input feature map (384 in this
            notebook). The projection and batch-norm widths follow this value,
            so the residual addition is valid for any input width.
    """

    def __init__(self, in_channels):
        super(Inception_ResNet_A, self).__init__()

        # 35x35xC -> 35x35x32
        self.branch1 = BasicConv2d(in_channels, 32, kernel_size = 1, stride = 1, padding = "same")

        # 35x35xC -> 35x35x32
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, 32, kernel_size = 1, stride = 1, padding = "same"),
            BasicConv2d(32, 32, kernel_size = 3, stride = 1, padding = "same")
        )

        # 35x35xC -> 35x35x64
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, 32, kernel_size = 1, stride = 1, padding = "same"),
            BasicConv2d(32, 48, kernel_size = 3, stride = 1, padding = "same"),
            BasicConv2d(48, 64, kernel_size = 3, stride = 1, padding = "same")
        )

        # After the concat: 35x35x128. Project back to the input width so the
        # shortcut can be added element-wise (previously hard-coded to 384).
        self.dimchange = nn.Conv2d(128, in_channels, kernel_size = 1, stride = 1, padding = "same")

        self.bn = nn.BatchNorm2d(in_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(x + projection(concat(branches(x)))))``."""
        x_shortcut = x
        x = torch.cat((self.branch1(x), self.branch2(x), self.branch3(x)), dim=1)
        x = self.dimchange(x)
        x = self.bn(x_shortcut + x)
        x = self.relu(x)
        return x

In [6]:
# Smoke test: a single Inception-ResNet-A block applied to the stem output.
model = Inception_ResNet_A(output_Stem.shape[1]).to(device)
output_resA = model(output_Stem)
print("Input size: ", output_Stem.shape)
print("Output size: ", output_resA.shape)

Input size:  torch.Size([3, 384, 35, 35])
Output size:  torch.Size([3, 384, 35, 35])


In [7]:
class Reduction_A(nn.Module):
    """Reduction-A block: halves the spatial grid (35x35 -> 17x17 in this notebook).

    Three branches read the same input and are concatenated on the channel axis:
    a max-pool (keeps ``in_channels``), a strided 3x3 convolution (``n`` channels)
    and a 1x1 -> 3x3 -> strided 3x3 stack (``k`` -> ``l`` -> ``m`` channels).

    Args:
        in_channels: channel count of the input feature map.
        k, l, m, n: filter counts of the convolution branches; Inception-ResNet-v2
            uses (256, 256, 384, 384).
    """

    def __init__(self, in_channels, k, l, m, n):
        super().__init__()

        # Pooling branch: 35x35xC -> 17x17xC
        self.branch1 = nn.MaxPool2d(3, stride=2)

        # Single strided convolution: 35x35xC -> 17x17xn
        self.branch2 = BasicConv2d(in_channels, n, kernel_size=3, stride=2, padding="valid")

        # Bottleneck stack: 35x35xC -> 17x17xm
        stack = [
            BasicConv2d(in_channels, k, kernel_size=1, stride=1, padding="same"),
            BasicConv2d(k, l, kernel_size=3, stride=1, padding="same"),
            BasicConv2d(l, m, kernel_size=3, stride=2, padding="valid"),
        ]
        self.branch3 = nn.Sequential(*stack)
        # Concatenated output: 17x17x(C + n + m); 17x17x1152 for Inception-ResNet-v2.

    def forward(self, x):
        """Return the channel-wise concatenation of the three branches."""
        pooled = self.branch1(x)
        conv_n = self.branch2(x)
        conv_m = self.branch3(x)
        return torch.cat((pooled, conv_n, conv_m), dim=1)

In [8]:
# Smoke test: Reduction-A with the Inception-ResNet-v2 filter counts (k, l, m, n).
model = Reduction_A(output_resA.shape[1], 256, 256, 384, 384).to(device)
output_redA = model(output_resA)
print("Input size: ", output_resA.shape)
print("Output size: ", output_redA.shape)

Input size:  torch.Size([3, 384, 35, 35])
Output size:  torch.Size([3, 1152, 17, 17])


In [9]:
class Inception_ResNet_B(nn.Module):
    """Inception-ResNet-B residual block (used on the 17x17 grid in this notebook).

    Two parallel branches (a 1x1 convolution, and a 1x1 -> 1x7 -> 7x1 stack),
    each ending in 192 channels, are concatenated (384 channels), projected back
    to ``in_channels`` with a 1x1 convolution and added to the block input.
    Batch normalisation and ReLU follow the addition, so the output has exactly
    the shape of the input.

    Args:
        in_channels: channel count of the input feature map (1152 in this
            notebook). The projection and batch-norm widths follow this value,
            so the residual addition is valid for any input width.
    """

    def __init__(self, in_channels):
        super(Inception_ResNet_B, self).__init__()

        # 17x17xC -> 17x17x192
        self.branch1 = BasicConv2d(in_channels, 192, kernel_size = 1, stride = 1, padding = "same")

        # 17x17xC -> 17x17x192
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, 128, kernel_size = 1, stride = 1, padding = "same"),
            BasicConv2d(128, 160, kernel_size = (1,7), stride = 1, padding = "same"),
            BasicConv2d(160, 192, kernel_size = (7,1), stride = 1, padding = "same")
        )

        # After the concat: 17x17x384. Project back to the input width so the
        # shortcut can be added element-wise (previously hard-coded to 1152).
        self.dimchange = nn.Conv2d(384, in_channels, kernel_size = 1, stride = 1, padding = "same")

        self.bn = nn.BatchNorm2d(in_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(x + projection(concat(branches(x)))))``."""
        x_shortcut = x
        x = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        x = self.dimchange(x)
        x = self.bn(x_shortcut + x)
        x = self.relu(x)
        return x

In [10]:
# Smoke test: a single Inception-ResNet-B block applied to the Reduction-A output.
model = Inception_ResNet_B(output_redA.shape[1]).to(device)
output_resB = model(output_redA)
print("Input size: ", output_redA.shape)
print("Output size: ", output_resB.shape)

Input size:  torch.Size([3, 1152, 17, 17])
Output size:  torch.Size([3, 1152, 17, 17])


In [11]:
class Reduction_B(nn.Module):
    """Reduction-B block: halves the spatial grid (17x17 -> 8x8 in this notebook).

    Four branches read the same input and are concatenated on the channel axis:
    a max-pool (keeps ``in_channels``) and three convolution stacks ending in
    384, 288 and 320 channels. The output therefore has
    ``in_channels + 992`` channels (2144 for the 1152-channel input used here).

    Args:
        in_channels: channel count of the input feature map. Every branch is
            sized from this value (the convolution branches were previously
            hard-coded to 1152 and ignored the argument).
    """

    def __init__(self, in_channels):
        super(Reduction_B, self).__init__()

        # 17x17xC -> 8x8xC
        self.branch1 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 0)

        # 17x17xC -> 8x8x384
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size = 1, stride = 1, padding = "same"),
            BasicConv2d(256, 384, kernel_size = 3, stride = 2, padding = "valid")
        )

        # 17x17xC -> 8x8x288
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size = 1, stride = 1, padding = "same"),
            BasicConv2d(256, 288, kernel_size = 3, stride = 2, padding = "valid")
        )

        # 17x17xC -> 8x8x320
        self.branch4 = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size = 1, stride = 1, padding = "same"),
            BasicConv2d(256, 288, kernel_size = 3, stride = 1, padding = "same"),
            BasicConv2d(288, 320, kernel_size = 3, stride = 2, padding = "valid")
        )
        # After the concat: 8x8x(C + 384 + 288 + 320), i.e. 8x8x2144 for C = 1152.

    def forward(self, x):
        """Return the channel-wise concatenation of the four branches."""
        x = torch.cat((self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)), dim=1)
        return x

In [12]:
# Smoke test: Reduction-B applied to the Inception-ResNet-B output.
model = Reduction_B(output_resB.shape[1]).to(device)
output_redB = model(output_resB)
print("Input size: ", output_resB.shape)
print("Output size: ", output_redB.shape)

Input size:  torch.Size([3, 1152, 17, 17])
Output size:  torch.Size([3, 2144, 8, 8])


In [13]:
class Inception_ResNet_C(nn.Module):
    """Inception-ResNet-C residual block (used on the 8x8 grid in this notebook).

    Two parallel branches (a 1x1 convolution with 192 channels, and a
    1x1 -> 1x3 -> 3x1 stack ending in 256 channels) are concatenated
    (448 channels), projected back to ``in_channels`` with a 1x1 convolution and
    added to the block input. Batch normalisation and ReLU follow the addition,
    so the output has exactly the shape of the input.

    Args:
        in_channels: channel count of the input feature map (2144 in this
            notebook). The projection and batch-norm widths follow this value,
            so the residual addition is valid for any input width.
    """

    def __init__(self, in_channels):
        super(Inception_ResNet_C, self).__init__()

        # 8x8xC -> 8x8x192
        self.branch1 = BasicConv2d(in_channels, 192, kernel_size = 1, stride = 1, padding = "same")

        # 8x8xC -> 8x8x256
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size = 1, stride = 1, padding = "same"),
            BasicConv2d(192, 224, kernel_size = (1,3), stride = 1, padding = "same"),
            BasicConv2d(224, 256, kernel_size = (3,1), stride = 1, padding = "same")
        )

        # After the concat: 8x8x448 (192 + 256). Project back to the input width
        # so the shortcut can be added element-wise (previously hard-coded to
        # 2144). Note: this notebook's Reduction-B emits 2144 channels, whereas
        # the original author's comment here observed that the paper says 2048.
        self.dimchange = nn.Conv2d(448, in_channels, kernel_size = 1, stride = 1, padding = "same")

        self.bn = nn.BatchNorm2d(in_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(x + projection(concat(branches(x)))))``."""
        x_shortcut = x
        x = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        x = self.dimchange(x)
        x = self.bn(x_shortcut + x)
        x = self.relu(x)
        return x

In [14]:
# Smoke test: a single Inception-ResNet-C block applied to the Reduction-B output.
model = Inception_ResNet_C(output_redB.shape[1]).to(device)
output_resC = model(output_redB)
print("Input size: ", output_redB.shape)
print("Output size: ", output_resC.shape)

Input size:  torch.Size([3, 2144, 8, 8])
Output size:  torch.Size([3, 2144, 8, 8])


In [15]:
class Inception_ResNet_V2(nn.Module):
    """Inception-ResNet-v2 classifier assembled from the blocks defined above.

    Layout: Stem -> A x Inception-ResNet-A -> Reduction-A -> B x Inception-ResNet-B
    -> Reduction-B -> C x Inception-ResNet-C -> global average pool -> dropout
    -> linear classifier.

    Args:
        A, B, C: how many Inception-ResNet-A / -B / -C blocks to stack.
        k, l, m, n: filter counts handed to Reduction-A.
        num_classes: size of the classifier output.
        init_weight: when true, re-initialise all weights with
            ``_initialize_weights`` after construction.
    """

    def __init__(self, A, B, C, k=256, l=256, m=384, n=384, num_classes=10, init_weight=True):
        super(Inception_ResNet_V2, self).__init__()

        # Channel widths between stages. They are derived from m and n instead
        # of being hard-coded, and handed to each block as its in_channels.
        # With the defaults this yields 384 -> 1152 -> 2144.
        stem_channels = 384                                   # Stem output
        reduction_a_channels = stem_channels + m + n          # pool + two conv branches
        reduction_b_channels = reduction_a_channels + 384 + 288 + 320

        blocks = [Stem()]
        blocks.extend(Inception_ResNet_A(stem_channels) for _ in range(A))
        blocks.append(Reduction_A(stem_channels, k, l, m, n))
        blocks.extend(Inception_ResNet_B(reduction_a_channels) for _ in range(B))
        blocks.append(Reduction_B(reduction_a_channels))
        blocks.extend(Inception_ResNet_C(reduction_b_channels) for _ in range(C))

        self.blocks = nn.Sequential(*blocks)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(p = 0.2)
        self.linear = nn.Linear(reduction_b_channels, num_classes)

        if init_weight:
            self._initialize_weights()

    def forward(self, x):
        """Return the class logits, shape (batch, num_classes), for a batch of images."""
        x = self.blocks(x)
        x = self.avgpool(x)
        # (batch, channels, 1, 1) -> (batch, channels); functional form avoids
        # building a fresh nn.Flatten module on every call.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.linear(x)
        return x

    def _initialize_weights(self):
        """Initialise every Conv2d, BatchNorm2d and Linear submodule in place.

        Conv2d weights use Kaiming-normal (fan_out, ReLU) and zero bias;
        BatchNorm2d starts as the identity (weight 1, bias 0); Linear weights are
        drawn from N(0, 0.01) with zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    # if bias exist, bias = 0
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

In [16]:
from torchsummary import summary

# Block counts: the paper only gives them for Inception-ResNet-v1, (5, 10, 5).
# The Google Research blog post gives (10, 20, 10) for Inception-ResNet-v2:
# https://ai.googleblog.com/2016/08/improving-inception-and-image.html?m=1
model = Inception_ResNet_V2(A=10, B=20, C=10).to(device)
summary(model, (3, 299, 299), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
       BasicConv2d-4         [-1, 32, 149, 149]               0
            Conv2d-5         [-1, 32, 147, 147]           9,216
       BatchNorm2d-6         [-1, 32, 147, 147]              64
              ReLU-7         [-1, 32, 147, 147]               0
       BasicConv2d-8         [-1, 32, 147, 147]               0
            Conv2d-9         [-1, 64, 147, 147]          18,432
      BatchNorm2d-10         [-1, 64, 147, 147]             128
             ReLU-11         [-1, 64, 147, 147]               0
      BasicConv2d-12         [-1, 64, 147, 147]               0
        MaxPool2d-13           [-1, 64, 73, 73]               0
           Conv2d-14           [-1, 96,

## Train
Train the model on CIFAR-10.

In [17]:
def load_dataset(batch_size: Optional[int] = None):
    """Build the CIFAR-10 training and test ``DataLoader`` objects.

    Images are converted to tensors, normalised with the per-channel mean/std
    given below and resized to 299x299 (the input size used for the model in
    this notebook). The dataset is downloaded into ``./data`` when missing.

    Args:
        batch_size: mini-batch size for both loaders. When omitted, the
            notebook-level ``args.batch_size`` is read at call time, which keeps
            ``load_dataset()`` working exactly as before.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader is shuffled;
        the test loader keeps a fixed order so evaluation is repeatable.
    """
    if batch_size is None:
        # Resolved lazily: `args` is defined in a later notebook cell.
        batch_size = args.batch_size

    # preprocess
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((299, 299))
    ])

    # load data
    train = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
    test = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [20]:
# hyperparameter
args = EasyDict()
args.batch_size = 32
args.learning_rate = 0.005
args.n_epochs = 5

# functions
# reduction='sum' makes each batch loss a sum over samples; the training and
# test loops divide the accumulated total by the sample count.
criterion = nn.CrossEntropyLoss(reduction = 'sum').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# Checkpoint written after every epoch by the training loop.
PATH = '/content/gdrive/MyDrive/Colab Notebooks/inceptionv4.pt'

# First epoch to run; overwritten below when a checkpoint is found.
epoch_start = 1

if os.path.exists(PATH):
    # map_location lets a checkpoint saved on a GPU runtime be restored on a
    # CPU-only runtime (and vice versa) instead of failing during deserialisation.
    checkpoint = torch.load(PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch_start = checkpoint['epoch'] + 1
    print("successfully loaded!")
    print("epoch saved until here: ", epoch_start-1)
    print("train starts from this epoch: Epoch ", epoch_start)

In [21]:
import time

# Build the CIFAR-10 loaders (the dataset is downloaded on first use).
train_loader, test_loader = load_dataset()

# One entry per trained epoch: mean loss per sample and training accuracy.
loss_hist, accuracy_hist = [], []

start_time = time.time()

# Resume at epoch_start (1 on a fresh run, last saved epoch + 1 otherwise).
for epoch in range(epoch_start, args.n_epochs + 1):
    model.train()
    train_loss = 0
    correct = count = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # The criterion sums over the batch, so train_loss is a running total.
        train_loss += loss.item()
        preds = output.max(dim=1)[1]  # max over classes returns (values, indices)
        count += labels.size(0)
        correct += (preds == labels).sum()

    epoch_loss = train_loss / count
    epoch_accuracy = correct / count
    loss_hist.append(epoch_loss)
    accuracy_hist.append(epoch_accuracy)
    print(f"[*] Epoch: {epoch} \tTrain accuracy: {epoch_accuracy} \tTrain Loss: {epoch_loss}")

    # Checkpoint after every epoch so an interrupted run can be resumed.
    snapshot = {
        'epoch' : epoch,
        'model_state_dict' : model.state_dict(),
        'optimizer_state_dict' : optimizer.state_dict(),
    }
    torch.save(snapshot, PATH)

end_time = time.time()

print(f"Training time : {end_time - start_time}")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:06<00:00, 28168472.96it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified




[*] Epoch: 1 	Train accuracy: 0.12003999948501587 	Train Loss: 2.3079440895080565


KeyboardInterrupt: ignored

In [None]:
# Evaluate on the CIFAR-10 test split: no gradient tracking, model in eval mode.
model.eval()
test_loss = 0
correct = count = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        output = model(images)
        # The criterion sums over the batch; the total is divided by count below.
        test_loss += criterion(output, labels).item()
        preds = output.max(dim=1)[1]
        count += labels.size(0)
        correct += (preds == labels).sum()

print(f"Test accuracy: {correct/count} \tTest Loss: {test_loss/count}")

In [None]:
# plot graph
width = len(loss_hist)
# loss_hist[0] / accuracy_hist[0] belong to epoch `epoch_start` (the first epoch
# the training loop ran), so the x-axis starts there rather than one epoch later.
epochs_trained = range(epoch_start, epoch_start + width)

plt.title("Train Loss")
plt.plot(epochs_trained, loss_hist)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()

# Convert each entry to a plain Python float; unlike torch.stack this also
# works when no epoch was trained (empty history) and moves values off the GPU.
accuracy_hist2 = [float(acc) for acc in accuracy_hist]
plt.title("Train Accuracy")
plt.plot(epochs_trained, accuracy_hist2)
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.show()