# Phase 4: Neural Posterior Training with BayesFlow

---

## 1. Introduction

**Goal:**  
Train a neural posterior estimator with BayesFlow for the leaky integrate-and-fire (LIF) neuron model, reusing the prior and simulator defined in previous phases. This phase walks through the full simulation-based inference (SBI) workflow: training, inference, and evaluation.

---

## 2. BayesFlow Workflow Recap

- **Prior:** Samples plausible parameters.
- **Simulator:** Generates simulated data given parameters.
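- **Amortized posterior:** Neural networks (a summary network plus an inference network) that learn to approximate the posterior $p(\theta|x)$.
- **Workflow:** `bf.BasicWorkflow` handles simulation-based training using the prior and simulator. (In BayesFlow v1 these two roles were played by `AmortizedPosterior` and `Trainer`.)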

BayesFlow trains the neural posterior by drawing fresh parameter/data pairs on the fly. The networks never see the same simulation twice, which reduces the risk of overfitting to a fixed training set.
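The on-the-fly scheme can be illustrated without BayesFlow at all. The sketch below uses a hypothetical one-parameter prior and simulator (`toy_prior` and `toy_simulator` are stand-ins, not part of this project) and yields a newly simulated batch for every training step.

```python
import numpy as np

rng = np.random.default_rng(0)

def toy_prior(batch_size):
    # Hypothetical prior: theta ~ Uniform(0, 1), shape (batch_size, 1)
    return rng.uniform(0.0, 1.0, size=(batch_size, 1))

def toy_simulator(theta, n_obs=20):
    # Hypothetical simulator: n_obs noisy observations centred on theta
    return theta + rng.normal(0.0, 0.1, size=(theta.shape[0], n_obs))

def online_batches(n_steps, batch_size):
    """Yield a freshly simulated (theta, x) batch for every training step."""
    for _ in range(n_steps):
        theta = toy_prior(batch_size)
        yield theta, toy_simulator(theta)

# Three training steps, four simulations each; no batch is ever reused
batches = list(online_batches(n_steps=3, batch_size=4))
```

A real training loop would pass each batch to an optimizer step instead of collecting them in a list.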

---

## 3. Setup: Imports and Functions

Make sure BayesFlow is installed:
```bash
pip install bayesflow
```

Import BayesFlow and bring in your prior and simulator from Phase 2.
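If `src.lif_model` from Phase 2 is not at hand, a stand-in with the same call signature can be sketched as follows. This is an assumption-laden forward-Euler version written for illustration, not the Phase 2 implementation: the name `lif_simulate_sketch`, the initial condition `V[0] = E_L`, and the way the refractory period is handled are all choices made here. With the units used in this notebook, `I / g_L` (pA / nS) is in mV.

```python
import numpy as np

def lif_simulate_sketch(T, dt, E_L, V_th, V_reset, tau_m, g_L, I, tref):
    """Hypothetical forward-Euler LIF model: tau_m * dV/dt = -(V - E_L) + I / g_L."""
    n_steps = int(round(T / dt))
    t = np.arange(n_steps) * dt          # time grid in ms
    V = np.empty(n_steps)
    V[0] = E_L                           # assumed initial condition: rest
    spikes = []
    refractory_until = -np.inf
    for k in range(1, n_steps):
        if t[k] < refractory_until:
            V[k] = V_reset               # clamp during the refractory period
            continue
        dV = (-(V[k - 1] - E_L) + I / g_L) / tau_m
        V[k] = V[k - 1] + dt * dV
        if V[k] >= V_th:                 # threshold crossing: record and reset
            spikes.append(t[k])
            V[k] = V_reset
            refractory_until = t[k] + tref
    return t, V, np.array(spikes)

# Strong input: steady state E_L + I / g_L = -50 mV lies above threshold
t, V, spikes = lif_simulate_sketch(200.0, 0.1, -70.0, -55.0, -75.0, 10.0, 10.0, 200.0, 2.0)
# Weak input: steady state -65 mV stays below threshold
_, V_weak, spikes_weak = lif_simulate_sketch(200.0, 0.1, -70.0, -55.0, -75.0, 10.0, 10.0, 50.0, 2.0)
```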

In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import numpy as np
import bayesflow as bf
from src.lif_model import lif_simulate
from scipy.stats import uniform

# Parameter bounds (should match previous phases)
PARAM_BOUNDS = {
    'tau_m':   (5.0, 30.0),    # ms
    'E_L':     (-80.0, -65.0), # mV
    'g_L':     (5.0, 15.0),    # nS
    'V_th':    (-60.0, -50.0), # mV
    'V_reset': (-80.0, -65.0), # mV
    'I':       (50.0, 300.0),  # pA
}
PARAM_KEYS = list(PARAM_BOUNDS.keys())

# Prior function
def prior(batch_size):
    samples = []
    for key in PARAM_KEYS:
        low, high = PARAM_BOUNDS[key]
        samples.append(uniform.rvs(loc=low, scale=high-low, size=batch_size))
    return np.stack(samples, axis=1)

# Simulator function (the next cell rebinds the name `simulator` to a LIFSimulator object)
def simulator(params, T=200.0, dt=0.1, tref=2.0, noise_std=0.5):
    traces = []
    for p in params:
        param_dict = dict(zip(PARAM_KEYS, p))
        sim_args = {
            'T': T,
            'dt': dt,
            'E_L': param_dict['E_L'],
            'V_th': param_dict['V_th'],
            'V_reset': param_dict['V_reset'],
            'tau_m': param_dict['tau_m'],
            'g_L': param_dict['g_L'],
            'I': param_dict['I'],
            'tref': tref
        }
        t, V, spikes = lif_simulate(**sim_args)
        V_noisy = V + np.random.normal(0, noise_std, size=V.shape)
        traces.append(V_noisy)
    return np.stack(traces, axis=0)

INFO:bayesflow:Using backend 'jax'
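Before training, it can help to confirm that prior draws respect the bounds. The sketch below is a seeded, vectorized alternative to the loop-based `prior` above. The name `sample_prior` and the use of `np.random.default_rng` are choices made here for reproducibility, and `PARAM_BOUNDS` is repeated so the block runs on its own.

```python
import numpy as np

PARAM_BOUNDS = {
    'tau_m':   (5.0, 30.0),    # ms
    'E_L':     (-80.0, -65.0), # mV
    'g_L':     (5.0, 15.0),    # nS
    'V_th':    (-60.0, -50.0), # mV
    'V_reset': (-80.0, -65.0), # mV
    'I':       (50.0, 300.0),  # pA
}
lows = np.array([low for low, _ in PARAM_BOUNDS.values()])
highs = np.array([high for _, high in PARAM_BOUNDS.values()])

def sample_prior(batch_size, rng):
    # One column per parameter; low/high broadcast across the batch axis
    return rng.uniform(lows, highs, size=(batch_size, len(lows)))

draws = sample_prior(1000, np.random.default_rng(0))
in_bounds = bool(np.all((draws >= lows) & (draws <= highs)))
print("shape:", draws.shape, "all within bounds:", in_bounds)
```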


In [None]:
# Simulator class for BayesFlow v2+
class LIFSimulator:
    def __init__(self, param_keys, noise_std=0.5):
        self.param_keys = param_keys
        self.noise_std = noise_std

    def sample(self, batch_size):
        params = prior(batch_size)
        traces = []
        param_dicts = {key: [] for key in self.param_keys}
        for p in params:
            param_dict = dict(zip(self.param_keys, p))
            sim_args = {
                'T': 200.0,
                'dt': 0.1,
                'E_L': param_dict['E_L'],
                'V_th': param_dict['V_th'],
                'V_reset': param_dict['V_reset'],
                'tau_m': param_dict['tau_m'],
                'g_L': param_dict['g_L'],
                'I': param_dict['I'],
                'tref': 2.0
            }
            t, V, spikes = lif_simulate(**sim_args)
            V_noisy = V + np.random.normal(0, self.noise_std, size=V.shape)
            traces.append(V_noisy)
            for key in self.param_keys:
                param_dicts[key].append(param_dict[key])
        # Convert lists to arrays
        for key in param_dicts:
            param_dicts[key] = np.array(param_dicts[key]).reshape(-1, 1)
        out = {**param_dicts, "trace": np.stack(traces, axis=0)}
        return out

# Instantiate the simulator object.
# Note: this rebinds the name `simulator`, replacing the function from the previous cell.
simulator = LIFSimulator(PARAM_KEYS)

# The workflow itself is built in Section 4, once the summary and inference networks exist.


In [15]:
# DEBUG: Check simulator output keys and shapes
sim_out = simulator.sample(1)
print("Simulator output keys:", list(sim_out.keys()))
for k, v in sim_out.items():
    print(f"{k}: shape {v.shape}")
# Should see all parameter names and 'trace' as keys


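Simulator output keys: ['parameters', 'trace']
parameters: shape (1, 6)
trace: shape (1, 2000)

Note: this recorded output does not match the `LIFSimulator.sample` method defined above, so it appears to come from an earlier version of the simulator that returned one stacked `parameters` array. The current method returns one `(batch_size, 1)` array per parameter name plus `trace`, which is the layout `inference_variables=PARAM_KEYS` expects. Rerun the cell to refresh the output.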
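If a simulator returns a single stacked `parameters` array, as in the recorded output above, it can be converted to the per-key layout after the fact. The helper below is a minimal sketch; `split_parameters` is a name introduced here, and the zero-filled `legacy` dictionary only mimics the shapes printed above.

```python
import numpy as np

PARAM_KEYS = ['tau_m', 'E_L', 'g_L', 'V_th', 'V_reset', 'I']

def split_parameters(sim_out, param_keys):
    """Replace a stacked 'parameters' array with one (batch, 1) column per key."""
    params = np.asarray(sim_out["parameters"])
    if params.ndim != 2 or params.shape[1] != len(param_keys):
        raise ValueError(f"expected shape (batch, {len(param_keys)}), got {params.shape}")
    out = {k: v for k, v in sim_out.items() if k != "parameters"}
    for j, key in enumerate(param_keys):
        out[key] = params[:, j:j + 1]   # slicing keeps the trailing axis
    return out

# Placeholder arrays with the shapes from the recorded output
legacy = {"parameters": np.zeros((1, 6)), "trace": np.zeros((1, 2000))}
fixed = split_parameters(legacy, PARAM_KEYS)

# Keys a workflow with inference_variables=PARAM_KEYS would look for
missing = set(PARAM_KEYS + ["trace"]) - set(fixed)
print("missing keys:", missing)
```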


## 4. Define and Initialize the Neural Posterior

We pair a summary network with an inference network inside a `bf.BasicWorkflow` to learn the mapping from simulated traces to parameters.

In [9]:
# Use BayesFlow v2+ BasicWorkflow for amortized inference
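# Define summary and inference networks.
# Assumption: TimeSeriesNetwork expects input shaped (batch, time, channels), so a
# (batch, time) trace may need a trailing axis, e.g. trace[..., None].
summary_network = bf.networks.TimeSeriesNetwork()
inference_network = bf.networks.CouplingFlow()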

# Define the workflow
workflow = bf.BasicWorkflow(
    inference_network=inference_network,
    summary_network=summary_network,
    inference_variables=PARAM_KEYS,
    summary_variables=["trace"],  # adjust as needed for your simulator output
    simulator=simulator
)

## 5. Training the Posterior with BayesFlow

The `workflow.fit_online(...)` method in BayesFlow v2+ runs simulation-based training. At each step it draws a fresh batch from the simulator (prior parameters plus the corresponding LIF traces) and updates the networks. Once trained, the networks return posterior samples for any new trace without retraining, which is what "amortized" inference means here.
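Because every batch is simulated anew, the total simulation cost is the product of epochs, batches per epoch, and batch size. The sketch below works this out under the assumption that `fit_online` also takes a `num_batches_per_epoch` argument, as in BayesFlow v2. The value 100 is illustrative, so check the default in your installed version.

```python
def simulation_budget(epochs, num_batches_per_epoch, batch_size):
    """Total number of simulator calls made during online training."""
    return epochs * num_batches_per_epoch * batch_size

# 10 epochs x 100 batches per epoch (assumed) x 128 simulations per batch
total = simulation_budget(epochs=10, num_batches_per_epoch=100, batch_size=128)
print(f"{total:,} simulations")  # 128,000 simulations
```

With a simulator that loops over parameters in Python, this count largely determines wall-clock training time.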

In [13]:
# Train the neural posterior with BayesFlow v2+
# This cell assumes the simulator and workflow are defined as above.

n_epochs = 10  # Increase for real use
batch_size = 128  # simulations per gradient step, not per epoch

history = workflow.fit_online(
    epochs=n_epochs,
    batch_size=batch_size,
    verbose=True
)

print("Training complete.")

INFO:bayesflow:Fitting on dataset instance of OnlineDataset.
INFO:bayesflow:Building on a test batch.
INFO:bayesflow:Building on a test batch.


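KeyError: "Missing keys: {'V_reset', 'V_th', 'E_L', 'tau_m', 'g_L', 'I'}"

Note: this error means the simulator output lacked the keys listed in `inference_variables`. It is consistent with a simulator that returns a single `parameters` array instead of one entry per parameter name. With the per-key `LIFSimulator` in place, rebuild the workflow and rerun this cell. "Training complete." is printed only when `fit_online` returns without error.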

## 6. Inference: Posterior Sampling

After training, you can use the neural posterior to infer parameters from new (possibly real) data.

In [None]:
# Parameter recovery: simulate new data, run inference, and compare posterior samples to true parameters

# Simulate new parameters and data
n_test = 1
sim_out = simulator.sample(n_test)
obs_dict = {"trace": sim_out["trace"]}
true_params = np.array([sim_out[key][0, 0] for key in PARAM_KEYS])

# Run inference (in BayesFlow v2, sampling is exposed on the workflow itself)
posterior_samples = workflow.sample(
    conditions=obs_dict,
    num_samples=1000
)

# Plot posterior samples vs. true parameters
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, len(PARAM_KEYS), figsize=(3*len(PARAM_KEYS), 3))
for i, key in enumerate(PARAM_KEYS):
    ax = axes[i] if len(PARAM_KEYS) > 1 else axes
    ax.hist(posterior_samples[key].flatten(), bins=30, alpha=0.7, label='Posterior')
    ax.axvline(true_params[i], color='r', linestyle='--', label='True')
    ax.set_title(key)
    ax.legend()
plt.tight_layout()
plt.show()

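TypeError: 'LIFSimulator' object is not callable

Note: this traceback appears to come from an earlier run in which `simulator` was called like a function. The cell above calls `simulator.sample(...)`, which does not raise it.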
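Histograms are easier to read alongside numeric summaries. The sketch below computes a posterior mean and a central 95% interval per parameter. It assumes the samples arrive as a dictionary of arrays keyed by parameter name, with shape `(n_datasets, n_samples, 1)`. The `fake_samples` values are synthetic placeholders, not results from the trained network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for posterior samples (assumed layout, made-up values)
fake_samples = {
    "tau_m": rng.normal(15.0, 1.0, size=(1, 1000, 1)),
    "I": rng.normal(150.0, 10.0, size=(1, 1000, 1)),
}

def summarize(samples, level=0.95):
    """Posterior mean and central credible interval for each parameter."""
    alpha = (1.0 - level) / 2.0
    summary = {}
    for key, draws in samples.items():
        flat = np.asarray(draws).reshape(-1)
        lower, upper = np.quantile(flat, [alpha, 1.0 - alpha])
        summary[key] = {"mean": float(flat.mean()),
                        "lower": float(lower),
                        "upper": float(upper)}
    return summary

summary = summarize(fake_samples)
for key, stats in summary.items():
    print(f"{key}: mean={stats['mean']:.2f}, "
          f"95% interval=[{stats['lower']:.2f}, {stats['upper']:.2f}]")
```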

## 7. Evaluation: Parameter Recovery

Let's visualize the true parameters vs. the posterior samples for a sanity check.

In [None]:
import matplotlib.pyplot as plt

# Reuses `posterior_samples` (dict keyed by parameter name) and
# `true_params` (1-D array ordered like PARAM_KEYS) from Section 6.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()
for i, key in enumerate(PARAM_KEYS):
    axes[i].hist(posterior_samples[key].flatten(), bins=30, color='lightgreen', edgecolor='k', alpha=0.7)
    axes[i].axvline(true_params[i], color='r', linestyle='--', label='True')
    axes[i].set_title(key)
    axes[i].legend()
plt.tight_layout()
plt.show()
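Two scalar diagnostics complement the visual check: the posterior z-score (how many posterior standard deviations the mean lies from the truth) and the posterior contraction (how much the posterior variance shrinks relative to the prior). The sketch below uses the variance of a uniform prior, $(b-a)^2/12$. The function name and the synthetic draws are illustrative only.

```python
import numpy as np

def recovery_diagnostics(draws, true_value, prior_low, prior_high):
    """Posterior z-score and contraction for one parameter under a uniform prior."""
    draws = np.asarray(draws).reshape(-1)
    post_mean = draws.mean()
    post_std = draws.std(ddof=1)
    prior_var = (prior_high - prior_low) ** 2 / 12.0   # variance of Uniform(a, b)
    z_score = (post_mean - true_value) / post_std
    contraction = 1.0 - post_std ** 2 / prior_var
    return float(z_score), float(contraction)

# Synthetic draws standing in for tau_m posterior samples (made-up values)
rng = np.random.default_rng(0)
draws = rng.normal(15.0, 1.0, size=1000)
z, contraction = recovery_diagnostics(draws, true_value=15.5,
                                      prior_low=5.0, prior_high=30.0)
print(f"z-score={z:.2f}, contraction={contraction:.3f}")
```

A contraction near 1 means the data were informative about the parameter, and z-scores mostly within about ±2 across many test simulations suggest the posterior is not systematically biased.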

## 8. Summary & Next Steps

- You have trained a BayesFlow neural posterior for the LIF neuron model.
- You can now use it for inference and parameter recovery on new data.
- For best results, increase `n_epochs` (10 is only a smoke test) and tune the summary and inference networks as needed.

---