In [None]:
# Notebook housekeeping: auto-reload imported modules before each cell runs
# (mode 2 = all modules), so edits to project code take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
# Use the ggplot stylesheet for every figure produced in this notebook.
plt.style.use('ggplot')
# Render figures inline in the cell output.
%matplotlib inline

In [None]:
import os
import logging

from jax import random
from torchvision import transforms
import wandb

from src.models import make_Hard_OvR_Ens_loss as make_loss
from src.models import make_Hard_OvR_Ens_MNIST_plots as make_plots
from src.models import make_Cls_Ens_loss as make_pretrain_loss
from src.data import get_image_dataset, NumpyLoader
from src.utils.training import setup_training, train_loop
from experiments.configs.mnist_hard_ovr_classification import get_config

In [None]:
# W&B cannot work out the notebook's name when it runs inside VS Code, so the
# name is supplied explicitly through the environment before logging in.
os.environ.update({'WANDB_NOTEBOOK_NAME': 'train_hard_ovr_classifier.ipynb'})

wandb.login()

In [None]:
# Root PRNG key for the whole notebook; every later key is split from this one.
rng = random.PRNGKey(seed=0)

In [None]:
# Load the experiment configuration defined in
# experiments.configs.mnist_hard_ovr_classification.
config = get_config()

In [None]:
# Build the dataset splits with flatten_img=True and no train-time augmentation.
# Commented-out augmentation candidates, kept for reference:
#   transforms.RandomCrop(28, padding=2)
#   transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=15)
#   transforms.RandomHorizontalFlip()
# NOTE(review): the result is unpacked as (train, test, val) -- confirm this
# matches the return order of get_image_dataset.
train_dataset, test_dataset, val_dataset = get_image_dataset(
    dataset_name=config.dataset_name,
    val_percent=config.val_percent,
    flatten_img=True,
    train_augmentations=[],
)

# One loader per split, all using the configured batch size.
train_loader, val_loader, test_loader = (
    NumpyLoader(split, config.batch_size)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Derive the pre-training configuration from the main one: take a copy (so the
# edits below do not touch `config`), switch the model name to 'Cls_Ens', and
# remove the β schedule entry.
# NOTE(review): presumably the 'Cls_Ens' setup/loss does not accept a
# β_schedule field -- confirm against setup_training / make_Cls_Ens_loss.
pretrain_config = config.copy_and_resolve_references()
pretrain_config.model_name = 'Cls_Ens'
del pretrain_config.β_schedule

In [None]:
# Carve a dedicated key for initialisation off the notebook-level key.
setup_rng, rng = random.split(rng, num=2)

# The first training example (input, label) serves as the example passed to
# setup_training.
init_x, init_y = train_dataset[0][0], train_dataset[0][1]

# Model and initial train state for the pre-training phase.
model, state = setup_training(
    pretrain_config,
    setup_rng,
    init_x,
    init_y,
)

In [None]:
# Pre-training phase: run the training loop with the 'Cls_Ens' loss factory
# (passed twice -- presumably once for training and once for evaluation; confirm
# against train_loop's signature) and keep the resulting train state.
#
# The loop gets its own PRNG key rather than `rng` itself: `rng` is split again
# further down for the main model's setup, and a JAX key must never be consumed
# twice, otherwise both phases draw from the same random stream.
pretrain_rng, rng = random.split(rng)
_, pre_train_state = train_loop(
    model, state, pretrain_config, pretrain_rng, make_pretrain_loss, make_pretrain_loss, train_loader, val_loader,
    # test_loader,
    wandb_kwargs={
        'mode': 'offline',  # log locally only; this run is not synced to W&B
        # 'notes': 'Data augmentation',
        # 'tags': ['MNIST testing'],
    },
    # plot_fn=make_plots,
    # plot_freq=1,
)

In [None]:
# Fresh initialisation key for the main model, split off the notebook-level key.
setup_rng, rng = random.split(rng, num=2)

# The first training example (input, label) again serves as the example passed
# to setup_training.
init_x, init_y = train_dataset[0][0], train_dataset[0][1]

# Model and initial train state for the main run, built from the full config.
model, state = setup_training(
    config,
    setup_rng,
    init_x,
    init_y,
)

In [None]:
# Initialise the main model from the pre-trained parameters.
# `replace` does not modify `state` in place -- it returns a new state object --
# so the result has to be assigned back; otherwise the main run silently starts
# from the fresh initialisation.
# NOTE(review): assumes the 'Cls_Ens' parameter tree has the same structure as
# the main model's -- confirm.
state = state.replace(params=pre_train_state.params)
# TODO: Also replace BN (model) state?

In [None]:
# Main training run: hard OvR ensemble loss factory (passed twice, as for the
# pre-training call), starting from the current `state`. With no 'mode'
# override, W&B uses its default mode; the run is annotated as 'pre-trained'.
_, state = train_loop(
    model,
    state,
    config,
    rng,
    make_loss,
    make_loss,
    train_loader,
    val_loader,
    # test_loader,
    wandb_kwargs=dict(
        # mode='offline',
        notes='pre-trained',
        tags=['MNIST testing', 'pre-trained'],
    ),
    # plot_fn=make_plots,
    # plot_freq=1,
)