In [54]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

# Generazione Dati

In [57]:
from IPython.utils.sysinfo import num_cpus

# numero di cluster
K=5

# dimensione del campione
d=2

# numero di sample
N=1000

key = random.PRNGKey(2023)
# vectors of mean of clusters
mu=random.normal(key,(K,d))

# cluster assignment
key = random.PRNGKey(2)
c=random.categorical(key,(1/K)*jnp.ones(shape=(K,)),axis=0,shape=(N,))
C=np.zeros(shape=(N,K))
for i in range(N):
  C[i,c[i]]=1
C=jnp.array(C)

# Data
X=jnp.matmul(C,mu)+random.normal(key,(N,d))

Array([[ 0.90678126,  0.44637144],
       [-0.6740475 ,  0.31190777],
       [-1.8890792 ,  1.2733475 ],
       ...,
       [-0.38996434, -0.1008999 ],
       [ 0.58381635, -0.42401785],
       [-0.79596436, -0.98130786]], dtype=float32)

# Variational Inference

We construct the model for our VI algorithm, we stick to the notation of the paper. Our model is described by:

\begin{align*}
    \mu_k\ \mid x &\stackrel{\tiny\mbox{iid}}{\sim} \mathcal{N}\left(m_k, s^2_k\right) \\
    x_i \mid c_i,μ &\sim \mathcal{N}\left(c^T_iμ, 1\right) \\
    \mu_k &\stackrel{\tiny\mbox{iid}}{\sim} \mathcal{N}\left(0, \sigma^2\right) \\
\end{align*}

We work in the family of Gaussian distribution for this first attempt, then we could expand it to exponential family.


Il paper descrive l'algoritmo in un caso 1D, dobbiamo adattarlo in un generico caso multidimensionale, io ho iniziato implementando il caso unidimensionale nell'algoritmo copiandolo dal paper, consapevole che andrà adattato

In [64]:
def update_phi(data,phi,m,s2):
  for i in jnp.arange(data.shape[0]):
    for k in jnp.arange(phi.shape[1]):
      phi[i,k]=jnp.exp(jnp.matmul(m[k,],data[i,].transpose())-(s2[k]+jnp.matmul(m[k,],m[k,].transpose()))/2) # non sono così sicuro della formula per mk^2
    phi[i,]=phi[i,]/jnp.sum(phi[i,])
  return phi

update_phi_jit=jit(update_phi)

def update_mean_and_variance(data,phi,m,s2,sigma):
  for k in jnp.arange(phi.shape[1]):
    m[k]=jnp.matlmul(phi[:,k].transpose(),data)/(1/sigma**2+jnp.sum(phi[:,k]))
    s2[k]=1/(1/sigma**2+jnp.sum(phi[:,k]))
  return m,s2

update_mean_and_variance_jit=jit(update_mean_and_variance)

def compute_ELBO(m,s2,phi):

  return

compute_ELBO_jit=jit(compute_ELBO)



In [60]:
# FUNCTION FOR VARIATIONAL INFERENCE
# Notation of the paper
def VI(data,K,sigma):
  # creating our variables as estimation of parameters for posterior probabilities
  # jax arrays are immutable, so I don't know how to create these variables in jax and use them
  # I iniialize them randomly, since there is no a-priori starting point which is best than others
  N=data.shape[0]
  d=data.shape[1]
  phi=np.zeros((N,K))
  for i in range(N):
    phi[np.random.categorical(K,p=1/K*np.ones(K))]=1
  m=np.random.normal(size=(K,d))
  s2=np.random.uniform(0,10,size=(K,))
  improvement=1
  tol= 10**-10
  ELBO_old=0 # probabilmente questo andrà modificato
  ELBO_new=0
  nit=0 # number of iterations
  while improvement>tol:
    phi=update_phi_jit(data,phi,m,s2)
    m,s2=update_mean_and_variance_jit(data,phi,m,s2,sigma)
    ELBO_old=ELBO_new
    ELBO_new=compute_ELBO_jit(m,s2,phi)
    improvement=jnp.abs(ELBO_new-ELBO_old)
    nit+=1
  return m,s2,phi


