## AlexNet on CIFAR-10

In [None]:
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

'''
Step 1:
'''

# Training-time pipeline, applied in this order:
# Pad(4) -> RandomHorizontalFlip -> RandomCrop(32) -> ToTensor.
transform = transforms.Compose(
    [
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor(),
    ]
)

# Training split: uses the augmenting pipeline above and requests a download.
train_dataset = datasets.CIFAR10(
    root='./cifar_10data/',
    download=True,
    train=True,
    transform=transform,
)

# Test split: only converted to tensors (no augmentation). It reads from the
# same root directory as the training split and does not request a download.
test_dataset = datasets.CIFAR10(
    root='./cifar_10data/',
    train=False,
    transform=transforms.ToTensor(),
)
    


'''
Step 2
'''

class AlexNet(nn.Module):
    """AlexNet-style CNN for 3-channel 32x32 images.

    The classifier's first linear layer expects 9216 input features, which
    is what the convolutional stack produces for a 32x32 input:
    32 -> 29 -> 27 (conv_layer1), 27 -> 13 (conv_layer2), 13 -> 6
    (conv_layer3), i.e. 256 channels * 6 * 6 = 9216.

    Args:
        num_class: Number of output logits produced by the final linear
            layer. Defaults to 10.
    """

    def __init__(self, num_class=10):
        super().__init__()

        # Two unpadded convolutions: spatial size 32 -> 29 -> 27.
        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=4),
            nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=3),
            nn.ReLU(),
        )
        # Size-preserving convolution followed by pooling: 27 -> 13.
        self.conv_layer2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Three size-preserving convolutions followed by pooling: 13 -> 6.
        self.conv_layer3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Classifier head. The output width follows ``num_class`` instead of
        # being hard-coded, so the constructor argument takes effect.
        self.fc_layer1 = nn.Sequential(
            nn.Dropout(),  # p=0.5 by default
            nn.Linear(9216, 4096),
            nn.ReLU(),
            nn.Dropout(),  # p=0.5 by default
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_class),
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch; the layer sizes above assume shape
                (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, num_class) with unnormalised logits.
        """
        output = self.conv_layer1(x)
        output = self.conv_layer2(output)
        output = self.conv_layer3(output)
        # Flatten per sample so the batch dimension is always preserved; an
        # input of the wrong spatial size then fails loudly in the first
        # Linear layer instead of being silently re-batched.
        output = output.flatten(start_dim=1)
        output = self.fc_layer1(output)
        return output

    

'''
Step 3
'''
# Build the network with its default settings and move it to the chosen device.
model = AlexNet()
model = model.to(device)

# Classification loss applied to the model outputs.
loss_function = nn.CrossEntropyLoss()

# SGD over all model parameters with a small L2 weight decay.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.1,
    weight_decay=5e-5,
)

'''
Step 4
'''
# Put the model in training mode before optimising.
model.train()
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)

import time
# perf_counter is monotonic, so the measured duration cannot be skewed by
# system clock adjustments during the run.
start = time.perf_counter()
for epoch in range(100):
    print("{}th epoch starting.".format(epoch))

    # Sample-weighted running sum of the loss, kept on the device so that
    # accumulating it does not force a host synchronisation every step.
    running_loss = torch.zeros((), device=device)
    seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        train_loss = loss_function(model(images), labels)
        train_loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        running_loss += train_loss.detach() * batch_size
        seen += batch_size

    # Report the mean loss over the whole epoch rather than the loss of
    # whichever mini-batch happened to come last.
    epoch_loss = running_loss.item() / max(seen, 1)
    print("Epoch [{}] Loss: {:.4f}".format(epoch + 1, epoch_loss))

end = time.perf_counter()
print("Time elapsed in training is: {}".format(end - start))


'''
Step 5
'''
# Put the model in evaluation mode for testing.
model.eval()
test_loss, correct, total = 0.0, 0, 0

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        output = model(images)
        # loss_function is created as CrossEntropyLoss() with default
        # arguments, i.e. it returns the mean over the batch. Scale it back
        # to a per-batch sum so that dividing by `total` below gives the
        # true per-sample average.
        test_loss += loss_function(output, labels).item() * batch_size

        # Predicted class = index of the largest logit for each sample.
        pred = output.argmax(dim=1)
        correct += pred.eq(labels).sum().item()

        total += batch_size

print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss / total, correct, total,
        100. * correct / total))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar_10data/cifar-10-python.tar.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./cifar_10data/cifar-10-python.tar.gz to ./cifar_10data/
0th epoch starting.
Epoch [1] Loss: 2.3054
1th epoch starting.
Epoch [2] Loss: 2.3034
2th epoch starting.
Epoch [3] Loss: 2.2971
3th epoch starting.
