<a href="https://colab.research.google.com/github/The20thDuck/MNIST_nn/blob/main/MNIST_nn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install einops



In [2]:
import torch
import einops
from torch import nn
import torchvision
from torch.utils.data.dataset import random_split


In [3]:
# Plain three-layer MLP baseline: 784 -> 64 -> 64 -> 10 logits.
# Kept under its own name rather than `model`: the residual-model cell below
# binds `model`, so reusing that name here made this definition dead code and
# let an out-of-order re-run silently swap in an untrained, CPU-resident network.
mlp_baseline = nn.Sequential(
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)



In [4]:
class ResNet(nn.Module):
  """Two-hidden-layer MLP with a skip connection around the second hidden layer.

  forward computes l3(dropout(h1 + h2)) with h1 = relu(l1(x)) and
  h2 = relu(l2(h1)); the identity path from h1 into the sum is the residual link.

  Args:
    in_features: length of each flattened input vector (28 * 28 for MNIST).
    hidden_features: width of both hidden layers (they must match for h1 + h2).
    num_classes: number of output logits.
    dropout: drop probability applied to the summed hidden activations.
  """

  def __init__(self, in_features=28 * 28, hidden_features=64, num_classes=10, dropout=0.1):
    super().__init__()
    self.l1 = nn.Linear(in_features, hidden_features)
    # Square layer so h2 has the same width as h1 and the skip sum is valid.
    self.l2 = nn.Linear(hidden_features, hidden_features)
    self.l3 = nn.Linear(hidden_features, num_classes)
    self.do = nn.Dropout(dropout)

  def forward(self, x):
    """Return raw (un-softmaxed) logits for `x`, a (..., in_features) float tensor."""
    h1 = nn.functional.relu(self.l1(x))
    h2 = nn.functional.relu(self.l2(h1))
    return self.l3(self.do(h1 + h2))


# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)  # reproducible weight initialisation
model = ResNet().to(DEVICE)

In [5]:

# Criterion applied to raw logits; targets may be class indices or class probabilities.
loss = nn.CrossEntropyLoss()

# Vanilla stochastic gradient descent over every parameter of `model`, fixed step size.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [6]:
VAL_SIZE = 5_000   # images held out of the MNIST training split for validation
BATCH_SIZE = 32
SPLIT_SEED = 0     # fixes which images land in the validation set

mnist_data = torchvision.datasets.MNIST(
    'data', train=True, download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Seeded generator: the train/validation partition is identical on every run,
# so validation numbers are comparable across runs. Sizes derive from the dataset
# length instead of being hard-coded.
train, val = random_split(
    mnist_data,
    [len(mnist_data) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Reshuffle the training set each epoch; validation order does not affect metrics.
train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=BATCH_SIZE)


In [7]:
def run_epoch(net, loader, criterion, optimizer=None):
  """Make one pass of `net` over `loader` and return (mean_loss, accuracy).

  With an `optimizer`, the network is put in training mode and takes one
  optimizer step per batch. Without one, it is put in eval mode and gradients
  are disabled. Both metrics are averaged per *sample*, not per batch, so a
  smaller final batch is not over-weighted.

  Args:
    net: classifier mapping a (batch, 784) float tensor to (batch, classes) logits.
    loader: iterable of (images, labels) batches; images are flattened here.
    criterion: loss that reduces a batch to its mean, e.g. nn.CrossEntropyLoss().
    optimizer: optional optimizer over `net`'s parameters; None means evaluate only.

  Returns:
    (mean_loss, accuracy) as Python floats.

  Raises:
    ValueError: if `loader` yields no samples.
  """
  training = optimizer is not None
  net.train(training)
  # Follow wherever the model lives instead of hard-coding 'cuda'.
  device = next(net.parameters()).device
  total_loss, total_correct, total_seen = 0.0, 0, 0
  with torch.set_grad_enabled(training):
    for x, y in loader:  # x: b x 1 x 28 x 28
      x = x.flatten(start_dim=1).to(device)  # b x 784 for the first Linear layer
      y = y.to(device)
      logits = net(x)
      batch_loss = criterion(logits, y)  # takes logits and class indices
      if training:
        optimizer.zero_grad()   # clear gradients from the previous step
        batch_loss.backward()   # accumulate fresh gradients
        optimizer.step()        # step against the gradient
      n = y.size(0)
      # criterion returns the batch mean; re-weight by batch size for a per-sample average.
      total_loss += batch_loss.item() * n
      total_correct += (logits.argmax(dim=1) == y).sum().item()
      total_seen += n
  if total_seen == 0:
    raise ValueError('loader yielded no samples')
  return total_loss / total_seen, total_correct / total_seen


nb_epochs = 5
# NOTE: re-running this cell keeps training the same `model`; re-run the model
# and optimizer cells first for a fresh start.
for epoch in range(nb_epochs):
  train_loss, train_acc = run_epoch(model, train_loader, loss, optimizer)
  print(f'Epoch {epoch+1}, train loss: {train_loss:.2f}, accuracy: {train_acc:.2f}')

# Validation calculation (eval mode, no gradients)
val_loss, val_acc = run_epoch(model, val_loader, loss)
print(f'validation loss: {val_loss:.2f}, accuracy: {val_acc:.2f}')

Epoch 1, train loss: 0.87, accuracy: 0.77
Epoch 2, train loss: 0.38, accuracy: 0.89
Epoch 3, train loss: 0.32, accuracy: 0.91
Epoch 4, train loss: 0.28, accuracy: 0.92
Epoch 5, train loss: 0.25, accuracy: 0.93
validation loss: 0.23, accuracy: 0.94
