# Simulation-Based Inference for a LIF Neuron (BayesFlow)

This notebook is a *frozen snapshot* of the project used to generate the figures in `figures/`.

**Workflow overview**
1. Imports + plotting defaults
2. LIF simulator and parameterization
3. Prior definition and BayesFlow simulator wrapper
4. Sanity checks (prior draws)
5. BayesFlow workflow (adapter + networks)
6. Offline training
7. Diagnostics: loss, recovery, calibration (SBC)
8. Single-example posterior + posterior predictive checks (PPC)

> **Reproducibility note:** Results depend on the installed versions of BayesFlow, Keras, and JAX; record or pin those versions when regenerating the figures.

---


In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ["KERAS_BACKEND"] = "jax"
import bayesflow
# @title Figure Settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True

import ipywidgets as widgets  # interactive display
%config InlineBackend.figure_format = 'retina'
# use NMA plot style
# The stylesheet is referenced by URL, so this call needs network access when the cell runs.
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")
# Default ipywidgets layout object; it is not referenced again in the visible cells.
my_layout = widgets.Layout()

# Seed NumPy's global RNG, which the np.random calls in the prior and simulator draw from.
np.random.seed(45)

## 1) Imports and environment

- Sets the Keras backend to **JAX** (required by the original setup).
- Imports BayesFlow and plotting utilities.
- Optional: applies a custom Matplotlib style.


In [None]:
# @markdown Execute this code to initialize the default parameters
# Import JAX through its public package instead of Keras' private
# `keras.src.utils.module_utils` lazy wrapper, which is an internal detail
# that may move between Keras releases.
import jax

# Create a JAX PRNG key. Creating the key does not seed any global state by
# itself; it only has an effect where it is passed explicitly to JAX code.
# NOTE(review): `jax_seed` is not used by the cells visible in this notebook.
jax_seed = jax.random.PRNGKey(45)

# Default parameters
def default_pars(**kwargs):
    """Build the dictionary of shared simulation settings.

    Defaults are ``tref`` (refractory time, ms), ``T`` (simulation duration,
    ms) and ``dt`` (time step, ms). Any keyword argument overrides or extends
    these entries.

    The time grid ``range_t`` is derived from the *final* ``T`` and ``dt``,
    so overriding either one yields a matching grid. Passing ``range_t``
    explicitly keeps the caller's array untouched.

    Returns:
        dict: the settings, including a ``range_t`` entry.
    """
    pars = {
        'tref': 2.,  # Refractory time (ms)
        'T': 200.,   # Simulation duration (ms)
        'dt': 0.2,   # Time step (ms)
    }
    # Apply overrides before building the grid; otherwise a custom T or dt
    # would be paired with a grid computed from the defaults.
    pars.update(kwargs)
    if 'range_t' not in kwargs:
        pars['range_t'] = np.arange(0, pars['T'], pars['dt'])
    return pars



## 2) Default simulation settings

Defines the time grid and shared simulation hyperparameters (duration, time step, refractory period). Also creates a JAX PRNG key; the NumPy seed itself is set in the imports cell.


