# MNIST with PyTorch
In this notebook, we will implement a fully connected network that classifies
handwritten digits.

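This time, we will use the torchvision MNIST dataset. The underlying data is
the same as in the Keras version, but torchvision is easier to interface with
from PyTorch. Note that torchvision is distributed separately from `torch` and
must be installed on its own (for example, with `pip install torchvision`).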

In [2]:
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
%matplotlib inline

DEVICE = torch.device("cpu")  # Put your device string here, e.g. "cuda" or "mps".

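
Rather than hard-coding the device string, one option is to choose it at runtime. This is a minimal sketch, assuming a standard PyTorch install. It uses a CUDA GPU if one is visible and otherwise falls back to the CPU:

```python
import torch

# Prefer a CUDA GPU when one is available; otherwise fall back to the CPU.
# (On Apple silicon, torch.backends.mps.is_available() plays the same role for "mps".)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)
```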

First, let's set up our dataset, as we did in the micrograd example. We will also
visualize a sample data point.

In [None]:
# Convert images to tensors, then standardize with the MNIST pixel mean and std.
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.1307,), (0.3081,))
])
# Note: despite their names, these are torchvision Dataset objects that return
# one (image, label) pair per index; they are not batched DataLoaders.
train_loader = MNIST("data", download=True, train=True, transform=transform)
test_loader = MNIST("data", download=True, train=False, transform=transform)
plt.imshow(train_loader[0][0][0], cmap="gray")
plt.title(f"Ground truth={train_loader[0][1]}");
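
The two transforms above can be unpacked by hand. `ToTensor` turns a grayscale image with `uint8` pixels into a float tensor in [0, 1] with a leading channel dimension, and `Normalize` then subtracts a mean and divides by a standard deviation per channel. The values 0.1307 and 0.3081 are the widely used mean and standard deviation of MNIST training pixels after scaling to [0, 1]. The sketch below reproduces both steps on a tiny, hypothetical 2x2 image using plain tensor operations instead of torchvision:

```python
import torch

MEAN, STD = 0.1307, 0.3081  # the constants passed to transforms.Normalize

# A hypothetical 2x2 grayscale image stored as uint8 pixels in [0, 255].
raw = torch.tensor([[0, 255], [128, 64]], dtype=torch.uint8)

# Roughly what ToTensor does for a grayscale image: scale to [0, 1] floats
# and add a leading channel dimension, giving shape (1, H, W).
x = raw.float().div(255).unsqueeze(0)

# What Normalize does for each channel: subtract the mean, divide by the std.
x_norm = (x - MEAN) / STD

print(x_norm.shape)  # torch.Size([1, 2, 2])
print(x_norm)
```

After this step a black pixel maps to about -0.42 and a white pixel to about 2.82, so the inputs are roughly centered around zero.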

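Next, let's implement our model class: a fully connected network with one hidden
layer of 800 ReLU units that maps the 784 (28 x 28) input pixels to 10 class scores.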

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MNISTNet(nn.Module):
  def __init__(self):
    super().__init__()
    # Note the use of nn.Sequential here for convenience.
    self.layers = nn.Sequential(
      nn.Flatten(),
      nn.Linear(784, 800),
      nn.ReLU(),
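      nn.Linear(800, 10),
      # No Softmax: nn.CrossEntropyLoss applies log-softmax internally and
      # expects raw scores (logits). argmax over logits gives the same prediction.
    )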

  def forward(self, x):
    return self.layers(x) # Using nn.Sequential makes this easy.
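
To see what each layer does to the data, we can trace a batch through the same Linear/ReLU stack. `nn.Flatten` (default `start_dim=1`) keeps the batch dimension and flattens each 1x28x28 image into 784 values, which is why the first `Linear` layer takes 784 inputs. This is a minimal sketch with a random, hypothetical batch of 64 images. It omits any final Softmax, which has no parameters and does not change the shape:

```python
import torch
import torch.nn as nn

# Same Linear/ReLU stack as MNISTNet (any final Softmax omitted: it adds no
# parameters and does not change the output shape).
layers = nn.Sequential(
  nn.Flatten(),
  nn.Linear(784, 800),
  nn.ReLU(),
  nn.Linear(800, 10),
)

# A hypothetical batch of 64 single-channel 28x28 images.
x = torch.randn(64, 1, 28, 28)

for layer in layers:
  x = layer(x)
  print(f"{layer.__class__.__name__:>8}: {tuple(x.shape)}")

n_params = sum(p.numel() for p in layers.parameters())
print("trainable parameters:", n_params)  # 784*800 + 800 + 800*10 + 10 = 636010
```

Almost all of the roughly 636k parameters sit in the first weight matrix, since every one of the 784 pixels connects to every one of the 800 hidden units.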

Now we will instantiate our model, loss function, and optimizer:

In [None]:
model = MNISTNet().to(DEVICE)  # Move the parameters to the same device as the data.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
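
`nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) and integer class labels. Internally it applies `log_softmax` and then takes the negative log-probability of the true class, averaged over the batch. For this reason, a network trained with it should not end in its own Softmax layer. The small check below uses made-up logits to confirm that equivalence:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical raw scores (logits) for a batch of 2 examples and 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])  # class indices, not one-hot vectors

loss = nn.CrossEntropyLoss()(logits, targets)

# The same quantity by hand: mean negative log-probability of the true class.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(loss.item(), manual.item())  # equal up to floating-point rounding
```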

Now that we have a much more powerful engine than micrograd, we can actually
train our model!

Note that instead of full-batch gradient descent, as we did in `iris.ipynb`, we
use mini-batch gradient descent: each parameter update is computed from a small,
randomly shuffled subset (a mini-batch) of the training set rather than from all
60,000 images at once.
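
Mini-batch gradient descent is easiest to see without any framework helpers. The sketch below works on a hypothetical toy linear-regression problem (y = 3x + 2 plus noise). It reshuffles the data every epoch with `torch.randperm` and takes one gradient step per slice of 32 examples. In PyTorch, `torch.utils.data.DataLoader` with `shuffle=True` usually handles this batching and shuffling for you.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: y = 3x + 2 plus a little Gaussian noise.
X = torch.randn(1000, 1)
Y = 3 * X + 2 + 0.1 * torch.randn(1000, 1)

w = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
lr, batch_size = 0.1, 32

for epoch in range(20):
  perm = torch.randperm(len(X))  # new random order every epoch
  for start in range(0, len(X), batch_size):
    idx = perm[start:start + batch_size]
    xb, yb = X[idx], Y[idx]
    loss = ((xb * w + b - yb) ** 2).mean()  # MSE on this mini-batch only
    loss.backward()
    with torch.no_grad():  # plain SGD update, then clear the gradients
      w -= lr * w.grad
      b -= lr * b.grad
      w.grad.zero_()
      b.grad.zero_()

print(w.item(), b.item())  # close to 3 and 2
```

Each epoch performs 32 small updates instead of one large one. This is the same trade-off the MNIST loop makes, with more, noisier steps for the same amount of data.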

In [None]:
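from torch.utils.data import DataLoader

# Serve the training set in shuffled mini-batches of 64 images
# (the batch size is a tunable choice, not a fixed requirement).
batches = DataLoader(train_loader, batch_size=64, shuffle=True)

model.train()
for epoch in range(100):
  for x, y in batches:
    x, y = x.to(DEVICE), y.to(DEVICE)
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
  # Report the loss of the final mini-batch of each epoch.
  print(f"Epoch {epoch} loss: {loss.item():.4f}")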
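
The `optimizer.zero_grad()` call at the start of each step matters because PyTorch accumulates gradients: every `backward()` adds into a parameter's `.grad` rather than overwriting it. The tiny demonstration below uses a single scalar:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()      # d(3w)/dw = 3
first = w.grad.item()

(3 * w).backward()      # without clearing, the new gradient is added on top
second = w.grad.item()

# Reset by hand. optimizer.zero_grad() clears every parameter's gradient in the
# same spirit (recent PyTorch versions set .grad to None instead of zeroing it).
w.grad.zero_()
print(first, second, w.grad.item())  # 3.0 6.0 0.0
```

Without the reset, each step would apply the sum of all previous gradients instead of the gradient for the current mini-batch.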

In [None]:
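with torch.no_grad():  # Evaluate in batches of 1000 without tracking gradients.
  print("Final Accuracy:", sum((torch.argmax(model(x.to(DEVICE)), dim=1) == y.to(DEVICE)).sum().item() for x, y in torch.utils.data.DataLoader(test_loader, batch_size=1000)) / len(test_loader))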
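
Accuracy is the fraction of examples whose highest-scoring class (the argmax over the class dimension) matches the label. The sketch below moves that computation into a small helper with the hypothetical name `accuracy`, which makes it easier to read and reuse on batched outputs:

```python
import torch

def accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
  """Fraction of rows whose highest-scoring class equals the label."""
  preds = logits.argmax(dim=1)
  return (preds == labels).float().mean().item()

# Hypothetical scores for 4 examples and 3 classes.
logits = torch.tensor([[0.9, 0.1, 0.0],
                       [0.2, 0.7, 0.1],
                       [0.3, 0.3, 0.4],
                       [0.8, 0.1, 0.1]])
labels = torch.tensor([0, 1, 0, 0])
print(accuracy(logits, labels))  # 0.75 (the third row predicts class 2)
```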

Now that our model is trained, try running it on some test examples, and see how it does!

In [None]:
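SAMPLE_IDX = 0  # Change this to inspect a different test image.
x, y = test_loader[SAMPLE_IDX]
plt.imshow(x[0], cmap="gray")
plt.title(f"Ground truth={y}")
with torch.no_grad():
  # unsqueeze(0) adds the batch dimension: (1, 28, 28) -> (1, 1, 28, 28).
  pred = torch.argmax(model(x.unsqueeze(0).to(DEVICE)), dim=1).item()
plt.xlabel(f"Prediction={pred}");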
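
A single argmax does not show how confident the network is. If the model returns raw logits (no final Softmax layer), `torch.softmax` turns them into per-class probabilities for display. Because softmax preserves ordering, the predicted class stays the same. The sketch below uses made-up logits for one image:

```python
import torch

# Hypothetical logits for a single image (batch of 1, 10 classes).
logits = torch.tensor([[1.0, 0.5, 4.0, 0.0, -1.0, 0.2, 0.3, 2.0, 0.1, -0.5]])

probs = torch.softmax(logits, dim=1)  # each row now sums to 1
pred = probs.argmax(dim=1).item()
confidence = probs[0, pred].item()

# Softmax is monotonic, so it never changes which class scores highest.
assert pred == logits.argmax(dim=1).item()
print(f"Prediction={pred} (p={confidence:.2f})")
```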