In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook as tqdm

In [2]:
# Training hyperparameters: mini-batch size, number of passes over the
# training set, and the optimizer learning rate.
batch_size = 32
num_epochs = 20
lr = 1e-4
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Image preprocessing: convert each image to a tensor, then shift/scale it
# with mean 0.5 and std 0.5.
# MNIST images have a single (grayscale) channel, so Normalize receives
# one-element mean/std. Three-element tuples cannot be broadcast against a
# 1-channel tensor and raise a RuntimeError in current torchvision releases.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Data loaders; the dataset is downloaded into data/ when it is not there yet.
train_loader = DataLoader(
    datasets.MNIST('data/', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True, num_workers=4)
# The evaluation result does not depend on sample order, so the test split
# is iterated in a fixed order instead of being reshuffled every pass.
test_loader = DataLoader(
    datasets.MNIST('data/', train=False, download=True, transform=transform),
    batch_size=batch_size, shuffle=False, num_workers=4)



In [4]:
# print(next(iter(train_loader)))

# for x, y in train_loader:
#     t = transforms.ToPILImage()
#     plt.imshow(t(x[0]))
#     print(y[0])
#     break

In [5]:
class MNISTClassifier(nn.Module):
    """Fully connected classifier: input_size -> 256 -> 1024 -> output_size.

    Both hidden layers are followed by a sigmoid; the last layer returns raw,
    unnormalised class scores.
    """

    def __init__(self, input_size, output_size):
        super(MNISTClassifier, self).__init__()
        # Remembered so forward() can flatten incoming batches.
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 1024)
        self.fc3 = nn.Linear(1024, output_size)

    def forward(self, x):
        """Flatten ``x`` into rows of ``input_size`` values and score them."""
        flat = x.view(-1, self.input_size)
        hidden = self.fc1(flat).sigmoid()
        hidden = self.fc2(hidden).sigmoid()
        return self.fc3(hidden)
    
class MNISTConvClassifier(nn.Module):
    """Convolutional classifier: two conv/ReLU/max-pool stages, then two linear layers.

    The flatten step assumes each sample ends up as a 50x4x4 feature map,
    which is what 28x28 inputs produce (28 -> 24 -> 12 -> 8 -> 4). The final
    layer returns raw, unnormalised class scores.
    """

    def __init__(self, in_channels, output_size):
        super(MNISTConvClassifier, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, output_size)

    def forward(self, x):
        """Run both conv stages, flatten the features, and return the scores."""
        # Each stage: convolution -> ReLU -> 2x2 max-pool with stride 2.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)
        features = x.view(-1, 4 * 4 * 50)
        return self.fc2(F.relu(self.fc1(features)))


# Alternative (fully connected) model:
#model = MNISTClassifier(28*28, 10).to(device)
# Convolutional model: one input channel, ten output classes.
model = MNISTConvClassifier(in_channels=1, output_size=10)
model = model.to(device)

In [6]:
def test(model, test_loader):
    num = 0
    loss = 0
    for x, y in test_loader:
        with torch.no_grad():
            x = x.to(device)
            y = y.to(device)
            pred_probs = model(x)
            topv, topi = pred_probs.topk(1)
            topi = topi.view(-1)
            loss += torch.sum(topi != y)
            num += len(y)
    return loss, num

In [7]:
# Adam over all model parameters. The loss module holds no per-batch state,
# so it is built once here rather than once per batch.
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    # Make sure the model is in training mode at the start of every epoch,
    # regardless of what the evaluation pass did to it.
    model.train()
    losses = []
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        # Raw class scores: CrossEntropyLoss expects unnormalised logits.
        pred_probs = model(x)
        loss = criterion(pred_probs, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Report the mean training loss for the epoch. The "test loss" field is
    # the pair returned by test(): misclassified samples / samples evaluated.
    print('[{}/{}]: train loss = {} test loss = {}/{}'.format(epoch, num_epochs,
                                                              torch.mean(torch.FloatTensor(losses)),
                                                              *test(model, test_loader)))

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[0/20]: train loss = 0.29921650886535645 test loss = 240/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[1/20]: train loss = 0.07648199796676636 test loss = 182/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[2/20]: train loss = 0.053167086094617844 test loss = 138/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[3/20]: train loss = 0.04239850491285324 test loss = 161/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[4/20]: train loss = 0.03377554193139076 test loss = 112/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[5/20]: train loss = 0.02811502292752266 test loss = 93/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[6/20]: train loss = 0.023468298837542534 test loss = 101/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[7/20]: train loss = 0.019344309344887733 test loss = 86/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[8/20]: train loss = 0.016692301258444786 test loss = 102/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[9/20]: train loss = 0.0141281234100461 test loss = 102/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[10/20]: train loss = 0.011579005979001522 test loss = 88/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[11/20]: train loss = 0.009657091461122036 test loss = 84/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[12/20]: train loss = 0.008890504948794842 test loss = 85/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[13/20]: train loss = 0.006390835158526897 test loss = 83/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[14/20]: train loss = 0.006752693559974432 test loss = 91/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[15/20]: train loss = 0.004755461122840643 test loss = 80/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[16/20]: train loss = 0.005365938413888216 test loss = 83/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[17/20]: train loss = 0.004354769363999367 test loss = 69/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[18/20]: train loss = 0.003254977520555258 test loss = 76/10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[19/20]: train loss = 0.003677463624626398 test loss = 82/10000
