# 10. CNN with MNIST

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.init
from torch.autograd import Variable

import torchvision.utils as utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import numpy as np
import random
import os

## 10.1 Preparing MNIST Data

In [2]:
# Training split of MNIST; images are converted to tensors by ToTensor().
# The data is stored under data/ and downloaded on first use.
mnist_train = dsets.MNIST(
    root='data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Held-out test split, prepared with the same transform and storage location.
mnist_test = dsets.MNIST(
    root='data/',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [3]:
# Number of images per mini-batch, shared by the training and test loaders.
batch_size = 100

# Training batches are reshuffled every epoch so SGD does not see the same
# samples grouped in the same order on every pass over the data.
train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                           batch_size=batch_size,
                                           shuffle=True)

# Evaluation only aggregates predictions over the whole split, so the test
# loader keeps the dataset order.
test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=batch_size,
                                          shuffle=False)

## 10.2 Define Model

In [4]:
class CNN(nn.Module):
    """Small convolutional classifier producing 10 output scores per image.

    The feature extractor stacks three 5x5 convolutions (1->16->32->64
    channels) with ReLU activations and two 2x2 max-pooling steps. The
    classifier head expects 64*3*3 flattened features, which is what the
    extractor yields for single-channel 28x28 inputs
    (28 -> 24 -> 20 -> 10 -> 6 -> 3 per spatial side).
    """

    def __init__(self):
        super(CNN, self).__init__()

        # Convolutional feature extractor.
        self.layer = nn.Sequential(
            nn.Conv2d(1,16,5),
            nn.ReLU(),
            nn.Conv2d(16,32,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(32,64,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2)
        )

        # Fully connected head mapping flattened features to 10 scores.
        self.fc_layer = nn.Sequential(
            nn.Linear(64*3*3,100),
            nn.ReLU(),
            nn.Linear(100,10)
        )

    def forward(self,x):
        """Return the raw class scores for a batch of images.

        Args:
            x: input tensor whose first dimension is the batch dimension;
                the layer sizes above assume shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) with unnormalised scores.
        """
        out = self.layer(x)
        # Flatten using the batch dimension of the actual input rather than a
        # notebook-level global, so partial batches and single images work.
        out = out.view(out.size(0),-1)
        out = self.fc_layer(out)

        return out

In [5]:
# Build the network, then move its parameters to the GPU.
model = CNN()
model = model.cuda()

In [6]:
# Plain stochastic gradient descent over every parameter of the model.
optimizer = optim.SGD(params=model.parameters(), lr=0.1)

# Cross-entropy criterion applied to the model's raw output scores.
loss = nn.CrossEntropyLoss()

In [7]:
# Number of full passes over the training loader performed by the training cell.
num_epochs = 5

In [8]:
# Reuse cached weights when the checkpoint exists; otherwise train the model
# and cache the result. The existence check and the load use the same path,
# so a saved checkpoint is actually picked up on the next run.
if os.path.isfile('cnn.pkl'):
    model.load_state_dict(torch.load('cnn.pkl'))
    print("Model Loaded!")

else:

    # Batches per epoch, used only in the progress message; it does not
    # change between epochs, so compute it once.
    total_batch = len(mnist_train) // batch_size

    for epoch in range(num_epochs):

        for i, (batch_images, batch_labels) in enumerate(train_loader):

            # Move the mini-batch to the GPU, where the model lives.
            X = batch_images.cuda()
            Y = batch_labels.cuda()

            pre = model(X)
            cost = loss(pre, Y)

            # Clear stale gradients, backpropagate, and apply the update.
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            if (i+1) % 100 == 0:
                # The loss is a 0-dim tensor; .item() yields a Python float
                # (indexing it with [0] raises on PyTorch >= 0.4).
                print('Epoch [%d/%d], lter [%d/%d] Loss: %.4f'
                     %(epoch+1, num_epochs, i+1, total_batch, cost.item()))

    # This branch only runs when no checkpoint exists, so always save, and
    # report success only after the file has been written.
    torch.save(model.state_dict(), 'cnn.pkl')
    print("Model Saved!")

Epoch [1/5], lter [100/600] Loss: 0.6938
Epoch [1/5], lter [200/600] Loss: 0.2466
Epoch [1/5], lter [300/600] Loss: 0.2362
Epoch [1/5], lter [400/600] Loss: 0.1060
Epoch [1/5], lter [500/600] Loss: 0.1295
Epoch [1/5], lter [600/600] Loss: 0.2647
Epoch [2/5], lter [100/600] Loss: 0.0219
Epoch [2/5], lter [200/600] Loss: 0.0438
Epoch [2/5], lter [300/600] Loss: 0.0688
Epoch [2/5], lter [400/600] Loss: 0.0398
Epoch [2/5], lter [500/600] Loss: 0.0607
Epoch [2/5], lter [600/600] Loss: 0.2283
Epoch [3/5], lter [100/600] Loss: 0.0109
Epoch [3/5], lter [200/600] Loss: 0.0218
Epoch [3/5], lter [300/600] Loss: 0.0232
Epoch [3/5], lter [400/600] Loss: 0.0243
Epoch [3/5], lter [500/600] Loss: 0.0467
Epoch [3/5], lter [600/600] Loss: 0.2067
Epoch [4/5], lter [100/600] Loss: 0.0094
Epoch [4/5], lter [200/600] Loss: 0.0096
Epoch [4/5], lter [300/600] Loss: 0.0126
Epoch [4/5], lter [400/600] Loss: 0.0191
Epoch [4/5], lter [500/600] Loss: 0.0289
Epoch [4/5], lter [600/600] Loss: 0.2089
Epoch [5/5], lte

## 10.3 Test Model

In [9]:
# Switch layers such as dropout/batch-norm (if any) to inference behaviour.
model.eval()

correct = 0
total = 0

# No gradients are needed for evaluation; disabling autograd avoids building
# the graph for every test batch.
with torch.no_grad():
    for images, labels in test_loader:

        images = images.cuda()
        outputs = model(images)

        # Index of the highest score along dim 1 is the predicted class.
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        # .item() keeps `correct` a plain Python int, so the percentage below
        # is ordinary float division rather than tensor arithmetic.
        correct += (predicted == labels.cuda()).sum().item()

print('Accuracy of test images: %f %%' % (100 * correct / total))

Accuracy of test images: 98.790000 %
