In [1]:
from collections import Counter
from dataclasses import dataclass, field
from itertools import product
import math
from operator import add, eq
import pathlib
from typing import NamedTuple, Self

import numpy as np
import torch as t
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from tqdm import tqdm

## Utils

In [34]:
################
#    Config    #
################

REFERENCE_MODEL_SIZE = 6

@dataclass(frozen=True, slots=True)
class Config:
    """Experiment configuration"""
    model_size: int
    """Number of (hidden) channels in each of the two convolutional layers
    in terms of power of sqrt(2) (rounded if necessary).
    """
    dataset_size: int
    """Fraction of training data used for training
    in terms of the powers of (1/2).
    """
    # KW
    fc_dim: int = field(default=64, kw_only=True)
    """Fully connected layer dimension"""
    kernel_size: int = field(default=3, kw_only=True)
    """Convolution kernel size"""
    n_epochs: int = field(default=1, kw_only=True)
    
    def __post_init__(self) -> None:
        assert 0 < self.model_size, f"model_size={self.model_size}"
        assert 0 <= self.dataset_size, f"dataset_size={self.dataset_size}"
        assert 0 < self.n_epochs, f"n_epochs={self.n_epochs}"
    
    @classmethod
    def default(cls) -> Self:
        return cls(
            model_size=1,
            dataset_size=0
        )

    @property
    def n_channels(self) -> int:
        return REFERENCE_MODEL_SIZE * int(math.sqrt(2) ** self.model_size)
    
    @property
    def dataset_fraction(self) -> float:
        return 0.5 ** self.dataset_size
        
cfg = Config.default()

#################
#    Dataset    #
#################


DATA_PATH = pathlib.Path("../data")
assert DATA_PATH.exists()

def preprocess_batch(batch: t.Tensor) -> t.Tensor:
    assert batch.ndim == 3
    assert eq(*batch.shape[1:])
    batch_dim, im_dim = batch.shape[:2]
    processed_batch = batch.to(dtype=t.float32).unsqueeze(-1).reshape(batch_dim, 1, im_dim, im_dim)
    return (processed_batch - processed_batch.mean()) / processed_batch.std()

@dataclass(frozen=True, slots=True)
class Dataset:
    train_x: t.Tensor
    train_y: t.Tensor
    test_x: t.Tensor
    test_y: t.Tensor
    
    @classmethod
    def make(cls, cfg: Config = Config.default()) -> Self:
        train = MNIST(str(DATA_PATH), train=True, download=True)
        test = MNIST(str(DATA_PATH), train=False, download=True)
        dataset_size = int(cfg.dataset_fraction * len(train.data))
        return cls(
            train_x=preprocess_batch(train.data)[:dataset_size],
            train_y=train.targets[:dataset_size],
            test_x=preprocess_batch(test.data),
            test_y=test.targets
        )

    def __post_init__(self) -> None:
        assert len(self.train_x) == len(self.train_y)
        assert len(self.test_x) == len(self.test_y)
        assert self.train_x.ndim == self.test_x.ndim == 4
        assert self.train_y.ndim == self.test_y.ndim == 1
        assert self.train_x.shape[1:] == self.test_x.shape[1:]
        
        
ds = Dataset.make()
ds.train_x.shape


torch.Size([60000, 1, 28, 28])

## Model

In [20]:

