## Import

In [1]:
from dataclasses import dataclass, field, asdict
import itertools as it
import math
import pickle
import random
from typing import NamedTuple

import numpy as np
import pandas as pd
from plotly import express as px, graph_objects as go
import torch as t
from torch import nn
from tqdm import tqdm

from src import CNN, Dataset

## Training

We vary the model size and dataset size.

For the **model size**, we multiply the width by powers of $\sqrt2$, rounding down if necessary. The idea is to vary the amount of compute used per forward pass by powers of $2$.

For the **dataset size**, we multiply the fraction of the full dataset used by powers of $2$, i.e. $1$, $\frac12$, $\frac14$, and so on.

To reduce noise, use a few random seeds and always use the full validation set.

In [2]:
# Model size
REFERENCE_MODEL_SIZE = 6
N_MODEL_SIZES = 8
MODEL_SIZES = (REFERENCE_MODEL_SIZE * math.sqrt(2) ** t.arange(N_MODEL_SIZES)).to(dtype=t.int64).tolist()

# Dataset size
N_DATASET_SIZES = 10
DATASET_FRACTIONS =  (1 / (2 ** t.arange(N_DATASET_SIZES))).tolist()

# Seeds
MASTER_SEED = 42
N_SEEDS = N_MODEL_SIZES * N_DATASET_SIZES
random.seed(MASTER_SEED)
SEEDS = random.sample(range(10 * N_SEEDS), k=N_SEEDS)

# Check
print(f"{MODEL_SIZES = }")
print(f"{DATASET_FRACTIONS = }")
print(f"{SEEDS = }")

MODEL_SIZES = [6, 8, 11, 16, 23, 33, 47, 67]
DATASET_FRACTIONS = [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625, 0.001953125]
SEEDS = [654, 114, 25, 759, 281, 250, 228, 142, 754, 104, 692, 758, 558, 89, 604, 432, 32, 30, 95, 223, 238, 517, 616, 27, 574, 203, 733, 665, 718, 429, 225, 459, 603, 284, 6, 777, 163, 714, 348, 159, 220, 781, 344, 94, 389, 99, 367, 352, 618, 270, 44, 747, 470, 549, 127, 387, 80, 565, 300, 643, 633, 370, 591, 196, 721, 71, 46, 677, 233, 791, 296, 81, 103, 464, 650, 373, 166, 379, 363, 214]


In [3]:
@dataclass(frozen=True, slots=True)
class TrainingResult:
    index: int
    model_size: int
    dataset_fraction: float
    seed: int
    train_loss: float
    test_loss: float
    train_acc: float
    test_acc: float

def acc_fn(
    logits: t.Tensor, y: t.Tensor, *, as_pct: bool = True, pct_round_digits: int = 2
) -> float:
    preds = logits.argmax(-1)
    acc = (preds == y).to(dtype=t.float).mean().item()
    if as_pct:
        acc = round(100 * acc, pct_round_digits + 2)
    return acc

def train(index: int, model_size: int, dataset_fraction: float, seed: int) -> TrainingResult:
    random.seed(seed)
    t.manual_seed(seed)
    model = CNN(model_size)
    ds = Dataset.load(dataset_fraction)
    LR = 1e-3
    optimizer = t.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()
    
    # Train for one epoch
    train_logits = model(ds.train_x)
    train_loss = loss_fn(train_logits, ds.train_y)
    train_loss.backward()
    optimizer.step()
    
    # Measure
    with t.no_grad():
        # Training set 
        train_logits = model(ds.train_x)
        train_loss = loss_fn(train_logits, ds.train_y).item()
        train_acc = acc_fn(train_logits, ds.train_y)
        
        # Test set
        test_logits = model(ds.test_x)
        test_loss = loss_fn(test_logits, ds.test_y).item()
        test_acc = acc_fn(test_logits, ds.test_y)
    
    t.save(model, f"models/model_{index}.pt")
    
    return TrainingResult(
        index,
        model_size,
        dataset_fraction,
        seed,
        train_loss,
        test_loss,
        train_acc,
        test_acc
    )

