# Binary Concrete spike-and-slab group-wise Poisson regression
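In Binary Concrete, you relax the hard binary gate $\gamma_i \in \{0,1\}$ into a continuous variable in $(0,1)$ using the Concrete (Gumbel-softmax) trick:

$$
\gamma_i = \sigma\left(\frac{\log \alpha_i + L_i}{\tau}\right), \qquad L_i = \log U_i - \log(1 - U_i), \quad U_i \sim \mathrm{Uniform}(0,1)
$$

where $\sigma$ is the logistic sigmoid, $\log \alpha_i$ is the gate's location (logit), $\tau > 0$ is the temperature, and $L_i$ is Logistic noise (equivalently, the difference of two independent Gumbel(0,1) draws).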

- You sample $\gamma_i$ during inference.
- Those samples represent how "on" or "off" the gate is for each group.
- Each sample lies in $(0,1)$, but as $\tau$ is annealed toward 0 the samples concentrate near 0 or 1, so over many posterior samples $\gamma_i$ behaves almost like a binary gate.
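To see the relaxation in action, here is a minimal JAX sketch of the sampler described above, using Logistic noise (the difference of two Gumbel draws). It is not the BayesBrain implementation; `sample_binary_concrete` and the `log_alpha` values are illustrative assumptions. It also shows a useful property of the Binary Concrete: $P(\gamma_i > 0.5) = \sigma(\log\alpha_i)$ for every $\tau$, so the temperature controls how close samples sit to 0 and 1, not which side they fall on.

```python
import jax
import jax.numpy as jnp


def sample_binary_concrete(key, log_alpha, tau, shape=()):
    """Draw relaxed gates gamma = sigmoid((log_alpha + L) / tau).

    L = log(U) - log(1 - U), U ~ Uniform(0, 1), is standard Logistic noise,
    i.e. the difference of two independent Gumbel(0, 1) draws.
    """
    # Sample U strictly inside (0, 1) so both logs stay finite.
    u = jax.random.uniform(key, shape + jnp.shape(log_alpha),
                           minval=1e-6, maxval=1.0 - 1e-6)
    logistic_noise = jnp.log(u) - jnp.log1p(-u)
    return jax.nn.sigmoid((log_alpha + logistic_noise) / tau)


key = jax.random.PRNGKey(0)
log_alpha = jnp.array([2.0, -2.0])  # group 0 leans "on", group 1 leans "off"

results = {}
for tau in (2.0, 0.5, 0.1):
    # Same key for every tau, so only the temperature changes between runs.
    gammas = sample_binary_concrete(key, log_alpha, tau, shape=(10_000,))
    near_binary = float(jnp.mean((gammas < 0.05) | (gammas > 0.95)))
    results[tau] = near_binary
    print(f"tau={tau}: mean gamma={gammas.mean(axis=0)}, share near 0 or 1={near_binary:.2f}")
```

Lowering $\tau$ increases the share of draws near 0 or 1, while the share above 0.5 stays at $\sigma(\log\alpha_i)$.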

## use this if: 
- You want principled model selection
- You're comparing sparsity structures
- You want to report posterior inclusion probabilities (which only stochastic gates give you)

## deets

- Works in SVI.
- Matches the prior's stochasticity.
- Encourages sparsity.
- Supports uncertainty quantification by capturing gate uncertainty.
- Convergence stability: ⚠️ slightly noisier early on.
- Predictive calibration: ✅ better-calibrated posterior.
- Temperature annealing: ✅ crucial for the gating effect.
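Temperature annealing can be as simple as an exponential decay with a floor. The schedule below is an assumption for illustration (the notebook does not say which schedule, if any, BayesBrain uses); `anneal_tau`, `tau_init`, `tau_min` and `rate` are hypothetical names and values.

```python
import jax.numpy as jnp


def anneal_tau(step, tau_init=1.0, tau_min=0.1, rate=1e-4):
    """Exponentially decay the Concrete temperature, never going below tau_min."""
    return jnp.maximum(tau_min, tau_init * jnp.exp(-rate * step))


steps = jnp.array([0, 5_000, 20_000, 50_000])
print(anneal_tau(steps))
```

Keeping a floor such as `tau_min` matters because the gradient of $\gamma_i$ scales like $1/\tau$ near the decision boundary and vanishes elsewhere, so a temperature that goes all the way to zero makes optimization unstable.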


## posterior inclusion computation
So how do you get the posterior inclusion probability?
After inference (SVI), collect all posterior samples of $\gamma_i$ for the group of interest (here group 3):

`gamma_samples = posterior_samples["gamma_3"]  # shape: (n_samples,)`

Then compute their mean:

`inclusion_prob = jnp.mean(gamma_samples)`

- If `inclusion_prob` ≈ 1.0: strong evidence that this group matters.
- If ≈ 0.0: the group is switched off.
- If ≈ 0.5: uncertain, possibly borderline.

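This is a Monte Carlo estimate of the posterior probability that group $i$ is included:
$$
p(\gamma_i = 1 \mid \text{data}) \approx \mathbb{E}[\gamma_i \mid \text{data}] \approx \frac{1}{S}\sum_{s=1}^{S} \gamma_i^{(s)}
$$
where $\gamma_i^{(s)}$ is the $s$-th of $S$ posterior samples.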
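A minimal sketch of this estimator, assuming the gate samples are an array with samples along axis 0 (one group, or one column per group). Besides the mean, it reports the fraction of samples above 0.5, which estimates $P(\gamma_i > 0.5 \mid \text{data})$ and agrees with the mean once $\tau$ is small, plus a Monte Carlo standard error that is valid when samples are independent draws from the guide. `inclusion_probability` is a hypothetical helper, not part of BayesBrain.

```python
import jax.numpy as jnp


def inclusion_probability(gamma_samples, threshold=0.5):
    """Monte Carlo summaries of posterior gate samples (samples along axis 0).

    Returns (mean of gamma, fraction of samples above `threshold`,
    Monte Carlo standard error of the mean assuming independent draws).
    """
    gamma_samples = jnp.asarray(gamma_samples)
    n = gamma_samples.shape[0]
    mean = jnp.mean(gamma_samples, axis=0)
    frac_on = jnp.mean(gamma_samples > threshold, axis=0)
    mc_se = jnp.std(gamma_samples, axis=0, ddof=1) / jnp.sqrt(n)
    return mean, frac_on, mc_se


# Hypothetical samples for one group: mostly near 1, one draw near 0.
samples = jnp.array([0.98, 0.95, 0.99, 0.02, 0.97])
mean, frac_on, mc_se = inclusion_probability(samples)
print(float(mean), float(frac_on), float(mc_se))
```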



In [1]:
# Imports
from BayesBrain import models, datasim, glm, utils
import jax.numpy as jnp




#### Simulate Poisson data with 2 groups: one highly relevant, the other barely relevant

In [2]:
sims=datasim.simulate_poisson_grouped()
X_dsgn=sims[0]
Y=sims[1]
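The internals of `datasim.simulate_poisson_grouped` are not shown in this notebook. As a hypothetical sketch of a comparable setup (not BayesBrain's code), the function below builds one design matrix per group, gives one group sizeable coefficients and the other coefficients near zero, and draws Poisson counts from the resulting log-linear rate. `simulate_grouped_poisson`, the group sizes, scales and intercept are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def simulate_grouped_poisson(key, n=1000, group_sizes=(3, 3),
                             group_scales=(0.5, 0.02), intercept=0.5):
    """Simulate Poisson counts whose log-rate depends on groups of covariates.

    Hypothetical stand-in for datasim.simulate_poisson_grouped: a large entry in
    `group_scales` makes that group relevant, a tiny one makes it nearly irrelevant.
    """
    keys = jax.random.split(key, 2 * len(group_sizes) + 1)
    X_list = []
    log_rate = jnp.full((n,), intercept)
    for g, (size, scale) in enumerate(zip(group_sizes, group_scales)):
        X = jax.random.normal(keys[2 * g], (n, size))               # group design matrix
        beta = scale * jax.random.normal(keys[2 * g + 1], (size,))  # group coefficients
        X_list.append(X)
        log_rate = log_rate + X @ beta
    y = jax.random.poisson(keys[-1], jnp.exp(log_rate))  # Poisson counts
    return X_list, y


X_sim, y_sim = simulate_grouped_poisson(jax.random.PRNGKey(0))
print(len(X_sim), X_sim[0].shape, y_sim.shape)
```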

#### Obtain default parameters for the model

In [3]:
paramglm=utils.param_defaults_bayes(modname='grouped_ss_concrete')
paramglm.update({'probs':0.1})
paramglm['visteps']=50000
paramglm['prior_alpha']=0.01
paramglm['type']='zip'
mod2fitall = glm.PoissonGLMbayes()

mod2fitall.add_data(y=jnp.array(Y))

# Learn smoothness from data
mod2fitall.define_model(model='grouped_ss_concrete', basis_x_list=X_dsgn, S_list=None,
                          tensor_basis_list=None, S_tensor_list=None)

mod2fitall.fit(params=paramglm,  fit_intercept=True)

# TODO: fix posterior-sample checking to use beta_i * lambda_ard[i]
# mod2fitall.sample_posterior(5000).summarize_posterior(90).coeff_relevance()


100%|██████████| 10000/10000 [00:02<00:00, 4896.99it/s, init loss: 1032116.3125, avg. loss [9501-10000]: 1934.1447]



In [None]:
import jax

nsamples = 5000
# Draw posterior samples from the fitted guide using the learned SVI parameters
poster = mod2fitall.guide.sample_posterior(jax.random.PRNGKey(1), mod2fitall.svi_result.params,
                                           sample_shape=(nsamples,))
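Assuming `mod2fitall.guide` is a NumPyro autoguide (the `sample_posterior(rng_key, params, sample_shape=...)` call matches that API), `poster` is a dict from site name to samples. The hypothetical helper below ranks every gate site by its posterior mean and labels it with the ≈1 / ≈0 / ≈0.5 reading given earlier. The `gamma_` prefix and the 0.9/0.1 cut-offs are assumptions, so check `poster.keys()` for the actual site names before calling `summarize_gates(poster)`.

```python
import jax.numpy as jnp


def summarize_gates(samples, prefix="gamma_", on=0.9, off=0.1):
    """Rank gate sites by posterior mean and label them on / off / uncertain.

    `on` and `off` are illustrative cut-offs, not values from the model.
    """
    rows = []
    for name, values in samples.items():
        if not name.startswith(prefix):
            continue  # skip coefficients, intercepts, etc.
        p = float(jnp.mean(jnp.asarray(values)))
        label = "on" if p >= on else "off" if p <= off else "uncertain"
        rows.append((name, p, label))
    # Most strongly included groups first.
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Stand-in for the dict returned by sample_posterior (site names are assumed).
fake = {
    "gamma_0": jnp.array([0.97, 0.99, 0.95]),
    "gamma_1": jnp.array([0.02, 0.10, 0.03]),
    "beta_0": jnp.array([1.2, 0.8, 1.0]),
}
for name, p, label in summarize_gates(fake):
    print(f"{name}: {p:.2f} ({label})")
```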