# TODO

- [ ] Figure out how to do hyperparam search
- [ ] Move code into separate `.py` files
- [ ] Run experiments with 4 optimizers (SGD, HeavyBall, Adam, Adam w/o momentum) and 2 datasets (MNIST, FashionMNIST)

In [1]:
from __future__ import annotations

from datetime import datetime
import itertools as it
import json
from pathlib import Path
from types import NoneType
from typing import Literal, TypedDict

from plotly import express as px, graph_objects as go
import torch as t
from torch.optim.lr_scheduler import LRScheduler

from src.simple_cnn import SimpleCNN, DatasetSimpleCNN
from src.utils import split_into_batches

# from kfac.preconditioner import KFACPreconditioner

  from kfac.distributed import get_rank


## Model, data, and training setup

In [2]:
# Fresh network plus the two datasets used by the experiments below.
model = SimpleCNN()
ds_m = DatasetSimpleCNN.load()  # MNIST
ds_f = DatasetSimpleCNN.load(fashion_mnist=True)  # FashionMNIST
# Batch sizes to experiment with: powers of two from 2**4 (16) up to 2**13 (8192).
BATCH_SIZES: list[int] = [2 ** exponent for exponent in range(4, 14)]

In [3]:
# Default stopping criteria and evaluation cadence for training runs.
TARGET_ACCURACY = 0.992  # Target accuracy on the val/test set
MAX_N_STEPS = 1 << 14  # Upper bound on the number of training steps (batches): 16384
ACC_MEASURE_FREQ = 100  # Test accuracy is measured once every this many steps (batches)

def acc_fn(
    logits: t.Tensor,
    target: t.Tensor,
) -> float:
    """Return the fraction of predictions that equal `target`.

    The prediction for each item is the argmax of `logits` over its last
    dimension; the result is the mean of the element-wise matches as a
    Python float.
    """
    hits = logits.argmax(dim=-1).eq(target)
    return hits.float().mean().item()

class TrainingResult(TypedDict):
    """Record of one training run, as assembled and written to JSON by `train`."""
    timestamp: str  # ISO-format local time at minute precision with the colons removed
    # names
    model_name: str  # `train` always stores "SimpleCNN"
    dataset_name: str  # "FashionMNIST" or "MNIST"
    optimizer_name: OptimizerType  # label produced by `get_optimizer_name`
    scheduler_name: str  # class name of the scheduler object ("NoneType" when no scheduler was given)
    # numbers
    batch_size: int
    target_accuracy: float
    acc_measure_freq: int  # number of steps between two test-accuracy measurements
    train_losses: list[float]  # training loss of every executed step, in order
    test_accuracies: list[float]  # one test accuracy per measurement, in order
    steps_to_target: int  # stored as len(train_losses), i.e. the number of steps executed

# Closed set of labels used to tag a run with the optimizer configuration it used
# (these are the values returned by `get_optimizer_name`).
OptimizerType = Literal[
    "SGD",  # torch SGD with zero momentum, no preconditioner
    "HeavyBall",  # torch SGD with non-zero momentum, no preconditioner
    "Adam",   # with momentum
    "Adam-m", # without momentum
    "K-FAC",  # with momentum (torch SGD plus a K-FAC preconditioner)
    "K-FAC-m" # without momentum (torch SGD plus a K-FAC preconditioner)
]

def get_optimizer_name(
    optimizer: t.optim.Optimizer,
    preconditioner: KFACPreconditioner | None = None
) -> OptimizerType:
    """Map an optimizer (and optional K-FAC preconditioner) to its `OptimizerType` label.

    Args:
        optimizer: A `torch.optim.SGD` or `torch.optim.Adam` instance. Only the
            hyperparameters of its first parameter group are inspected.
        preconditioner: `None`, or a K-FAC preconditioner wrapping the model.
            It only affects the label of SGD optimizers.

    Returns:
        "SGD" / "HeavyBall" for SGD without / with momentum, "K-FAC-m" / "K-FAC"
        for the same when a preconditioner is given, and "Adam-m" / "Adam" for
        Adam with a zero / non-zero first-moment coefficient (`betas[0]`).

    Raises:
        AssertionError: If `optimizer` is neither SGD nor Adam, or if
            `preconditioner` is neither `None` nor a `KFACPreconditioner`.
    """
    assert isinstance(optimizer, (t.optim.SGD, t.optim.Adam)), f"Invalid optimizer: type={optimizer.__class__.__name__}"
    # The `KFACPreconditioner` import is commented out at the top of the file, so an
    # `isinstance` check against that name would raise NameError on every call.
    # Matching the class name along the MRO keeps the validation (subclasses
    # included) without requiring the kfac package to be importable.
    is_kfac = any(cls.__name__ == "KFACPreconditioner" for cls in type(preconditioner).__mro__)
    assert preconditioner is None or is_kfac, f"Invalid {preconditioner=}"

    hyperparams = optimizer.param_groups[0]
    if isinstance(optimizer, t.optim.SGD):
        if hyperparams["momentum"] == 0:
            return "SGD" if preconditioner is None else "K-FAC-m"
        return "HeavyBall" if preconditioner is None else "K-FAC"

    # Adam param groups have no "momentum" entry; the momentum-like term is the
    # first-moment decay rate, stored as betas[0].
    beta1 = hyperparams["betas"][0]
    return "Adam" if beta1 != 0 else "Adam-m"

