# Import libraries

In [None]:
import jax
import numpy as np
import ximinf.nn_inference as nninf
# import ximinf.generate_sim as gsim
import pandas as pd
import jax.numpy as jnp

# Select compute devices and seed the random key

In [None]:
# Show the devices JAX can see. A bare `jax.devices()` here would not be
# displayed, because it is not the last expression of the cell.
print(jax.devices())
cpu = jax.devices("cpu")[0]
# `jax.devices("gpu")` raises RuntimeError when no GPU backend is available;
# fall back to None so the notebook still runs on CPU-only machines.
try:
    gpu = jax.devices("gpu")[0]
except RuntimeError:
    gpu = None

# Fixed seed so that the sampling below is reproducible.
key = jax.random.PRNGKey(42)

# Load data

In [None]:
# Names of the observed quantities. At present they are only referenced by
# the disabled flattening step: gsim.flatten_df(df, columns).
columns = [
    'magobs', 'magobs_err',
    'x1', 'x1_err',
    'c', 'c_err',
    'prompt',
    'localcolor', 'localcolor_err',
    'z',
]

# Observed data used for inference.
df = pd.read_parquet("../data/inference_data_frame.parquet")

# Load the neural network

In [None]:
# Restore the neural network from disk via the project helper nninf.load_nn.
# NOTE(review): presumably '../data/NNs/nn_model' is a saved checkpoint — confirm.
model = nninf.load_nn(
    '../data/NNs/nn_model'
)


# MCMC

In [None]:
# Put the loaded model into evaluation mode before it is used for inference.
# NOTE(review): assumes `eval()` follows the flax nnx convention — confirm.
model.eval()

# Prior bounds: one [lower, upper] row per sampled parameter. The MCMC cell
# starts the chain at the midpoint of each row.
BOUNDS = jnp.array(
    [
        [-0.3, 0],
        [2, 4],
        [-20, -18],
        [0, 0.3],
    ]
)

In [None]:
print("Launch MCMC ...")

# Every column of the DataFrame, in its own column order.
data_names = list(df.columns)

# Stack the per-column arrays along a new trailing axis.
# NOTE(review): expected to be (N, M, C) — confirm against the data layout.
stacked = jnp.stack(
    [df[name].squeeze() for name in data_names],
    axis=-1,
)

# Flatten everything after the first axis, so that the column values of each
# entry along axis 1 are interleaved: (N, M*C).
interleaved = stacked.reshape((stacked.shape[0], -1))


def log_post(theta):
    """Evaluate nninf.log_prob_fn at ``theta`` with the loaded model and data."""
    return nninf.log_prob_fn(theta, model, interleaved)


# Start the chain at the midpoint of each prior interval in BOUNDS.
theta_init = BOUNDS.sum(axis=1) / 2.0

# Run the sampler; it also hands back the updated PRNG key.
key, post = nninf.sample_posterior(
    log_post,
    init_position=theta_init,
    n_warmup=200,
    n_samples=2000,
    rng_key=key,
)

print("...finished")

In [None]:
from pathlib import Path

# Persist the posterior samples. np.save takes the destination first and the
# data second; the original call had them swapped and would fail.
posterior_path = Path("../data/results/sbi_posterior.npy")
# np.save does not create missing directories. Create it here so a finished
# MCMC run is not lost at the final step.
posterior_path.parent.mkdir(parents=True, exist_ok=True)
# NOTE(review): if `post` is a dict/pytree rather than an array, it is stored
# pickled, and loading requires np.load(..., allow_pickle=True).
np.save(posterior_path, post)