In [None]:
# LIF model
def run_LIF(pars, Iinj=None, noise_std=50):
    """Simulate a leaky integrate-and-fire neuron with a forward-Euler update.

    Args:
        pars: dict with keys ``V_th``, ``V_reset``, ``tau_m``, ``g_L``,
            ``V_init``, ``E_L``, ``tref``, ``dt`` and ``range_t`` (the time
            grid as a NumPy array).
        Iinj: injected current. ``None`` draws one Gaussian sample per time
            step (mean 200, standard deviation ``noise_std``) from NumPy's
            global RNG. A scalar is used as a constant current. Any other
            value is indexed per time step, as before.
        noise_std: standard deviation of the default noisy current.

    Returns:
        tuple: ``(v, spike_train)``, both 1-D arrays with one entry per grid
        point. ``spike_train`` is 1.0 at steps where a spike was registered
        and 0.0 elsewhere.

    Raises:
        ValueError: if ``tau_m`` or ``g_L`` is not positive, or if ``V_th``
            is not greater than ``V_reset``.
    """
    # Validate parameters
    if pars['tau_m'] <= 0 or pars['g_L'] <= 0:
        raise ValueError("tau_m and g_L must be positive")
    if pars['V_th'] <= pars['V_reset']:
        raise ValueError("V_th must be greater than V_reset")

    V_th, V_reset = pars['V_th'], pars['V_reset']
    tau_m, g_L = pars['tau_m'], pars['g_L']
    V_init, E_L = pars['V_init'], pars['E_L']
    dt, range_t = pars['dt'], pars['range_t']
    Lt = range_t.size
    tref = pars['tref']

    v = np.zeros(Lt)
    v[0] = V_init

    if Iinj is None:
        Iinj = np.random.normal(loc=200, scale=noise_std, size=Lt)  # Noisy input
    elif np.ndim(Iinj) == 0:
        # A scalar cannot be indexed per step; expand it to a constant trace.
        Iinj = np.full(Lt, Iinj, dtype=float)

    spike_steps = []  # integer step indices at which the threshold was crossed
    tr = 0.           # remaining refractory steps
    for it in range(Lt - 1):
        if tr > 0:
            # Still refractory: hold the membrane at reset and count down.
            v[it] = V_reset
            tr -= 1
        elif v[it] >= V_th:
            # Threshold crossing: record the spike, reset, start refractory period.
            spike_steps.append(it)
            v[it] = V_reset
            tr = tref / dt
        dv = (-(v[it] - E_L) + Iinj[it] / g_L) * (dt / tau_m)
        v[it + 1] = v[it] + dv
        v[it + 1] = np.clip(v[it + 1], -100, 100)  # Prevent overflow

    # Mark spikes with the recorded integer steps directly. Converting to
    # times and dividing by dt again can round down to the neighbouring bin.
    spike_train = np.zeros(Lt)
    spike_train[np.asarray(spike_steps, dtype=int)] = 1
    return v, spike_train


## 3) LIF simulator

Implements the leaky integrate-and-fire (LIF) dynamics and returns:
- `v`: membrane potential trace
- `spike_train`: binary spike indicator per time step

Includes basic parameter validation.


In [None]:
from scipy.stats import truncnorm
def truncated_normal(mean, std, low, high):
    """Draw one sample from a normal(mean, std) restricted to [low, high]."""
    # truncnorm takes its bounds in standardised units, i.e. relative to
    # loc/scale, so convert the absolute limits first.
    lower_z = (low - mean) / std
    upper_z = (high - mean) / std
    return truncnorm.rvs(lower_z, upper_z, loc=mean, scale=std)

def prior():
    """Draw one set of LIF parameters from the prior.

    Voltages (mV) come from truncated normals. The upper bounds of V_init
    and V_reset depend on the previously drawn E_L and V_init, respectively.
    tau_m (ms) and g_L (nS) are lognormal draws clipped to fixed ranges.

    Returns:
        dict: keys ``V_th``, ``tau_m``, ``g_L``, ``V_reset``, ``V_init``, ``E_L``.
    """
    # Voltages: threshold and leak reversal are drawn independently.
    threshold = truncated_normal(mean=-55, std=2, low=-60, high=-50)
    leak_reversal = truncated_normal(mean=-70, std=2, low=-75, high=-60)

    # The initial voltage is capped at 0.5 mV below E_L (and never above -65).
    init_cap = min(leak_reversal - 0.5, -65)
    initial = truncated_normal(mean=-70, std=2, low=-80, high=init_cap)

    # The reset voltage is capped at 0.5 mV below V_init (and never above -65).
    reset_cap = min(initial - 0.5, -65)
    reset = truncated_normal(mean=-72, std=2, low=-85, high=reset_cap)

    # Positive-only parameters: lognormal with median 10, clipped to a range.
    log_median = np.log(10)
    membrane_tau = np.clip(np.random.lognormal(mean=log_median, sigma=0.2), 5, 20)  # ms
    leak_conductance = np.clip(np.random.lognormal(mean=log_median, sigma=0.2), 5, 15)  # nS

    return dict(
        V_th=threshold,
        tau_m=membrane_tau,
        g_L=leak_conductance,
        V_reset=reset,
        V_init=initial,
        E_L=leak_reversal,
    )