In [4]:
model = CNN(MODEL_SIZES[0])
random.seed(SEEDS[0])
ds = Dataset.load(DATASET_FRACTIONS[0])
train_logits = model(ds.train_x)
test_logits = model(ds.test_x)

print(f"Initial training accuracy: {acc_fn(train_logits, ds.train_y)}%")
print(f"Initial test accuracy: {acc_fn(test_logits, ds.test_y)}%")

Initial training accuracy: 10.4433%
Initial test accuracy: 10.27%


In [5]:
Index = int
ModelSize = int
DatasetFraction = float
Seed = int

Param = tuple[Index, ModelSize, DatasetFraction, Seed]

PARAMS: list[Param] = [
    (index, model_size, dataset_fraction, seed)
    for index, ((model_size, dataset_fraction), seed)  in
    enumerate(zip(it.product(MODEL_SIZES, DATASET_FRACTIONS), SEEDS, strict=True))
]

In [6]:
# results: dict[Param, TrainingResult] = {
#     param: train(*param) for param in tqdm(PARAMS)
# }

# with open("results.pkl", "wb") as f:
#     pickle.dump(results, f)

with open("results.pkl", "rb") as f:
    results = pickle.load(f)


## Plot results

In [41]:
df = pd.DataFrame([asdict(tr) for tr in results.values()])

df["dataset_fraction_pow"] = df["dataset_fraction"].map(lambda frac: f"2^{int(math.log2(frac))}")

heatmap_cols = [col for col in df.columns if col.endswith("loss") or col.endswith("acc")]
print(f"{heatmap_cols = }")

ROUND_DECIMALS = 3

for heatmap_col in heatmap_cols:
    df[heatmap_col] = df[heatmap_col].round(decimals=ROUND_DECIMALS)

df.head()

heatmap_cols = ['train_loss', 'test_loss', 'train_acc', 'test_acc']


Unnamed: 0,index,model_size,dataset_fraction,seed,train_loss,test_loss,train_acc,test_acc,dataset_fraction_pow
0,0,6,1.0,654,2.286,2.287,10.477,10.58,2^0
1,1,6,0.5,114,2.295,2.295,9.977,9.95,2^-1
2,2,6,0.25,25,2.297,2.297,10.773,10.8,2^-2
3,3,6,0.125,759,2.299,2.299,11.12,11.17,2^-3
4,4,6,0.0625,281,2.305,2.306,10.96,10.6,2^-4


In [42]:
heatmap_dfs = {
    heatmap_col: df.pivot(
        index="model_size",
        columns="dataset_fraction_pow",
        values=heatmap_col
    ).rename(index=str, columns=str)
    for heatmap_col in heatmap_cols
}

In [43]:
for heatmap_col, heatmap_df in heatmap_dfs.items():
    fig = px.imshow(heatmap_df, title=heatmap_col, text_auto=True, height=800, width=800)
    fig.show()

### Log scale compute against test loss

It's supper fucking noisy but yeah...

In [49]:
df["compute"] = df["dataset_fraction"] * df["model_size"]
df["log_compute"] = df["compute"].map(math.log)
fig = px.line(df.sort_values("log_compute"), x="log_compute", y="test_loss")
fig.show()

So let's try model size and dataset fraction separately

In [54]:
df["log_model_size"] = df["model_size"].map(math.log)
fig = px.line(df.groupby("log_model_size")["test_loss"].mean())
fig.show()

In [55]:
df["log_dataset_fraction"] = df["dataset_fraction"].map(math.log)
fig = px.line(df.groupby("log_dataset_fraction")["test_loss"].mean())
fig.show()

So apparently "compute" means ~model size?

Let's check validation accuracy.

In [56]:
fig = px.line(df.groupby("log_model_size")["test_acc"].mean())
fig.show()

Interesting..., apparently starts overfitting at the biggest model.