In [1]:
import jax
import jax.numpy as jnp

In [2]:
key = jax.random.PRNGKey(42)
N = 100
A = jax.random.normal(key, shape=(N, N))

In [3]:
A_inv = jnp.linalg.inv(A)

In [4]:
batch_size = 50
key, subkey = jax.random.split(key)

A_bacthed = jax.random.normal(subkey, shape=(batch_size, N, N), dtype=jnp.float64)
print(A_bacthed.dtype)
A_bacthed = jnp.asarray(A_bacthed, dtype=jnp.float64)
print(A_bacthed.dtype)
batched_inv = jax.vmap(jnp.linalg.inv, (0), 0)

A_inv_batched = batched_inv(A_bacthed)

float64
float64


In [7]:
A_inv_batched.shape

supposed_eye = jnp.matmul(A_inv_batched, A_bacthed, precision=jax.lax.Precision.HIGH)

print(jnp.linalg.norm(supposed_eye - jnp.eye(N), axis=0).mean())

1.7136684115479701e-13


In [216]:
from sklearn.cluster import kmeans_plusplus

class GuassianMixture:
    def __init__(self, dim, n_comp, n_init):
        self.dim = dim
        self.n_comp = n_comp
        self.means = None
        self.cov = None
        self.weigths = None
        self.n_init = n_init

    def fit(self, data, max_iter):
        new_weights, new_means, new_cov = [], [], []
        for i in range(max_iter):
            for j in range(self.n_init):
                log_lik = self.loglikelihood(data, j)
                new_weights_j, new_means_j, new_cov_j = self.em_step(log_lik, data, self.weights[j], self.means[j], self.cov[j])
                new_weights.append(new_weights_j)
                new_means.append(new_means_j)
                new_cov.append(new_cov_j)
                
            self.weights = jnp.asarray(new_weights)
            self.means = jnp.asarray(new_means)
            self.cov = jnp.asarray(new_cov)
            print(self.score(data))
        
    def em_step(self, log_lik, data, weigths, means, cov):
        # Expectation
        new_z = jax.nn.softmax(log_lik, 0)
        # Minimization
        # Weights
        new_weights = jnp.mean(new_z, 1)
        
        # Means
        mul = jax.vmap(lambda x, y: jnp.expand_dims(x, 1) * y, (0, None), 0)
        new_mu = jnp.sum(mul(new_z, data), 1) / jnp.expand_dims(jnp.sum(new_z, 1), 1)
        
        # Sigma
        outer_product_helper_1 = jax.vmap(lambda data, mu, z: z * jnp.expand_dims((data - mu), 1) @ jnp.expand_dims((data - mu), 1).T, (None, 0, 0), 0)
        outer_product_helper_2 = jax.vmap(outer_product_helper_1, (0, None, 0), 0)
        new_sigma = jnp.sum(outer_product_helper_2(data, new_mu, new_z.T), axis=0) / jnp.expand_dims(jnp.sum(new_z, 1), (1, 2))
        
        return new_weights, new_mu, new_sigma
            
    def k_pp_init(self, data):
        means = []
        for i in range(self.n_init):
            means.append(kmeans_plusplus(data, self.n_comp)[0])
            
        means = jnp.asarray(means)
        
        cov = jnp.expand_dims(jnp.eye(self.dim), 0)
        cov = cov.repeat(self.n_comp, 0)
        cov = jnp.expand_dims(cov, 0)
        cov = cov.repeat(self.n_init, 0)
        
        self.means = means
        self.cov = cov
        self.weights = jnp.ones([self.n_init, self.n_comp]) / self.n_comp
            
    def loglikelihood(self, data, n_init=0):
        N = data.shape[0]
        res = []
        
        exp_sq = lambda data, mean, cov: -1 / 2 * (data - mean).T @ jnp.linalg.inv(cov) @ (data - mean)
        exp_sq_vmap = jax.vmap(exp_sq, (0, None, None), 0)
        
        for i in range(self.n_comp):
            log_lik_data_comp = - self.dim / 2 * jnp.log(2 * jnp.pi) - 1 / 2 * jnp.log(jnp.linalg.det(self.cov[n_init][i]))
            
            res_exp_sq = exp_sq_vmap(data, self.means[n_init][i], self.cov[n_init][i])
            log_lik_data_comp += res_exp_sq
            res.append(log_lik_data_comp)
        
        res = jnp.asarray(res) + jnp.log(jnp.expand_dims(self.weights[i], 1))
        return res
        
    def score(self, data):
        ret_val = []
        for i in range(self.n_init):
            log_lik_i = self.loglikelihood(data, n_init=i)
            log_lik_i = jax.scipy.special.logsumexp(log_lik_i, axis=0)
            log_lik_i = log_lik_i.mean()
            ret_val.append(log_lik_i)
            
        return ret_val

In [217]:
from sklearn import datasets

data = datasets.make_blobs(n_samples=100)[0]

gmm_jax = GuassianMixture(dim=data.shape[1], n_comp=3, n_init=10)

gmm_jax.k_pp_init(data)

gmm_jax.loglikelihood(data)

gmm_jax.score(data)

gmm_jax.fit(data, max_iter=5)

[Array(-5.06478636, dtype=float64), Array(-12.18989855, dtype=float64), Array(-4.40431633, dtype=float64), Array(-4.90278576, dtype=float64), Array(-5.13034089, dtype=float64), Array(-5.39848367, dtype=float64), Array(-4.32437989, dtype=float64), Array(-4.72667753, dtype=float64), Array(-4.61508997, dtype=float64), Array(-4.31340248, dtype=float64)]
[Array(-3.95643908, dtype=float64), Array(-4.34296159, dtype=float64), Array(-3.95029736, dtype=float64), Array(-3.96914911, dtype=float64), Array(-4.01983422, dtype=float64), Array(-4.02073494, dtype=float64), Array(-3.95801431, dtype=float64), Array(-3.96190086, dtype=float64), Array(-3.97195546, dtype=float64), Array(-3.95510778, dtype=float64)]
[Array(-3.95643908, dtype=float64), Array(-4.34296159, dtype=float64), Array(-3.95029736, dtype=float64), Array(-3.96914911, dtype=float64), Array(-4.01983422, dtype=float64), Array(-4.02073494, dtype=float64), Array(-3.95801431, dtype=float64), Array(-3.96190086, dtype=float64), Array(-3.9719554

In [221]:
from sklearn.mixture import GaussianMixture

gmm_sklearn = GaussianMixture(3, verbose=2, verbose_interval=1)

gmm_sklearn.fit(data)
print(f'Reference LogLik: {gmm_sklearn.score(data)}')

weights = gmm_sklearn.weights_
means = gmm_sklearn.means_
cov = gmm_sklearn.covariances_


gmm_jax.weights = jnp.asarray(jnp.expand_dims(weights, 0).repeat(10, 0))
gmm_jax.means = jnp.asarray(jnp.expand_dims(means, 0).repeat(10, 0))
gmm_jax.cov = jnp.asarray(jnp.expand_dims(cov, 0).repeat(10, 0))

gmm_jax.score(data)

Initialization 0
  Iteration 1	 time lapse 0.00409s	 ll change inf
  Iteration 2	 time lapse 0.00044s	 ll change 0.00071
Initialization converged: True	 time lapse 0.00455s	 ll -3.94960
Reference LogLik: -3.949423290853575


[Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64),
 Array(-3.94942329, dtype=float64)]