# Simulator
def simulator(V_th, tau_m, g_L, V_reset, V_init, E_L):
    """Simulate one noisy LIF observation for the given parameters.

    Runs the LIF model on the default time grid with the default noisy
    input current, then adds unit-variance Gaussian observation noise to
    the voltage trace.

    Returns:
        dict: ``voltage`` and ``spikes``, each with a trailing axis of
        length 1 appended to the per-time-step trace.
    """
    lif_pars = default_pars(
        V_th=V_th,
        tau_m=tau_m,
        g_L=g_L,
        V_reset=V_reset,
        V_init=V_init,
        E_L=E_L,
    )
    trace, spikes = run_LIF(lif_pars, Iinj=None)
    # Observation noise is added in place on the simulated voltage.
    trace += np.random.normal(0, 1, trace.shape)
    return {
        "voltage": np.expand_dims(trace, axis=-1),
        "spikes": np.expand_dims(spikes, axis=-1),
    }

# BayesFlow setup (simplified)
# Combine the prior and the LIF simulator into a single BayesFlow simulator object.
sim = bayesflow.make_simulator([prior, simulator])

# `simulator` appends a trailing axis to each trace, i.e. (T,) -> (T, 1).
# With the default T=200 ms and dt=0.2 ms that is (1000,) -> (1000, 1).

## 4) Prior + simulator wrapper

- Defines a *truncated normal / lognormal* prior for LIF parameters.
- Wraps the simulator to return the BayesFlow expected dictionary format:
  - `voltage`: shape `(T, 1)`
  - `spikes`: shape `(T, 1)`
- Creates a BayesFlow simulator via `bayesflow.make_simulator`.


In [None]:
# Draw a batch of joint prior/simulator samples and report two array shapes:
# one parameter and one simulated time series.
sim_draws = sim.sample(500)
for key in ("V_th", "voltage"):
    print(sim_draws[key].shape)

## 5) Quick sanity check: simulator output shapes

Draws a small batch from the simulator to confirm shapes and basic functionality.


In [None]:
# Sample the prior repeatedly and overlay histograms of two voltage parameters.
samples = [prior() for _ in range(10000)]
import matplotlib.pyplot as plt
for name in ("V_th", "E_L"):
    plt.hist([draw[name] for draw in samples], bins=30, label=name)
plt.legend()
plt.show()

## 6) Quick sanity check: prior marginals

Samples from the prior and visualizes marginals for a couple parameters (helps catch obvious prior mistakes).


In [None]:
# Names of the parameters to infer and of the simulated time series.
parameter_keys = ["V_th", "tau_m", "g_L", "V_reset", "V_init", "E_L"]
observable_keys = ["voltage", "spikes"]

# Adapter: cast float64 arrays to float32, then group the parameters and the
# time series under the keys passed to the workflow below.
adapter = bayesflow.Adapter()
adapter = adapter.convert_dtype("float64", "float32")
adapter = adapter.concatenate(parameter_keys, into="inference_variables")
adapter = adapter.concatenate(observable_keys, into="summary_variables")

# NOTE(review): `hidden_dim` and `num_params` are kept exactly as in the
# original setup; confirm that the installed BayesFlow version recognises
# these keyword arguments.
summary_network = bayesflow.networks.TimeSeriesNetwork(hidden_dim=64, summary_dim=32)
inference_network = bayesflow.networks.CouplingFlow(num_params=6)

workflow = bayesflow.BasicWorkflow(
    simulator=sim,
    adapter=adapter,
    summary_network=summary_network,
    inference_network=inference_network,
    standardize=["inference_variables", "summary_variables"],
)

## 7) BayesFlow workflow setup

Defines the data adapter, summary network, and inference network, then combines them into a `BasicWorkflow`.

- Adapter concatenates parameters into `inference_variables`
- Adapter concatenates time series into `summary_variables`


