In [1]:
# =============================================================================
# Main script to run a full end-to-end example for the abcnre package.
# =============================================================================

import jax
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# --- Imports from abcnre package ---
from abcnre.simulation import ABCSimulator
from abcnre.simulation.models import GaussGaussModel
from abcnre.inference.persistence import save_classifier
from abcnre.inference.config import ExperimentConfig, get_experiment_config
from abcnre.inference.estimator import NeuralRatioEstimator
from abcnre.inference.networks.base import create_network_from_config


In [2]:
# All artifacts produced by this example (simulator YAML, observed data,
# training configuration, classifier files) are written under one directory.
output_dir = Path("./gauss_example")
simulator_path = output_dir.joinpath("gauss_1D_simulator.yml")
network_config_path = output_dir.joinpath("config_mlp_reduce_on_plateau.yml")


In [3]:
# --- Experiment settings ---
# Value of theta used to generate the synthetic observations.
true_theta = 2.5
# Quantile of the distance distribution used by ABCSimulator to set epsilon.
epsilon_quantile = 0.05
# Number of synthetic observations.
n_obs = 100

# GaussGaussModel hyperparameters, named instead of inline literals.
# NOTE(review): presumably mu0/sigma0 parameterize the prior on theta and
# sigma is the observation noise scale — confirm against GaussGaussModel.
model_mu0 = 0.0
model_sigma0 = 2.0
model_sigma = 0.5
# Seed for the JAX PRNG used to draw the observed data.
data_seed = 42

# --- Step 1: Create and Configure Simulator ---
print("--- Step 1: Create and Configure Simulator ---")
key = jax.random.PRNGKey(data_seed)
key, subkey = jax.random.split(key)


gauss_model = GaussGaussModel(mu0=model_mu0, sigma0=model_sigma0, sigma=model_sigma)
# Synthetic observations: true_theta plus standard-normal noise scaled by the
# model's sigma.
observed_data = true_theta + gauss_model.sigma * jax.random.normal(subkey, shape=(n_obs,))

simulator = ABCSimulator(
    model=gauss_model,
    observed_data=observed_data,
    quantile_distance=epsilon_quantile
)
print(f"Simulator created with epsilon = {simulator.epsilon:.4f}")



--- Step 1: Create and Configure Simulator ---
Computing epsilon for 5.0% quantile...
Computed epsilon = 0.270363
Simulator created with epsilon = 0.2704


In [4]:
# --- Step 2: Save Simulator Configuration ---
print("\n--- Step 2: Save Simulator Configuration ---")

# Create the artifact directory (no-op if it already exists), then persist the
# simulator so it can be reloaded from `simulator_path` in Step 4.
output_dir.mkdir(parents=True, exist_ok=True)
simulator.save(simulator_path)
print("Simulator config saved to:", simulator_path)


--- Step 2: Save Simulator Configuration ---
✅ Simulator saved with hash: 6b74249163b0
   - Configuration: gauss_example/gauss_1D_simulator.yml
   - Observed Data: gauss_example/observed_data_6b74249163b0.npy
Simulator config saved to: gauss_example/gauss_1D_simulator.yml


In [5]:
# --- Step 3: Create Network and Training Configuration ---
# Start from the named experiment configuration 'default_mlp_plateau' and
# override only the training budget for this example.

print("\n--- Step 3: Create Network Configuration ---")
exp_config = get_experiment_config('default_mlp_plateau')
exp_config.training.num_epochs = 50
exp_config.training.n_samples_per_epoch = 10240
exp_config.training.num_thetas_to_store = 10000

# Create the destination directory here as well, so this cell does not depend
# on Step 2 having created it first (mkdir is idempotent with exist_ok=True).
network_config_path.parent.mkdir(parents=True, exist_ok=True)
exp_config.save(network_config_path)
print(f"Training configuration saved to: {network_config_path}")


--- Step 3: Create Network Configuration ---
Training configuration saved to: gauss_example/config_mlp_reduce_on_plateau.yml


In [6]:
# --- Step 4: Training the Classifier ---
print("\n--- Step 4: Training the Classifier ---")

# Train against the simulator as reloaded from disk, which also exercises the
# save/load round trip from Step 2.
loaded_simulator = ABCSimulator.load(simulator_path)

# Build the network from the experiment configuration and wrap it in the
# ratio estimator together with the training settings and seed.
network = create_network_from_config(exp_config.network.to_dict())
estimator = NeuralRatioEstimator(
    network=network,
    training_config=exp_config.training,
    random_seed=exp_config.random_seed,
)

# NOTE(review): `train_key` is not consumed anywhere in this cell; the estimator
# is seeded through `exp_config.random_seed` above. The split is kept so that
# `key` advances exactly as before for any later cell that uses it.
key, train_key = jax.random.split(key)

# The training budget is read straight from the experiment configuration.
estimator.train(
    simulator=loaded_simulator,
    output_dir=output_dir,
    **{
        setting: getattr(exp_config.training, setting)
        for setting in ("num_epochs", "n_samples_per_epoch", "batch_size")
    },
)


--- Step 4: Training the Classifier ---
✅ Simulator loaded from: gauss_example/gauss_1D_simulator.yml
Initialized network with 10,753 parameters

--- DEBUG INFO ---
Does epoch_data contain phi_samples? True
Shape of phi_samples: (5120, 1)
--- END DEBUG INFO ---

Epoch 1/50 | Train Loss: 0.6856, Val Loss: 0.6949, Train Acc 55.08%, Val Acc: 48.58%, Learning rate = 0.001000
Epoch 2/50 | Train Loss: 0.6838, Val Loss: 0.6840, Train Acc 71.88%, Val Acc: 55.86%, Learning rate = 0.001000
Epoch 3/50 | Train Loss: 0.6697, Val Loss: 0.6697, Train Acc 70.70%, Val Acc: 78.03%, Learning rate = 0.001000
Epoch 4/50 | Train Loss: 0.6436, Val Loss: 0.6373, Train Acc 72.27%, Val Acc: 81.10%, Learning rate = 0.001000
Epoch 5/50 | Train Loss: 0.5900, Val Loss: 0.6083, Train Acc 80.08%, Val Acc: 66.46%, Learning rate = 0.001000
Epoch 6/50 | Train Loss: 0.5468, Val Loss: 0.5436, Train Acc 76.56%, Val Acc: 79.35%, Learning rate = 0.001000
Epoch 7/50 | Train Loss: 0.5576, Val Loss: 0.5013, Train Acc 70.70%, V

In [7]:
# --- Step 5: Saving All Classifier Artifacts ---
print("\n--- Step 5: Saving All Classifier Artifacts ---")

# Write every classifier artifact under `output_dir` with the "gauss_1D" prefix.
# The path returned by save_classifier is kept in `final_config_path`.
final_config_path = save_classifier(
    output_dir=output_dir,
    filename_base="gauss_1D",
    estimator=estimator,
    # The simulator is passed so that its dimensions are saved with the classifier.
    simulator=loaded_simulator,
)


--- Step 5: Saving All Classifier Artifacts ---
✅ Classifier saved. Master config: gauss_example/gauss_1D_classifier.yml
