# Basic Tutorial

In [None]:
import matplotlib.pyplot as plt
import corner
import numpy as np
import jax
import jax.numpy as jnp

import diffmahnet

## Generate fake data
- an (N, 5) array of unbounded MAH parameters
- an (N, 2) array of the conditional variables $M_{\rm obs}$ and $t_{\rm obs}$, on which the MAH parameters depend

In [None]:
# One fixed root key, split into six subkeys (keys[0] ... keys[5]) so that every
# random draw in this tutorial is reproducible and uses its own stream.
randkey = jax.random.key(0)
keys = jax.random.split(randkey, 6)

ndata = 1000

# Fake conditional variables: column 0 plays the role of m_obs and column 1 of
# t_obs; both are standard-normal draws shifted to a mean of 1.5.
fake_conditions = 1.5 + jax.random.normal(keys[0], shape=(ndata, 2))
# Make the fake MAH parameters depend on the conditional variables M_obs and t_obs
def gen_uparams(key, condition, nparams=5):
    """Generate fake MAH unbound parameters that depend on the conditions.

    Each parameter starts as a uniform draw on [3, 4) and is then multiplied
    by ``m_obs**2 * t_obs**3`` of its own row. This gives the flow a
    conditional dependence to learn.

    Parameters
    ----------
    key : jax PRNG key
        Key used for the uniform draws.
    condition : array, shape (N, 2)
        Column 0 is used as m_obs and column 1 as t_obs.
    nparams : int, optional
        Number of parameters generated per row (default 5, matching the
        original hard-coded value).

    Returns
    -------
    array, shape (N, nparams)
        The fake unbound parameters.
    """
    base = jax.random.uniform(key, (condition.shape[0], nparams)) + 3.0
    # Slicing with 0:1 / 1:2 keeps a (N, 1) column so it broadcasts across
    # all nparams columns.
    m_obs = condition[:, 0:1]
    t_obs = condition[:, 1:2]
    return base * m_obs ** 2 * t_obs ** 3

# Fake training set: one parameter vector per row of fake_conditions
fake_mah_uparams = gen_uparams(keys[1], fake_conditions)
# Scaler built from the training parameters and conditions; it is handed to the
# flow constructor (presumably holds normalization statistics — TODO confirm)
scaler = diffmahnet.ScalerHolder.compute(fake_mah_uparams, fake_conditions)

## Create a small flow model with only 914 parameters

In [None]:
# Deliberately small architecture so the tutorial trains quickly; see the last
# section for which of these hyperparameters to enlarge.
flow = diffmahnet.DiffMahFlow(
    scaler,
    nn_depth=2,
    nn_width=50,
    flow_layers=8,
)
# Total number of trainable parameters (shown as the cell output)
flow.get_params().size

## Train the model on the fake data generated above

In [None]:
# Fit the flow to the fake training data; whatever init_fit returns is kept in `res`
res = flow.init_fit(
    fake_mah_uparams,
    fake_conditions,
    randkey=keys[2],
    max_epochs=100,
    max_patience=100,
)

## Optionally, save the trained model and reload it later

In [None]:
# Write the trained flow to disk; it can be read back with DiffMahFlow.load
flow.save("fake_model.eqx")

In [None]:
same_flow = diffmahnet.DiffMahFlow.load("fake_model.eqx")
# Round-trip check: the reloaded parameters must match the in-memory ones.
# array_equal also requires identical shapes, whereas jnp.all(a == b) would
# broadcast mismatched shapes and could report True.
jnp.array_equal(same_flow.get_params(), flow.get_params())

## Draw predictive samples from the flow model

In [None]:
# Held-out "test" set generated exactly like the training data, but 100x larger
test_conditions = 1.5 + jax.random.normal(keys[3], (ndata * 100, 2))
test_uparams = gen_uparams(keys[4], test_conditions)
# Columns: m_obs, t_obs, first MAH parameter -> an (N, 3) array for corner
test_conditions_vs_param1 = np.column_stack([test_conditions, test_uparams[:, 0]])

# Generate samples, given the new "test" values of m_obs and t_obs
flow_mah_uparams = flow.sample(test_conditions, keys[5])
flow_conditions_vs_param1 = np.column_stack(
    [test_conditions, flow_mah_uparams[:, 0]]
)

In [None]:
# Overlay the test-set distribution (C0) and the flow's predictions (C1) on one
# corner plot; their rough agreement shows how well the flow has been fit.
# The first pass creates the figure (fig=None); the second draws onto it.
fig = None
for samples, color in (
    (test_conditions_vs_param1, "C0"),
    (flow_conditions_vs_param1, "C1"),
):
    fig = corner.corner(
        samples,
        labels=["M_obs", "t_obs", "param1"],
        fig=fig,
        show_titles=True,
        title_kwargs={"fontsize": 12},
        quantiles=[0.16, 0.5, 0.84],
        fill_contours=True,
        levels=(0.68, 0.95, 0.995),
        plot_datapoints=False,
        color=color,
        alpha=0.1,
    )
plt.show()

In [None]:
# Passing asparams=True makes sample return actual DiffmahParams. The same key
# (keys[5]) is reused here, presumably to match the draws above — confirm.
flow.sample(test_conditions, keys[5], asparams=True)

## Try improving the fit by adjusting the flow hyperparameters
- Increase the size of the neural network using:
    - `nn_depth`
    - `nn_width`
    - `flow_layers`
- Increase `max_patience` and/or `max_epochs` in the call to `init_fit`