In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook as tqdm

In [2]:
batch_size = 32
num_epochs = 20
lr = 0.001

In [3]:
# transformation on image
transform = transforms.Compose([
                                transforms.Grayscale(num_output_channels=3),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
# data loader
train_loader = DataLoader(datasets.MNIST('data/', train=True, download=True, transform=transform), 
                          batch_size=batch_size, shuffle=True)



In [4]:
# print(next(iter(train_loader)))

# for x, y in train_loader:
#     t = transforms.ToPILImage()
#     plt.imshow(t(x[0]))
#     print(y[0])
#     break

In [5]:
class MNISTClassifier(nn.Module):
    def __init__(self, input_size, output_size):
        super(MNISTClassifier, self).__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 1024)
        self.fc3 = nn.Linear(1024, output_size)
        
    def forward(self, x):
        x = x.view(-1, self.input_size)
        x = F.sigmoid(self.fc1(x))
        x = F.sigmoid(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        return x

model = MNISTClassifier(3*28*28, 10)

In [6]:
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(num_epochs):
    losses = []
    for x, y in tqdm(train_loader):
        optimizer.zero_grad()
        pred = model(x)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print('[{}/{}]: {}'.format(epoch, num_epochs, torch.mean(torch.FloatTensor(losses))))

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




[0/20]: 1.5997411012649536


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[1/20]: 1.5372625589370728


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[2/20]: 1.5274230241775513


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[3/20]: 1.5252759456634521


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[4/20]: 1.5204590559005737


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[5/20]: 1.5186926126480103


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[6/20]: 1.5197830200195312


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[7/20]: 1.5182468891143799


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[8/20]: 1.5158323049545288


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[9/20]: 1.514346718788147


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[10/20]: 1.5130913257598877


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[11/20]: 1.5142236948013306


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[12/20]: 1.5120887756347656


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[13/20]: 1.5131043195724487


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[14/20]: 1.5103023052215576


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[15/20]: 1.5093730688095093


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[16/20]: 1.5107849836349487


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[17/20]: 1.5101836919784546


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[18/20]: 1.5137245655059814


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


[19/20]: 1.5131880044937134
