In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
import jax

# Generazione Dati

In [87]:
from IPython.utils.sysinfo import num_cpus

# numero di cluster
K=5

# dimensione del campione
d=2

# numero di sample
N=1000

key = random.PRNGKey(2023)
# vectors of mean of clusters
sigma=1
mu=random.normal(key,(K,d))*sigma

# cluster assignment
key = random.PRNGKey(2)
c=random.categorical(key,(1/K)*jnp.ones(shape=(K,)),axis=0,shape=(N,))
C=np.zeros(shape=(N,K))
for i in range(N):
  C[i,c[i]]=1
C=jnp.array(C)

# Data
X=jnp.matmul(C,mu)+random.normal(key,(N,d))

# Variational Inference

We construct the model for our VI algorithm, we stick to the notation of the paper. Our model is described by:

\begin{align*}
    \mu_k\ \mid x &\stackrel{\tiny\mbox{}}{\sim} \mathcal{N}\left(m_k, s^2_k\right) \\
    x_i \mid c_i,μ &\sim \mathcal{N}\left(c^T_iμ, 1\right) \\
    \mu_k &\stackrel{\tiny\mbox{}}{\sim}\mathcal{N}\left(0, \sigma^2\right) \\
\end{align*}

We work in the family of Gaussian distribution for this first attempt, then we could expand it to exponential family.


Il paper descrive l'algoritmo in un caso 1D, dobbiamo adattarlo in un generico caso multidimensionale, io ho iniziato implementando il caso unidimensionale nell'algoritmo copiandolo dal paper, consapevole che andrà adattato

In [88]:

#def update_phi(data,phi,m,s2):
#  for i in jnp.arange(data.shape[0]):
#    for k in jnp.arange(phi.shape[1]):
#      phi[i,k]=jnp.exp(jnp.matmul(m[k,:],data[i,:].transpose())-(s2[k]+jnp.matmul(m[k,:],m[k,:].transpose()))/2) # non sono così sicuro della formula per mk^2
#    phi[i,:]=phi[i,:]/jnp.sum(phi[i,:])
#  return phi

def update_phi(data, phi, m, s2):
    updated_phi = jnp.zeros_like(phi)
    M=jnp.matmul(m,m.T)
    log_likelihood = jnp.matmul(data,m.T) - 0.5 * jnp.matmul(jnp.ones(shape=(N,1)),(s2.T + jnp.resize(jnp.diag(M),(1,K))))
    updated_phi = jnp.exp(log_likelihood)

    updated_phi /= jnp.sum(updated_phi,axis=1,keepdims=True)

    return updated_phi
update_phi_jit=jit(update_phi)


def update_mean_and_variance(data,phi,sigma):
  K=phi.shape[1]
  d=data.shape[1]
  N=data.shape[0]
  updated_m=(jnp.matmul(phi.T,data)/(1/sigma**2*jnp.ones(shape=(K,d))+jnp.matmul(jnp.resize(jnp.sum(phi,axis=0),(K,1)),jnp.ones(shape=(1,d)))))
  updated_s2=(1/(1/sigma**2*jnp.ones(shape=(K,1))+jnp.resize(jnp.sum(phi,axis=0),(K,1))))
  return updated_m,updated_s2

update_mean_and_variance_jit=jit(update_mean_and_variance)


def compute_ELBO(m,s2,phi,data):
  # when computing the ELBO value, we omit constants because once we compute the improvement they would have a total of 0
  # Fn stands for the nth component of the formula (21) in the review paper
  d=m.shape[1]
  M=jnp.matmul(m,m.T)
  F1=-0.5*sigma**2 *jnp.sum( d*s2+jnp.diag(M))
  F5=-d/2*jnp.sum(jnp.log(2*jnp.pi*s2))-0.5*(jnp.sum(jnp.diag(jnp.matmul(phi.T,jnp.matmul(data,m.T))))+d*jnp.sum(jnp.matmul(phi,s2))+jnp.sum(jnp.matmul(phi,jnp.diag(M))))
  # F2= -log(K) sum over k from 1 to K => constant in every iteration
  F3=-0.5*jnp.sum(jnp.matmul(phi.T,jnp.diag(jnp.matmul(data,data.T))))
  # (-2*jnp.matmul(data[i,:],m[k,:].transpose()+d*s2[k]+jnp.matmul(m[k,:],m[k,:].transpose())))
  # -d/2*jnp.log(2*jnp.pi)*phi[i,k] summed over i and k should be constant over time, since phi[i,:] is a probability it should sum N every time
  F4=jnp.sum(jnp.log(phi)*phi)

  return F1+F3+F4+F5

