In [None]:
import numpy as np
import pandas as pd
import flax.linen as flax_nn
from jax import nn
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.diagnostics import hpdi

import patsy
import seaborn as sns
import matplotlib.pyplot as plt
import arviz as az
from GLM import models


np.random.seed(42)
n_samples = 10000
x1 = np.random.uniform(0, 3, n_samples)
x2 = np.linspace(0, 6.28, n_samples)
y = np.random.poisson((1 + 0.1 * x1 +  np.sin(x2)))
X1 = x1.reshape(-1, 1)
X2 = x2.reshape(-1, 1)
X_both = np.column_stack([x1, x2,x1*x2])
df=pd.DataFrame(X_both)
df.rename(columns={0:'x1',1:'x2',2:'x1x2'}, inplace=True)

#Define formulw
basis_x = patsy.dmatrix("bs(x1, degree=0,df=5)+bs(x2,degree=0, df=5)", return_type="dataframe")



Fit

In [None]:
nuts_kernel = NUTS(models.gaussian_prior)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=3000)
mcmc.run(jax.random.PRNGKey(0), basis_x=jnp.array(basis_x.values), y=y)
samples = mcmc.get_samples()

Plot

In [None]:
for i, coef in enumerate(samples["coefs"].T):
    sns.kdeplot(coef, label=f'Coefficient {i+1}')
plt.legend()
plt.xlabel('Coefficient Value')
plt.ylabel('Density')
plt.title('Posterior Distributions of Coefficients')
plt.show()

Compute credible intervals

In [None]:
# Compute 95% credible intervals for each coefficient
credible_intervals = np.percentile(samples["coefs"], [2.5, 97.5], axis=0)

# Compute the mean estimate of each coefficient
mean_coefs = np.mean(samples["coefs"], axis=0)

# Zero out coefficients where the 95% CI overlaps with zero
adjusted_coefs = np.where((credible_intervals[0] <= 0) & (credible_intervals[1] >= 0), 0, mean_coefs)

Make predicted 'tuning curves'

In [1]:
# Define ranges for x1 and x2
x1_range = jnp.linspace(jnp.median(x1), jnp.median(x1), 100)  # Example range for x1
x2_range = jnp.linspace(jnp.min(x2), jnp.max(x2), 100)  # Example range for x2
data = {"x1": x1_range, "x2": x2_range}
# x1_grid, x2_grid = jnp.meshgrid(x1_range, x2_range)
# 
# # Flatten grids and prepare data for patsy transformation
# x1_flat = x1_grid.flatten()
# x2_flat = x2_grid.flatten()
# data = {"x1": x1_flat, "x2": x2_flat}

# Transform the data using the CR basis function (update formula as needed)
basis_x = patsy.dmatrix("bs(x1,degree=3, df=5) + bs(x2,degree=3, df=5)", data, return_type="dataframe")
basis_x = jnp.array(basis_x.values)  # Convert to JAX array

# Compute predictions for each posterior sample
linear_preds_samples = jnp.dot(samples["coefs"], basis_x.T)
rate_preds_samples = jnp.exp(linear_preds_samples)  # Poisson rate (lambda)

# Compute mean prediction and HPDI for each x1 value
mean_rate_pred = rate_preds_samples.mean(axis=0)
hpdi_rate_pred = hpdi(rate_preds_samples, prob=0.95) 


hpdi_pred_samples = hpdi(rate_preds_samples, prob=0.95)  # 95% HPDI for predictive samples
hpdi_mean = hpdi(mean_rate_pred, prob=0.95)  # Uncertainty in mean




# 
# 
# ## FOR mean prediciton
# # Compute the predictions using adjusted coefficients
# linear_pred = jnp.dot(basis_x, adjusted_coefs)
# rate_pred = jnp.exp(linear_pred)  # Poisson rate (lambda) for the GLM
# 
# # Reshape predictions to match grid shape
# rate_pred = rate_pred.reshape(x1_grid.shape)

NameError: name 'jnp' is not defined

Model selection

In [None]:
from numpyro.infer import Predictive

# Get posterior samples
samples = mcmc.get_samples()

# Set up predictive distribution
predictive = Predictive(model_Gaussian, posterior_samples=samples)

# Generate posterior predictive samples
y_pred = predictive(jax.random.PRNGKey(1), jnp.array(basis_x))
##UNCOMMENT usage for mcmc penRegSpline models
# y_pred = predictive(jax.random.PRNGKey(1), basis_x_list=basis_x_list, S_list=S_list)

# Compare predicted y_pred['y'] to your actual y values


# Convert samples to ArviZ inference data
inference_data = az.from_numpyro(mcmc)

# Compute WAIC
waic = az.waic(inference_data)
print("WAIC:", waic)
loo = az.loo(inference_data)
 print("LOO-CV:", loo)

# Assume waic_scores contains WAIC scores for models
weights = np.exp(-0.5 * np.array(waic_scores))
weights /= weights.sum()  # Normalize to sum to 1