In [None]:
import os
# Point the C/C++ compilers used for building the Stan model at devtoolset-10,
# but only when that toolchain is actually installed. Unconditionally setting
# CC/CXX to a missing path would break compilation on any other machine.
DEVTOOLSET_BIN = "/opt/rh/devtoolset-10/root/usr/bin"
if os.path.isdir(DEVTOOLSET_BIN):
    os.environ["CC"] = os.path.join(DEVTOOLSET_BIN, "gcc")
    os.environ["CXX"] = os.path.join(DEVTOOLSET_BIN, "g++")

In [None]:
# Patch asyncio so that code which starts its own event loop can run inside
# Jupyter's already-running loop. Presumably needed because PyStan 3's
# stan.build / sample use asyncio internally (PyStan docs recommend this in
# notebooks) — confirm against the installed PyStan version.
import nest_asyncio
nest_asyncio.apply()

import stan
import numpy as np

In [None]:
!gcc --version

In [None]:
# Stan program: hierarchical Dirichlet-multinomial model for N confusion
# matrices, each flattened to 4 cell counts.
#   y[n]   ~ multinomial(pi[n])         observed counts of matrix n
#   pi[n]  ~ dirichlet(gam * pi_pop)    per-matrix cell probabilities
#   pi_pop ~ dirichlet(rep(alpha, 4))   shared population simplex
#   gam    ~ gamma(gam_a, gam_b)        concentration (Stan: shape, rate)
# Generated quantities draw a fresh matrix from the population (pi_hat, and
# y_hat with K counts) for posterior-predictive checks.
# Declarations use the `array[...]` syntax (Stan >= 2.26). The legacy
# `int y[N, 4]` form was removed in Stan 2.33 and no longer compiles.
cm_code = """
data {
  int<lower=1> N;              // number of confusion matrices
  int<lower=1> K;              // total count for each posterior-predictive draw
  array[N, 4] int<lower=0> y;  // multinomial observations (cell counts)
  real<lower=0> alpha;         // hyperprior for population distribution

  real<lower=0> gam_a;         // gamma prior shape for gam
  real<lower=0> gam_b;         // gamma prior rate for gam
}

parameters {
  simplex[4] pi_pop;           // overarching distribution of CM probabilities
  array[N] simplex[4] pi;      // probabilities for all N confusion matrices
  real<lower=0.000001> gam;    // concentration of each pi[n] around pi_pop
}

model {
  pi_pop ~ dirichlet(rep_vector(alpha, 4));
  gam ~ gamma(gam_a, gam_b);
  for (n in 1:N) {
    pi[n] ~ dirichlet(gam * pi_pop);
    y[n] ~ multinomial(pi[n]);
  }
}

generated quantities {
  vector[4] pi_hat = dirichlet_rng(pi_pop * gam);
  array[4] int y_hat = multinomial_rng(pi_hat, K);
}
"""

In [None]:
# Simulation settings: N confusion matrices, each a multinomial draw of K
# counts over 4 cells with true probabilities `probs`.
SEED = 1
K = 1000                       # counts per confusion matrix (multinomial total)
probs = [0.1, 0.2, 0.5, 0.2]   # true cell probabilities
N = 10                         # number of confusion matrices
alpha = 1                      # symmetric Dirichlet hyperprior on pi_pop
gam_a = 0.001                  # gamma prior shape for gam
gam_b = 0.001                  # gamma prior rate for gam (Stan's gamma(shape, rate))

# Seeded generator so the simulated data, not only the Stan sampler, is
# reproducible across Restart & Run All.
rng = np.random.default_rng(SEED)
y = rng.multinomial(K, probs, size=N)   # shape (N, 4); each row sums to K

pooled_data_dict = {
    'N': N,
    'K': K,
    'y': y,
    'alpha': alpha,
    'gam_a': gam_a,
    'gam_b': gam_b,
}

In [None]:
# Compile the Stan program with the simulated data (random_seed fixes the
# sampler's RNG), then draw 10000 samples from each of 4 chains.
posterior = stan.build(
    cm_code,
    data=pooled_data_dict,
    random_seed=1,
)
fit = posterior.sample(
    num_chains=4,
    num_samples=10000,
)

In [None]:
# Posterior mean of the population simplex pi_pop, averaged over draws (axis 1).
print(fit['pi_pop'].mean(axis=1))

In [None]:
# Compare the posterior-predictive mean of pi_hat with the true probabilities.
# Only a cell's last expression is displayed, so both are stacked into one
# array: row 0 is the posterior mean of pi_hat, row 1 the true `probs`.
# `probs` is intentionally not re-assigned here; doing so would silently
# override the simulation settings used to generate y.
np.vstack([fit['pi_hat'].mean(axis=1), probs])

