# Create transit animation

## Notebook setup

In [None]:
import multiprocessing
import os

os.environ["XLA_FLAGS"] = (
    f"--xla_force_host_platform_device_count={multiprocessing.cpu_count()}"
)

In [1]:
# import libraries
import pathlib

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import HTML
from jax import numpy as jnp
from jax import random as jr
from jaxoplanet.light_curves import limb_dark_light_curve
from jaxoplanet.orbits import TransitOrbit
from matplotlib.animation import FuncAnimation
from tqdm import tqdm

from gallifrey.data import Dataset
from gallifrey.model import GPConfig, GPModel, unbatch_states

gallifrey: Setting flag `JAX_ENABLE_X64` to `True`
gallifrey: Setting flag `OMP_NUM_THREADS` to `1`


In [2]:
# notebook settings

# making the plots pretty
sns.set_theme(
    context="poster",
    style="ticks",
    palette="rocket",
    font_scale=1,
    rc={
        "figure.figsize": (16, 7),
        "axes.grid": False,
        "font.family": "serif",
        "text.usetex": False,
        "lines.linewidth": 5,
        # "axes.grid": True,
    },
)

# setting saving defaults
save_figures = True

# set saving paths
path = pathlib.Path.cwd().parent
figure_directory = path / "figures/animations/"
if not figure_directory.exists():
    figure_directory.mkdir(parents=True)

# set a random key for for this notebook
rng_key = jr.PRNGKey(77)

## Create data

In [3]:
# create transit model
def transit_model(t, params, period=10.0, duration=0.2, t0=0.0, impact_param=0.0):
    orbit = TransitOrbit(
        period=jnp.array(period),
        duration=jnp.array(duration),
        time_transit=jnp.array(t0),
        impact_param=jnp.array(impact_param),
        radius_ratio=params["r"],
    )
    return limb_dark_light_curve(orbit, jnp.array([params["u1"], params["u2"]]))(t)

In [4]:
key, noise_key = jr.split(rng_key, 2)

# generate the deterministic data
full_time = jnp.linspace(-0.8, 0.8, 1000)
background = 0.003 * (
    5 * full_time**2 + jnp.sin(20 * full_time) + 0.3 * jnp.cos(50 * full_time)
)

# generate white noise
white_noise_stddev = 0.001
white_noise = white_noise_stddev * jr.normal(noise_key, (len(full_time),))
# generate AR(1) noise
noise_auto_corr = 0.1
ar_noise = jnp.zeros(len(full_time))
ar_noise = ar_noise.at[0].set(white_noise[0])
for i in range(1, len(full_time)):
    ar_noise = ar_noise.at[i].set(noise_auto_corr * ar_noise[i - 1] + white_noise[i])

# generate the transit signal
transit_params = {"r": 0.1, "u1": 0.1, "u2": 0.3}
transit = 1.5 * transit_model(full_time, transit_params)

# generate the light curve
full_ight_curve = transit + background + ar_noise

# select a subset of the data as mock observations
num_train = 150
obs_idx = jnp.sort(jr.choice(rng_key, len(full_time), (num_train,), replace=False))
time = full_time[obs_idx]
light_curve = full_ight_curve[obs_idx]

# get transit mask
transit_mask = (time > -0.12) & (time < 0.12)

xtrain = time[~transit_mask]
ytrain = light_curve[~transit_mask]

## Get the `gallifrey` GP model

In [5]:
# create GP model instance
key, gallifrey_key = jr.split(rng_key)
gpmodel = GPModel(
    gallifrey_key,
    x=xtrain,
    y=ytrain,
    num_particles=8,
    config=GPConfig(),
)

In [6]:
# load the model
final_smc_state = gpmodel.load_state(
    str(path / "model_checkpoints/mcmc_comparison_learned_many_rounds/final_state.pkl")
)
history = gpmodel.load_state(
    str(path / "model_checkpoints/mcmc_comparison_learned_many_rounds/history.pkl")
)

## Create predictive distributions for every SMC round

In [7]:
xall = full_time[::2]
xall_norm = gpmodel.x_transform(xall)

In [8]:
history_states = unbatch_states(history)

means = []
lower = []
upper = []
datapoints = []
masks = []

for state in tqdm(history_states):
    gpmodel_hist = gpmodel.update_state(state)

    included_datapoints = state.num_data_points
    data_norm = Dataset(
        x=gpmodel.x_transform(xtrain[:included_datapoints]),
        y=gpmodel.y_transform(ytrain[:included_datapoints]),
    )

    predictive_gmm = gpmodel_hist.get_mixture_distribution(xall_norm, data=data_norm)
    means.append(predictive_gmm.mean())
    stddevs = predictive_gmm.stddev()
    lower.append(predictive_gmm.mean() - stddevs)
    upper.append(predictive_gmm.mean() + stddevs)
    datapoints.append(included_datapoints)
    masks.append(
        [xall < xtrain[included_datapoints], xall >= xtrain[included_datapoints]]
    )