def train(
    model: SimpleCNN,
    ds: DatasetSimpleCNN,
    optimizer: t.optim.Optimizer, #: OptimizerType
    batch_size: int,
    *,
    preconditioner: KFACPreconditioner | None = None,
    target_accuracy: float = TARGET_ACCURACY,
    acc_measure_freq: int = ACC_MEASURE_FREQ,
    max_n_steps: int = MAX_N_STEPS,
    verbose: bool = True,
    scheduler: LRScheduler | None = None
) -> tuple[SimpleCNN, TrainingResult]:
    """Train `model` on `ds` until the target test accuracy or the step budget is reached.

    The training data is split into batches of `batch_size` and cycled through
    repeatedly. Every `acc_measure_freq` steps (starting with step 0) the model
    is evaluated on the full test set; training stops at the first measurement
    that reaches `target_accuracy`, or after `max_n_steps` steps.

    Args:
        model: Network to train; updated in place through `optimizer`.
        ds: Dataset providing `train_x`, `train_y`, `test_x`, `test_y` and the
            `fashion_mnist` flag.
        optimizer: Optimizer already bound to `model`'s parameters.
        batch_size: Batch size passed to `split_into_batches`.
        preconditioner: Optional K-FAC preconditioner, stepped after the
            backward pass and before the optimizer step.
        target_accuracy: Test accuracy at which training stops.
        acc_measure_freq: Number of steps between test-accuracy measurements.
        max_n_steps: Maximum number of training steps.
        verbose: Whether to print running loss and test accuracy at each measurement.
        scheduler: Optional LR scheduler, stepped once after every optimizer step.

    Returns:
        The trained model and the `TrainingResult` describing the run.
        `steps_to_target` is the number of steps executed, so it equals
        `max_n_steps` when the target was never reached.

    Side effects:
        Creates `results/` if needed and writes `model_<suffix>.pt` (the model,
        via `torch.save`) and `tr_<suffix>.json` (the training result) into it.
    """
    loss_fn = t.nn.CrossEntropyLoss()

    train_x_batches = split_into_batches(ds.train_x, batch_size)
    train_y_batches = split_into_batches(ds.train_y, batch_size)
    # Cycle over the batches (epoch after epoch), capped at max_n_steps in total.
    batch_iter = it.islice(
        it.cycle(zip(train_x_batches, train_y_batches)),
        max_n_steps,
    )

    train_losses: list[float] = []
    test_accuracies: list[float] = []

    for step_i, (batch_x, batch_y) in enumerate(batch_iter):
        optimizer.zero_grad()
        train_logits = model(batch_x)
        train_loss = loss_fn(train_logits, batch_y)
        train_loss.backward()
        # The preconditioner rewrites the gradients, so it must run between
        # backward() and optimizer.step().
        if preconditioner is not None:
            preconditioner.step()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        train_losses.append(train_loss.item())

        if step_i % acc_measure_freq == 0:
            with t.no_grad():
                test_logits = model(ds.test_x)
                test_acc = acc_fn(test_logits, ds.test_y)
                test_accuracies.append(test_acc)
            if verbose:
                running_loss = t.tensor(train_losses[-acc_measure_freq:]).mean().item()
                print(f"Step {step_i}: {running_loss=:.4f}, {test_acc=:.2%}")
            if test_acc >= target_accuracy:
                print(f"{target_accuracy=} achieved after {step_i} steps")
                # Stop here so that steps_to_target reflects the steps actually
                # needed instead of always running out the full step budget.
                break

    timestamp = datetime.now().isoformat("T", "minutes").replace(":", "")
    # Pass the preconditioner along so K-FAC runs are not labelled as plain SGD/HeavyBall.
    optimizer_name = get_optimizer_name(optimizer, preconditioner)
    tr: TrainingResult = {
        "timestamp": timestamp,
        "model_name": "SimpleCNN",
        "dataset_name": "FashionMNIST" if ds.fashion_mnist else "MNIST",
        "optimizer_name": optimizer_name,
        "scheduler_name": scheduler.__class__.__name__, #TODO instead of scheduler name, should be params of scheduler (since it's always Linear I think?) and optimizer
        "batch_size": batch_size,
        "target_accuracy": target_accuracy,
        "acc_measure_freq": acc_measure_freq,
        "train_losses": train_losses,
        "test_accuracies": test_accuracies,
        "steps_to_target": len(train_losses),
    }

    path = Path("results")
    path.mkdir(parents=True, exist_ok=True)
    # Suffix: optimizer label, first letter of the dataset name ("m"/"f"), timestamp.
    suffix = f"{optimizer_name}_{tr['dataset_name'][0].lower()}_{timestamp}"

    model_filename = f"model_{suffix}.pt"
    t.save(model, path / model_filename)

    tr_filename = f"tr_{suffix}.json"
    with open(path / tr_filename, "w", encoding="utf-8") as f:
        json.dump(tr, f)

    return model, tr

