In [1]:
import torch 
from torch import nn
from torch.autograd import Variable
import torchvision
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Apply the seaborn theme FIRST. sns.set() writes its own matplotlib rcParams
# (its style includes image.cmap, and older seaborn releases also set
# figure.figsize), so calling it after the manual overrides below would
# silently discard them.
sns.set(color_codes=True)

# Notebook-specific overrides, applied on top of the seaborn theme.
plt.rcParams['figure.figsize'] = (10, 6)          # default figure size in inches
plt.rcParams['image.interpolation'] = 'nearest'   # do not smooth pixels in imshow
plt.rcParams['image.cmap'] = 'gray'               # grayscale colormap for images

import torchvision.datasets as dsets
import torchvision.transforms as T

from torch.utils.data import DataLoader

In [2]:
# --- Hyperparameters -------------------------------------------------------
# Network shape: flattened 28x28 input -> hidden layer -> 10 output units.
in_dim, hidden_dim, out_dim = 28 * 28, 100, 10

# Optimisation settings.
batch = 100      # mini-batch size used by the training DataLoader
n_epoch = 5      # number of passes over the training set
lr = 1e-2        # learning rate handed to the optimizer

In [3]:
# MNIST train / test splits, stored under ./data (downloaded if requested via
# download=True). ToTensor() converts each image to a torch tensor.
mnist_train = dsets.MNIST(
    root='./data',
    train=True,
    transform=T.ToTensor(),
    download=True,
)
mnist_test = dsets.MNIST(
    root='./data',
    train=False,
    transform=T.ToTensor(),
    download=True,
)

In [4]:
# Training loader: shuffled mini-batches of `batch` samples.
mnist_train_loader = DataLoader(
    mnist_train,
    batch_size=batch,
    shuffle=True,
    num_workers=5,
)

# Test loader: fixed order, batches of up to 10000 samples.
mnist_test_loader = DataLoader(
    mnist_test,
    batch_size=10000,
    shuffle=False,
    num_workers=5,
)

In [5]:
class Model(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_dim: number of input features expected by the first linear layer.
        hidden_dim: width of the hidden layer.
        out_dim: number of output units (raw, un-normalised scores).
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        # In-place ReLU overwrites the fc1 activation instead of allocating a copy.
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Compute output scores for ``x``.

        Inputs whose last dimension already equals ``in_dim`` are passed
        through unchanged. Any other input with more than two dimensions
        (e.g. an image batch shaped (N, 1, 28, 28)) is flattened to
        (N, -1) first, so callers do not have to reshape by hand.

        Returns:
            Tensor of raw scores with ``out_dim`` entries in the last dimension.
        """
        # Only reshape inputs that the first linear layer could not consume
        # as-is; everything that worked before keeps its exact behaviour.
        if x.dim() > 2 and x.size(-1) != self.fc1.in_features:
            x = x.contiguous().view(x.size(0), -1)
        net = self.fc1(x)
        net = self.relu(net)
        net = self.fc2(net)
        return net

In [6]:
# Pick the compute device. GPU 1 stays the preferred target, but instead of
# crashing on machines with fewer than two GPUs we fall back to GPU 0, and to
# the CPU when CUDA is not available at all.
if torch.cuda.is_available():
    gpu_id = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device('cuda:{}'.format(gpu_id))
else:
    device = torch.device('cpu')

# Move the model to the device BEFORE building the optimizer so the optimizer
# is constructed over the parameters that will actually be trained.
model = Model(in_dim, hidden_dim, out_dim).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr)

In [7]:
# Train for n_epoch passes over the training set, logging the loss every
# 100 iterations.

# Send each batch to whatever device the model's weights live on, rather than
# a hard-coded GPU index.
device = next(model.parameters()).device

for e in range(n_epoch):
    for i, (images, labels) in enumerate(mnist_train_loader):
        # Flatten each image to a vector of in_dim features.
        x = images.view(-1, in_dim).to(device)
        y = labels.to(device)

        logits = model(x)
        loss = criterion(logits, y)

        # Clear gradients left over from the previous step before backprop.
        optim.zero_grad()
        loss.backward()
        optim.step()

        if i % 100 == 0:
            # print(...) with a single argument works on Python 2 and 3;
            # loss.item() extracts the Python float from the scalar loss
            # (loss.data[0] fails on 0-dim tensors in current PyTorch).
            print('epoch: [{0:d}/{1:d}], iter: {2: d}, loss: [{3: .4f}]'.format(e, n_epoch, i, loss.item()))

epoch: [0/5], iter:  0, loss: [ 2.3411]
epoch: [0/5], iter:  100, loss: [ 0.2154]
epoch: [0/5], iter:  200, loss: [ 0.1975]
epoch: [0/5], iter:  300, loss: [ 0.1142]
epoch: [0/5], iter:  400, loss: [ 0.0474]
epoch: [0/5], iter:  500, loss: [ 0.1247]
epoch: [1/5], iter:  0, loss: [ 0.0885]
epoch: [1/5], iter:  100, loss: [ 0.0492]
epoch: [1/5], iter:  200, loss: [ 0.1291]
epoch: [1/5], iter:  300, loss: [ 0.0787]
epoch: [1/5], iter:  400, loss: [ 0.0205]
epoch: [1/5], iter:  500, loss: [ 0.0727]
epoch: [2/5], iter:  0, loss: [ 0.0763]
epoch: [2/5], iter:  100, loss: [ 0.0392]
epoch: [2/5], iter:  200, loss: [ 0.0635]
epoch: [2/5], iter:  300, loss: [ 0.0730]
epoch: [2/5], iter:  400, loss: [ 0.1127]
epoch: [2/5], iter:  500, loss: [ 0.1678]
epoch: [3/5], iter:  0, loss: [ 0.0905]
epoch: [3/5], iter:  100, loss: [ 0.0138]
epoch: [3/5], iter:  200, loss: [ 0.0382]
epoch: [3/5], iter:  300, loss: [ 0.1351]
epoch: [3/5], iter:  400, loss: [ 0.1130]
epoch: [3/5], iter:  500, loss: [ 0.1450]


In [8]:
# Evaluate classification accuracy over the whole test set.

# Run on whatever device the model's weights live on.
device = next(model.parameters()).device

model.eval()
n_correct = 0
n_total = 0
# no_grad(): inference only, so skip building the autograd graph (this
# matters for the very large test batches used here).
with torch.no_grad():
    for images, labels in mnist_test_loader:
        x = images.view(-1, in_dim).to(device)
        logits = model(x)
        # Predicted class = index of the largest score for each sample.
        labels_pred = logits.argmax(dim=1)
        # Count matches with torch ops; accumulating across batches keeps the
        # result correct even if the loader yields more than one batch.
        n_correct += (labels_pred.cpu() == labels).sum().item()
        n_total += labels.size(0)

# float(...) keeps this a true division on Python 2 as well.
print(float(n_correct) / n_total)

0.9675
