In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.utils.data as data
import torch.optim as optim
import numpy as np
from datasets import load_dataset

In [2]:
mnist = load_dataset("mnist")
train, test = mnist.get("train"), mnist.get("test")

In [3]:
train.set_format(type="numpy", columns=["image", "label"])
test.set_format(type="numpy", columns=["image", "label"])
num_train_samples = 10000
num_test_samples = 1000

train_indices = np.random.choice(num_train_samples, num_train_samples, replace=False)
test_indices = np.random.choice(num_test_samples, num_test_samples, replace=False)
train = train.rename_column("image", "input").select(train_indices)
test = test.rename_column("image", "input").select(test_indices)

In [4]:
def preprocess(example):
    arr = np.reshape(example["input"], -1)
    arr = arr / np.linalg.norm(arr, axis=-1, keepdims=True)
    example["input"] = arr
    return example

train = train.map(preprocess, num_proc=2)
test = test.map(preprocess, num_proc=2)

Map (num_proc=2):   0%|          | 0/10000 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [5]:
train_inputs = torch.from_numpy(train["input"]).float().squeeze()
test_inputs = torch.from_numpy(test["input"]).float().squeeze()
train_labels = torch.from_numpy(train["label"]).long()
test_labels = torch.from_numpy(test["label"]).long()

In [6]:
train_dataset = data.TensorDataset(train_inputs, train_labels)
test_dataset = data.TensorDataset(test_inputs, test_labels)

In [7]:
class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(28 * 28, 512)
        self.norm_1 = nn.LayerNorm(512)
        self.drop_1 = nn.Dropout(p=0.5)
        self.linear_2 = nn.Linear(512, 512)
        self.norm_2 = nn.LayerNorm(512)
        self.drop_2 = nn.Dropout(p=0.25)
        self.linear_3 = nn.Linear(512, 256)
        self.norm_3 = nn.LayerNorm(256)
        self.drop_3 = nn.Dropout(p=0.25)
        self.linear_4 = nn.Linear(256, 256)
        self.norm_4 = nn.LayerNorm(256)
        self.linear_5 = nn.Linear(256, 10)
        

    def forward(self, x):
        x = self.drop_1(f.relu(self.norm_1(self.linear_1(x))))
        x = self.drop_2(f.relu(self.norm_2(self.linear_2(x))))
        x = self.drop_3(f.relu(self.norm_3(self.linear_3(x))))
        x = f.relu(self.norm_4(self.linear_4(x)))
        out = self.linear_5(x)
        return out


In [8]:
def create_dataloader(dataset, batch_size):
    return data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=2)

def test(model, dataloader,  device=None, verbose=False):
    if verbose:
        print("Testing has started")
    
    model.eval()
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    test_loss = 0

    with torch.no_grad():

            
        test_loss /= len(dataloader)
        
    if verbose:
        print(f"Testing complete, loss: {test_loss:.3f}")
        
    return test_loss

def train(model, optimizer, train_dataloader, test_dataloader, epochs=10, device=None, verbose=False):
    if verbose:
        print("Training has started")
        
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    train_losses = []
    test_losses = []
    
    for epoch in range(epochs):
        train_loss = 0
        model.train()
        
        for inputs, labels in train_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
    
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        test_loss = 0
        accuracy = 0
        num_samples = 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                test_loss += loss.item()
                correct = torch.sum(torch.argmax(outputs, dim=-1) == labels)
                accuracy += correct.item()
                num_samples += len(labels)

        train_loss /= len(train_dataloader)
        test_loss /= len(test_dataloader)
        accuracy /= num_samples
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        if verbose:
            print(f"Epoch {epoch + 1} complete, train loss: {train_loss:.3f}, test loss: {test_loss:.3f}, accuracy: {accuracy * 100:.2f}")

    if verbose:
        print("Training is complete")

    return train_losses, test_losses

## Batch Gradient Descent

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
train_dataloader = create_dataloader(train_dataset, batch_size=len(train_dataset))
test_dataloader = create_dataloader(test_dataset, batch_size=len(test_dataset))

model = Model()
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0)

In [10]:
print(f"parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

parameters: 867338


In [11]:
losses = train(model, optimizer, train_dataloader, test_dataloader, epochs=100, device=device, verbose=True)

Training has started
Epoch 1 complete, train loss: 2.372, test loss: 2.210, accuracy: 24.30
Epoch 2 complete, train loss: 2.302, test loss: 2.111, accuracy: 28.90
Epoch 3 complete, train loss: 2.243, test loss: 1.954, accuracy: 45.00
Epoch 4 complete, train loss: 2.183, test loss: 1.878, accuracy: 45.30
Epoch 5 complete, train loss: 2.131, test loss: 1.797, accuracy: 40.60
Epoch 6 complete, train loss: 2.085, test loss: 2.128, accuracy: 18.30
Epoch 7 complete, train loss: 2.184, test loss: 1.969, accuracy: 29.00
Epoch 8 complete, train loss: 2.150, test loss: 1.916, accuracy: 46.60
Epoch 9 complete, train loss: 2.094, test loss: 2.151, accuracy: 29.60
Epoch 10 complete, train loss: 2.120, test loss: 2.546, accuracy: 21.00
Epoch 11 complete, train loss: 2.450, test loss: 1.967, accuracy: 34.30
Epoch 12 complete, train loss: 2.124, test loss: 1.746, accuracy: 42.50
Epoch 13 complete, train loss: 1.939, test loss: 1.648, accuracy: 40.20
Epoch 14 complete, train loss: 1.881, test loss: 1.8

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/Cellar/python@3.12/3.12.3/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Cellar/python@3.12/3.12.3/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tonimo/git/teaching-networks/venv/lib/python3.12/site-packages/torch/__init__.py", line 1854, in <module>
    from . import _meta_registrations
  File "/Users/tonimo/git/teaching-networks/venv/lib/python3.12/site-packages/torch/_meta_registrations.py", line 6242, in <module>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/Cellar/python@3.12/3.12.3/Frameworks/Python.framework/Versions/3.

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/Users/tonimo/git/teaching-networks/venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/b6/9bc2z2316vx9_v730fprmrnr0000gn/T/ipykernel_85201/3952519297.py", line 1, in <module>
    losses = train(model, optimizer, train_dataloader, test_dataloader, epochs=100, device=device, verbose=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/b6/9bc2z2316vx9_v730fprmrnr0000gn/T/ipykernel_85201/2650912727.py", line 36, in train
    for inputs, labels in train_dataloader:
  File "/Users/tonimo/git/teaching-networks/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/Users/tonimo/git/teaching-networks/venv/lib/python3.12/site-packages/torch/utils/data/dat