In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as f

from datasets import load_dataset 
from torch.utils.data import DataLoader
from src.trainer import Trainer
from src.model import Model

In [2]:
dataset = load_dataset("ylecun/mnist", num_proc=2)
train = dataset.get("train")
test = dataset.get("test")

In [3]:
def to_numpy(example):
    arr = np.reshape(example["image"], -1) / 255.0
    example["input"] = arr
    return example

train_dataset = train.map(to_numpy, num_proc=2).select_columns(["input", "label"])
test_dataset = test.map(to_numpy, num_proc=2).select_columns(["input", "label"])

In [4]:
def collate_fn(batch):
    inputs = torch.tensor([ex["input"] for ex in batch]).float()
    labels = torch.tensor([ex["label"] for ex in batch]).long()
    return inputs, labels

trainloader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=False, collate_fn=collate_fn, num_workers=2)
testloader = data.DataLoader(test_dataset, batch_size=1024, shuffle=False, drop_last=False, collate_fn=collate_fn, num_workers=2)

In [5]:
class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(784, 512)
        self.drop_1 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(512, 512)
        self.drop_2 = nn.Dropout(0.25)
        self.linear_3 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.drop_1(f.relu(self.linear_1(x)))
        x = self.drop_2(f.relu(self.linear_2(x)))
        return dict(logits=f.relu(self.linear_3(x)))

In [6]:
model = Model(MLP())
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), momentum=0.9)

In [7]:
trainer = Trainer(model, optim, loss_fn, trainloader, testloader, path="./train.h5", verbose=True)

In [9]:
trainer.train(epochs=10)

training started
(epoch: 1): train loss: 0.3334, test loss: 0.2650, train acc: 0.9035, test acc: 0.9234
weights saved to ./train.h5/trajectory/weights-epoch-1
(epoch: 2): train loss: 0.3180, test loss: 0.2533, train acc: 0.9080, test acc: 0.9266
weights saved to ./train.h5/trajectory/weights-epoch-2
(epoch: 3): train loss: 0.3066, test loss: 0.2421, train acc: 0.9113, test acc: 0.9292
weights saved to ./train.h5/trajectory/weights-epoch-3
(epoch: 4): train loss: 0.2935, test loss: 0.2329, train acc: 0.9142, test acc: 0.9320
weights saved to ./train.h5/trajectory/weights-epoch-4
(epoch: 5): train loss: 0.2820, test loss: 0.2236, train acc: 0.9178, test acc: 0.9351
weights saved to ./train.h5/trajectory/weights-epoch-5
(epoch: 6): train loss: 0.2731, test loss: 0.2160, train acc: 0.9205, test acc: 0.9368
weights saved to ./train.h5/trajectory/weights-epoch-6
(epoch: 7): train loss: 0.2635, test loss: 0.2067, train acc: 0.9234, test acc: 0.9408
weights saved to ./train.h5/trajectory/weigh

{'train_losses': [0.33337495325089517,
  0.31802281537162724,
  0.30656916325661676,
  0.2934945853216562,
  0.2820141342466574,
  0.27305524947165427,
  0.26348362365828903,
  0.2550256851194764,
  0.24571655692258623,
  0.23900088671046787],
 'test_losses': [0.2649575486779213,
  0.2533253595232964,
  0.2420528218150139,
  0.2328703783452511,
  0.22356420755386353,
  0.21602563187479973,
  0.20673343017697335,
  0.1998528406023979,
  0.19284727349877356,
  0.18716058358550072],
 'train_accs': [0.9034666666666666,
  0.908,
  0.9112833333333333,
  0.9142,
  0.9178166666666666,
  0.9205166666666666,
  0.9234333333333333,
  0.92535,
  0.9282333333333334,
  0.9297333333333333],
 'test_accs': [0.9234,
  0.9266,
  0.9292,
  0.932,
  0.9351,
  0.9368,
  0.9408,
  0.9415,
  0.9442,
  0.9457]}