# Bayesian Data Analysis for Neuroscience Data

## Introduction

This notebook is part of a 20-week internship project carried out at Ulster University.
The main goal of the project is to make advanced Bayesian statistical models more accessible
to experimental neuroscientists through user-friendly code, tutorials, and examples.

Specifically, this notebook applies existing Bayesian models to neuroscience datasets
using Python libraries such as PyMC and ArviZ.

## Objectives

1. Apply the existing Bayesian models to a neuroscience dataset from scratch,
   documenting each step so that a beginner can follow it.

2. Design a simple and reproducible analysis pipeline using PyMC.

3. Produce clear, well-documented code that can later be integrated into
   an interactive tutorial or a web application.

## Author

- Mathis DA SILVA
- Ulster University Internship (July–December 2025)
- Supervisors: Dr. Cian O'Donnell & Dr. Conor Houghton

## References

- [Dataset from "Classification of psychedelics and psychoactive drugs based on brain-wide imaging of cellular c-Fos expression"](https://www.nature.com/articles/s41467-025-56850-6#Sec25)
- [Hierarchical Bayesian modeling of multi-region brain cell count data](https://elifesciences.org/reviewed-preprints/102391v1)
- [Statistical Rethinking 2023 PDF](https://civil.colorado.edu/~balajir/CVEN6833/bayes-resources/RM-StatRethink-Bayes.pdf)
- [Statistical Rethinking 2023 Videos](https://www.youtube.com/watch?v=FdnMWdICdRs&list=PLDcUM9US4XdPz-KxHM4XHt7uUVGWWVSus)

We start by importing the libraries used in this notebook so far.

In [None]:
import numpy as np
import pymc as pm
import pandas as pd
import matplotlib.pyplot as plt
import arviz as az
import seaborn as sns

We will use the dataset from the paper "Classification of psychedelics and psychoactive drugs based on brain-wide imaging of cellular c-Fos expression".



In [None]:
dataset = pd.read_excel('data/dataset_neuroscience_vo.xlsx')

dataset

#### Notes:

The cell above loads the dataset. Its first three columns identify each brain region, including its name and abbreviation. The remaining columns hold the mice, grouped by drug: **MDMA**, **Ketamine**, **Fluoxetine**, ...

There are **64 mice** and **315 brain regions** in the dataset. Each mouse has one value per brain region: the number of cells expressing c-Fos, a marker of **neuronal activity**.
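
The layout described above can be checked by counting mouse columns per drug group. The sketch below uses a small stand-in table rather than the real spreadsheet: the column names, the helper `mice_per_group`, and the convention that the group is the first word of each column name are assumptions made for illustration.

```python
import pandas as pd

# Toy stand-in for the real spreadsheet; all names and values are hypothetical.
toy = pd.DataFrame({
    "abbreviation": ["FRP", "MOp"],
    "region name": ["Frontal pole", "Primary motor area"],
    "brain area": ["Isocortex", "Isocortex"],
    "MDMA 1": [12, 40],
    "MDMA 2": [15, 38],
    "Ketamine 1": [9, 51],
})

def mice_per_group(df, n_meta_cols=3):
    """Count mouse columns per group, taking the group as the first word of the column name."""
    mouse_cols = df.columns[n_meta_cols:]
    groups = pd.Series([col.split(" ")[0] for col in mouse_cols])
    return groups.value_counts().to_dict()

counts = mice_per_group(toy)
print(counts)
print("total mice:", sum(counts.values()))
```

On the real dataset, the total should come out at 64 if the column naming follows the same convention.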

### Statistical Models
---

We will use **hierarchical Bayesian models** to analyze the dataset. These models account for the hierarchical structure of the data, where measurements are nested within brain regions and mice. We consider three models: **Poisson**, **Horseshoe**, and **Zero-Inflated Poisson (ZIP)**.

For now, we work with a subset of the dataset: the first 20 brain regions and 2 groups of mice.

In [None]:
dataset1 = pd.read_excel('data/dataset_neuroscience_1.xlsx')

dataset1

Next, we reshape the data into the long format used by the models: one entry per count, with the index of its region and of its group.

In [None]:
def prepare_data(dataset):
    """Prepare the data for the PyMC analysis."""
    # Drop the last row, which holds the sex information
    data_clean = dataset.iloc[:-1].copy()

    # Separate the metadata from the count data
    metadata = data_clean[['abbreviation', 'region name', 'brain area']]
    count_data = data_clean.iloc[:, 3:]  # All count columns

    # Convert to long format for the analysis
    count_data_long = []
    region_indices = []
    group_indices = []

    for region_idx, region in enumerate(metadata['abbreviation']):
        for col_idx, column in enumerate(count_data.columns):
            count = count_data.iloc[region_idx, col_idx]

            count_data_long.append(count)
            region_indices.append(region_idx)
            # The group is read from the column name: A-SSRI=0, otherwise C-SSRI=1
            group_indices.append(0 if 'A-SSRI' in column else 1)

    return {
        'counts': np.array(count_data_long, dtype=int),
        'region_idx': np.array(region_indices),
        'group_idx': np.array(group_indices),
        'n_regions': len(metadata),
        'n_groups': 2,
        'region_names': metadata['abbreviation'].tolist(),
        'group_names': ['A-SSRI', 'C-SSRI']
    }

# Prepare the data
data = prepare_data(dataset1)
print(f"Number of regions: {data['n_regions']}")
print(f"Number of groups: {data['n_groups']}")
print(f"Total number of observations: {len(data['counts'])}")
print(f"Regions: {data['region_names'][:5]}...")
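
The nested loops in `prepare_data` can be cross-checked with `DataFrame.melt`, which performs the same wide-to-long reshaping. This is only a sketch on a hypothetical toy table (the values and the `toy` and `long_df` names are made up). Note that `melt` orders the rows mouse by mouse, whereas the loops above order them region by region.

```python
import pandas as pd

# Hypothetical toy table with the same column layout as the real subset.
toy = pd.DataFrame({
    "abbreviation": ["R1", "R2"],
    "region name": ["Region one", "Region two"],
    "brain area": ["Area", "Area"],
    "A-SSRI 1": [10, 3],
    "A-SSRI 2": [12, 0],
    "C-SSRI 1": [7, 5],
})

meta_cols = ["abbreviation", "region name", "brain area"]

# One row per (region, mouse) pair
long_df = toy.melt(id_vars=meta_cols, var_name="mouse", value_name="count")

# Integer index of each region, in the order the regions appear in the table
long_df["region_idx"] = pd.Categorical(
    long_df["abbreviation"], categories=toy["abbreviation"]
).codes

# Same rule as prepare_data: A-SSRI=0, anything else=1
is_a_ssri = long_df["mouse"].str.contains("A-SSRI", regex=False)
long_df["group_idx"] = (~is_a_ssri).astype(int)

print(long_df[["abbreviation", "mouse", "count", "region_idx", "group_idx"]])
```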

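Next, we build the Poisson model with PyMC.

The model is specified as follows:

\begin{gather*}
y_i \sim \mathrm{Poisson}(\lambda_i)\\
\log(\lambda_i) = E_i + \gamma_i\\
\theta_{rg} \sim \mathrm{Normal}(5, 2)\\
\tau_{rg} \sim \mathrm{HalfNormal}(\log(1.05))\\
\gamma_i \sim \mathrm{Normal}(\theta_{r[i]g[i]}, \tau_{r[i]g[i]})
\end{gather*}

Here $r[i]$ and $g[i]$ denote the region and the group of observation $i$. The code below does not include the $E_i$ term, which is equivalent to setting $E_i = 0$.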
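
To get a feel for what these priors imply, the generative process can be simulated forward with NumPy before any fitting. The sketch below assumes $E_i = 0$ and uses hypothetical sizes (20 regions, 2 groups, 4 mice per group); it is a prior simulation only and involves no data.

```python
import numpy as np

rng = np.random.default_rng(0)
n_regions, n_groups, n_mice = 20, 2, 4  # hypothetical sizes

# theta_rg ~ Normal(5, 2)
theta = rng.normal(5.0, 2.0, size=(n_regions, n_groups))

# tau_rg ~ HalfNormal(log(1.05)): absolute value of a zero-mean normal draw
tau = np.abs(rng.normal(0.0, np.log(1.05), size=(n_regions, n_groups)))

# gamma_i ~ Normal(theta_rg, tau_rg), one draw per mouse in each region-group cell
gamma = rng.normal(theta[..., None], tau[..., None],
                   size=(n_regions, n_groups, n_mice))

# log(lambda_i) = gamma_i (E_i assumed to be 0), then y_i ~ Poisson(lambda_i)
lam = np.exp(gamma)
y = rng.poisson(lam)

print("median simulated count:", np.median(y))
print("largest simulated count:", y.max())
```

Because $\tau$ is small, mice within the same region-group cell receive almost the same rate, while the wide prior on $\theta$ lets rates differ by orders of magnitude between cells.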

In [None]:
with pm.Model() as poisson_model:

    # Hyperpriors for each region-group combination
    theta = pm.Normal('theta', mu=5, sigma=2,
                     shape=(data['n_regions'], data['n_groups']))

    tau = pm.HalfNormal('tau', sigma=np.log(1.05),
                       shape=(data['n_regions'], data['n_groups']))

    # Individual effect for each observation
    gamma = pm.Normal('gamma',
                     mu=theta[data['region_idx'], data['group_idx']],
                     sigma=tau[data['region_idx'], data['group_idx']],
                     shape=len(data['counts']))

    # Poisson rate parameter
    lambda_i = pm.math.exp(gamma)

    # Likelihood
    y_obs = pm.Poisson('y_obs', mu=lambda_i, observed=data['counts'])

print("Poisson model built successfully!")
print(poisson_model)


In [None]:
# Fit the Poisson model
print("Fitting the Poisson model...")
with poisson_model:
    trace_poisson = pm.sample(draws=1000, chains=2,
                             target_accept=0.8,
                             return_inferencedata=True)

print("Fitting complete!")

In [None]:
# Diagnostics for the Poisson model
print("=== POISSON MODEL DIAGNOSTICS ===")

# Summary
summary_poisson = az.summary(trace_poisson)
print("\nParameter summary:")
print(summary_poisson.head(10))  # Show the first 10 parameters

# Convergence check
rhat_max = summary_poisson['r_hat'].max()
print(f"\nMaximum R-hat: {rhat_max:.3f}")
if rhat_max < 1.1:
    print("✓ Convergence OK (R-hat < 1.1)")
else:
    print("Possible convergence problem")

# Diagnostic plots
az.plot_trace(trace_poisson, var_names=['theta', 'tau'], compact=True)
plt.tight_layout()
plt.show()
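
The R-hat check above compares the variance between chains with the variance within chains. The sketch below implements the classic Gelman-Rubin formula in NumPy to show the idea; it is not what `az.summary` computes exactly (to our understanding ArviZ uses a rank-normalized split variant), and `gelman_rubin` is a hypothetical helper name.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split, non-rank-normalized) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    chain_means = chains.mean(axis=1)
    within = chains.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n_draws * chain_means.var(ddof=1)     # B: between-chain variance
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(1)

# Two chains sampling the same distribution: R-hat should be close to 1
good = rng.normal(0.0, 1.0, size=(2, 1000))

# Same chains, but the second one is shifted: the chains disagree
bad = good + np.array([[0.0], [3.0]])

rhat_good = gelman_rubin(good)
rhat_bad = gelman_rubin(bad)
print(f"R-hat, agreeing chains:    {rhat_good:.3f}")
print(f"R-hat, disagreeing chains: {rhat_bad:.3f}")
```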

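Next, we build the Horseshoe model with PyMC.

The model is specified as follows:

\begin{gather*}
y_i \sim \mathrm{Poisson}(\lambda_i)\\
\log(\lambda_i) = E_i + \gamma_i\\
\theta_{rg} \sim \mathrm{Normal}(5, 2)\\
\tau_{rg} \sim \mathrm{HalfNormal}(\log(1.05))\\
\kappa_i \sim \mathrm{HalfNormal}(1)\\
\gamma_i \sim \mathrm{Normal}(\theta_{r[i]g[i]}, \kappa_i\times\tau_{r[i]g[i]})
\end{gather*}

The code below implements a non-centered, additive variant of this specification: the mean of $\gamma_i$ is a global intercept plus a region effect and a group effect, and the per-observation scale is $\kappa_i$ alone rather than $\kappa_i\times\tau_{r[i]g[i]}$. As in the Poisson model, the $E_i$ term is omitted.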
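
A normal prior such as the one on $\gamma_i$ can be written in two equivalent ways: centered, drawing directly from $\mathrm{Normal}(\mu, \sigma)$, or non-centered, drawing a standard normal and then shifting and scaling it. The NumPy sketch below, with arbitrary illustrative values for `mu` and `scale`, checks that both give the same distribution.

```python
import numpy as np

rng = np.random.default_rng(2)
mu, scale, n = 5.0, 0.3, 200_000  # arbitrary illustrative values

# Centered: sample gamma directly from Normal(mu, scale)
centered = rng.normal(mu, scale, size=n)

# Non-centered: sample a standard normal, then shift and scale it
raw = rng.normal(0.0, 1.0, size=n)
non_centered = mu + raw * scale

print(f"centered:     mean={centered.mean():.3f}, sd={centered.std():.3f}")
print(f"non-centered: mean={non_centered.mean():.3f}, sd={non_centered.std():.3f}")
```

The two forms describe the same prior; the non-centered one is often easier for the sampler when the scale is itself a parameter with little data to constrain it.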

In [None]:
with pm.Model() as horseshoe_model:
    # 1. Global mean
    mu_global = pm.Normal("mu_global", mu=5, sigma=2)

    # 2. Hierarchical effects per region
    sigma_region = pm.HalfNormal("sigma_region", sigma=1)
    region_raw = pm.Normal("region_raw", mu=0, sigma=1, shape=data['n_regions'])
    region_effects = pm.Deterministic("region_effects", region_raw * sigma_region)

    # 3. Hierarchical effects per group
    sigma_group = pm.HalfNormal("sigma_group", sigma=1)
    group_raw = pm.Normal("group_raw", mu=0, sigma=1, shape=data['n_groups'])
    group_effects = pm.Deterministic("group_effects", group_raw * sigma_group)

    # 4. Horseshoe-style shrinkage scale for each observation
    kappa = pm.HalfNormal('kappa', sigma=1, shape=len(data['counts']))
    gamma_raw = pm.Normal("gamma_raw", mu=0, sigma=1, shape=len(data['counts']))

    # 5. Combine all effects
    mu_i = (mu_global +
            region_effects[data['region_idx']] +
            group_effects[data['group_idx']])

    gamma = pm.Deterministic("gamma", mu_i + gamma_raw * kappa)

    # 6. Poisson rate parameter
    lambda_i = pm.math.exp(gamma)

    # 7. Likelihood
    y_obs = pm.Poisson('y_obs', mu=lambda_i, observed=data['counts'])

print("✓ Horseshoe model built successfully!")

Next, we build the Zero-Inflated Poisson model with PyMC.

The model is specified as follows:

\begin{gather*}
y_i \sim \mathrm{ZIPoisson}(\lambda_i,\pi)\\
\log(\lambda_i) = E_i + \gamma_i\\
\pi \sim \mathrm{Beta}(1,5)\\
\theta_{rg} \sim \mathrm{Normal}(5, 2)\\
\tau_{rg} \sim \mathrm{HalfNormal}(\log(1.05))\\
\gamma_i \sim \mathrm{Normal}(\theta_{r[i]g[i]}, \tau_{r[i]g[i]})
\end{gather*}

Here $\pi$ is the zero-inflation parameter, shared across all observations.
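
The zero-inflated likelihood mixes a point mass at zero with an ordinary Poisson. The sketch below writes out its probability mass function, assuming $\pi$ is the probability of a structural zero; `zip_pmf` is a hypothetical helper, not a PyMC function. As far as we understand, the `psi` argument of PyMC's `ZeroInflatedPoisson` is the expected proportion of Poisson draws, which corresponds to $1 - \pi$ in this notation; this is worth checking against the PyMC documentation.

```python
import math

def zip_pmf(y, lam, pi):
    """P(Y = y) for a zero-inflated Poisson, with pi the probability of a structural zero."""
    # Ordinary Poisson pmf, computed on the log scale for numerical stability
    poisson = math.exp(-lam + y * math.log(lam) - math.lgamma(y + 1))
    # A zero can come from the point mass (prob. pi) or from the Poisson part (prob. 1 - pi)
    return pi * (y == 0) + (1.0 - pi) * poisson

lam, pi = 3.0, 0.2  # arbitrary illustrative values

p0 = zip_pmf(0, lam, pi)
total = sum(zip_pmf(k, lam, pi) for k in range(100))

print(f"P(Y=0) under ZIP:           {p0:.4f}")
print(f"P(Y=0) under plain Poisson: {math.exp(-lam):.4f}")
print(f"sum of pmf over 0..99:      {total:.6f}")
```

With these values, zeros are several times more likely than under a plain Poisson with the same rate, which is the excess that the ZIP model is designed to absorb.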

In [None]:
with pm.Model() as zero_inflatedpoisson_model:

    # Zero-inflation parameter
    pi = pm.Beta('pi', alpha=1, beta=5)

    # Hyperpriors for each region-group combination
    theta = pm.Normal('theta', mu=5, sigma=2,
                     shape=(data['n_regions'], data['n_groups']))

    tau = pm.HalfNormal('tau', sigma=np.log(1.05),
                       shape=(data['n_regions'], data['n_groups']))

    # Individual effects
    gamma = pm.Normal('gamma',
                     mu=theta[data['region_idx'], data['group_idx']],
                     sigma=tau[data['region_idx'], data['group_idx']],
                     shape=len(data['counts']))

    # Poisson rate parameter
    lambda_i = pm.math.exp(gamma)

    # Zero-Inflated Poisson likelihood.
    # PyMC's `psi` is the expected proportion of Poisson draws, so the
    # zero-inflation probability pi enters as 1 - pi.
    y_obs = pm.ZeroInflatedPoisson('y_obs', mu=lambda_i, psi=1 - pi,
                                  observed=data['counts'])

print("Zero-Inflated Poisson model built successfully!")

In [None]:
# Fit the three models
models = {
    'Poisson': poisson_model,
    'Horseshoe': horseshoe_model,
    'ZIP': zero_inflatedpoisson_model
}

traces = {}

for model_name, model in models.items():
    print(f"\n=== FITTING {model_name.upper()} MODEL ===")
    try:
        with model:
            # Store pointwise log-likelihoods: az.waic and az.compare need them
            trace = pm.sample(draws=1000, chains=2,
                             target_accept=0.8,
                             return_inferencedata=True,
                             idata_kwargs={"log_likelihood": True})
        traces[model_name] = trace
        print(f"✓ {model_name} model fitted successfully!")

    except Exception as e:
        print(f"Error in {model_name} model: {e}")

print(f"\n{len(traces)} models fitted successfully!")

In [None]:
# Model comparison
if len(traces) > 1:
    print("=== MODEL COMPARISON ===")

    # Compute WAIC for each model
    waic_results = {}
    for name, trace in traces.items():
        waic = az.waic(trace)
        waic_results[name] = waic
        print(f"{name}: elpd_waic = {waic.elpd_waic:.1f} ± {waic.se:.1f}")

    # Formal comparison (az.compare defaults to LOO, so request WAIC explicitly)
    comparison = az.compare(traces, ic="waic")
    print("\nDetailed comparison:")
    print(comparison)

    # Comparison plot
    az.plot_compare(comparison)
    plt.title('Model comparison (WAIC)')
    plt.show()
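
WAIC is computed from the pointwise log-likelihood of every observation under every posterior draw. The sketch below shows the standard calculation on the log-score scale (log pointwise predictive density minus the effective number of parameters) for a toy one-parameter Poisson model. The data, the stand-in posterior draws, and the helper `waic_from_loglik` are all made up for illustration, and the result is not guaranteed to match `az.waic` digit for digit.

```python
import math
import numpy as np

def waic_from_loglik(log_lik):
    """Return (elpd_waic, p_waic) from log-likelihoods of shape (n_draws, n_observations)."""
    log_lik = np.asarray(log_lik, dtype=float)
    n_draws = log_lik.shape[0]
    # lppd_i = log of the mean likelihood over draws, computed stably
    max_ll = log_lik.max(axis=0)
    lppd_i = max_ll + np.log(np.exp(log_lik - max_ll).sum(axis=0)) - np.log(n_draws)
    # Penalty: variance of the log-likelihood over draws, per observation
    p_waic_i = log_lik.var(axis=0, ddof=1)
    return lppd_i.sum() - p_waic_i.sum(), p_waic_i.sum()

rng = np.random.default_rng(3)

# Toy data and stand-in posterior draws for a single Poisson rate
y = rng.poisson(4.0, size=50)
lam_draws = rng.gamma(shape=y.sum() + 1, scale=1.0 / len(y), size=2000)

# Pointwise Poisson log-likelihood: one row per draw, one column per observation
log_fact = np.array([math.lgamma(k + 1) for k in y])
log_lik = -lam_draws[:, None] + y[None, :] * np.log(lam_draws[:, None]) - log_fact[None, :]

elpd, p_waic = waic_from_loglik(log_lik)
print(f"elpd_waic = {elpd:.1f}, p_waic = {p_waic:.2f}")
```

Higher `elpd_waic` indicates better expected out-of-sample predictive accuracy, and `p_waic` should be of the order of the number of effective parameters (about one in this toy case).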

In [None]:
# Visualize the results of the best model
best_model_name = comparison.index[0]  # The model ranked first by WAIC
best_trace = traces[best_model_name]

print(f"Visualizing the best model: {best_model_name}")

# Plot of the effects by group and region
# Note: `theta` exists only in the Poisson and ZIP models; the Horseshoe
# model defined above has no `theta` variable.
posterior = best_trace.posterior

# Posterior means of theta
theta_mean = posterior['theta'].mean(dim=['chain', 'draw'])

# Build a DataFrame for plotting
results_df = []
for r in range(data['n_regions']):
    for g in range(data['n_groups']):
        results_df.append({
            'Region': data['region_names'][r],
            'Group': data['group_names'][g],
            'Theta': float(theta_mean[r, g])
        })

results_df = pd.DataFrame(results_df)

# Bar plot
plt.figure(figsize=(12, 8))
sns.barplot(data=results_df, x='Region', y='Theta', hue='Group')
plt.xticks(rotation=45)
plt.title(f'Estimated effects by region and group - {best_model_name} model')
plt.tight_layout()
plt.show()

# Horizontal bar chart of the differences between groups
theta_diff = theta_mean[:, 1] - theta_mean[:, 0]  # C-SSRI - A-SSRI
plt.figure(figsize=(10, 6))
plt.barh(range(len(data['region_names'])), theta_diff)
plt.yticks(range(len(data['region_names'])), data['region_names'])
plt.xlabel('Difference C-SSRI - A-SSRI')
plt.title('Differences between groups by region')
plt.axvline(x=0, color='red', linestyle='--')
plt.tight_layout()
plt.show()
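
The bar chart above shows only the posterior mean of each difference. The posterior draws also say how uncertain each difference is. The sketch below uses stand-in draws for a single region (made-up numbers, not results from this dataset) to compute a 94% equal-tailed interval and the posterior probability that the difference is positive.

```python
import numpy as np

rng = np.random.default_rng(4)

# Stand-in posterior draws of theta for one region, one array per group.
# The values are hypothetical; they are not results from the dataset.
theta_a = rng.normal(5.0, 0.10, size=4000)  # A-SSRI
theta_c = rng.normal(5.3, 0.10, size=4000)  # C-SSRI

# Difference computed draw by draw, so the uncertainty carries through
diff = theta_c - theta_a

lower, upper = np.percentile(diff, [3, 97])  # 94% equal-tailed interval
prob_positive = (diff > 0).mean()            # share of draws with C-SSRI > A-SSRI

print(f"mean difference: {diff.mean():.3f}")
print(f"94% interval:    [{lower:.3f}, {upper:.3f}]")
print(f"P(diff > 0):     {prob_positive:.3f}")
```

With the real trace, the same arithmetic would be applied to the draws in `best_trace.posterior['theta']` before averaging over `chain` and `draw`.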