# Create quickstart animation

## Notebook setup

In [None]:
import multiprocessing
import os

# Ask XLA to present the host platform as one device per CPU core.
# This is set here, ahead of the jax import in the following cell.
host_device_count = multiprocessing.cpu_count()
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={host_device_count}"

In [None]:
# import libraries
import pathlib

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import HTML
from jax import numpy as jnp
from jax import random as jr
from matplotlib.animation import FuncAnimation
from tqdm import tqdm

from gallifrey.data import Dataset
from gallifrey.model import GPConfig, GPModel, unbatch_states

In [None]:
# notebook settings

# making the plots pretty
# Matplotlib rc overrides that are passed to the seaborn theme below.
theme_rc = {
    "figure.figsize": (16, 7),
    "axes.grid": False,
    "font.family": "serif",
    "text.usetex": False,
    "lines.linewidth": 5,
}
sns.set_theme(
    context="poster",
    style="ticks",
    palette="rocket",
    font_scale=1,
    rc=theme_rc,
)

# setting saving defaults
save_figures = True

# set saving paths
# `path` is the parent of the current working directory; the animation
# output directory is <path>/figures/animations/.
path = pathlib.Path.cwd().parent
figure_directory = path / "figures/animations/"
# exist_ok=True makes the creation idempotent and avoids the
# check-then-create race of a separate exists() test.
figure_directory.mkdir(parents=True, exist_ok=True)

# Root random key for this notebook, created from the fixed seed 0.
rng_key = jr.PRNGKey(seed=0)

## Create data

In [None]:
# Mock data
# A sine wave whose amplitude grows with x, plus Gaussian noise with
# variance `noise_var`, sampled on a regular grid of `n` points in [0, 15].
key, data_key = jr.split(rng_key)
n = 160
noise_var = 9.0
x = jnp.linspace(0, 15, n)
signal = (x + 0.01) * jnp.sin(x * 3.2)
noise = jnp.sqrt(noise_var) * jr.normal(data_key, (n,))
y = signal + noise


# mask values
# Only points with x < 10 go into the training set; the remaining points
# (x >= 10) are not part of xtrain / ytrain.
train_mask = x < 10
xtrain = x[train_mask]
ytrain = y[train_mask]

## Initialize the GP Model

In [None]:
# Derive a dedicated key for the model; `key` is carried forward.
key, model_key = jr.split(key)

# GP model built from the training subset, using a GPConfig created
# without arguments and 8 particles.
config = GPConfig()
gpmodel = GPModel(model_key, x=xtrain, y=ytrain, num_particles=8, config=config)

In [None]:
# load the model
# Both checkpoint files are resolved relative to `path` and handed to
# `load_state` as plain strings.
final_state_file = path / "model_checkpoints/quickstart_example/final_state.pkl"
history_file = path / "model_checkpoints/quickstart_example/history.pkl"

final_smc_state = gpmodel.load_state(str(final_state_file))
history = gpmodel.load_state(str(history_file))

## Create predictive distributions for every SMC round

In [None]:
# Prediction grid: every x value, not only the training subset (xtrain).
xall = x
# Map the grid through the model's x transform; the transformed grid is what
# gets passed to the predictive distribution further down.
xall_norm = gpmodel.x_transform(xall)

In [None]:
# Split the loaded history into individual states (presumably one per SMC
# round, as the section title suggests -- confirm against `unbatch_states`).
history_states = unbatch_states(history)

# Per-state results, appended in iteration order.
means = []
lower = []
upper = []
datapoints = []
masks = []

# Position of every prediction point on the grid. The seen/unseen split below
# is positional, which relies on xtrain being the leading part of xall (true
# here: x is increasing and xtrain keeps the points with x < 10).
grid_index = jnp.arange(xall.shape[0])

for state in tqdm(history_states):
    gpmodel_hist = gpmodel.update_state(state)

    # Only the first `num_data_points` training points are passed as data
    # for this state.
    included_datapoints = state.num_data_points
    data_norm = Dataset(
        x=gpmodel.x_transform(xtrain[:included_datapoints]),
        y=gpmodel.y_transform(ytrain[:included_datapoints]),
    )

    predictive_gmm = gpmodel_hist.get_mixture_distribution(xall_norm, data=data_norm)
    # Evaluate the mean once and reuse it for the +/- one-stddev band.
    mean_prediction = predictive_gmm.mean()
    stddevs = predictive_gmm.stddev()
    means.append(mean_prediction)
    lower.append(mean_prediction - stddevs)
    upper.append(mean_prediction + stddevs)
    datapoints.append(included_datapoints)

    # Split the grid into the region covered by the included points and the
    # rest. Comparing positions instead of `xall < xtrain[included_datapoints]`
    # avoids indexing one past the end of xtrain once every training point is
    # included (JAX clamps out-of-bounds reads instead of raising, which would
    # silently put the last training point into the "unseen" region).
    seen = grid_index < included_datapoints
    masks.append([seen, ~seen])

## Make animation

In [None]:
# function to plot intervals with different colors
def plot_intervals(ax, x, mean, fill_lower, fill_upper, masks, colors):
    """Draw a mean curve with its band on ``ax``, one colored segment per mask.

    For every mask, the selected points of ``mean`` are drawn as a solid line,
    the area between ``fill_lower`` and ``fill_upper`` is shaded, and both
    bounds are outlined with dashed lines, all in the color found at the same
    position in ``colors``.

    Args:
        ax: Object providing ``plot`` and ``fill_between``.
        x: Array of x positions, indexable by an integer index array.
        mean: Array of mean values, indexed like ``x``.
        fill_lower: Array with the lower bound of the band, indexed like ``x``.
        fill_upper: Array with the upper bound of the band, indexed like ``x``.
        masks: Sequence of boolean masks over ``x``, one per segment.
        colors: Sequence of colors; ``colors[i]`` is used for ``masks[i]``.
    """
    # Index that ties a segment to the one before it (0 for the first segment).
    bridge = 0
    for position, segment_mask in enumerate(masks):
        segment_color = colors[position]
        # Prepend the final index of the previous segment so consecutive
        # segments join without a visible gap.
        selected = jnp.insert(jnp.where(segment_mask)[0], 0, bridge)
        xs = x[selected]
        lower_vals = fill_lower[selected]
        upper_vals = fill_upper[selected]

        # mean curve
        ax.plot(xs, mean[selected], color=segment_color, linewidth=3)
        # shaded band between the bounds
        ax.fill_between(xs, lower_vals, upper_vals, color=segment_color, alpha=0.3)
        # dashed outline of the band: lower bound first, then upper bound
        for bound in (lower_vals, upper_vals):
            ax.plot(xs, bound, color=segment_color, linestyle="--", linewidth=2)

        # Remember where this segment ended for the next one.
        bridge = selected[-1]

In [None]:
from matplotlib import font_manager

# add xkcd font
# Resolve the font inside the current user's font directory instead of a
# hard-coded home directory, so the notebook is not tied to one account.
# The path is kept as a plain string, as before.
font_path = str(pathlib.Path.home() / ".local/share/fonts/xkcd.otf")
# Register the font file with matplotlib's font manager and create the font
# properties object that refers to it.
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)