In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import numpy
import matplotlib.pyplot as plt

In [2]:
class Model(nn.Module):
    """Small two-stage CNN classifier that outputs 10 logits per sample.

    Layout:
        conv 5x5 (8 ch, padding 2) -> ReLU -> 2x2 max-pool
        -> conv 3x3 (16 ch, padding 1) -> ReLU -> 2x2 max-pool
        -> flatten -> linear(128) -> ReLU -> linear(10)

    The convolutions keep the spatial size, and each pool halves it. ``fc1``
    expects 16*7*7 features, so the input must be 28x28. Other sizes fail
    with a shape mismatch at ``fc1``.

    Args:
        in_channel: number of channels in the input images.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Submodules are created in this exact order so that random
        # initialisation and state_dict keys match the original definition.
        self.conv1 = nn.Conv2d(in_channel, 8, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(16 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return un-normalised class logits of shape (batch, 10)."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = torch.flatten(features, start_dim=1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [3]:
def load_data(root='../data', pre_process=None, train_batch_size=64, test_batch_size=1024):
    """Build MNIST train/test DataLoaders.

    Both splits are requested with ``download=True``, so they are fetched
    into ``root`` if not already present.

    Args:
        root: directory where the MNIST files are stored / downloaded.
        pre_process: transform applied to every sample (e.g. ToTensor +
            Normalize). ``None`` applies no transform.
        train_batch_size: mini-batch size of the training loader.
        test_batch_size: mini-batch size of the evaluation loader.

    Returns:
        ``(train_loader, test_loader)``. The training loader reshuffles every
        epoch. The test loader keeps a fixed order, because summed
        evaluation metrics do not need shuffling and a fixed order keeps
        runs comparable.
    """
    train_dataset = torchvision.datasets.MNIST(root, download=True, train=True, transform=pre_process)
    test_dataset = torchvision.datasets.MNIST(root, download=True, train=False, transform=pre_process)

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_batch_size,
        shuffle=False,  # evaluation does not benefit from shuffling
    )
    return train_loader, test_loader

In [4]:
def train(train_loader, model, loss_fn, criterion, epochs=100):
    """Train ``model`` in place with mini-batch gradient descent.

    Args:
        train_loader: iterable yielding ``(data, label)`` tensor batches.
        model: the ``nn.Module`` to optimise; it is switched to training mode.
        loss_fn: callable ``loss_fn(output, label)`` returning a scalar tensor.
        criterion: the *optimizer* (despite the name) that owns ``model``'s
            parameters; ``zero_grad()`` and ``step()`` are called on it.
        epochs: number of full passes over ``train_loader``.

    Returns:
        The same ``model`` object, after training.

    After every epoch, prints the sum of the ``loss_fn`` values of all
    batches in that epoch (a sum, not a mean).
    """
    model.train()

    # Move batches to wherever the model's parameters live, rather than
    # guessing from torch.cuda.is_available(). A CPU model on a CUDA machine
    # would otherwise receive GPU tensors and crash.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(epochs):
        epoch_loss = 0.0
        for data, label in train_loader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            criterion.zero_grad()
            batch_loss = loss_fn(output, label)
            batch_loss.backward()
            criterion.step()

            epoch_loss += batch_loss.item()

        print('loss for epoch {} : {}'.format(epoch, epoch_loss))

    return model

In [5]:
def eval(test_loader, model, loss_fn):
    """Evaluate ``model`` on ``test_loader``, print loss and accuracy.

    Note: this name shadows Python's built-in ``eval`` in the notebook
    namespace. It is kept for compatibility with existing callers.

    Args:
        test_loader: a DataLoader-like object yielding ``(data, label)``
            batches and exposing ``.dataset`` (used for the sample count).
        model: the classifier; it is switched to eval mode.
        loss_fn: callable ``loss_fn(output, label)`` returning a scalar tensor.

    Returns:
        ``(loss, accuracy)`` as Python floats.
        - ``loss`` is the sum of ``loss_fn`` over all batches.
        - ``accuracy`` is the fraction of samples whose arg-max prediction
          equals the label (``nan`` for an empty dataset).
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA whenever available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            total_loss += loss_fn(output, label).item()
            pred = torch.max(output, dim=1)[1]
            # .item() keeps the counter a Python int. Accumulating the tensor
            # made the printed accuracy a (possibly CUDA) tensor repr.
            correct += pred.eq(label.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    accuracy = correct / n_samples if n_samples else float('nan')

    print('loss on test dataset : {}'.format(total_loss))
    print('accuracy : {}'.format(accuracy))
    return total_loss, accuracy

In [6]:
if __name__=='__main__':
    # Fix the global torch RNG so weight init and DataLoader shuffling are
    # reproducible across Restart Kernel -> Run All.
    SEED = 0
    EPOCHS = 100  # same as train()'s default; lower it for quick experiments
    torch.manual_seed(SEED)

    # Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 on the single channel.
    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_data, test_data = load_data(pre_process=pre_process)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(1).to(device)  # 1 input channel, matching the 1-channel Normalize

    loss_fn = nn.CrossEntropyLoss()
    # This is the optimizer; train() happens to name that parameter `criterion`.
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    model = train(train_data, model, loss_fn, optimizer, epochs=EPOCHS)

    eval(test_data, model, loss_fn)

loss for epoch 0 : 1147.6480618864298
loss for epoch 1 : 255.77923431247473
loss for epoch 2 : 158.52334367111325
loss for epoch 3 : 115.3131293617189
loss for epoch 4 : 92.82713167369366
loss for epoch 5 : 79.34567393362522
loss for epoch 6 : 69.62860082089901
loss for epoch 7 : 62.309159621596336
loss for epoch 8 : 56.3913088850677
loss for epoch 9 : 52.139172468334436
loss for epoch 10 : 48.811040475964546
loss for epoch 11 : 45.86175771802664
loss for epoch 12 : 42.98154804855585
loss for epoch 13 : 39.88753915950656
loss for epoch 14 : 38.08829386904836
loss for epoch 15 : 36.13586365431547
loss for epoch 16 : 34.32751172967255
loss for epoch 17 : 32.28043647110462
loss for epoch 18 : 31.11438038945198
loss for epoch 19 : 29.24854524806142
loss for epoch 20 : 28.004321239888668
loss for epoch 21 : 26.97936587035656
loss for epoch 22 : 25.386552929878235
loss for epoch 23 : 24.927415922284126
loss for epoch 24 : 23.1916126832366
loss for epoch 25 : 22.0575286783278
loss for epoch 2