In [None]:
import numpy as np
import tnn
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as f

from datasets import load_dataset

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Load the MNIST splits and choose how many examples of each split to keep.
dataset = load_dataset("ylecun/mnist", num_proc=2)
train_size = 60000
test_size = 10000

# Index the splits directly: a missing split then raises KeyError right here,
# whereas `.get()` would return None and only fail later on `len(None)`.
train = dataset["train"]
test = dataset["test"]

# Draw the subsets from a seeded generator so that re-running the notebook
# selects the same examples in the same order.
rng = np.random.default_rng(0)
train_indices = rng.choice(len(train), size=train_size, replace=False)
test_indices = rng.choice(len(test), size=test_size, replace=False)

train = train.select(train_indices)
test = test.select(test_indices)

In [None]:
def to_numpy(example):
    """Add a flattened, rescaled copy of ``example["image"]`` under ``"input"``.

    The image is flattened to one dimension and divided by 255.0. The same
    ``example`` mapping is updated in place and returned.
    """
    flat_pixels = np.asarray(example["image"]).reshape(-1)
    example["input"] = flat_pixels / 255.0
    return example


# Apply `to_numpy` to every example of both splits, then keep only the two
# columns the loaders read: the flattened pixels and the class label.
train_dataset = train.map(function=to_numpy, num_proc=2).select_columns(
    column_names=["input", "label"]
)
test_dataset = test.map(function=to_numpy, num_proc=2).select_columns(
    column_names=["input", "label"]
)

In [None]:
def collate_fn(batch):
    """Stack a list of examples into an ``(inputs, labels)`` tensor pair.

    Each element of ``batch`` is indexed with the keys ``"input"`` and
    ``"label"``. Inputs are returned as a float32 tensor and labels as an
    int64 tensor, both in batch order.
    """
    features = [example["input"] for example in batch]
    inputs = torch.tensor(features).float()

    targets = [example["label"] for example in batch]
    labels = torch.tensor(targets).long()

    return inputs, labels


