# Phase 4.0: Simulation-Based Inference for LIF Neuron Model with BayesFlow

This notebook demonstrates a full simulation-based inference (SBI) workflow for a Leaky Integrate-and-Fire (LIF) neuron model using BayesFlow v2+. We will:
- Define the prior and simulator for the LIF model
- Build and train a neural posterior with BayesFlow
- Perform inference and evaluate parameter recovery

Each step is explained in detail for clarity and reproducibility.

## 1. Environment Setup and Imports

First, ensure you have the required packages installed. You need:
- `bayesflow` (v2+)
- `jax`, `jaxlib` (for JAX backend)
- `keras` (v3+)
- `numpy`, `matplotlib`, `scipy`

You can install them with:
```bash
pip install bayesflow jax jaxlib keras numpy matplotlib scipy
```
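
The requirements are stated as major versions (BayesFlow 2+, Keras 3+). One way to confirm what is actually installed, before importing anything heavy, is `importlib.metadata.version` from the standard library. The sketch below compares only the leading (major) version number. The helper names `major_version` and `check_versions` are our own, not part of any library.

```python
from importlib.metadata import version, PackageNotFoundError

# Minimum major versions stated in this notebook's requirements
MIN_MAJOR = {"bayesflow": 2, "keras": 3}

def major_version(version_string):
    """Return the leading integer of a version string such as '2.0.4' or '3.10.0'."""
    head = version_string.split(".")[0]
    digits = "".join(ch for ch in head if ch.isdigit())
    return int(digits) if digits else 0

def check_versions(requirements=MIN_MAJOR):
    """Map each package name to (installed version or None, meets requirement)."""
    report = {}
    for pkg, min_major in requirements.items():
        try:
            v = version(pkg)
        except PackageNotFoundError:
            report[pkg] = (None, False)
            continue
        report[pkg] = (v, major_version(v) >= min_major)
    return report

for pkg, (v, ok) in check_versions().items():
    print(f"{pkg}: {v or 'not installed'} -> {'OK' if ok else 'install or upgrade needed'}")
```

If either package reports an older major version, upgrading it before continuing avoids API mismatches later in the notebook.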

Now, import all necessary libraries and set up the backend.

In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"  # Use JAX backend; must be set before Keras/BayesFlow are imported
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Hide all GPUs so JAX runs on the CPU
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Make the project root importable (for src.lif_model)

import numpy as np
import bayesflow as bf
from src.lif_model import lif_simulate
from scipy.stats import uniform
import matplotlib.pyplot as plt

INFO:bayesflow:Using backend 'jax'
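
The introduction promises reproducibility, but the notebook never sets a random seed. Both random sources in the data pipeline, `uniform.rvs` in `prior` (called without `random_state`) and `np.random.normal` in the simulator, draw from NumPy's global random state, so seeding it with `np.random.seed` makes the simulated data repeatable. The sketch below checks that with the same two calls; the seed value and the helper name `draw_tau_m_and_noise` are arbitrary illustrations. Network initialization is handled on the Keras side: Keras 3 provides `keras.utils.set_random_seed(seed)`, which seeds Python, NumPy, and the active backend, and calling it near the top of the notebook should cover that as well.

```python
import numpy as np
from scipy.stats import uniform

SEED = 42  # arbitrary; any fixed integer works

def draw_tau_m_and_noise(n):
    """Mimic the notebook's two random sources: a scipy prior draw and numpy noise."""
    tau_m = uniform.rvs(loc=5.0, scale=25.0, size=n)  # same range as PARAM_BOUNDS['tau_m']
    noise = np.random.normal(0, 0.5, size=n)          # same std as LIFSimulator's default noise_std
    return tau_m, noise

np.random.seed(SEED)
first = draw_tau_m_and_noise(5)
np.random.seed(SEED)
second = draw_tau_m_and_noise(5)

print(all(np.array_equal(a, b) for a, b in zip(first, second)))  # True
```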


## 2. Define the Prior Distribution

The prior defines plausible ranges for each LIF model parameter. We place independent uniform priors on all six parameters, with bounds chosen to cover physiologically plausible values (units are given in the comments below).

In [2]:
# Parameter bounds for the LIF model
PARAM_BOUNDS = {
    'tau_m':   (5.0, 30.0),    # Membrane time constant (ms)
    'E_L':     (-80.0, -65.0), # Resting potential (mV)
    'g_L':     (5.0, 15.0),    # Leak conductance (nS)
    'V_th':    (-60.0, -50.0), # Spike threshold (mV)
    'V_reset': (-80.0, -65.0), # Reset potential (mV)
    'I':       (50.0, 300.0),  # Input current (pA)
}
PARAM_KEYS = list(PARAM_BOUNDS.keys())

def prior(batch_size):
    """Sample parameters from the prior."""
    samples = []
    for key in PARAM_KEYS:
        low, high = PARAM_BOUNDS[key]
        samples.append(uniform.rvs(loc=low, scale=high-low, size=batch_size))
    return np.stack(samples, axis=1)
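
Because `prior` draws each parameter independently from `U(low, high)`, the joint prior density is the product of the per-parameter uniform densities: constant inside the parameter box and zero outside. The same draws can be produced in one vectorized call, since NumPy's `Generator.uniform` broadcasts arrays of lower and upper bounds across columns. The sketch below is a hypothetical alternative (the names `prior_vectorized` and `log_prior` are ours), not a replacement the notebook requires.

```python
import numpy as np

PARAM_BOUNDS = {
    'tau_m':   (5.0, 30.0),
    'E_L':     (-80.0, -65.0),
    'g_L':     (5.0, 15.0),
    'V_th':    (-60.0, -50.0),
    'V_reset': (-80.0, -65.0),
    'I':       (50.0, 300.0),
}
PARAM_KEYS = list(PARAM_BOUNDS.keys())
LOWS = np.array([PARAM_BOUNDS[k][0] for k in PARAM_KEYS])
HIGHS = np.array([PARAM_BOUNDS[k][1] for k in PARAM_KEYS])

def prior_vectorized(batch_size, rng=None):
    """Hypothetical alternative to `prior`: returns shape (batch_size, n_params)."""
    rng = np.random.default_rng() if rng is None else rng
    # LOWS/HIGHS have shape (n_params,) and broadcast along the last axis
    return rng.uniform(low=LOWS, high=HIGHS, size=(batch_size, len(PARAM_KEYS)))

def log_prior(theta):
    """Log density of the independent-uniform prior; -inf outside the box."""
    theta = np.atleast_2d(theta)
    inside = np.all((theta >= LOWS) & (theta <= HIGHS), axis=1)
    log_density = -np.sum(np.log(HIGHS - LOWS))  # log of 1 / box volume
    return np.where(inside, log_density, -np.inf)

rng = np.random.default_rng(0)
theta = prior_vectorized(4, rng)
print(theta.shape)                          # (4, 6)
print(np.isfinite(log_prior(theta)).all())  # True
```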

## 3. Define the Simulator

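The simulator draws parameters from the prior, runs the LIF model for 200 ms at a 0.1 ms time step (2000 samples per trace), and adds Gaussian measurement noise with standard deviation `noise_std` (0.5 by default, in the units of the voltage trace). BayesFlow's online training requests data through a `.sample(batch_size)` method, so we wrap the simulator in a class that exposes it. The method returns a dictionary with one `(batch_size, 1)` array per parameter, plus the noisy trace of shape `(batch_size, 2000, 1)`.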

In [3]:
class LIFSimulator:
    def __init__(self, param_keys, noise_std=0.5):
        self.param_keys = param_keys
        self.noise_std = noise_std

    def sample(self, batch_size):
        params = prior(batch_size)
        traces = []
        param_dicts = {key: [] for key in self.param_keys}
        for p in params:
            param_dict = dict(zip(self.param_keys, p))
            sim_args = {
                'T': 200.0,
                'dt': 0.1,
                'E_L': param_dict['E_L'],
                'V_th': param_dict['V_th'],
                'V_reset': param_dict['V_reset'],
                'tau_m': param_dict['tau_m'],
                'g_L': param_dict['g_L'],
                'I': param_dict['I'],
                'tref': 2.0
            }
            t, V, spikes = lif_simulate(**sim_args)
            V_noisy = V + np.random.normal(0, self.noise_std, size=V.shape)
            traces.append(V_noisy)
            for key in self.param_keys:
                param_dicts[key].append(param_dict[key])
        # Convert lists to arrays
        for key in param_dicts:
            param_dicts[key] = np.array(param_dicts[key]).reshape(-1, 1)
        out = {**param_dicts, "trace": np.stack(traces, axis=0)[..., np.newaxis]}
        return out

# Instantiate the simulator
simulator = LIFSimulator(PARAM_KEYS)
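
`lif_simulate` is imported from `src/lif_model.py`, which is not shown here. For readers without that module, the sketch below is a hypothetical stand-in with the same keyword arguments as used in `LIFSimulator` (returning `t, V, spikes`). It assumes the textbook LIF equation `tau_m * dV/dt = -(V - E_L) + I / g_L`, integrated with forward Euler; with `I` in pA and `g_L` in nS, `I / g_L` is in mV. The real implementation may differ (for example in how it handles capacitance or the refractory period). One consequence of this assumed form is worth noting for parameter recovery: `I` and `g_L` enter only through their ratio, so traces generated this way cannot pin down the two separately.

```python
import numpy as np

def lif_simulate_sketch(T, dt, E_L, V_th, V_reset, tau_m, g_L, I, tref):
    """Hypothetical stand-in for src.lif_model.lif_simulate (forward-Euler LIF).

    Assumed dynamics: tau_m * dV/dt = -(V - E_L) + I / g_L.
    Returns (t, V, spike_times); times in ms, voltages in mV.
    """
    n_steps = int(round(T / dt))
    t = np.arange(n_steps) * dt
    V = np.empty(n_steps)
    V[0] = E_L  # start at rest
    spikes = []
    refractory_until = -np.inf
    for k in range(1, n_steps):
        if t[k] < refractory_until:
            V[k] = V_reset  # clamp the voltage during the refractory period
            continue
        dV = (-(V[k - 1] - E_L) + I / g_L) / tau_m
        V[k] = V[k - 1] + dt * dV
        if V[k] >= V_th:  # threshold crossing: record a spike and reset
            spikes.append(t[k])
            V[k] = V_reset
            refractory_until = t[k] + tref
    return t, V, np.array(spikes)

t, V, spikes = lif_simulate_sketch(T=200.0, dt=0.1, E_L=-70.0, V_th=-55.0, V_reset=-75.0,
                                   tau_m=20.0, g_L=10.0, I=200.0, tref=2.0)
print(V.shape, len(spikes) > 0)  # (2000,) True
```

Here the steady-state voltage `E_L + I / g_L = -50 mV` lies above the threshold, so the neuron fires repeatedly; halving both `I` and `g_L` would produce an identical trace under this assumed model.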

## 4. Check Simulator Output

Let's verify that the simulator returns the correct keys and shapes for BayesFlow.

In [4]:
sim_out = simulator.sample(2)
print("Simulator output keys:", list(sim_out.keys()))
for k, v in sim_out.items():
    print(f"{k}: shape {v.shape}")

Simulator output keys: ['tau_m', 'E_L', 'g_L', 'V_th', 'V_reset', 'I', 'trace']
tau_m: shape (2, 1)
E_L: shape (2, 1)
g_L: shape (2, 1)
V_th: shape (2, 1)
V_reset: shape (2, 1)
I: shape (2, 1)
trace: shape (2, 2000, 1)
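
The printed shapes can also be turned into an automatic check that fails loudly, which is useful after changing `T`, `dt`, or the parameter list. The sketch also rejects NaN or infinite values, since a diverging simulation would otherwise pass silently into training. `validate_batch` is a hypothetical helper of ours; in the notebook it could be called as `validate_batch(simulator.sample(2), PARAM_KEYS)`.

```python
import numpy as np

def validate_batch(batch, param_keys, trace_key="trace"):
    """Check (B, 1) per parameter and (B, T, 1) for the trace; return (B, T)."""
    trace = np.asarray(batch[trace_key])
    if trace.ndim != 3 or trace.shape[-1] != 1:
        raise ValueError(f"{trace_key} should be (B, T, 1), got {trace.shape}")
    B = trace.shape[0]
    for key in param_keys:
        arr = np.asarray(batch[key])
        if arr.shape != (B, 1):
            raise ValueError(f"{key} should be ({B}, 1), got {arr.shape}")
    if not np.all(np.isfinite(trace)):
        raise ValueError(f"{trace_key} contains NaN or inf values")
    return B, trace.shape[1]

# Usage with a small synthetic batch laid out like the simulator output above
keys = ["tau_m", "E_L"]
fake = {"tau_m": np.zeros((2, 1)), "E_L": np.zeros((2, 1)), "trace": np.zeros((2, 2000, 1))}
print(validate_batch(fake, keys))  # (2, 2000)
```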


## 5. Define BayesFlow Networks and Workflow

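We use BayesFlow's `TimeSeriesNetwork` to compress each voltage trace into a fixed-size vector of learned summary statistics, and a `CouplingFlow` (a normalizing flow) to estimate the posterior over the parameters given those summaries. The `BasicWorkflow` object ties the networks to the simulator: `inference_variables` names the simulator outputs to infer, and `summary_variables` names the data routed through the summary network.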

In [5]:
summary_network = bf.networks.TimeSeriesNetwork()
inference_network = bf.networks.CouplingFlow()

workflow = bf.BasicWorkflow(
    inference_network=inference_network,
    summary_network=summary_network,
    inference_variables=PARAM_KEYS,
    summary_variables=["trace"],
    simulator=simulator
)


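To make the roles of `inference_variables` and `summary_variables` concrete: our understanding (worth checking against the BayesFlow documentation for your version) is that `BasicWorkflow` builds a default adapter that concatenates the listed parameter keys, in the given order, into a single target array, and passes `trace` to the summary network. The NumPy snippet below only illustrates the resulting shapes; it is not BayesFlow's adapter code.

```python
import numpy as np

PARAM_KEYS = ["tau_m", "E_L", "g_L", "V_th", "V_reset", "I"]
B, T = 2, 2000

# Shapes as printed by the simulator check: one (B, 1) array per parameter
batch = {key: np.full((B, 1), float(i)) for i, key in enumerate(PARAM_KEYS)}
batch["trace"] = np.zeros((B, T, 1))

# Illustration only: concatenating in PARAM_KEYS order yields one (B, 6) target
inference_variables = np.concatenate([batch[k] for k in PARAM_KEYS], axis=-1)
summary_variables = batch["trace"]  # (B, T, 1) time series for the TimeSeriesNetwork

print(inference_variables.shape, summary_variables.shape)  # (2, 6) (2, 2000, 1)
print(inference_variables[0])  # [0. 1. 2. 3. 4. 5.]
```

Column order follows `PARAM_KEYS`, which is one reason to keep a single `PARAM_KEYS` list shared by the prior, the simulator, and the workflow.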
## 6. Train the Neural Posterior

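We train the neural posterior with `workflow.fit_online`, which simulates a fresh batch of parameter/trace pairs for every training step instead of reusing a fixed dataset. Training on this effectively unlimited stream of simulations reduces the risk of overfitting, and the resulting network is amortized: once trained, it can be applied to any new trace without retraining.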

In [6]:
history = workflow.fit_online(
    epochs=10,                # Increase for real use
    batch_size=128,
    verbose=True
)
print("Training complete.")

INFO:bayesflow:Fitting on dataset instance of OnlineDataset.
INFO:bayesflow:Building on a test batch.


Epoch 1/10



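Because every batch is simulated on the fly, the total number of simulator calls is epochs × batches per epoch × batch size. `fit_online` also accepts a `num_batches_per_epoch` argument that the training cell leaves at its default; the value 100 used below is an assumption to verify for your BayesFlow version. Since `LIFSimulator.sample` loops over parameter sets in Python and each trace takes 2000 time steps, simulation rather than the networks may dominate the wall-clock time, so timing one batch before a long run is cheap insurance. The helper names below are our own.

```python
import time

def simulation_budget(epochs, batch_size, num_batches_per_epoch):
    """Total number of simulated datasets drawn during online training."""
    return epochs * num_batches_per_epoch * batch_size

def time_per_batch(sample_fn, batch_size, repeats=3):
    """Average wall-clock seconds for sample_fn(batch_size), e.g. simulator.sample."""
    start = time.perf_counter()
    for _ in range(repeats):
        sample_fn(batch_size)
    return (time.perf_counter() - start) / repeats

# Settings from the training cell; num_batches_per_epoch=100 is an assumed default
n_sims = simulation_budget(epochs=10, batch_size=128, num_batches_per_epoch=100)
print(n_sims)  # 128000
```

In the notebook, `time_per_batch(simulator.sample, 128) * 10 * 100` gives a rough estimate of the time spent simulating during this run.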
## 7. Inference and Parameter Recovery

We now test the trained posterior by simulating new data, running inference, and comparing the posterior samples to the true parameters.

In [None]:
# Simulate new data
sim_out = simulator.sample(1)
obs_dict = {"trace": sim_out["trace"]}
true_params = np.array([sim_out[key][0, 0] for key in PARAM_KEYS])

# Run inference: draw 1000 posterior samples conditioned on the observed trace
posterior_samples = workflow.sample(
    conditions=obs_dict,
    num_samples=1000
)

# Plot posterior samples vs. true parameters
fig, axes = plt.subplots(1, len(PARAM_KEYS), figsize=(3*len(PARAM_KEYS), 3))
for i, key in enumerate(PARAM_KEYS):
    ax = axes[i] if len(PARAM_KEYS) > 1 else axes
    ax.hist(posterior_samples[key].flatten(), bins=30, alpha=0.7, label='Posterior')
    ax.axvline(true_params[i], color='r', linestyle='--', label='True')
    ax.set_title(key)
    ax.legend()
plt.tight_layout()
plt.show()
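
The histograms give a visual comparison for one dataset. A few numbers summarize the same comparison per parameter: the posterior mean and standard deviation, a z-score `(mean - true) / std` (how many posterior standard deviations the estimate is from the truth), and the posterior contraction `1 - Var_post / Var_prior`, using the fact that a `U(a, b)` prior has variance `(b - a)^2 / 12`. Contraction near 1 means the data narrowed the prior considerably; near 0 means the parameter is poorly informed by the trace. The function name is ours; it flattens whatever array it receives, so it assumes the samples belong to a single dataset.

```python
import numpy as np

def recovery_summary(samples, true_value, bounds):
    """Posterior mean/std, z-score, and contraction for one scalar parameter of one dataset."""
    s = np.asarray(samples, dtype=float).ravel()
    low, high = bounds
    post_mean, post_std = s.mean(), s.std(ddof=1)
    prior_var = (high - low) ** 2 / 12.0  # variance of U(low, high)
    return {
        "mean": post_mean,
        "std": post_std,
        "z": (post_mean - true_value) / post_std,
        "contraction": 1.0 - post_std**2 / prior_var,
    }

# Exercise the function on synthetic Gaussian samples (not notebook results)
rng = np.random.default_rng(1)
fake_samples = rng.normal(loc=20.0, scale=1.0, size=(1, 1000, 1))
print(recovery_summary(fake_samples, true_value=20.5, bounds=(5.0, 30.0)))
```

In the notebook, the same loop as the plot applies: `recovery_summary(posterior_samples[key], true_params[i], PARAM_BOUNDS[key])` for each `key`.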

## 8. Summary

- Defined a prior and simulator for the LIF neuron model.
- Trained a BayesFlow neural posterior for parameter inference.
- Demonstrated parameter recovery on new simulated data.

The 10 training epochs used here are only a quick demonstration; for reliable posteriors, train for more epochs and tune the network architecture as needed.
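
A single simulated dataset cannot show whether the trained posterior is calibrated. A common complementary check is empirical coverage: simulate many datasets, compute a central credible interval for each dataset and parameter, and count how often the true value falls inside. For a calibrated posterior, a 90% interval should contain the truth about 90% of the time; much lower coverage indicates overconfident posteriors, much higher indicates underconfident ones. The sketch below (function name ours) expects, for one parameter, the true values with shape `(N,)` and posterior draws arranged as `(N, num_samples)`.

```python
import numpy as np

def empirical_coverage(true_values, samples, level=0.9):
    """Fraction of datasets whose true value lies in the central `level` credible interval.

    true_values: shape (N,), ground-truth parameter for each simulated dataset
    samples:     shape (N, S), posterior draws for the same parameter, per dataset
    """
    alpha = 1.0 - level
    lower = np.quantile(samples, alpha / 2, axis=1)
    upper = np.quantile(samples, 1 - alpha / 2, axis=1)
    inside = (true_values >= lower) & (true_values <= upper)
    return inside.mean()

# Synthetic sanity check: when the "posterior" and the truth come from the same
# distribution, coverage should land near the nominal level.
rng = np.random.default_rng(0)
truth = rng.normal(size=2000)
draws = rng.normal(size=(2000, 1000))
print(round(empirical_coverage(truth, draws, level=0.9), 2))
```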