In [None]:
# Simulate the training set first and the validation set second, then fit offline.
training_data = workflow.simulate(20000)
validation_data = workflow.simulate(1000)

fit_settings = dict(epochs=50, batch_size=128, validation_data=validation_data)
history = workflow.fit_offline(data=training_data, **fit_settings)

## 8) Offline simulation + training

Generates training/validation simulations and fits the amortized inference network offline.


In [None]:
# Global font sizes applied to every subsequent Matplotlib figure
diagnostic_font_sizes = {
    'font.size': 16,          # base font (ticks, legends, 'r' text)
    'axes.labelsize': 18,     # x/y axis labels
    'axes.titlesize': 20,     # titles
    'xtick.labelsize': 16,    # x tick labels
    'ytick.labelsize': 16,    # y tick labels
    'legend.fontsize': 16,    # legend entries
}
for rc_key, rc_size in diagnostic_font_sizes.items():
    plt.rcParams[rc_key] = rc_size

## 9) Plot styling for diagnostics

Sets global font sizes for consistent and presentation-ready diagnostic plots.


In [None]:
# Build the loss figure from the training history
plot_loss = bayesflow.diagnostics.plots.loss(history)

# Resize the text elements of every axes in the figure
for ax in plot_loss.axes:
    # Title and both axis labels share one size
    ax.title.set_size(20)
    for axis_label in (ax.xaxis.label, ax.yaxis.label):
        axis_label.set_size(20)

    # Major tick labels on both axes
    ax.tick_params(axis='both', which='major', labelsize=16)

    # Legend entries and legend title, when the axes has a legend
    legend = ax.get_legend()
    if legend:
        plt.setp(legend.get_texts(), fontsize=16)
        plt.setp(legend.get_title(), fontsize=20)

    # Free-standing text annotations on the axes
    for text in ax.texts:
        text.set_fontsize(16)

# Display the resized figure
plt.show()

## 10) Training diagnostics: loss trajectory

Plots training vs. validation loss across epochs.
- Useful to spot under/overfitting
- Confirms stable optimization


In [None]:
# Number of simulated test datasets and posterior draws per dataset
num_datasets, num_samples = 1000, 100

# Simulate `num_datasets` test scenarios
test_sims = workflow.simulate(num_datasets)

# Draw `num_samples` posterior samples for each scenario, then plot recovery
samples = workflow.sample(conditions=test_sims, num_samples=num_samples)
j = bayesflow.diagnostics.plots.recovery(samples, test_sims)

# Enlarge text on every axes for presentation
for ax in j.axes:
    ax.title.set_size(36)
    for axis_label in (ax.xaxis.label, ax.yaxis.label):
        axis_label.set_size(28)
    ax.tick_params(axis='both', labelsize=18)
    legend = ax.get_legend()
    if legend:
        plt.setp(legend.get_texts(), fontsize=18)
    for text in ax.texts:  # e.g. the 'r' annotations
        text.set_fontsize(32)

plt.show()

## 11) Parameter recovery (simulation-based check)

Simulates many datasets, draws posterior samples, and plots **recovery** (ground truth vs posterior estimates).


In [None]:
# Number of simulated test datasets and posterior draws per dataset
num_datasets, num_samples = 1000, 100

# Simulate a fresh set of `num_datasets` test scenarios
test_sims = workflow.simulate(num_datasets)

# Draw `num_samples` posterior samples per scenario, then plot rank histograms
samples = workflow.sample(conditions=test_sims, num_samples=num_samples)
g = bayesflow.diagnostics.plots.calibration_histogram(samples, test_sims)

# Enlarge text on every axes for presentation
for ax in g.axes:
    ax.title.set_size(36)
    for axis_label in (ax.xaxis.label, ax.yaxis.label):
        axis_label.set_size(28)
    ax.tick_params(axis='both', labelsize=18)
    legend = ax.get_legend()
    if legend:
        plt.setp(legend.get_texts(), fontsize=18)
    for text in ax.texts:  # any text annotations on the axes
        text.set_fontsize(32)

plt.show()

