In [1]:
import torch
from torch.autograd import Variable
from torch.optim import Adam
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import mnist
from torchvision import transforms
from torch.utils.data import DataLoader

In [2]:
NUM_CLASSES = 10
BATCH_SIZE = 50
LEARING_RATE = 0.001

In [3]:
class Flatten(object):
    def __call__(self, data):
        return data.view(-1)

In [4]:
data_transform = transforms.Compose([
    transforms.ToTensor(),
    Flatten()
])
# train set
dataset = mnist.MNIST('./data/', train=True, download=True, transform=data_transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE)

# validation set
validation_dataset = mnist.MNIST('./data/', train=False, download=True, transform=data_transform)
validation_loader = DataLoader(dataset, batch_size=BATCH_SIZE)

In [5]:
# input_size
data, _ = next(iter(loader))
input_size = len(data[0])  # flatten 28x28 tensor to 1x784 tensor

In [6]:
HIDDEN_SIZE = 100

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = nn.Linear(input_size, HIDDEN_SIZE)
        self.h2 = nn.Linear(HIDDEN_SIZE, NUM_CLASSES)
    
    def forward(self, x):
        x = self.h1(x)
        x = F.relu(x)
        x = self.h2(x)
        x = F.softmax(x, dim=1)
        return x

In [7]:
model = Model()
optimizer = Adam(params=model.parameters(), lr=LEARING_RATE)

for epoch in range(10):
    for data, labels in loader:
        predictions_per_class = model(data)
        highest_prediction, highest_prediction_class = predictions_per_class.max(1)

        # how good are we? compare output with the target classes
        loss = F.nll_loss(predictions_per_class, labels)

        model.zero_grad() # ???
        loss.backward() # backpropagate
        optimizer.step()

In [8]:
loss = 0.0
for data, labels in validation_loader:
    predictions_per_class = model(data)
    _, highest_prediction_class = predictions_per_class.max(1)
    loss += F.nll_loss(predictions_per_class, labels)

print(loss/len(validation_loader))

tensor(-0.9774, grad_fn=<DivBackward0>)
