In [1]:
import torch_utils as tu
import experiment_utils as eu
import torch
import torch.nn as nn
import datasets
from PIL import Image
import numpy as np
from torchvision import transforms

In [2]:
# Fetch MNIST by its Hugging Face Hub id; the result is a DatasetDict with
# "train" and "test" splits (see the displayed output below).
data = datasets.load_dataset(path="ylecun/mnist")
data

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [3]:
# Preprocessing pipeline, built once at cell execution instead of on every
# batch: resize to the 32x32 input size Model's layer sizes assume, convert to
# a float tensor, then normalise. The mean/std constants were presumably
# measured on the resized training images — TODO confirm their provenance.
_MNIST_TRANSFORM = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.1309,), (0.2893,)),
    ]
)


def transform(samples):
    """Convert a batch of images into normalised tensors, in place.

    Meant to be registered with ``Dataset.set_transform``. It receives a batch
    mapping, replaces its ``"image"`` list with tensors and returns the same
    mapping.

    Args:
        samples: mapping whose ``"image"`` entry is an iterable of PIL images.
            Other entries (e.g. ``"label"``) pass through untouched.

    Returns:
        ``samples``, with ``"image"`` replaced by a list of tensors shaped
        ``(C, 32, 32)``. C is presumably 1 (grayscale), matching ``conv1``'s
        single input channel.
    """
    samples["image"] = [_MNIST_TRANSFORM(img) for img in samples["image"]]
    return samples

In [4]:
# Attach the batch-level preprocessing to both splits, then wrap them in
# DataLoaders. Only the training loader shuffles.
train_data, test_data = data["train"], data["test"]
train_data.set_transform(transform)
test_data.set_transform(transform)

train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=64)

In [None]:
class Model(nn.Module):
    """Small two-convolution CNN classifier for 32x32 single-channel images.

    ``forward`` returns raw, unnormalised class scores (logits) of shape
    ``(batch, 10)``. Softmax is deliberately NOT applied:
    ``nn.functional.cross_entropy`` applies log-softmax internally, so
    feeding it probabilities would apply softmax twice and squash gradients.
    Call ``torch.softmax(logits, dim=1)`` explicitly when probabilities are
    needed. ``argmax`` is the same on logits and probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)   # 32x32 -> 30x30, max-pooled to 15x15
        self.conv2 = nn.Conv2d(32, 64, 3)  # 15x15 -> 13x13, max-pooled to 6x6
        self.fc1 = nn.Linear(64 * 6 * 6, 128)  # 64 channels x 6x6 after the second pool
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a ``(batch, 1, 32, 32)`` tensor to ``(batch, 10)`` logits."""
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


class ModelWrapper(nn.Module):
    """Adapt ``Model`` to the batch-dict interface passed to ``tu.train``.

    ``forward`` takes a batch mapping with ``"image"`` and ``"label"`` entries
    and returns a dict with the loss, the raw model output and the batch
    accuracy.
    """

    # Extra metric(s) reported alongside the loss, each paired with a
    # comparison function from experiment_utils. Presumably this tells the
    # logger that a larger accuracy is better — confirm in experiment_utils.
    metrics = {"accuracy": eu.compare_fns.max}

    def __init__(self):
        super().__init__()
        self.model = Model()

    def forward(self, batch):
        """Run the wrapped network on one batch and package the results.

        Args:
            batch: mapping with ``"image"`` (input tensor for ``Model``) and
                ``"label"`` (class-index targets) entries.

        Returns:
            dict with ``"loss"`` (cross-entropy of the model output against
            the labels), ``"output"`` (whatever ``Model`` returned) and
            ``"accuracy"`` (fraction of argmax predictions equal to the
            labels). Accuracy is computed in both train and eval mode.
        """
        images, labels = batch["image"], batch["label"]
        scores = self.model(images)

        result = {
            "loss": nn.functional.cross_entropy(scores, labels),
            "output": scores,
        }
        correct = scores.argmax(dim=1).eq(labels)
        result["accuracy"] = correct.float().mean()
        return result

# Seed torch's global RNG right before building the network so the random
# weight initialisation is reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
model = ModelWrapper()
model

ModelWrapper(
  (model): Model(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (fc1): Linear(in_features=2304, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [6]:
# AdamW over every parameter of the wrapper, i.e. the wrapped CNN's weights.
lr = 0.001
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [7]:
# Project experiment logger. "../.logs" is presumably the log directory,
# relative to the notebook's working directory — confirm in experiment_utils.
# start_experiment() presumably opens a new run record.
# NOTE(review): if the training cell raises, the later end_experiment() call
# is skipped; consider a try/finally around training.
logger = eu.Logger("../.logs")
logger.start_experiment()

In [8]:
# Train with the project helper tu.train. Presumably train_steps counts
# optimizer steps (batches) and metrics go to `logger` every log_steps steps
# — confirm against torch_utils.
# NOTE(review): test_loader is built earlier but is not used in the visible
# cells, so no held-out evaluation happens here.
tu.train(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    train_steps=1000,
    log_steps=100,
    logger=logger
)

Training:   0%|          | 0/1000 [00:00<?, ?it/s]

In [9]:
# Close the run opened by logger.start_experiment(); presumably this
# finalises the log output — confirm in experiment_utils.
logger.end_experiment()