100%|██████████| 129/129 [04:25<00:00,  2.06s/it]


## Make animation

In [9]:
# function to plot intervals with different colors
def plot_intervals(ax, x, mean, fill_lower, fill_upper, masks, colors):
    last_index = 0
    for i in range(len(masks)):
        mask = masks[i]
        color = colors[i]
        # get indices of mask, add last index from previous mask for smooth plotting
        indices = jnp.insert(jnp.where(mask)[0], 0, last_index)
        xm = x[indices]
        # plot means
        ax.plot(xm, mean[indices], color=color, linewidth=3)
        # plot stddevs
        ax.fill_between(
            xm,
            fill_lower[indices],
            fill_upper[indices],
            color=color,
            alpha=0.3,
        )
        ax.plot(xm, fill_lower[indices], color=color, linestyle="--", linewidth=2)
        ax.plot(xm, fill_upper[indices], color=color, linestyle="--", linewidth=2)
        last_index = indices[-1]  # Update last index

In [10]:
from matplotlib import font_manager

# add xkcd font
font_path = "/home/chris/.local/share/fonts/xkcd.otf"
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)

In [None]:
with plt.xkcd():
    # Set up the figure and axis with your seaborn theme
    fig, ax = plt.subplots()

    xmin = jnp.abs(full_time.min())

    # function to initialize the plot
    def init():
        ax.clear()
        # ax.set_xlabel("Time")
        # ax.set_ylabel("Flux")

        ax.set_xticks(jnp.linspace(-0.8 + xmin, 0.8 + xmin, 5))
        ax.set_yticks([-1, 0, 1])

        ax.set_ylim(-2.0, 1.4)
        ax.set_xlim(-0.8 + xmin, 0.8 + xmin)

        sns.scatterplot(
            x=(time + xmin)[~transit_mask],
            y=gpmodel.y_transform(light_curve)[~transit_mask],
            marker="$\circ$",
            color="grey",
            s=300,
            alpha=0.3,
            ax=ax,
        )
        sns.scatterplot(
            x=(time + xmin)[transit_mask],
            y=gpmodel.y_transform(light_curve)[transit_mask],
            color="grey",
            s=300,
            alpha=0.3,
            ax=ax,
        )

        # add note for transit
        ax.annotate(
            "Transit",
            xy=(xmin, -1.4),
            xytext=(xmin + 0.06, -1.8),
            fontsize=28,
            color="C0",
            alpha=0.8,
            arrowprops=dict(
                color="C0",
                alpha=0.8,
                arrowstyle="->",
                connectionstyle="arc3,rad=0.2",
            ),
            fontproperties=prop,
        )
        return []

    # function to update the plot for each frame
    def update(frame):
        init()

        # plot predictive mean and stddev
        plot_intervals(
            ax,
            xall + xmin,
            means[frame],
            lower[frame],
            upper[frame],
            masks[frame],
            ["C1", "C5"],
        )

        # plot observed data points
        included_dp = datapoints[frame]
        sns.scatterplot(
            x=xtrain[:included_dp] + xmin,
            y=gpmodel.y_transform(ytrain[:included_dp]),
            color="C0",
            s=300,
            ax=ax,
        )

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        fig.tight_layout()
        return []

    # Create the animation
    anim = FuncAnimation(
        fig,
        update,
        frames=len(means),
        init_func=init,
        blit=True,
        interval=200,
    )

    # Save the animation if a path is provided
    if save_figures:
        writervideo = animation.FFMpegWriter(fps=10)
        anim.save(
            str(figure_directory / "transit_animation.webm"),
            writer=writervideo,
            dpi=100,
            savefig_kwargs={"transparent": True, "facecolor": "none"},
        )
        anim.save(
            str(figure_directory / "transit_animation.gif"),
            writer="pillow",
            fps=10,
            dpi=100,
        )

    # display
    plt.close()
HTML(anim.to_jshtml())

findfont: Font family 'xkcd Script' not found.
findfont: Font family 'Comic Neue' not found.
findfont: Font family 'Comic Sans MS' not found.
findfont: Font family 'xkcd Script' not found.
findfont: Font family 'Comic Neue' not found.
findfont: Font family 'Comic Sans MS' not found.
findfont: Font family 'xkcd Script' not found.
findfont: Font family 'Comic Neue' not found.
findfont: Font family 'Comic Sans MS' not found.
findfont: Font family 'xkcd Script' not found.
findfont: Font family 'Comic Neue' not found.
findfont: Font family 'Comic Sans MS' not found.
findfont: Font family 'xkcd Script' not found.
findfont: Font family 'Comic Neue' not found.
findfont: Font family 'Comic Sans MS' not found.
findfont: Font family 'xkcd Script' not found.
findfont: Font family 'Comic Neue' not found.
findfont: Font family 'Comic Sans MS' not found.
findfont: Font family 'xkcd Script' not found.
findfont: Font family 'Comic Neue' not found.
findfont: Font family 'Comic Sans MS' not found.
findfo