## NiN (Network in Network) on CIFAR-10

In [None]:
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

'''
Step 1:
'''
# Training images are augmented as follows:
#   - pad 4 pixels on every side,
#   - randomly flip left/right,
#   - randomly crop back to 32x32,
#   - convert to a tensor.
# Test images are only converted to tensors.
# Both splits are read from ./cifar_10data/. Only the training split passes
# download=True; the test split relies on the files already being under that root.
transform = transforms.Compose([transforms.Pad(4),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomCrop(32),
                                transforms.ToTensor()])

train_dataset = datasets.CIFAR10(root='./cifar_10data/', train=True,
                                 transform=transform, download=True)

test_dataset = datasets.CIFAR10(root='./cifar_10data/', train=False,
                                transform=transforms.ToTensor())
    
# Step 2: the Network-in-Network model.
class NiN(nn.Module):
    """Network-in-Network (NiN) image classifier.

    Three "mlpconv" stages, each a spatial convolution followed by two 1x1
    convolutions with a ReLU after every convolution. The first two stages end
    in 3x3 / stride-2 max-pooling plus dropout. The last stage produces one
    feature map per class, and global average pooling reduces each map to a
    single score.

    Shape comments below assume 3x32x32 inputs (CIFAR-10). Other resolutions
    also work, because the final pooling is global.

    Args:
        num_classes: number of scores produced per sample (default 10).
    """

    def __init__(self, num_classes=10):
        super(NiN, self).__init__()

        self.mlpconv_layer1 = nn.Sequential(
                nn.Conv2d(3, 192, kernel_size=5, padding=2),     # 192 * 32 * 32
                nn.ReLU(),
                nn.Conv2d(192, 160, kernel_size=1),              # 160 * 32 * 32
                nn.ReLU(),
                nn.Conv2d(160, 96, kernel_size=1),               # 96 * 32 * 32
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 96 * 16 * 16
                nn.Dropout()
                )
        self.mlpconv_layer2 = nn.Sequential(
                nn.Conv2d(96, 192, kernel_size=5, padding=2),    # 192 * 16 * 16
                nn.ReLU(),
                nn.Conv2d(192, 192, kernel_size=1),              # 192 * 16 * 16
                nn.ReLU(),
                nn.Conv2d(192, 192, kernel_size=1),              # 192 * 16 * 16
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 192 * 8 * 8
                nn.Dropout()
                )
        self.mlpconv_layer3 = nn.Sequential(
                nn.Conv2d(192, 192, kernel_size=3, padding=1),   # 192 * 8 * 8
                nn.ReLU(),
                nn.Conv2d(192, 192, kernel_size=1),              # 192 * 8 * 8
                nn.ReLU(),
                nn.Conv2d(192, num_classes, kernel_size=1),      # num_classes * 8 * 8
                # NOTE: this ReLU makes every class score >= 0 before the loss;
                # kept to preserve the existing architecture.
                nn.ReLU(),
                # Global average pooling. For 32x32 inputs this is exactly the
                # old AvgPool2d(kernel_size=8); unlike a fixed kernel, it always
                # yields 1x1 maps whatever the input resolution.
                nn.AdaptiveAvgPool2d(1)                          # num_classes * 1 * 1
                )

    def forward(self, x):
        """Return class scores of shape (batch, num_classes) for images ``x``."""
        output = self.mlpconv_layer1(x)
        output = self.mlpconv_layer2(output)
        output = self.mlpconv_layer3(output)
        # (N, C, 1, 1) -> (N, C). The batch dimension is kept explicitly; the
        # former view(-1, 10) silently folded any leftover spatial positions
        # into the batch dimension.
        return torch.flatten(output, 1)


# Step 3: build the network on `device` together with its loss and optimiser.
# Adam: learning rate 3e-4, L2 weight decay 1e-5.
model = NiN()
model = model.to(device)

loss_function = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=3e-4,
                             weight_decay=1e-5)


# Step 4: train for 100 epochs over shuffled mini-batches of 128 images and
# report the mean training loss of every epoch plus the total wall-clock time.
model.train()
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

import time
start = time.perf_counter()  # monotonic clock, intended for measuring intervals
for epoch in range(100):
    print("Epoch [{}] starting.".format(epoch + 1))
    # Accumulate on the device so the loop does not force a host sync per batch.
    running_loss = torch.zeros((), device=device)
    seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        train_loss = loss_function(model(images), labels)
        train_loss.backward()

        optimizer.step()

        # loss_function returns the batch-mean loss. Weight it by the batch size
        # so the epoch figure is a true per-sample mean (the last batch may be
        # smaller). Previously only the final batch's loss was printed.
        batch_size = labels.size(0)
        running_loss += train_loss.detach() * batch_size
        seen += batch_size

    print("Epoch [{}] Loss: {:.4f}".format(epoch + 1, running_loss.item() / max(seen, 1)))

end = time.perf_counter()
print("Time elapsed in training is: {}".format(end - start))


# Step 5: evaluate on the CIFAR-10 test split (no dropout, no gradients) and
# report the mean per-sample loss and the top-1 accuracy.
model.eval()
test_loss, correct, total = 0, 0, 0

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        output = model(images)
        batch_size = labels.size(0)
        # loss_function returns the batch-mean loss. Turn it back into a
        # per-batch sum so dividing by `total` below gives the per-sample mean.
        # The old code summed batch means and divided by the sample count,
        # which understated the loss by a factor of batch_size.
        test_loss += loss_function(output, labels).item() * batch_size

        pred = output.argmax(dim=1)
        correct += pred.eq(labels).sum().item()

        total += batch_size

print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss / total, correct, total,
        100. * correct / total))