# Full-batch loaders: the batch size equals the dataset length, so each pass
# over a loader yields exactly one batch containing every example.
trainloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=len(train_dataset),
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)
# The test loader keeps the dataset order (no shuffling).
testloader = data.DataLoader(
    dataset=test_dataset,
    batch_size=len(test_dataset),
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
class MLP(nn.Module):
    """Multi-layer perceptron mapping 784 input features to 10 logits.

    Three 512-wide hidden stages (Linear -> LayerNorm -> ReLU), with dropout
    after the first (p=0.4) and second (p=0.2) stages, followed by a linear
    output layer. ``forward`` returns a dict with the single key ``"logits"``.
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        width = 512
        n_classes = 10

        # Stage 1
        self.linear_1 = nn.Linear(in_features, width)
        self.norm_1 = nn.LayerNorm(width)
        self.drop_1 = nn.Dropout(p=0.4)

        # Stage 2
        self.linear_2 = nn.Linear(width, width)
        self.norm_2 = nn.LayerNorm(width)
        self.drop_2 = nn.Dropout(p=0.2)

        # Stage 3 (no dropout) and the output projection
        self.linear_3 = nn.Linear(width, width)
        self.norm_3 = nn.LayerNorm(width)
        self.linear_4 = nn.Linear(width, n_classes)

    def forward(self, x):
        """Run the network on ``x`` and return ``{"logits": tensor}``."""
        hidden = f.relu(self.norm_1(self.linear_1(x)))
        hidden = self.drop_1(hidden)

        hidden = f.relu(self.norm_2(self.linear_2(hidden)))
        hidden = self.drop_2(hidden)

        hidden = f.relu(self.norm_3(self.linear_3(hidden)))
        logits = self.linear_4(hidden)
        return {"logits": logits}

## Batch Gradient Descent

In [None]:
# Learning rate and loss shared by every run below; a fresh model and a plain
# SGD optimiser (no momentum) for the full-batch run.
lr = 0.1
loss_fn = nn.CrossEntropyLoss()
model = tnn.Model(MLP())
optim = torch.optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Trainer for the full-batch run. Model, optimiser, loss and the two loaders
# are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-batch.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the full-batch configuration for 80 epochs and keep whatever the
# trainer returns (presumably the recorded metrics — confirm against tnn).
batch_metrics = trainer.train(epochs=80)

## Stochastic Gradient Descent

In [None]:
# Materialise the whole training set as two tensors on `device`, so the
# one-example-per-step loader below needs neither a collate function nor workers.
inputs = (
    torch.tensor([row["input"] for row in train_dataset])
    .to(torch.float32)
    .to(device)
)
labels = (
    torch.tensor([row["label"] for row in train_dataset])
    .to(torch.int64)
    .to(device)
)
train_tensor_dataset = data.TensorDataset(inputs, labels)

# batch_size=1: one example per optimisation step.
trainloader = data.DataLoader(
    dataset=train_tensor_dataset,
    batch_size=1,
    shuffle=True,
    drop_last=False,
)

In [None]:
# Fresh model and plain SGD optimiser for the batch-size-1 run.
model = tnn.Model(MLP())
optim = torch.optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Trainer for the batch-size-1 run. Model, optimiser, loss and the two loaders
# are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-sgd.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the batch-size-1 configuration for 80 epochs and keep whatever the
# trainer returns (presumably the recorded metrics — confirm against tnn).
sgd_metrics = trainer.train(epochs=80)

## Mini-batch Gradient Descent

### Batch size 32

In [None]:
# Mini-batch loader: 32 examples per step, reshuffled on every pass; the final
# partial batch is kept.
trainloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
# Fresh model and plain SGD optimiser for the batch-size-32 run.
model = tnn.Model(MLP())
optim = torch.optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Trainer for the batch-size-32 run. Model, optimiser, loss and the two loaders
# are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-mini-batch-32.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the batch-size-32 configuration for 80 epochs and keep whatever the
# trainer returns (presumably the recorded metrics — confirm against tnn).
mini_batch_32_metrics = trainer.train(epochs=80)

### Batch size 64

In [None]:
# Mini-batch loader: 64 examples per step, reshuffled on every pass; the final
# partial batch is kept.
trainloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
# Fresh model and plain SGD optimiser for the batch-size-64 run.
model = tnn.Model(MLP())
optim = torch.optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Trainer for the batch-size-64 run. Model, optimiser, loss and the two loaders
# are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-mini-batch-64.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the batch-size-64 configuration for 80 epochs and keep whatever the
# trainer returns (presumably the recorded metrics — confirm against tnn).
mini_batch_64_metrics = trainer.train(epochs=80)

### Batch size 128

In [None]:
# Mini-batch loader: 128 examples per step, reshuffled on every pass; the final
# partial batch is kept.
trainloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
# Fresh model and plain SGD optimiser for the batch-size-128 run.
model = tnn.Model(MLP())
optim = torch.optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Trainer for the batch-size-128 run. Model, optimiser, loss and the two
# loaders are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-mini-batch-128.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the batch-size-128 configuration for 80 epochs and keep whatever the
# trainer returns (presumably the recorded metrics — confirm against tnn).
mini_batch_128_metrics = trainer.train(epochs=80)

### Batch size 256

In [None]:
# Mini-batch loader: 256 examples per step, reshuffled on every pass; the final
# partial batch is kept.
trainloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
# Fresh model and plain SGD optimiser for the batch-size-256 run.
model = tnn.Model(MLP())
optim = torch.optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Trainer for the batch-size-256 run. Model, optimiser, loss and the two
# loaders are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-mini-batch-256.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the batch-size-256 configuration for 80 epochs and keep whatever the
# trainer returns (presumably the recorded metrics — confirm against tnn).
mini_batch_256_metrics = trainer.train(epochs=80)

### Batch size 512

In [None]:
# Mini-batch loader: 512 examples per step, reshuffled on every pass; the final
# partial batch is kept.
trainloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
# Fresh model and plain SGD optimiser for the batch-size-512 run.
model = tnn.Model(MLP())
optim = torch.optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Trainer for the batch-size-512 run. Model, optimiser, loss and the two
# loaders are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-mini-batch-512.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the batch-size-512 configuration for 80 epochs and keep whatever the
# trainer returns (presumably the recorded metrics — confirm against tnn).
mini_batch_512_metrics = trainer.train(epochs=80)

## SGD with Momentum

### No Nesterov Accelerated Gradient

In [None]:
# Loader for the momentum runs: 32 examples per step, reshuffled on every
# pass; the final partial batch is kept.
trainloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
# Fresh model and SGD with classical momentum (0.9, Nesterov disabled).
# The wrapper must be referenced as `tnn.Model`: only the `tnn` module is
# imported, so a bare `Model` raises NameError on a fresh kernel.
model = tnn.Model(MLP())
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
# Trainer for the momentum run without Nesterov. Model, optimiser, loss and
# the two loaders are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-momentum-no-nag.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the momentum (no Nesterov) configuration for 80 epochs. The result is
# stored, as for the other runs in this notebook, so it is not discarded and
# can be compared with them later.
momentum_metrics = trainer.train(epochs=80)

### With Nesterov Accelerated Gradient

In [None]:
# Fresh model and SGD with Nesterov momentum (0.9).
# The wrapper must be referenced as `tnn.Model`: only the `tnn` module is
# imported, so a bare `Model` raises NameError on a fresh kernel.
model = tnn.Model(MLP())
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)

In [None]:
# Trainer for the Nesterov-momentum run. Model, optimiser, loss and the two
# loaders are passed positionally, in the same order as before.
trainer = tnn.Trainer(
    model,
    optim,
    loss_fn,
    trainloader,
    testloader,
    device=device,
    path="../training/mnist-momentum-nag.h5",
    save_weights=False,
    verbose=10,
)

In [None]:
# Train the Nesterov-momentum configuration for 80 epochs. The result is
# stored, as for the other runs in this notebook, so it is not discarded and
# can be compared with them later.
nag_metrics = trainer.train(epochs=80)