[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pranavm19/SBI-Tutorial/blob/main/notebooks/03_NFlows_BioMime.ipynb)

## Applying SBI to Motor Unit Physiology
**Pranav Mamidanna, PhD** (p.mamidanna22@imperial.ac.uk), April 2025

In this notebook, we will apply the concepts learned in the previous tutorials to a real-world problem in motor unit physiology. We will use the BioMime model, which simulates motor unit action potentials (MUAPs) based on six physiological parameters:

1. Fiber Density (FD)
2. Depth (D)
3. Angle (A)
4. Innervation Zone (IZ)
5. Conduction Velocity (CV)
6. Fiber Length (FL)

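Our goal is to use SBI to estimate these parameters from observed MUAP data. This is particularly challenging because:
1. The relationship between parameters and MUAPs is highly non-linear
2. The parameter space is 6-dimensional, so brute-force search scales badly: a grid with only 10 values per parameter already requires a million simulations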

In [None]:
# # Uncomment and run the following in your colab env
# !pip install git+https://github.com/shihan-ma/BioMime.git
# !gdown 1RIYnYxLkBZ9_7MJQgQBSjAk_oXBTqY0b -O model_linear.pth

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interact_manual, FloatSlider

import BioMime.utils.basics as bm_basics
import BioMime.models.generator as bm_gen

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config = bm_basics.update_config('../data/config.yaml')

def initialize_generator(model_path):
    generator = bm_gen.Generator(config.Model.Generator)
    generator = bm_basics.load_generator(model_path, generator, 'cpu')
    return generator

def sample_biomime(generator, pars):
    if pars.ndim == 1:
        pars = pars[None, :]

    n_MU = pars.shape[0]
    sim_muaps = [generator.sample(n_MU, pars.to(device).float(), pars.to(device)).to("cpu") for _ in range(10)]
    muap = torch.stack(sim_muaps).mean(0)
    
    return muap.flatten()

# Initialize the generators
BIOMIME6 = initialize_generator('../data/model_linear.pth')

# Define simulators
def simulator_biomime6(pars):
    return sample_biomime(BIOMIME6, pars)
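`sample_biomime` averages several draws from a stochastic generator and should return one row per parameter set, which is the batch layout sbi expects. The sketch below replaces BioMime with a hypothetical stand-in, `toy_generator`: a parameter-dependent template plus Gaussian noise, with an output shape assumed to mimic one MUAP per row. It shows the shape contract, and that averaging K draws shrinks the noise by roughly 1/sqrt(K).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a stochastic generator (not BioMime): a template that
# depends on the parameters, plus i.i.d. Gaussian noise on every draw.
# Output shape (n_MU, 96, 10, 32) is an assumption chosen to mimic one MUAP per row.
def toy_generator(pars, noise_std=1.0):
    n_mu = pars.shape[0]
    template = np.broadcast_to(pars.sum(axis=1)[:, None, None, None], (n_mu, 96, 10, 32))
    return template + noise_std * rng.standard_normal((n_mu, 96, 10, 32))

def averaged_simulator(pars, n_draws=10):
    pars = np.atleast_2d(pars)                 # a single parameter vector becomes a batch of one
    draws = np.stack([toy_generator(pars) for _ in range(n_draws)])
    muap = draws.mean(axis=0)                  # average over draws -> (n_MU, 96, 10, 32)
    return muap.reshape(pars.shape[0], -1)     # one flattened MUAP per row

pars = rng.uniform(0.5, 1.0, size=(4, 6))
x = averaged_simulator(pars)
print(x.shape)  # (4, 30720)

# Residual noise after averaging 10 draws should be close to 1/sqrt(10) of the single-draw noise
residual = x - pars.sum(axis=1, keepdims=True)
print(residual.std())
```

Flattening the whole batch into a single vector would break the one-row-per-simulation convention. Keeping `reshape(n_MU, -1)` lets sbi pass a batch of parameters in a single call.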

In [3]:
def plot_muap(muap, overlap=False):
    """
    Plot sEMG MUAP with minimal axis style.

    Parameters
    ----------
    muap : np.ndarray
        (Average) MUAP of shape [n_row, n_col, n_time], units: mV.
    overlap : bool
        Whether rows of the MUAP should overlap (drawn on top of each other)
        instead of being stacked vertically.
    """
    n_row, n_col, n_time = muap.shape

    # Create figure
    plt.close('all')
    _, ax = plt.subplots(figsize=[int(n_col*0.6), 6])

    # Lay each row out as one long trace: append NaN gaps after every channel
    # (matplotlib breaks the line at NaNs), then offset rows vertically
    row_space = 0.0 if overlap else np.max(np.abs(muap)) * 1.5
    col_space = np.full((n_row, n_col, n_time // 2), np.nan)
    plotable = np.dstack((muap, col_space)).reshape([n_row, -1])
    plotable -= np.arange(n_row)[:, None] * row_space

    # Plot
    ax.plot(plotable.T, linewidth=0.5, color='k')
    ax.axis('off')

    # Get data range for scaling
    min_y = np.nanmin(plotable)
    y_range = np.ptp(muap.reshape((n_row*n_col, -1)))
    x_range = n_time  # in samples

    # Add scale bars
    ytick_length = min(y_range, 0.25)  # mV
    ytick_pos = -0.1 * x_range  # Place tick 10% to the left of data
    ax.plot([ytick_pos, ytick_pos], [min_y, min_y+ytick_length], color='k', linewidth=1)
    ax.text(ytick_pos, min_y + ytick_length/2, f'{ytick_length:.2f}mV', ha='right', va='center', fontsize=8, rotation=90)

    xtick_length = min(x_range, 100)  # samples
    xtick_pos = min_y - 0.1 * y_range  # Place tick 10% below data
    ax.plot([0, xtick_length], [xtick_pos, xtick_pos], color='k', linewidth=1)
    # The bar length is in samples; labelling it in ms would require the sampling rate
    ax.text(xtick_length/2, xtick_pos - 0.1 * y_range, f'{int(xtick_length)} samples', ha='center', va='top', fontsize=8)
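`plot_muap` draws each electrode row as one long trace. It appends `n_time // 2` NaN samples after every channel, so matplotlib leaves a gap between columns, and then shifts each row down by a constant offset. The toy array below, with hypothetical sizes, shows just that reshaping step.

```python
import numpy as np

# Hypothetical toy grid: 2 rows x 3 columns of channels, 4 time samples each
n_row, n_col, n_time = 2, 3, 4
muap = np.arange(n_row * n_col * n_time, dtype=float).reshape(n_row, n_col, n_time)

# NaN padding after each channel: matplotlib does not draw line segments across NaNs
gap = np.full((n_row, n_col, n_time // 2), np.nan)
rows = np.dstack((muap, gap)).reshape(n_row, -1)
print(rows.shape)  # (2, 18)

# Shift each row down by a constant so stacked rows do not overlap when plotted
row_space = 10.0
rows -= np.arange(n_row)[:, None] * row_space
```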


### Understanding the BioMime Simulator

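The BioMime simulator maps six physiological parameters to a MUAP recorded on a grid of electrode channels. In this notebook, every parameter is given on the same scale: both the sliders below and the prior used later span [0.5, 1.0]. Let's first explore how changing these parameters affects the generated MUAP.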

In [4]:
# Define sliders for each parameter
fdensity = FloatSlider(value=0.75, min=0.5, max=1.0, step=0.01, description='FD')
depth = FloatSlider(value=0.75, min=0.5, max=1.0, step=0.01, description='D')
angle = FloatSlider(value=0.75, min=0.5, max=1.0, step=0.01, description='A')
izone = FloatSlider(value=0.75, min=0.5, max=1.0, step=0.01, description='IZ')
cvel = FloatSlider(value=0.75, min=0.5, max=1.0, step=0.01, description='CV')
flength = FloatSlider(value=0.75, min=0.5, max=1.0, step=0.01, description='FL')

def generate_plot_muap(fd, d, a, iz, cv, fl):
    # Generate MUAP given specified conditions
    context = torch.tensor((fd, d, a, iz, cv, fl))[None, :]
    sim_muaps = simulator_biomime6(context).reshape((-1, 10, 32)).detach().numpy()
    plot_muap(sim_muaps.transpose((1, 2, 0))[:, ::2, :])
    plt.show()
    return

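# interact_manual renders the widget (with a "Run Interact" button) itself,
# so no explicit display() call is needed
widget = interact_manual(generate_plot_muap, fd=fdensity, d=depth, a=angle, iz=izone, cv=cvel, fl=flength)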


> **Task 3.1** Explore the simulator
1. How does changing each parameter affect the MUAP?
2. Are there any parameters that seem to have similar effects?
3. Which parameters appear to have the most significant impact on the MUAP shape?
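Question 3 can be made quantitative with a one-at-a-time sensitivity sweep. Vary each parameter across [0.5, 1.0] while holding the others at a baseline, and measure how much the output moves. `oat_sensitivity` is a hypothetical helper. It assumes a simulator that maps an `(n, 6)` NumPy array to an `(n, D)` array. The linear `toy_sim` stands in for BioMime so the result is easy to check.

```python
import numpy as np

def oat_sensitivity(simulator, baseline, low=0.5, high=1.0, n_steps=11):
    """One-at-a-time sweep: vary each parameter over [low, high] with the others
    held at `baseline`, and return the RMS change in the simulator output
    relative to the baseline output, one value per parameter."""
    baseline = np.asarray(baseline, dtype=float)
    x0 = simulator(baseline[None, :])[0]
    grid = np.linspace(low, high, n_steps)
    scores = []
    for i in range(baseline.size):
        pars = np.tile(baseline, (n_steps, 1))
        pars[:, i] = grid                      # sweep parameter i only
        x = simulator(pars)                    # (n_steps, D)
        scores.append(np.sqrt(np.mean((x - x0) ** 2)))
    return np.array(scores)

# Hypothetical stand-in simulator: parameter i enters the output with weight w[i]
w = np.array([0.0, 1.0, 2.0, 0.5, 3.0, 1.0])

def toy_sim(pars):
    return pars * w                            # (n, 6) -> (n, 6)

scores = oat_sensitivity(toy_sim, baseline=[0.75] * 6)
print(scores.round(3))
```

To apply this to BioMime, wrap `simulator_biomime6` so it accepts and returns NumPy arrays with one row per parameter set. Because BioMime is stochastic, even a parameter with no effect will score above zero. Compare the scores against the spread of repeated simulations at the baseline.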

### Setting up SBI for BioMime

Now that we understand the simulator, let's set up SBI to estimate the parameters from observed MUAPs. We'll need to:
1. Define a prior distribution over the parameters
2. Create an embedding network that compresses each high-dimensional MUAP into a short feature vector
3. Simulate training data and train the neural posterior estimator (NPE)

In [25]:
from sbi.analysis import pairplot
from sbi.inference import NPE
from sbi.utils import BoxUniform
from sbi.neural_nets import posterior_nn
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)

In [None]:
# # import the different choices of pre-configured embedding networks
# from sbi.neural_nets.embedding_nets import (
#     FCEmbedding,
#     CNNEmbedding,
#     PermutationInvariantEmbedding
# )

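# choose which type of pre-configured embedding net to use (e.g. CNN);
# input_shape must multiply out to the length of one flattened MUAP (96 * 320 = 30720):
# (96, 320) = time samples x electrode channels, matching the (96, 10, 32) reshape used for plotting
# embedding_net = CNNEmbedding(input_shape=(96, 320))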

In [26]:
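# Define the prior distribution: independent uniform on [0.5, 1.0] for each parameter
n_dim = 6
prior = BoxUniform(low=0.5 * torch.ones(n_dim), high=torch.ones(n_dim))

# Define a neural network to process MUAPs
class SummaryNet_2D(torch.nn.Module):
    """Small CNN that compresses one MUAP into a 16-dimensional summary."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2)
        self.pool = torch.nn.MaxPool2d(kernel_size=8, stride=8)
        self.fc = torch.nn.Linear(in_features=16 * 40 * 12, out_features=16)

    def forward(self, x):
        # Flattened MUAPs are assumed to be laid out as (time=96, rows=10, cols=32),
        # the layout generate_plot_muap relies on; rearrange each one into a
        # 1-channel image of 320 electrode channels x 96 time samples
        x = x.reshape(-1, 96, 320).transpose(1, 2).unsqueeze(1)  # (batch, 1, 320, 96)
        x = self.pool(torch.nn.functional.relu(self.conv1(x)))    # (batch, 16, 40, 12)
        x = x.flatten(start_dim=1)                                # (batch, 7680)
        x = torch.nn.functional.relu(self.fc(x))                  # (batch, 16)
        return x

embedding_net = SummaryNet_2D()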
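The linear layer's `in_features=16 * 40 * 12` follows from PyTorch's output-size rule for `Conv2d` and `MaxPool2d`. With dilation 1, each axis becomes floor((in + 2 * padding - kernel) / stride) + 1. The helper `conv_out` is a hypothetical name used only for this worked check.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one axis of Conv2d / MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

# One MUAP as a 2D image: 320 electrode channels (10 x 32) by 96 time samples
h, w = 320, 96

# conv1: kernel 5, padding 2, stride 1 keeps the spatial size
h, w = conv_out(h, 5, padding=2), conv_out(w, 5, padding=2)
# pool: kernel 8, stride 8 divides each axis by 8
h, w = conv_out(h, 8, stride=8), conv_out(w, 8, stride=8)

print(h, w, 16 * h * w)  # 40 12 7680
```

If you change the kernel sizes, padding, or pooling, re-run this arithmetic before editing `in_features`. Otherwise the `fc` layer will raise a shape mismatch.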

In [27]:
# Check the prior and simulator, and make them ready for sbi
prior, num_parameters, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator_biomime6, prior, prior_returns_numpy)

# Consistency check: runs the simulator on prior samples and checks the shapes
check_sbi_inputs(simulator, prior)

# instantiate the conditional neural density estimator (a masked autoregressive flow)
neural_posterior = posterior_nn(model="maf", embedding_net=embedding_net)

# setup the inference procedure with NPE, using the processed prior
inference = NPE(prior=prior, density_estimator=neural_posterior)

In [None]:
# Draw parameters from the prior and simulate the corresponding MUAPs
num_simulations = 100  # small, for a quick end-to-end check; Task 3.2 asks for 5000
theta = prior.sample((num_simulations,))
x = simulator(theta)
print("theta.shape", theta.shape)
print("x.shape", x.shape)

In [None]:
# train the density estimator
density_estimator = inference.append_simulations(theta, x).train()

# build the posterior
posterior = inference.build_posterior(density_estimator)

> **Task 3.2** Training the SBI model
1. Run the training with 5000 simulations
2. Save the trained model
3. Test the model with a known parameter set
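Calling the simulator on all 5000 parameter sets at once pushes every generator draw through the model in one go, which may exhaust memory. A common workaround is to simulate in chunks and concatenate the results. `simulate_in_batches` is a hypothetical NumPy helper, and `toy_sim` is a stand-in for the batched BioMime simulator.

```python
import numpy as np

def simulate_in_batches(simulator, theta, batch_size=100):
    """Run `simulator` on `theta` in chunks of `batch_size` rows and stack the
    results, so peak memory scales with batch_size rather than len(theta)."""
    outputs = [simulator(theta[i:i + batch_size]) for i in range(0, len(theta), batch_size)]
    return np.concatenate(outputs, axis=0)

# Hypothetical stand-in for a batched simulator: one output row per parameter row
def toy_sim(th):
    return np.repeat(th, 5, axis=1)            # (n, 6) -> (n, 30)

rng = np.random.default_rng(0)
theta = rng.uniform(0.5, 1.0, size=(5000, 6))  # same box as the prior in this notebook
x = simulate_in_batches(toy_sim, theta, batch_size=128)
print(theta.shape, x.shape)  # (5000, 6) (5000, 30)
```

With torch tensors, swap `np.concatenate` for `torch.cat`. sbi also provides `simulate_for_sbi` with a `simulation_batch_size` argument that serves a similar purpose. Check its signature for your installed version.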

In [None]:
# Your code here
# 1. Generate training data
# 2. Train the model
# 3. Save the model
# 4. Test with known parameters

### Analyzing the Results

Once we have a trained model, we can use it to estimate parameters from observed MUAPs. Let's analyze how well our model performs.

In [None]:
# Code for analyzing results will go here
# This will include:
# 1. Generating test cases
# 2. Running parameter estimation
# 3. Visualizing the results

> **Task 3.3** Model Analysis
1. How accurate is the parameter estimation?
2. Are there any parameters that are harder to estimate than others?
3. How does the uncertainty in parameter estimates vary across the parameter space?
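For a single test case with known parameters, three per-parameter numbers help answer these questions. The first is the error of the posterior mean. The second is the width of a central credible interval, a proxy for uncertainty. The third is whether the true value falls inside that interval. `posterior_summary` is a hypothetical helper that works on any array of posterior draws. The synthetic draws below are made up purely to exercise it.

```python
import numpy as np

def posterior_summary(samples, theta_true, level=0.9):
    """Per-parameter posterior-mean error, central credible-interval width,
    and whether theta_true lies inside that interval.

    samples:    (n_samples, n_dim) draws from p(theta | x_o)
    theta_true: (n_dim,) parameters used to simulate x_o
    """
    samples = np.asarray(samples)
    theta_true = np.asarray(theta_true)
    alpha = (1.0 - level) / 2.0
    lo, hi = np.quantile(samples, [alpha, 1.0 - alpha], axis=0)
    return {
        "abs_error": np.abs(samples.mean(axis=0) - theta_true),
        "width": hi - lo,
        "covered": (lo <= theta_true) & (theta_true <= hi),
    }

# Synthetic check (made-up numbers): draws centred on the truth,
# with the spread growing from the first to the last parameter
rng = np.random.default_rng(1)
theta_true = np.full(6, 0.75)
scales = np.linspace(0.01, 0.06, 6)
samples = theta_true + scales * rng.standard_normal((10_000, 6))
summary = posterior_summary(samples, theta_true)
print(summary["width"].round(3))
```

With the trained posterior, `samples` could come from `posterior.sample((1000,), x=x_o).numpy()` for an observation `x_o` simulated from a known parameter set. Repeating this over many test cases and averaging `covered` gives a rough calibration check. For a well-calibrated posterior, it should be close to `level`.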