In [9]:
import os
import sys
import numpy as np
import einops
from typing import Union, Optional, Tuple, List, Dict
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int
import functools
from pathlib import Path
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
from tqdm.notebook import tqdm
from dataclasses import dataclass
from PIL import Image
import json

from cnn import SimpleMLP

import plotly
import plotly.graph_objects as go

from IPython import get_ipython
ipython = get_ipython()
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

MAIN = __name__ == "__main__"

device = torch.device("cuda" if torch.cuda.is_available() else "mps:0")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


# Load MNIST

In [4]:
MNIST_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

def get_mnist(subset: int = 1):
    '''Returns MNIST training data, sampled by the frequency given in `subset`.'''
    mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=MNIST_TRANSFORM)
    mnist_testset = datasets.MNIST(root="./data", train=False, download=True, transform=MNIST_TRANSFORM)

    if subset > 1:
        mnist_trainset = Subset(mnist_trainset, indices=range(0, len(mnist_trainset), subset))
        mnist_testset = Subset(mnist_testset, indices=range(0, len(mnist_testset), subset))

    return mnist_trainset, mnist_testset


mnist_trainset, mnist_testset = get_mnist()
mnist_trainloader = DataLoader(mnist_trainset, batch_size=64, shuffle=True)
mnist_testloader = DataLoader(mnist_testset, batch_size=64, shuffle=False)

# Training

In [12]:
mlp = SimpleMLP().to(device)

batch_size = 64
epochs = 3
mnist_train, _ = get_mnist(subset=10)
mnist_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
losses = []
test_losses = []
for epoch in range(epochs):
    for imgs, labels in mnist_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        logits = mlp(imgs)

        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())

# Plot losses
fig = go.Figure()
fig.add_trace(go.Scatter(y=losses, mode="lines"))
fig.update_layout(title="Losses", xaxis_title="Iteration", yaxis_title="Loss")
fig.show()