In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
import jax

# Generazione Dati

In [3]:
from IPython.utils.sysinfo import num_cpus

# Number of clusters
K = 5

# Dimension of each sample
d = 1

# Number of samples
N = 1000

# Prior standard deviation of the cluster means: mu_k ~ N(0, sigma^2 I_d)
sigma = 1

# Cluster means, one row per cluster: shape (K, d)
key = random.PRNGKey(2023)
mu = random.normal(key, (K, d)) * sigma

# Cluster assignments and observation noise get separate subkeys: drawing both
# from the same JAX key would make them statistically dependent.
key = random.PRNGKey(2)
key_assign, key_noise = random.split(key)
# `random.categorical` takes logits; equal logits give the uniform prior 1/K.
c = random.categorical(key_assign, jnp.zeros(K), shape=(N,))
# One-hot assignment matrix, shape (N, K): row i is c_i in the paper's notation.
C = jax.nn.one_hot(c, K)

# Data: x_i = c_i^T mu + eps_i, with eps_i ~ N(0, I_d)
X = jnp.matmul(C, mu) + random.normal(key_noise, (N, d))



# Variational Inference

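We now build the model for our VI algorithm, following the notation of the review paper (Blei, Kucukelbir & McAuliffe, *Variational Inference: A Review for Statisticians*). The model and its variational approximation are: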

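\begin{align*}
    q(\mu_k) &= \mathcal{N}\left(m_k, s^2_k\right) \approx p(\mu_k \mid x) \\
    q(c_i) &= \mathrm{Categorical}\left(\varphi_i\right) \\
    x_i \mid c_i, \mu &\sim \mathcal{N}\left(c^\top_i \mu, 1\right) \\
    c_i &\stackrel{\text{iid}}{\sim} \mathrm{Categorical}\left(\tfrac{1}{K}, \dots, \tfrac{1}{K}\right) \\
    \mu_k &\stackrel{\text{iid}}{\sim} \mathcal{N}\left(0, \sigma^2\right)
\end{align*}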

In this first attempt the variational factors for the cluster means are Gaussian; the approach could later be extended to general exponential-family distributions.


Il paper descrive l'algoritmo nel caso unidimensionale (1D), mentre noi dobbiamo adattarlo al caso multidimensionale generico. Per ora ho implementato il caso unidimensionale copiando l'algoritmo dal paper, consapevole che andrà adattato.

In [4]:

def update_phi(data, phi, m, s2):
    """Coordinate-ascent update of the cluster responsibilities phi.

    For every sample i and cluster k:
        phi_ik  proportional to  exp( x_i^T m_k - E[mu_k^T mu_k] / 2 ),
    where, for the isotropic factor q(mu_k) = N(m_k, s2_k * I_d),
        E[mu_k^T mu_k] = d * s2_k + m_k^T m_k.
    For d == 1 this reduces to the paper's s2_k + m_k^2.

    Args:
        data: observations, shape (N, d).
        phi: current responsibilities, shape (N, K). Not read by the update;
            kept so existing callers keep working.
        m: variational means, shape (K, d).
        s2: variational variances, shape (K,).

    Returns:
        Updated responsibilities, shape (N, K); every row sums to 1.
    """
    d = m.shape[1]
    # logits[i, k] = x_i . m_k - (d * s2_k + |m_k|^2) / 2
    logits = data @ m.T - 0.5 * (d * s2 + jnp.sum(m * m, axis=1))
    # softmax normalises each row in log-space (max shift), so large logits do
    # not overflow exp() into inf/inf = nan.
    return jax.nn.softmax(logits, axis=1)
update_phi_jit=jit(update_phi)


def update_mean_and_variance(data,phi,m,s2,sigma):
  """Coordinate-ascent update of the Gaussian factors q(mu_k) = N(m_k, s2_k I).

  With N_k = sum_i phi_ik:
      m_k  = (sum_i phi_ik x_i) / (1/sigma^2 + N_k)
      s2_k = 1 / (1/sigma^2 + N_k)

  JAX arrays are immutable (`x.at[...].set(...)` returns a NEW array), so the
  updated values are built as fresh arrays and returned.

  Args:
    data: observations, shape (N, d).
    phi: responsibilities, shape (N, K).
    m: previous means, shape (K, d). Not read; kept for interface compatibility.
    s2: previous variances, shape (K,). Not read; kept for interface compatibility.
    sigma: prior standard deviation of the cluster means.

  Returns:
    (m_new, s2_new) with shapes (K, d) and (K,).
  """
  # posterior precision of each mu_k: 1/sigma^2 + N_k, shape (K,)
  precision = 1 / sigma**2 + jnp.sum(phi, axis=0)
  m_new = (phi.T @ data) / precision[:, None]
  s2_new = 1 / precision
  return m_new, s2_new

update_mean_and_variance_jit=jit(update_mean_and_variance)


def compute_ELBO(m,s2,phi,data,sigma=1.0):
  """Evidence lower bound of the Gaussian mixture, up to an additive constant.

  Constants are omitted because only the change of the ELBO between
  iterations is used, and there they cancel. Fn is the nth component of
  formula (21) in the review paper:
      ELBO = F1 + F2 + F3 - F4 - F5
  where F4 and F5 are the expected log variational densities, which are
  SUBTRACTED (ELBO = E[log p] - E[log q]).

  Args:
    m: variational means, shape (K, d).
    s2: variational variances, shape (K,); q(mu_k) = N(m_k, s2_k I_d).
    phi: responsibilities, shape (N, K); rows assumed to sum to 1.
    data: observations, shape (N, d).
    sigma: prior standard deviation of the cluster means (defaults to 1,
      the value used to generate the data in this notebook).

  Returns:
    Scalar ELBO (without constants).
  """
  d = m.shape[1]
  K = phi.shape[1]
  # E[mu_k^T mu_k] under q(mu_k): d * s2_k + |m_k|^2, shape (K,)
  second_moment = d * s2 + jnp.sum(m * m, axis=1)
  # F1 = sum_k E[log p(mu_k)] with prior N(0, sigma^2 I_d) (+ const)
  F1 = -0.5 / sigma**2 * jnp.sum(second_moment)
  # F2 = sum_i E[log p(c_i)] = -N log K: constant, omitted
  # F3 = sum_i E[log p(x_i | c_i, mu)]. The -x_i^T x_i / 2 and -d/2 log(2 pi)
  # parts are constant because each row of phi sums to 1.
  F3 = jnp.sum(phi * (data @ m.T - 0.5 * second_moment))
  # F4 = sum_i E[log q(c_i)] = sum_ik phi_ik log phi_ik, with 0 log 0 := 0
  F4 = jnp.sum(phi * jnp.log(jnp.where(phi > 0, phi, 1.0)))
  # F5 = sum_k E[log q(mu_k)] = -sum_k [d/2 log(2 pi s2_k) + d/2]
  F5 = -0.5 * d * jnp.sum(jnp.log(2 * jnp.pi * s2)) - 0.5 * K * d
  return F1 + F3 - F4 - F5

compute_ELBO_jit=jit(compute_ELBO)



In [5]:
# FUNCTION FOR VARIATIONAL INFERENCE
# Notation of the paper
def VI(data,K,sigma,nMAX,key=None,tol=1e-5):
  """Fit the Gaussian mixture with coordinate-ascent variational inference.

  Alternates the phi update and the (m, s2) update until the absolute change
  of the ELBO between two consecutive sweeps is not greater than `tol`, or
  `nMAX` sweeps have been performed. At least one sweep is always performed.

  Args:
    data: observations, shape (N, d).
    K: number of clusters.
    sigma: prior standard deviation of the cluster means (used by the m/s2 update).
    nMAX: maximum number of sweeps.
    key: optional jax PRNG key for the random initialisation of m and s2.
      If None, random.PRNGKey(0) is used, so the function does not depend on
      any notebook-level variable.
    tol: convergence threshold on |ELBO_new - ELBO_old|.

  Returns:
    (m, s2, phi): means (K, d), variances (K,), responsibilities (N, K).
  """
  N=data.shape[0]
  d=data.shape[1]
  if key is None:
    key = random.PRNGKey(0)
  # Independent subkeys: m and s2 must not be drawn from the same key.
  key_m, key_s2 = random.split(key)
  # Random initialisation: there is no a-priori starting point better than others.
  phi=jnp.ones((N,K))/K
  m=random.normal(key_m,shape=(K,d))
  s2=random.uniform(key_s2,minval=0,maxval=10,shape=(K,))
  # -inf makes the first improvement infinite, so the first sweep can never
  # be mistaken for convergence.
  elbo_old = float("-inf")
  for _ in range(max(1, nMAX)):
    phi=update_phi_jit(data,phi,m,s2)
    m,s2=update_mean_and_variance_jit(data,phi,m,s2,sigma)
    # NOTE(review): compute_ELBO_jit is not given `sigma`, so its prior term
    # does not use this function's sigma; they only agree when both are equal.
    elbo_new = float(compute_ELBO_jit(m,s2,phi,data))
    improvement = abs(elbo_new - elbo_old)
    elbo_old = elbo_new
    # written as `not (>)` so a nan ELBO also stops the loop
    if not (improvement > tol):
      break
  return m,s2,phi




In [None]:
%%timeit
m,s2,phi=VI(X,K,sigma,10000)
print(m,'\n',mu)
print(s2)
print(phi,'\n',C)