## 12) Calibration / SBC rank histograms

Simulation-based calibration (SBC) style diagnostic:
- For well-calibrated posteriors, rank histograms should be ~uniform.


In [None]:
# Compute BayesFlow's default diagnostic metrics. `test_data` is given as an
# integer here (presumably the number of fresh simulations to draw; confirm
# against the BayesFlow documentation).
metrics = workflow.compute_default_diagnostics(test_data=1000)
# Bare expression so the notebook renders the result as the cell output.
metrics

## 13) Default diagnostics summary

Computes BayesFlow’s default diagnostic metrics on held-out simulations.


In [None]:
# --- Posterior Inference on New Observation ---
# Draw ground-truth parameters from the prior and simulate one observation
true_params = prior()
test_sim = simulator(**true_params)

# Prepend a batch axis of length 1 to both time series.
# With the default time grid the resulting shape is (1, 1000, 1).
observed_voltage = test_sim["voltage"][np.newaxis, ...]
observed_spikes = test_sim["spikes"][np.newaxis, ...]

# Condition on both time series and draw posterior samples
single_observation = {"voltage": observed_voltage, "spikes": observed_spikes}
posterior_samples = workflow.sample(conditions=single_observation, num_samples=1000)

## 14) Posterior inference for one synthetic observation

Samples a single set of “true” parameters, simulates an observation, and draws posterior samples conditioned on that observation.


In [None]:
# --- Visualization ---
import seaborn as sns

param_names = ["V_th", "tau_m", "g_L", "V_reset", "V_init", "E_L"]

# One panel per parameter on a 3x2 grid
fig, axs = plt.subplots(3, 2, figsize=(12, 12))
axs = axs.flatten()

for panel, param in zip(axs, param_names):
    # Marginal posterior density with the generating value overlaid
    draws = posterior_samples[param].reshape(-1)
    sns.kdeplot(draws, ax=panel, fill=True, label='Posterior')
    panel.axvline(true_params[param], color='r', linestyle='--', label='True value')

    panel.set_title(f"Posterior for {param}", fontsize=28)
    panel.set_xlabel(param, fontsize=18)
    panel.set_ylabel('Density', fontsize=18)
    panel.tick_params(axis='both', labelsize=16)
    panel.legend(fontsize=16)

plt.tight_layout()
plt.show()

## 15) Posterior visualization

Plots marginal posterior densities (KDE) for each parameter and overlays the true value.


In [None]:
# --- Minimal PPC (no plots) ---
# Re-seed so the posterior indices and replicated simulations are repeatable.
np.random.seed(45)
K = 200  # posterior predictive draws
param_names = ["V_th", "tau_m", "g_L", "V_reset", "V_init", "E_L"]

# Flatten the observation to a 1-D trace and a scalar spike count
obs_v = observed_voltage.reshape(-1)
obs_spike_count = int(observed_spikes.sum())

# Flatten each parameter's posterior draws once, outside the loop
flat_posterior = {p: posterior_samples[p].reshape(-1) for p in param_names}

# Simulate K replicated datasets from randomly chosen posterior draws
pp_volt = []
pp_spike_counts = []
post_idx = np.random.randint(0, posterior_samples[param_names[0]].size, size=K)
for i in post_idx:
    pars = {p: flat_posterior[p][i] for p in param_names}
    # Use a dedicated name: assigning to `sim` here would overwrite the
    # BayesFlow simulator object created earlier in the notebook.
    ppc_sim = simulator(**pars)  # default noise/current, as in training
    pp_volt.append(ppc_sim["voltage"].reshape(-1))
    pp_spike_counts.append(int(ppc_sim["spikes"].sum()))
pp_volt = np.vstack(pp_volt)
pp_spike_counts = np.array(pp_spike_counts)

# Pointwise 5%-95% predictive band and median across the K replicates
v_lo = np.percentile(pp_volt, 5, axis=0)
v_hi = np.percentile(pp_volt, 95, axis=0)
v_med = np.median(pp_volt, axis=0)

