In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import logging
from functools import partial

import jax
from jax import random
from jax import numpy as jnp
import wandb
import matplotlib
import matplotlib.pyplot as plt

from src.models import make_PoG_Ens_loss as make_loss
from src.models import make_PoG_plots as make_plots
import src.data
from src.data import NumpyLoader
from src.utils.training import setup_training, train_loop
from experiments.configs.toy_pog_ens import get_config
from src.models.pog import calculate_pog_loc_scale

In [None]:
# W&B cannot detect the notebook name when running inside VS Code, so it is
# supplied explicitly through the environment before logging in.
# NOTE(review): 'pon' may be a typo for 'pog' -- confirm against the real notebook filename.
os.environ.update({'WANDB_NOTEBOOK_NAME': 'train_pon_ens.ipynb'})

wandb.login()

In [None]:
# Root PRNG key (fixed seed); later cells split it and hand it to training.
rng = jax.random.PRNGKey(seed=0)

In [None]:
# Experiment configuration, built by `get_config` from
# `experiments.configs.toy_pog_ens` (see the import cell).
config = get_config()

In [None]:
# The config names the generator function to use; look it up on `src.data`.
data_gen_fn = getattr(src.data, config.dataset_name)

# NOTE(review): the generator is unpacked as (train, test, val) -- confirm this
# matches the order it actually returns.
dataset_kwargs = config.dataset.to_dict()
train_dataset, test_dataset, val_dataset = data_gen_fn(**dataset_kwargs)

# One loader per split, all sharing the configured batch size.
train_loader, val_loader, test_loader = (
    NumpyLoader(split, config.batch_size)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Reserve a key for initialisation and keep the remainder as the running key.
# NOTE: this rebinds `rng`, so re-running the cell yields a different key.
setup_rng, rng = random.split(rng)

# The first training example (input, target) is handed to `setup_training`
# together with the config and the initialisation key.
init_x, init_y = train_dataset[0][0], train_dataset[0][1]

model, state = setup_training(config, setup_rng, init_x, init_y)

In [None]:
# Run training with W&B in offline mode. `test_loader` is deliberately not
# passed here.
# NOTE: the returned value is written back to `state`, which is also an input,
# so re-running this cell starts from whatever the previous run returned.
state = train_loop(
    model,
    state,
    config,
    rng,
    make_loss,  # same factory passed twice; presumably train and validation loss -- TODO confirm
    make_loss,
    train_loader,
    val_loader,
    wandb_kwargs=dict(mode='offline'),
    plot_fn=make_plots,
)

## Paper Plots

In [None]:
# --- Figure geometry (inches) and resolution used for the paper figures ---
text_width = 6.75133  # in  --> confirmed with template explanation
line_width = 3.25063  # presumably the single-column width in inches -- TODO confirm
dpi = 400

# --- Font sizes (pt) ---
fs_m1 = 7  # for figure ticks (also used for the legend below)
fs = 8     # for regular figure text
fs_p1 = 9  # for figure titles

matplotlib.rc('font', size=fs)            # default text size
matplotlib.rc('axes', titlesize=fs)       # axes title
matplotlib.rc('axes', labelsize=fs)       # x and y axis labels
matplotlib.rc('xtick', labelsize=fs_m1)   # x tick labels
matplotlib.rc('ytick', labelsize=fs_m1)   # y tick labels
matplotlib.rc('legend', fontsize=fs_m1)   # legend entries
matplotlib.rc('figure', titlesize=fs_p1)  # figure-level title

# Render all text through LaTeX with a serif (Palatino) font.
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Palatino']})
matplotlib.rc('text', usetex=True)
# `text.latex.preamble` must be a single string: the list form was deprecated
# in Matplotlib 3.3 and non-string values are rejected by later releases.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [None]:
# Dense 1-D grid of inputs covering the plotted x-range.
xs = jnp.linspace(-2.25, 2.25, num=501)

# Prediction function bound to the trained variables; it calls `model.pred`
# in eval mode and asks for the per-member predictions as well.
variables = {"params": state.params, **state.model_state}
pred_fun = partial(
    model.apply,
    variables,
    train=False,
    return_ens_preds=True,
    method=model.pred,
)

# Map over the grid one point at a time. `out_axes=(0, 1)` puts the mapped axis
# first for the (discarded) first output and second for the (locs, scales)
# pair, so presumably the leading axis of `locs`/`scales` indexes ensemble
# members -- TODO confirm against `model.pred`.
batched_pred = jax.vmap(pred_fun, in_axes=(0,), out_axes=(0, 1), axis_name="batch")
_, (locs, scales) = batched_pred(xs[:, None])

# Transpose the dataset's (x, y) pairs into one sequence of inputs and one of targets.
X_train, y_train = zip(*train_loader.dataset)

In [None]:
# Grid of panels: each row uses one ordering of the ensemble members and each
# column combines one more member (in that order) than the column before it.
fig, axs = plt.subplots(
    2, 5,
    figsize=(text_width, text_width / 2.4),
    dpi=dpi,
    sharex=True,
    sharey=True,
    layout='tight',
)

# Member orderings, one per row of the figure.
ORDER = [jnp.array([0, 3, 2, 4, 1]), jnp.array([4, 3, 2, 1, 0])]

for order, axrow in zip(ORDER, axs):
    n_cols = len(axrow)
    for col, ax in enumerate(axrow):
        # Training data as faint black dots.
        ax.scatter(X_train, y_train, c='k', s=1, lw=0.5, alpha=0.5)

        # Members that come later than the next one in the ordering: thin
        # dashed lines at loc +/- scale, coloured per member.
        for j in range(col + 2, config.model.size):
            member = order[j]
            dashed = dict(c=f'C{member + 2}', alpha=0.4, lw=0.4)
            ax.plot(xs, locs[member] + scales[member], '--', **dashed)
            ax.plot(xs, locs[member] - scales[member], '--', **dashed)

        # The next member in the ordering (if any): shaded loc +/- scale band.
        if col + 1 < n_cols:
            upcoming = order[col + 1]
            up_loc = locs[upcoming, :, 0]
            up_scale = scales[upcoming, :, 0]
            ax.fill_between(
                xs, up_loc - up_scale, up_loc + up_scale,
                color='C1', alpha=0.15, lw=0.1,
            )

        # Combination of the first `col + 1` members of this ordering.
        combined = order[:col + 1]
        loc, scale = calculate_pog_loc_scale(
            locs[combined, :, 0], scales[combined, :, 0]
        )
        ax.plot(xs, loc, c='C0', lw=1, alpha=0.5)
        ax.fill_between(xs, loc - scale, loc + scale, color='C0', alpha=0.4, lw=0.1)

        # Fixed square view with no ticks or tick labels.
        ax.set_ylim(-2.25, 2.25)
        ax.set_xlim(-2.25, 2.25)
        ax.tick_params(
            bottom=False, top=False, left=False, right=False,
            labelbottom=False, labelleft=False,
        )

plt.savefig('toy_evolution.pdf', dpi=dpi, bbox_inches='tight')