In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to project code are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import logging
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import wandb
import matplotlib
import matplotlib.pyplot as plt

import src.data
from src.models import make_Hard_OvR_Ens_loss as make_loss
from src.models import make_Hard_OvR_Ens_toy_plots as make_plots
from src.data import NumpyLoader
from src.utils.training import setup_training, train_loop
from experiments.configs.spirals_hard_ovr_classification import get_config

In [None]:
# Give W&B the notebook's filename explicitly through the environment.
# NOTE(review): the name below says "ivr" while the imports in this notebook
# use "OvR" -- confirm it matches the actual notebook filename on disk.
os.environ['WANDB_NOTEBOOK_NAME'] = 'train_hard_ivr_classifier_spirals.ipynb'
# ^ W&B doesn't know how to handle VS Code notebooks.

# Log in to Weights & Biases before the training cell runs.
wandb.login()

In [None]:
# Root JAX PRNG key with a fixed seed (0). It is split once for model setup
# and the remainder is handed to the training loop.
rng = random.PRNGKey(0)

In [None]:
# Experiment configuration from
# experiments.configs.spirals_hard_ovr_classification; later cells read
# dataset_name, dataset, batch_size and n_classes from it.
config = get_config()

In [None]:
# Look up the dataset generator by name in `src.data` and call it with the
# keyword arguments stored under `config.dataset`.
dataset_kwargs = config.dataset.to_dict()
data_gen_fn = getattr(src.data, config.dataset_name)

# NOTE(review): the generator's return order is taken to be
# (train, test, val), exactly as unpacked here -- confirm against src.data.
train_dataset, test_dataset, val_dataset = data_gen_fn(**dataset_kwargs)

# One loader per split, all sharing the configured batch size.
train_loader, val_loader, test_loader = (
    NumpyLoader(split, config.batch_size)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Collect the whole training set into two arrays: the dataset is iterated as
# (input, label) pairs, which `zip(*...)` transposes into inputs and labels.
features, labels = zip(*train_loader.dataset)
X_train = jnp.array(features)
y_train = jnp.array(labels)

# Quick sanity plot: one point cloud per class, coloured with the default
# matplotlib colour cycle (C0, C1, ...).
for class_id in range(config.n_classes):
    in_class = y_train == class_id
    plt.plot(
        X_train[in_class, 0], X_train[in_class, 1], '.',
        c=f'C{class_id}', alpha=1, ms=1,
    )

In [None]:
# Split off a key for model setup and keep the other half as `rng` for the
# training cell.
# NOTE(review): `rng` is rebound from itself, so re-running this cell without
# re-running the PRNGKey cell advances the key and changes the results.
setup_rng, rng = random.split(rng)
# The first training example (input, label) is passed to setup_training as a
# sample input/target pair.
init_x = train_dataset[0][0]
init_y = train_dataset[0][1]

# Build the model and the initial training state.
model, state = setup_training(config, setup_rng, init_x, init_y)

In [None]:
# Run training. `train_loop` returns two states; `best_state` is the one the
# paper-plot cell reads params, model_state and β from.
# NOTE(review): `make_loss` is passed twice -- presumably once as the training
# loss factory and once as the evaluation loss factory; confirm against
# src.utils.training.train_loop.
state, best_state = train_loop(
    model, state, config, rng, make_loss, make_loss, train_loader, val_loader,
    # test_loader,
    # ^ test loader currently not passed to train_loop.
    wandb_kwargs={
        # W&B 'offline' mode for this run.
        'mode': 'offline',
        # 'notes': '',
    },
    # Plotting callback imported from src.models as make_Hard_OvR_Ens_toy_plots.
    plot_fn=make_plots
)

## Paper plot

In [None]:
# Global matplotlib settings for the paper figure: figure geometry, font sizes
# and LaTeX text rendering. `text_width` and `dpi` are read by the plotting
# cell that follows.

# --- Geometry (inches) ---
text_width = 6.75133  # in  --> Confirmed with template explanation
line_width = 3.25063  # presumably the single-column width in inches -- TODO confirm
dpi = 200

# --- Font sizes (points) ---
fs_m1 = 7  # for figure ticks
fs = 8  # for regular figure text
fs_p1 = 9  # figure titles

matplotlib.rc('font', size=fs)  # controls default text sizes
matplotlib.rc('axes', titlesize=fs)  # fontsize of the axes title
matplotlib.rc('axes', labelsize=fs)  # fontsize of the x and y labels
matplotlib.rc('xtick', labelsize=fs_m1)  # fontsize of the tick labels
matplotlib.rc('ytick', labelsize=fs_m1)  # fontsize of the tick labels
matplotlib.rc('legend', fontsize=fs_m1)  # legend fontsize
matplotlib.rc('figure', titlesize=fs_p1)  # fontsize of the figure title

# --- Typeface and LaTeX rendering ---
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Palatino']})
matplotlib.rc('text', usetex=True)
# `text.latex.preamble` must be a single string: list values were deprecated
# in Matplotlib 3.3 and are rejected by later releases.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [None]:
# Paper figure: evolution of the ensemble prediction as members are combined.
# Panel j shows, for every class, filled contours of the element-wise product
# of the first j + 1 members' predictions over a dense grid, with the training
# points drawn on top. The result is written to 'spiral_evolution.pdf'.

hard_ovr_params, hard_ovr_model_state = best_state.params, best_state.model_state

n_class = int(y_train.max()) + 1

# hard_ovr preds
h = .05  # step size in the mesh
# Mesh over the training data's bounding box, each bound scaled by 1.25.
x_min, x_max = X_train[:, 0].min() * 1.25, X_train[:, 0].max() * 1.25
y_min, y_max = X_train[:, 1].min() * 1.25, X_train[:, 1].max() * 1.25
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))
xs = np.c_[xx.ravel(), yy.ravel()]