compute_ELBO_jit=jit(compute_ELBO)



In [98]:
# FUNCTION FOR VARIATIONAL INFERENCE
# Notation of the paper
def VI(data,K,sigma,nMAX,niniz):
  # creating our variables as estimation of parameters for posterior probabilities
  # jax arrays are immutable, so I don't know how to create these variables in jax and use them
  # I iniialize them randomly, since there is no a-priori starting point which is best than others
  ELBO_max=0
  for i in jnp.arange(niniz):
    N=data.shape[0]
    d=data.shape[1]
    phi=jnp.ones((N,K))/K
    key = random.PRNGKey(i)
    m=random.normal(key,shape=(K,d))
    s2=random.uniform(key,minval=0,maxval=10,shape=(K,1))
    improvement=1
    tol= 10**-12
    ELBO_old=0 # probabilmente questo andrà modificato
    ELBO_new=compute_ELBO_jit(m,s2,phi,data)
    nit=0
    if i==0:
      ELBO_max=ELBO_new
      m_max=m
      s2_max=s2
      phi_max=phi
      n_max=0
    print('Iter ',nit,'\t ELBO: ',ELBO_new,'\t Improvement: ',improvement,'\n')
    print('=================================================\n')
    # number of iterations
    while (improvement>tol and nit<nMAX) or (nit<100):
      phi=update_phi_jit(data,phi,m,s2)
      m,s2=update_mean_and_variance_jit(data,phi,sigma)
      ELBO_old=ELBO_new
      ELBO_new=compute_ELBO_jit(m,s2,phi,data)
      improvement=ELBO_new-ELBO_old
      nit+=1
      print('Iter ',nit,'\t ELBO: ',ELBO_new,'\t Improvement: ',improvement,'\n')
      print('=================================================\n')
    if ELBO_new>ELBO_max:
      ELBO_max=ELBO_new
      m_max=m
      s2_max=s2
      phi_max=phi
      n_max=i
  print('Best initialization at ',n_max+1,' \t ELBO: ',ELBO_max)
  return m_max,s2_max,phi_max




In [99]:

m,s2,phi=VI(X,K,sigma,10000,10)
print(m,'\n',mu)
print(s2)
print(phi,'\n',C)


Iter  0 	 ELBO:  -8938.941 	 Improvement:  1 


Iter  1 	 ELBO:  -3550.0874 	 Improvement:  5388.854 


Iter  2 	 ELBO:  -4196.209 	 Improvement:  -646.1216 


Iter  3 	 ELBO:  -4281.743 	 Improvement:  -85.53418 


Iter  4 	 ELBO:  -4343.073 	 Improvement:  -61.33008 


Iter  5 	 ELBO:  -4381.3066 	 Improvement:  -38.2334 


Iter  6 	 ELBO:  -4402.3193 	 Improvement:  -21.012695 


Iter  7 	 ELBO:  -4413.4365 	 Improvement:  -11.1171875 


Iter  8 	 ELBO:  -4419.4043 	 Improvement:  -5.9677734 


Iter  9 	 ELBO:  -4422.716 	 Improvement:  -3.3115234 


Iter  10 	 ELBO:  -4424.6113 	 Improvement:  -1.8955078 


Iter  11 	 ELBO:  -4425.723 	 Improvement:  -1.1118164 


Iter  12 	 ELBO:  -4426.379 	 Improvement:  -0.6557617 


Iter  13 	 ELBO:  -4426.765 	 Improvement:  -0.38623047 


Iter  14 	 ELBO:  -4426.985 	 Improvement:  -0.21972656 


Iter  15 	 ELBO:  -4427.1006 	 Improvement:  -0.115722656 


Iter  16 	 ELBO:  -4427.1484 	 Improvement:  -0.047851562 


Iter  17 	 ELBO:  -4427.1