In [None]:
# Plotting libraries. NOTE(review): consider moving these imports into the
# top-of-notebook import cell so all dependencies are visible up front.
import seaborn as sns
import matplotlib.pyplot as plt
# Apply matplotlib's built-in 'ggplot' style sheet to all subsequent figures.
plt.style.use('ggplot')

In [None]:
# Sanity check on the posterior-predictive counts. Each draw of y_hat is
# multinomial_rng(pi_hat, K), so summing over the 4 cells (axis 0) must give K
# for every draw. Summarize with np.unique instead of displaying one total per draw.
y_hat_totals = fit['y_hat'].sum(axis=0)
assert np.all(y_hat_totals == K), "posterior-predictive y_hat draws must sum to K"
np.unique(y_hat_totals)

In [None]:
# Posterior-predictive density of pi_hat[d] (KDE over all draws), compared with
# the true probability and the empirical proportion averaged over matrices.
d = 0
fig, ax = plt.subplots(figsize=(14, 7))
sns.kdeplot(fit['pi_hat'][d, :], ax=ax, label='pi_hat (posterior predictive)')
# axvline does not take colours from the cycle; with default settings both
# reference lines would share the KDE's colour, so set them explicitly.
ax.axvline(x=probs[d], label='probability', color='k')
ax.axvline(x=(y/K).mean(axis=0)[d], label='y_proba', ls='--', color='C1')
ax.set(title=f'pi_hat[{d}]', xlabel='probability', ylabel='density')
ax.legend();

In [None]:
# Empirical cell proportions, averaged over the N simulated confusion matrices.
np.mean(y / K, axis=0)

In [None]:
# Posterior-predictive density of pi_hat[d] (KDE over all draws), compared with
# the true probability and the empirical proportion averaged over matrices.
d = 1
fig, ax = plt.subplots(figsize=(14, 7))
sns.kdeplot(fit['pi_hat'][d, :], ax=ax, label='pi_hat (posterior predictive)')
# axvline does not take colours from the cycle; with default settings both
# reference lines would share the KDE's colour, so set them explicitly.
ax.axvline(x=probs[d], label='probability', color='k')
ax.axvline(x=(y/K).mean(axis=0)[d], label='y_proba', ls='--', color='C1')
ax.set(title=f'pi_hat[{d}]', xlabel='probability', ylabel='density')
ax.legend();

In [None]:
# Posterior-predictive density of pi_hat[d] (KDE over all draws), compared with
# the true probability and the empirical proportion averaged over matrices.
d = 2
fig, ax = plt.subplots(figsize=(14, 7))
sns.kdeplot(fit['pi_hat'][d, :], ax=ax, label='pi_hat (posterior predictive)')
# axvline does not take colours from the cycle; with default settings both
# reference lines would share the KDE's colour, so set them explicitly.
ax.axvline(x=probs[d], label='probability', color='k')
ax.axvline(x=(y/K).mean(axis=0)[d], label='y_proba', ls='--', color='C1')
ax.set(title=f'pi_hat[{d}]', xlabel='probability', ylabel='density')
ax.legend();

In [None]:
# Posterior-predictive density of pi_hat[d] (KDE over all draws), compared with
# the true probability and the empirical proportion averaged over matrices.
d = 3
fig, ax = plt.subplots(figsize=(14, 7))
sns.kdeplot(fit['pi_hat'][d, :], ax=ax, label='pi_hat (posterior predictive)')
# axvline does not take colours from the cycle; with default settings both
# reference lines would share the KDE's colour, so set them explicitly.
ax.axvline(x=probs[d], label='probability', color='k')
ax.axvline(x=(y/K).mean(axis=0)[d], label='y_proba', ls='--', color='C1')
ax.set(title=f'pi_hat[{d}]', xlabel='probability', ylabel='density')
ax.legend();

In [None]:
# Posterior of the concentration gam against draws from its prior.
# Stan's gamma(gam_a, gam_b) is parameterised by shape and *rate*, whereas
# numpy's gamma takes shape and *scale*, so the prior scale is 1 / gam_b.
# NOTE(review): with a shape this small, prior draws are presumably almost all
# ~0 with a few very large values, so the prior KDE will look nearly flat. The
# Stan model also bounds gam below at 1e-6, which these raw draws ignore.
fig, ax = plt.subplots(figsize=(14, 7))
sns.kdeplot(fit['gam'].flatten(), ax=ax, label='post')
prior_draws = np.random.gamma(gam_a, 1.0 / gam_b, size=fit['gam'].size)
sns.kdeplot(prior_draws, ax=ax, label='prior')
ax.set(title='gam: posterior vs prior', xlabel='gam', ylabel='density')
ax.legend();