# Prediction function bound to the best checkpoint's parameters and state.
pred_fun = partial(
    model.apply,
    {"params": hard_ovr_params, **hard_ovr_model_state},
    train=False, return_ens_preds=True, β=best_state.β, hard_pred=True,
    method=model.pred
)

# Map over grid points. out_axes=(0, 1) puts the grid-point axis second in
# `ens_preds`, so its leading axis is the one the panels iterate over.
_, ens_preds = jax.vmap(
    pred_fun, out_axes=(0, 1), in_axes=(0,), axis_name="batch"
)(xs)

# One panel per leading entry of `ens_preds` (previously hard-coded to 5, which
# raised IndexError for any other count). squeeze=False keeps `axs` an array
# even when there is a single panel.
n_panels = ens_preds.shape[0]
fig, axs = plt.subplots(
    1, n_panels, figsize=(text_width, text_width/4.6), dpi=dpi,
    sharey=True, sharex=True, layout='tight', squeeze=False,
)
axs = axs.ravel()

colormaps = ['Blues', 'Oranges', 'Greens', 'Reds', 'Purples']
markers = ['o', 'v', 's', 'P', 'X']  # only the first marker is used for every class
contour_step = 0.05

for j in range(n_panels):
    # Product over the first j + 1 entries along the leading axis.
    preds = ens_preds[:j+1].prod(axis=0)

    for i in range(n_class):
        f = preds[:, i].reshape(xx.shape)
        # Named f_min / f_max so the builtins `min` / `max` are not shadowed
        # for the rest of the notebook session.
        f_min = np.amin(f)
        f_max = np.amax(f)
        levels = np.arange(f_min, f_max, contour_step) + contour_step
        if levels.size < 2:
            # contourf needs at least two levels; a (near-)constant surface
            # has nothing to draw, so skip this class.
            continue
        axs[j].contourf(
            xx, yy, f, levels, alpha=0.25,
            cmap=colormaps[i % len(colormaps)],  # cycle if there are more classes than colormaps
            antialiased=True,
        )

    for i in range(n_class):
        idxs = (y_train == i)
        axs[j].plot(X_train[idxs, 0], X_train[idxs, 1], markers[0], c=f'C{i}', alpha=1, ms=0.1)
    # Hide ticks and tick labels on every panel.
    axs[j].tick_params(bottom=False, top=False, left=False, right=False, labelbottom=False, labelleft=False)

plt.savefig('spiral_evolution.pdf', dpi=dpi, bbox_inches='tight')