# Example Usage for `mix_gamma_vi`

In [10]:
from mix_gamma_vi import mix_gamma_vi
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

## Generate Dataset

Generate 10000 samples from a mixture of two gamma distributions.

In [11]:
N = 10000
pi_true = [0.5, 0.5]
a_true  = [20,  80 ]
B_true  = [20,  40 ]

mix_gamma = tfp.distributions.MixtureSameFamily(
    mixture_distribution=tfp.distributions.Categorical(probs=pi_true),
    components_distribution=tfp.distributions.Gamma(concentration=a_true, rate=B_true))

x = mix_gamma.sample(N)
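
The same two-stage sampling can be sketched in plain NumPy, which makes the structure of the mixture explicit: first draw a component label, then draw from that component's gamma distribution. This is only an illustrative sketch, not how `tfp` implements the sampler; the seed and the names `rng`, `z` and `x_np` are assumptions introduced here. Note that NumPy's gamma sampler takes a scale, so the rate has to be inverted.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only

N = 10000
pi_true = np.array([0.5, 0.5])    # mixture weights
a_true = np.array([20.0, 80.0])   # shapes
B_true = np.array([20.0, 40.0])   # rates

# Stage 1: pick a component label for each observation.
z = rng.choice(len(pi_true), size=N, p=pi_true)

# Stage 2: draw from the chosen component (NumPy uses scale = 1 / rate).
x_np = rng.gamma(shape=a_true[z], scale=1.0 / B_true[z])

# Each component has mean a / B, so the mixture mean is sum(pi * a / B).
mixture_mean = float(np.sum(pi_true * a_true / B_true))
print(mixture_mean, x_np.mean())
```

The sample mean should land close to the theoretical mixture mean of 1.5.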

## Variational Inference Under the Shape-Mean Parameterisation

The default parameterisation for the function `mix_gamma_vi` is the shape-mean parameterisation, under which the variational approximations to the posterior are

\begin{align*}
q^*(\mathbf{\pi}) &= \mathrm{Dirichlet} \left( \zeta_1, ..., \zeta_K \right) ,  \\
q^*(\alpha_k) &= \mathcal{N}(\hat{\alpha}_k, \sigma_k^2) ,  \\
q^* (\mu_k) &=  \operatorname{Inv-Gamma} \left( \gamma_k, \lambda_k \right) .  
\end{align*}

The product of these factors approximates the joint posterior

\begin{align*}
p(\mathbf{\pi}, \mathbf{\alpha}, \mathbf{\mu} \mid \mathbf{x}) &\approx q^*(\mathbf{\pi}) \prod_{k=1}^K q^*(\alpha_k) q^*(\mu_k).
\end{align*}
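
To relate the two parameterisations, recall that a gamma distribution with shape $\alpha$ and rate $\beta$ has mean $\mu = \alpha / \beta$. The helpers below are a minimal sketch of that mapping; the function names are hypothetical and are not part of the `mix_gamma_vi` package.

```python
import numpy as np

def shape_rate_to_shape_mean(alpha, beta):
    """Map (shape, rate) to (shape, mean) using mean = shape / rate."""
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    return alpha, alpha / beta

def shape_mean_to_shape_rate(alpha, mu):
    """Map (shape, mean) back to (shape, rate) using rate = shape / mean."""
    alpha = np.asarray(alpha, dtype=float)
    mu = np.asarray(mu, dtype=float)
    return alpha, alpha / mu

# True parameters used to simulate the data in this notebook.
alpha_true, mu_true = shape_rate_to_shape_mean([20, 80], [20, 40])
print(mu_true)
```

The simulated components therefore have means 1 and 2, which are the values the posterior means of `mu` should be close to.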

In [12]:
# Fit a model
fit = mix_gamma_vi(x, 2)

# Get the fitted distribution
distribution = fit.distribution()

# Get the means of the parameters under the fitted posterior
distribution.mean()

{'pi': <tf.Tensor: id=2699, shape=(1, 2), dtype=float32, numpy=array([[0.502409  , 0.49759102]], dtype=float32)>,
 'mu': <tf.Tensor: id=2706, shape=(1, 2), dtype=float32, numpy=array([[1.0024861, 2.0024283]], dtype=float32)>,
 'alpha': <tf.Tensor: id=2710, shape=(1, 2), dtype=float32, numpy=array([[20.434174, 77.96415 ]], dtype=float32)>}

In [13]:
# Get the posterior standard deviations
distribution.stddev()

{'pi': <tf.Tensor: id=2720, shape=(1, 2), dtype=float32, numpy=array([[0.00499919, 0.00499919]], dtype=float32)>,
 'mu': <tf.Tensor: id=2733, shape=(1, 2), dtype=float32, numpy=array([[0.00312878, 0.00321495]], dtype=float32)>,
 'alpha': <tf.Tensor: id=2737, shape=(1, 2), dtype=float32, numpy=array([[0.4077023, 1.5630537]], dtype=float32)>}
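
For a sense of where the posterior standard deviation of `pi` comes from, the mean and standard deviation of a Dirichlet distribution can be computed directly from its concentration parameters. The sketch below is a generic NumPy calculation; the concentration values are hypothetical, chosen to sum to roughly the sample size, and are not read from the fitted object.

```python
import numpy as np

def dirichlet_mean_std(zeta):
    """Marginal means and standard deviations of Dirichlet(zeta)."""
    zeta = np.asarray(zeta, dtype=float)
    total = zeta.sum()
    mean = zeta / total
    # Var(pi_k) = mean_k * (1 - mean_k) / (total + 1)
    std = np.sqrt(mean * (1.0 - mean) / (total + 1.0))
    return mean, std

zeta = [5024.0, 4976.0]  # hypothetical concentrations, summing to 10000
pi_mean, pi_std = dirichlet_mean_std(zeta)
print(pi_mean, pi_std)
```

With concentrations summing to about 10000, both standard deviations come out close to 0.005, the same order as the fitted values for `pi`.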

## Variational Inference Under the Shape-Rate Parameterisation

The traditional parameterisation of the gamma distribution is the shape-rate parameterisation, which this package also supports (although it is not recommended). In this case, the variational approximations to the posterior are

\begin{align*}
q^*(\mathbf{\pi}) &= \mathrm{Dirichlet} \left( \zeta_1, ..., \zeta_K \right) ,  \\
q^*(\alpha_k) &= \mathcal{N}(\hat{\alpha}_k, \sigma_k^2) , \\
q^* (\beta_k) &=  \operatorname{Gamma} \left( \gamma_k, \lambda_k \right) .  
\end{align*}

The product of these factors approximates the joint posterior

\begin{align*}
p(\mathbf{\pi}, \mathbf{\alpha}, \mathbf{\beta} \mid \mathbf{x}) &\approx q^*(\mathbf{\pi}) \prod_{k=1}^K q^*(\alpha_k) q^*(\beta_k) .
\end{align*}

In [14]:
# Fit a model
fit = mix_gamma_vi(x, 2, parameterisation="shape-rate")

# Get the fitted distribution
distribution = fit.distribution()

# Get the means of the parameters under the fitted posterior
distribution.mean()

{'pi': <tf.Tensor: id=2748, shape=(1, 2), dtype=float64, numpy=array([[0.50260262, 0.49739738]])>,
 'beta': <tf.Tensor: id=2755, shape=(1, 2), dtype=float64, numpy=array([[0.04935887, 0.0256532 ]])>,
 'alpha': <tf.Tensor: id=2759, shape=(1, 2), dtype=float64, numpy=array([[20.3168614 , 78.06337287]])>}

In [15]:
# Get the posterior standard deviations
distribution.stddev()

{'pi': <tf.Tensor: id=2769, shape=(1, 2), dtype=float64, numpy=array([[0.00499918, 0.00499918]])>,
 'beta': <tf.Tensor: id=2782, shape=(1, 2), dtype=float64, numpy=array([[1.54471779e-04, 4.11686881e-05]])>,
 'alpha': <tf.Tensor: id=2786, shape=(1, 2), dtype=float64, numpy=array([[0.06295893, 0.12455087]])>}

So the posterior standard deviation of $\mathbf{\alpha}$ is much lower under the shape-rate parameterisation (about 0.06 and 0.12) than under the shape-mean parameterisation (about 0.41 and 1.56).
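
The size of the gap can be quantified by dividing the two sets of standard deviations. The snippet below simply reuses the numbers printed above; the variable names are introduced here for illustration.

```python
import numpy as np

# Posterior standard deviations of alpha, copied from the outputs above.
sd_alpha_shape_mean = np.array([0.4077023, 1.5630537])
sd_alpha_shape_rate = np.array([0.06295893, 0.12455087])

# How many times larger the shape-mean standard deviation is, per component.
ratio = sd_alpha_shape_mean / sd_alpha_shape_rate
print(ratio)
```

For this fit, the shape-mean standard deviations are about 6.5 and 12.5 times larger than their shape-rate counterparts.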