# Import libraries

In [None]:
import os  # Path handling (os.path)
import pathlib  # File path handling library

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp  # Checkpointing library
from flax import nnx

import ximinf.nn_inf as nni
import ximinf.nn_train as nnt

# Set device type

In [None]:
# Show every device JAX can see. A bare expression only renders when it is the
# last line of a notebook cell, so print it explicitly.
print(jax.devices())
cpu = jax.devices("cpu")[0]
try:
    gpu = jax.devices("gpu")[0]
except RuntimeError:
    # jax.devices("gpu") raises RuntimeError when no GPU backend is present;
    # fall back to the CPU so the notebook still runs on CPU-only machines.
    print("No GPU backend found; using the CPU device for `gpu`.")
    gpu = cpu

# Load NN

In [None]:
# Parameters
# N = 300_000  # Number of samples # RETRIEVE FROM DATA
M = 100      # Number of points per sample # RETRIEVE FROM DATA

# Sizes passed to DeepSetClassifier. They must match the values used when the
# checkpoint was trained (the original trailing "#64" on Nsize_p presumably
# recorded an earlier hard-coded value -- confirm).
Nsize_p = 2*M
Nsize_r = 20*M
phi_batch = 1

# Checkpoint directory, relative to the notebook's working directory.
# Path.resolve() already yields an absolute path, so os.path.abspath is not needed.
ckpt_dir = pathlib.Path('../data/NNs/nn_model').resolve()

# Fail fast with a clear message if the checkpoint directory is missing.
if not ckpt_dir.exists():
    raise ValueError(f"Checkpoint directory {ckpt_dir} does not exist. Please check the path.")

# 1. Re-create the checkpointer
checkpointer = ocp.StandardCheckpointer()

# 2. Build a shape-only model with nnx.eval_shape and split it into the
#    GraphDef (structure), the RngKey / RngCount state, and everything else
#    (abstract_state), which is what the checkpoint is restored into.
abstract_model = nnx.eval_shape(
    lambda: nnt.DeepSetClassifier(0.05, Nsize_p, Nsize_r, phi_batch, rngs=nnx.Rngs(0))
)
abs_graphdef, abs_rngkey, abs_rngcount, abstract_state = nnx.split(
    abstract_model, nnx.RngKey, nnx.RngCount, ...
)

# 3. Restore, passing abstract_state as the target so orbax restores into the
#    model's State tree and checks the saved structure/shapes/dtypes against it.
state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
print('NNX State restored: ')
nnx.display(state_restored)

# 4. Reassemble a concrete model from the graph definition and restored state.
# NOTE(review): the RNG key/count come from the eval_shape (abstract) model,
# not from the checkpoint -- confirm they are unused at inference time or are
# re-seeded before any stochastic layer runs.
model = nnx.merge(abs_graphdef, abs_rngkey, abs_rngcount, state_restored)

nnx.display(model)

# MCMC

In [None]:
# Switch the NNX model to evaluation mode before it is used for sampling.
model.eval()

# ========== Global Settings ==========
# Each row is [min, max] for one parameter (a first, then b).
# Normalised alternative: BOUNDS = jnp.array([[0., 1.], [0., 1.]])
# NOTE(review): a_min, a_max, b_min and b_max are not defined in the cells
# shown here; they must be set before this cell runs.
BOUNDS = jnp.array([
    [a_min, a_max],
    [b_min, b_max],
])
NDIM = len(BOUNDS)  # number of rows in BOUNDS, i.e. one per parameter

In [None]:
print("Launch MCMC ...")


def log_post(theta):
    """Score ``theta`` with ``nni.log_prob_fn`` using the trained ``model``.

    ``model`` and ``interleaved`` are notebook globals, looked up when the
    function is called (not when it is defined).
    """
    return nni.log_prob_fn(theta, model, interleaved)


# NOTE(review): `interleaved`, `theta_true` and `key` are not defined in the
# cells shown here; they must exist before this cell runs.
# sample_posterior hands back a fresh RNG key (rebound to `key`) together with
# `post` (presumably the posterior draws -- confirm against nni).
key, post = nni.sample_posterior(
    log_post,
    rng_key=key,
    init_position=theta_true,
    n_warmup=200,
    n_samples=2000,
)

print("...finished")