## Experiments

### Optimizer: SGD

In [4]:
# Plain SGD (no momentum) on MNIST at the first entry of BATCH_SIZES.
model = SimpleCNN()
#TODO tune LR (and other hyperparams for other optimizers)
lr = 1e-3
optimizer = t.optim.SGD(model.parameters(), lr)
#TODO fine-tune scheduler
scheduler = t.optim.lr_scheduler.LinearLR(optimizer)
#TODO sweep over BATCH_SIZES, collecting one TrainingResult per batch size:
# trs: dict[int, TrainingResult] = {
#     batch_size: train(model, ds_m,)
# }
# The scheduler has to be handed to `train` explicitly; otherwise it is never
# stepped and the saved result records no scheduler at all.
# NOTE: `train` returns a (model, TrainingResult) tuple, so `tr` holds that tuple.
tr = train(model, ds_m, optimizer, BATCH_SIZES[0], scheduler=scheduler)


Step 0: running_loss=6.9812, test_acc=0.00%
Step 100: running_loss=6.5339, test_acc=18.16%


KeyboardInterrupt: 

### Optimizer: HeavyBall

In [17]:
# Heavy-ball run: SGD with momentum 0.8 on MNIST at the first entry of BATCH_SIZES.
model = SimpleCNN()
lr = 1e-3
optimizer = t.optim.SGD(params=model.parameters(), lr=lr, momentum=0.8)
tr = train(
    model=model,
    ds=ds_m,
    optimizer=optimizer,
    batch_size=BATCH_SIZES[0],
)

Step 0: running_loss=7.0442, test_acc=0.00%
Step 100: running_loss=2.5846, test_acc=70.55%
Step 200: running_loss=0.6326, test_acc=87.40%
Step 300: running_loss=0.4278, test_acc=81.19%


KeyboardInterrupt: 

### Optimizer: Adam with momentum
    

In [None]:
# Adam with its default betas (i.e. with momentum) on MNIST at the first entry of BATCH_SIZES.
model = SimpleCNN()
lr = 1e-3
optimizer = t.optim.Adam(params=model.parameters(), lr=lr)
tr = train(
    model=model,
    ds=ds_m,
    optimizer=optimizer,
    batch_size=BATCH_SIZES[0],
)

### Optimizer: Adam without momentum

In [5]:
# Adam with both beta coefficients set to zero (no momentum) on MNIST at the
# first entry of BATCH_SIZES.
model = SimpleCNN()
lr = 1e-3
optimizer = t.optim.Adam(params=model.parameters(), lr=lr, betas=(0, 0))
tr = train(
    model=model,
    ds=ds_m,
    optimizer=optimizer,
    batch_size=BATCH_SIZES[0],
)

Step 0: running_loss=6.8807, test_acc=22.50%
Step 100: running_loss=0.9023, test_acc=85.82%


KeyboardInterrupt: 