In [100]:
import torch
import torchvision

from torch import nn
from torch import optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset

import numpy as np
import matplotlib.pyplot as plt

In [66]:
torch.manual_seed(123)

<torch._C.Generator at 0x7fb883107630>

In [67]:
train_ds = torchvision.datasets.MNIST(
        './files/', train=True, download=True, 
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.1307,), (0.3081,) # mean & std of mnist data
            )]))
test_ds = torchvision.datasets.MNIST(
        './files/', train=False, download=True, 
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.1307,), (0.3081,) # mean & std of mnist data
            )]))

In [68]:
train_ds, valid_ds = random_split(train_ds, [55000, 5000])

In [83]:
train_loader = DataLoader(train_ds, shuffle=True, batch_size=64)
valid_loader = DataLoader(valid_ds, shuffle=True, batch_size=64)
test_loader = DataLoader(test_ds, shuffle=True, batch_size=100)

In [71]:
class Net(nn.Module):
    def __init__(self, inp_size, hidden_size, out_size):
        super().__init__()
        self.l1 = nn.Linear(inp_size, hidden_size)
        self.l2 = nn.ReLU()
        self.l3 = nn.Linear(hidden_size, out_size)
    
    def forward(self, x):
        x = self.l1(x).requires_grad_()
        x = self.l2(x)
        x = self.l3(x)
        return x

In [72]:
net = Net(784, 50, 10)
optimizer = optim.SGD(net.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [98]:
def train(epochs):
    for epoch in range(epochs):    
        net.train() # set model to train mode
        for xb, yb in train_loader:
            yhat = net(xb.view(-1, 784,)) # forward prop the input through the network. Need to flatten array for MLP
            loss = loss_fn(yhat, yb)
            loss.backward() # backprop w.r.t 
            optimizer.step() # update params
            optimizer.zero_grad() # reset grads
            
        net.eval() # switch to evaluation mode
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in valid_loader:
                pred = net(xb.view(-1, 784))
                tot_loss += loss_fn(pred, yb)
                tot_acc += (torch.argmax(pred, dim=1)==yb).float().mean()
            print(tot_loss, tot_acc)

In [99]:
train(10)

tensor(10.9731) tensor(75.7031)
tensor(10.4552) tensor(75.9375)
tensor(10.3089) tensor(75.9062)
tensor(9.9877) tensor(75.9219)
tensor(9.7753) tensor(76.0469)
tensor(9.5666) tensor(76.2969)
tensor(9.1427) tensor(76.2188)
tensor(9.7240) tensor(76.0469)
tensor(8.8160) tensor(76.4375)
tensor(8.6029) tensor(76.4531)