# 1) Voltage coverage: fraction of observed points inside the band (target ~90%)
coverage = np.mean((obs_v >= v_lo) & (obs_v <= v_hi))

# 2) Spike count percentile (should be non-extreme, e.g., 0.1–0.9)
spike_percentile = np.mean(pp_spike_counts <= obs_spike_count)

# 3) Subthreshold deviation (mV): mean absolute gap to the predictive median,
#    restricted to observed points below the posterior-median V_th
Vth_med = np.median(posterior_samples["V_th"])
mask_sub = obs_v < Vth_med
subthreshold_deviation = np.mean(np.abs(obs_v[mask_sub] - v_med[mask_sub]))

print(f"PPC — voltage coverage in 90% band: {coverage*100:.1f}%")
print(f"PPC — spike count percentile: {spike_percentile:.2f}")
print(f"PPC — subthreshold deviation vs. predictive median: {subthreshold_deviation:.2f} mV")


## 16) Posterior predictive checks (PPC-lite)

Samples parameter draws from the posterior, simulates replicated datasets, and compares:
- Spike count distribution vs observed spike count
- Voltage trace predictive band vs observed voltage

This checks whether the inferred posterior can reproduce key features of the observed data.


In [None]:
# --- PPC-lite with plots ---
import numpy as np
import matplotlib.pyplot as plt

# Re-seed so the posterior indices and replicated simulations are repeatable.
np.random.seed(45)
K = 200  # number of posterior predictive draws
param_names = ["V_th", "tau_m", "g_L", "V_reset", "V_init", "E_L"]

# Observation as a 1-D trace and a scalar spike count, plus the default time grid
obs_v = observed_voltage.reshape(-1)
obs_spike_count = int(observed_spikes.sum())
t = default_pars()['range_t']

# Flatten each parameter's posterior draws once, outside the loop
flat_posterior = {p: posterior_samples[p].reshape(-1) for p in param_names}

# Simulate K replicated datasets from randomly chosen posterior draws
pp_volt = []
pp_spike_counts = []
post_idx = np.random.randint(0, posterior_samples[param_names[0]].size, size=K)
for i in post_idx:
    pars = {p: flat_posterior[p][i] for p in param_names}
    # Use a dedicated name: assigning to `sim` here would overwrite the
    # BayesFlow simulator object created earlier in the notebook.
    ppc_sim = simulator(**pars)  # same default noise/current
    pp_volt.append(ppc_sim["voltage"].reshape(-1))
    pp_spike_counts.append(int(ppc_sim["spikes"].sum()))
pp_volt = np.vstack(pp_volt)
pp_spike_counts = np.array(pp_spike_counts)

# Pointwise 5%-95% predictive band and median across the K replicates
v_lo = np.percentile(pp_volt, 5, axis=0)
v_hi = np.percentile(pp_volt, 95, axis=0)
v_med = np.median(pp_volt, axis=0)

# Plot: Voltage PPC ribbon
plt.figure(figsize=(10, 4))
plt.fill_between(t, v_lo, v_hi, alpha=0.3, label='Posterior predictive 90% band')
plt.plot(t, v_med, 'k-', lw=1.5, label='Posterior predictive median')
plt.plot(t, obs_v, 'r-', lw=1, alpha=0.8, label='Observed voltage')
plt.xlabel('Time (ms)')
plt.ylabel('Membrane potential (mV)')
plt.title('PPC-lite: Voltage Trace')
plt.legend()
plt.tight_layout()
plt.show()

# Plot: Spike count PPC histogram, one unit-width bin centred on each integer count
count_bins = np.arange(pp_spike_counts.min() - 0.5, pp_spike_counts.max() + 1.5, 1)
plt.figure(figsize=(6, 4))
plt.hist(pp_spike_counts, bins=count_bins, alpha=0.7, color='skyblue', edgecolor='black')
plt.axvline(obs_spike_count, color='red', linestyle='--', lw=2, label=f'Observed: {obs_spike_count}')
plt.xlabel('Spike count')
plt.ylabel('Frequency')
plt.title('PPC-lite: Spike Count')
plt.legend()
plt.tight_layout()
plt.show()