class CNN(nn.Module):
    def __init__(self, cfg: Config = Config.default()) -> None:
        super().__init__()
        self.cfg = cfg
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=cfg.n_channels, kernel_size=cfg.kernel_size, padding=1
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(
            in_channels=cfg.n_channels, out_channels=cfg.n_channels, kernel_size=cfg.kernel_size, padding=1
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(cfg.n_channels * 7 * 7, cfg.fc_dim)
        self.fc2 = nn.Linear(cfg.fc_dim, 10)
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

####################
#    Evaluation    #
####################

def assert_valid_x(x: t.Tensor) -> None:
    assert x.ndim == 4
    assert x.shape[1] == 1
    assert eq(*x.shape[2:])
    
def assert_valid_y(y: t.Tensor) -> None:
    assert y.ndim == 1
    assert y.min() >= 0
    assert y.max() <= 9

def acc_fn(
    logits: t.Tensor, y: t.Tensor, *, as_pct: bool = True, pct_round_digits: int = 2
) -> float:
    # assert_valid_y(y)
    preds = logits.argmax(-1)
    acc = (preds == y).to(dtype=t.float).mean().item()
    if as_pct:
        acc = round(100 * acc, pct_round_digits + 2)
    return acc

model = CNN()
train_logits = model(ds.train_x)
test_logits = model(ds.test_x)

print(f"Initial training accuracy: {acc_fn(train_logits, ds.train_y)}%")
print(f"Initial test accuracy: {acc_fn(test_logits, ds.test_y)}%")

Initial training accuracy: 9.9217%
Initial test accuracy: 10.1%


## Training

In [38]:
@dataclass(frozen=True, slots=True)
class TrainingResult:
    cfg: Config
    ds: Dataset
    model: CNN
    train_loss: list[float] = field(default_factory=list, kw_only=True)
    test_loss: list[float] = field(default_factory=list, kw_only=True)
    train_acc: list[float] = field(default_factory=list, kw_only=True)
    test_acc: list[float] = field(default_factory=list, kw_only=True)
    
    def __post_init__(self) -> None:
        assert len(self.train_loss) == len(self.test_loss) == len(self.train_acc) == len(self.test_acc)
    
    def __len__(self) -> int:
        return len(self.train_acc)
        
    def append(
        self,
        *,
        train_loss: float,
        test_loss: float,
        train_acc: float,
        test_acc: float
    ) -> None:
        self.train_loss.append(train_loss)
        self.test_loss.append(test_loss)
        self.train_acc.append(train_acc)
        self.test_acc.append(test_acc)
    
    def log_last(self, *, loss_digits: int = 4, acc_digits: int = 2) -> None:
        if len(self) == 0:
            print(f"[0]")
            return
        train_loss = round(self.train_loss[-1], loss_digits)
        test_loss = round(self.test_loss[-1], loss_digits)
        train_acc = round(100 * self.train_acc[-1], acc_digits + 2) if 0 <= self.train_acc[-1] <= 1 else self.train_acc[-1]
        test_acc = round(100 * self.test_acc[-1], acc_digits + 2) if 0 <= self.test_acc[-1] <= 1 else self.test_acc[-1]
        print(f"[{len(self)}]: {train_loss=}, {test_loss=}, train_acc={train_acc}%, test_acc={test_acc}%")    

def train(cfg: Config) -> TrainingResult:
    ds = Dataset.make(cfg)
    model = CNN(cfg)
    tr = TrainingResult(cfg, ds, model)
    optimizer = t.optim.AdamW(model.parameters()) #TODO: change/tweak/experiment with hyperparams?
    loss_fn = nn.CrossEntropyLoss()
    
    for epoch_i in range(cfg.n_epochs):
        # Training 
        train_logits = model(ds.train_x)
        train_loss = loss_fn(train_logits, ds.train_y)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        
        # Measure
        with t.inference_mode():
            train_logits = model(ds.train_x)
            train_loss = loss_fn(train_logits, ds.train_y).item()
            train_acc = acc_fn(train_logits, ds.train_y)
            test_logits = model(ds.test_x)
            test_loss = loss_fn(test_logits, ds.test_y).item()
            test_acc = acc_fn(test_logits, ds.test_y)
            tr.append(train_loss=train_loss, test_loss=test_loss, train_acc=train_acc, test_acc=test_acc)
        
    
    return tr

cfg = Config(1, 0, n_epochs=5)
tr = train(cfg)
tr.log_last()

[5]: train_loss=2.2226, test_loss=2.2239, train_acc=23.8333%, test_acc=23.84%


## Experiment

In [44]:
MODEL_SIZES: list[int] = list(range(1, 10))
DATASET_SIZES: list[int] = list(range(0, 10))

PARAMS = list(product(MODEL_SIZES, DATASET_SIZES))
N = len(PARAMS)
print(f"{N=}")

trs: list[TrainingResult] = []

for param_i, (model_size, dataset_size) in enumerate(PARAMS):
    print(f"[{param_i}/{N}] model: {model_size}, dataset: {dataset_size}")
    cfg = Config(model_size, dataset_size)
    tr = train(cfg)
    trs.append(tr)

N=90
[0/90] model: 1, dataset: 0
[1/90] model: 1, dataset: 1
[2/90] model: 1, dataset: 2
[3/90] model: 1, dataset: 3
[4/90] model: 1, dataset: 4
[5/90] model: 1, dataset: 5
[6/90] model: 1, dataset: 6
[7/90] model: 1, dataset: 7
[8/90] model: 1, dataset: 8
[9/90] model: 1, dataset: 9
[10/90] model: 2, dataset: 0
[11/90] model: 2, dataset: 1
[12/90] model: 2, dataset: 2
[13/90] model: 2, dataset: 3
[14/90] model: 2, dataset: 4
[15/90] model: 2, dataset: 5
[16/90] model: 2, dataset: 6
[17/90] model: 2, dataset: 7
[18/90] model: 2, dataset: 8
[19/90] model: 2, dataset: 9
[20/90] model: 3, dataset: 0
[21/90] model: 3, dataset: 1
[22/90] model: 3, dataset: 2
[23/90] model: 3, dataset: 3
[24/90] model: 3, dataset: 4
[25/90] model: 3, dataset: 5
[26/90] model: 3, dataset: 6
[27/90] model: 3, dataset: 7
[28/90] model: 3, dataset: 8
[29/90] model: 3, dataset: 9
[30/90] model: 4, dataset: 0
[31/90] model: 4, dataset: 1
[32/90] model: 4, dataset: 2
[33/90] model: 4, dataset: 3
[34/90] model: 4, d

## Experiment

In [96]:
MODEL_SIZES: list[int] = np.arange(3, 20, step=3).tolist()
DATASET_SIZES: list[float] = np.linspace(0.1, 1, 10).round(1).tolist()
print(f"{MODEL_SIZES = }\n{DATASET_SIZES = }")

class Result(NamedTuple):
    cfg: Config
    cnn: CNN
    train_loss: float
    test_loss: float
    train_acc: float
    test_acc: float

def acc_fn(logits: t.Tensor, y: t.Tensor) -> float:
    preds = logits.argmax(-1)
    acc = (preds == y).to(t.float).mean().item()
    return acc

def train(model_size: int, dataset_size: float) -> Result:
    # Setup
    cfg = Config(model_size=model_size, dataset_size=dataset_size)
    cnn = CNN(cfg)
    ds = Dataset.make(cfg)    
    optimizer = t.optim.AdamW(cnn.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    # Training
    train_logits = cnn(ds.train_x)
    train_loss = loss_fn(train_logits, ds.train_y)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    # Eval
    with t.inference_mode():
        train_logits = cnn(ds.train_x)
        train_loss = loss_fn(train_logits, ds.train_y).item()
        test_logits = cnn(ds.test_x)
        test_loss = loss_fn(test_logits, ds.test_y).item()
        train_acc = acc_fn(train_logits, ds.train_y)
        test_acc = acc_fn(test_logits, ds.test_y)
    return Result(
        cfg=cfg,
        cnn=cnn,
        train_loss=train_loss,
        test_loss=test_loss,
        train_acc=train_acc,
        test_acc=test_acc
    )


MODEL_SIZES = [3, 6, 9, 12, 15, 18]
DATASET_SIZES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]


In [98]:
PARAMS = list(product(MODEL_SIZES, DATASET_SIZES))
N = len(PARAMS)

results: list[Result] = []
for i, (model_size, dataset_size) in enumerate(PARAMS):
    print(f"[{i}/{N}]: {model_size=}, {dataset_size=}")
    result = train(model_size, dataset_size)
    results.append(result)


[0/60]: model_size=3, dataset_size=0.1
[1/60]: model_size=3, dataset_size=0.2
[2/60]: model_size=3, dataset_size=0.3
[3/60]: model_size=3, dataset_size=0.4
[4/60]: model_size=3, dataset_size=0.5
[5/60]: model_size=3, dataset_size=0.6
[6/60]: model_size=3, dataset_size=0.7
[7/60]: model_size=3, dataset_size=0.8
[8/60]: model_size=3, dataset_size=0.9
[9/60]: model_size=3, dataset_size=1.0
[10/60]: model_size=6, dataset_size=0.1
[11/60]: model_size=6, dataset_size=0.2
[12/60]: model_size=6, dataset_size=0.3
[13/60]: model_size=6, dataset_size=0.4
[14/60]: model_size=6, dataset_size=0.5
[15/60]: model_size=6, dataset_size=0.6
[16/60]: model_size=6, dataset_size=0.7
[17/60]: model_size=6, dataset_size=0.8
[18/60]: model_size=6, dataset_size=0.9
[19/60]: model_size=6, dataset_size=1.0
[20/60]: model_size=9, dataset_size=0.1
[21/60]: model_size=9, dataset_size=0.2
[22/60]: model_size=9, dataset_size=0.3
[23/60]: model_size=9, dataset_size=0.4
[24/60]: model_size=9, dataset_size=0.5
[25/60]: m

KeyboardInterrupt: 

## Plot and analyze the results