In [1]:
import logging

logging.basicConfig()

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import math
from datasets import load_dataset
import sadl
from sadl import xp
from tqdm import tqdm



In [2]:
N_TRAIN_SAMPLES = 60_000
N_TEST_SAMPLES = 10_000

BATCH_SIZE = 256

N_TRAIN_BATCHES = math.ceil(N_TRAIN_SAMPLES / BATCH_SIZE) # mnist train has 60k images
N_TEST_BATCHES = math.ceil(N_TEST_SAMPLES / BATCH_SIZE) # mnist test has 10k images

N_EPOCHS = 10

DEVICE = "cpu"

In [3]:
ds = load_dataset("ylecun/mnist")



In [4]:
def normalize(examples):
    # we could also use sadl.tensor here, but xp (numpy/cupy) is sufficient because we just transform the data once
    pixel_values = [xp.array(img, dtype=xp.float32).flatten() for img in examples["image"]]
    examples["pixel_values"] = [(pv / 255.0 - 0.1307) / 0.3081 for pv in pixel_values]
    return examples


In [5]:
ds_train = ds["train"].map(normalize, remove_columns=["image"], batched=True)
ds_eval = ds["test"].map(normalize, remove_columns=["image"], batched=True)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [6]:
def to_sadl_tensors(batch, onehot=True):
    x = sadl.tensor(batch["pixel_values"], dtype=xp.float32)
    y = sadl.tensor(xp.eye(10)[batch["label"]] if onehot else batch["label"])
    return x, y

In [7]:
model = sadl.Mlp([
    sadl.Linear(dim_in=784, dim_out=784),
    sadl.Linear(dim_in=784, dim_out=10),
])
log_softmax = sadl.LogSoftmax()

In [8]:
optimizer = sadl.SGD(params=list(model.parameters), lr=1e-3)

In [9]:
model = model.copy_to_device(device=DEVICE)
log_softmax = log_softmax.copy_to_device(device=DEVICE)

In [10]:
@sadl.no_grad_fn
def eval(model, ds_eval) -> float:
    n_correct = 0
    n_seen = 0

    for batch in tqdm(
        ds_eval.iter(batch_size=BATCH_SIZE),
        desc=f"Evaluating",
        total=N_TEST_BATCHES,
    ):
        x, y, = to_sadl_tensors(batch, onehot=False)

        x = x.copy_to_device(device=DEVICE)
        y = y.copy_to_device(device=DEVICE)

        logits = model(x)

        n_correct += xp.sum(logits.argmax(axis=-1) == y).item()
        n_seen += y.shape[0]


    return n_correct / n_seen
    

In [None]:
for epoch in range(N_EPOCHS):

    ds_train_iter = ds_train.shuffle(seed=epoch).iter(batch_size=BATCH_SIZE)
    
    for batch in tqdm(
        ds_train_iter,
        desc=f"Epoch {epoch+1}",
        total=N_TRAIN_BATCHES,
    ):

        optimizer.zero_grad()

        x, y, = to_sadl_tensors(batch)

        x = x.copy_to_device(device=DEVICE)
        y = y.copy_to_device(device=DEVICE)

        logits = model(x)

        loss = -xp.mean(xp.sum(log_softmax(logits) * y, axis=-1))

        optimizer.backward(loss=loss)
        optimizer.step()


    eval_accuracy = eval(model, ds_eval)

    logger.info(f"Train loss: {loss.item()}")
    logger.info(f"Eval accuracy: {eval_accuracy*100:.2f}%")


Epoch 1:   0%|          | 0/1875 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 1875/1875 [00:14<00:00, 128.25it/s]
INFO:__main__:Loss: 0.5011394479866884
Epoch 2: 100%|██████████| 1875/1875 [00:14<00:00, 128.72it/s]
INFO:__main__:Loss: 0.3186600973746674
Epoch 3: 100%|██████████| 1875/1875 [00:14<00:00, 133.13it/s]
INFO:__main__:Loss: 0.2250154400856804
Epoch 4:  37%|███▋      | 694/1875 [00:05<00:08, 131.28it/s]


KeyboardInterrupt: 