In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from time import time

### Load data

In [3]:
from sklearn.preprocessing import MinMaxScaler

def load_synthetic_dataset(filename):
    data = np.load(filename)
    scaler = MinMaxScaler()
    scaler.fit(data)
    data = scaler.transform(data)
    return data

def load_cloud_dataset():
    cloud_data_name = 'cloud.data'
    with open(cloud_data_name) as f:
        cloud_data = pd.DataFrame([item.split() for item in f.readlines()])
    cloud_data = cloud_data.astype(float).to_numpy()
    scaler = MinMaxScaler()
    scaler.fit(cloud_data)
    cloud_data = scaler.transform(cloud_data)
    return cloud_data


### Jax GMM

In [36]:
from sklearn.cluster import kmeans_plusplus
from sklearn.cluster import KMeans

class GuassianMixture:
    def __init__(self, dim, n_comp, n_init, verbose=False):
        self.dim = dim
        self.n_comp = n_comp
        self.means = None
        self.cov = None
        self.weigths = None
        self.n_init = n_init
        self.verbose = verbose
        self.is_init = False

    def fit(self, data, max_iter):
        
        if not self.is_init:
            self.km_init(data)
            
        new_weights, new_means, new_cov = [], [], []
        for i in range(max_iter):
            for j in range(self.n_init):
                start = time()
                log_lik = self.loglikelihood(data, self.weights[j], self.means[j], self.cov[j])
                print(f'LogLik time: {time() - start}')
                
                start = time()
                new_weights_j, new_means_j, new_cov_j = self.em_step(log_lik, data, self.weights[j], self.means[j], self.cov[j])
                
                print(f'EM time: {time() - start}')
                
                new_weights.append(new_weights_j)
                new_means.append(new_means_j)
                new_cov.append(new_cov_j)
                
            self.weights = jnp.asarray(new_weights)
            self.means = jnp.asarray(new_means)
            self.cov = jnp.asarray(new_cov)
#             print(self.score(data))
        
    @staticmethod
    @jax.jit
    def em_step(log_lik, data, weigths, means, cov):
        # Expectation
        new_z = jax.nn.softmax(log_lik, 0)
        # Minimization
        # Weights
        new_weights = jnp.mean(new_z, 1)
        
        # Means
        mul = jax.vmap(lambda x, y: jnp.expand_dims(x, 1) * y, (0, None), 0)
        new_mu = jnp.sum(mul(new_z, data), 1) / jnp.expand_dims(jnp.sum(new_z, 1), 1)
        
        # Sigma
        outer_product_helper_1 = jax.vmap(lambda data, mu, z: z * jnp.expand_dims((data - mu), 1) @ jnp.expand_dims((data - mu), 1).T, (None, 0, 0), 0)
        outer_product_helper_2 = jax.vmap(outer_product_helper_1, (0, None, 0), 0)
        new_sigma = jnp.sum(outer_product_helper_2(data, new_mu, new_z.T), axis=0) / jnp.expand_dims(jnp.sum(new_z, 1), (1, 2))
        
        return new_weights, new_mu, new_sigma
            
    def k_pp_init(self, data):
        means = []
        for i in range(self.n_init):
            means.append(kmeans_plusplus(data, self.n_comp)[0])
            
        means = jnp.asarray(means)
        
        cov = jnp.expand_dims(jnp.eye(self.dim), 0)
        cov = cov.repeat(self.n_comp, 0)
        cov = jnp.expand_dims(cov, 0)
        cov = cov.repeat(self.n_init, 0)
        
        self.means = means
        self.cov = cov
        self.weights = jnp.ones([self.n_init, self.n_comp]) / self.n_comp
        
    def km_init(self, data):
        means = []
        
        for i in range(self.n_init):
            km = KMeans(n_clusters=self.n_comp, n_init=1).fit(data)
            means.append(km.cluster_centers_)
            
        cov = jnp.expand_dims(jnp.eye(self.dim), 0)
        cov = cov.repeat(self.n_comp, 0)
        cov = jnp.expand_dims(cov, 0)
        cov = cov.repeat(self.n_init, 0)
        
        self.means = means
        self.cov = cov
        self.weights = jnp.ones([self.n_init, self.n_comp]) / self.n_comp
         
    @staticmethod
    @jax.jit
    def loglikelihood(data, weigths, means, cov):
        N = data.shape[0]
        res = []
        dim = means.shape[1]
        
        exp_sq = lambda data, mean, cov: -1 / 2 * (data - mean).T @ jnp.linalg.inv(cov) @ (data - mean)
        exp_sq_vmap = jax.vmap(exp_sq, (0, None, None), 0)
        
        norm_const_fn = lambda cov: - dim / 2 * jnp.log(2 * jnp.pi) - 1 / 2 * jnp.log(jnp.linalg.det(cov))
        
        norm_const_fn_vmap = jax.vmap(norm_const_fn, (0), 0)

        norm_const = norm_const_fn_vmap(cov)
        
        log_lik_fn = lambda data, means, cov, norm_const: exp_sq_vmap(data, means, cov) + norm_const
        log_lik_fn_vmap = jax.vmap(log_lik_fn, (None, 0, 0, 0), 0)
        
        res = log_lik_fn_vmap(data, means, cov, norm_const)
        res += jnp.log(jnp.expand_dims(weights, 1))
        return res
        
    def score(self, data):
        ret_val = []
        for i in range(self.n_init):
            log_lik_i = self.loglikelihood(data, self.weights[i], self.means[i], self.cov[i])

            log_lik_i = jax.scipy.special.logsumexp(log_lik_i, axis=0)
            log_lik_i = log_lik_i.mean()
            ret_val.append(log_lik_i)
            
        return ret_val

In [37]:
from sklearn import datasets

data = datasets.make_blobs(n_samples=100)[0]

gmm_jax = GuassianMixture(dim=data.shape[1], n_comp=3, n_init=4)

gmm_jax.fit(data, max_iter=5)
print(gmm_jax.score(data))

LogLik time: 0.44620847702026367
EM time: 0.2828030586242676
LogLik time: 0.010161161422729492
EM time: 0.004685640335083008
LogLik time: 0.004673957824707031
EM time: 0.0043735504150390625
LogLik time: 0.004444122314453125
EM time: 0.004311800003051758
LogLik time: 0.0074310302734375
EM time: 0.007873058319091797
LogLik time: 0.008213043212890625
EM time: 0.0062520503997802734
LogLik time: 0.006018400192260742
EM time: 0.011748075485229492
LogLik time: 0.025290966033935547
EM time: 0.013001203536987305
LogLik time: 0.015445470809936523
EM time: 0.014984130859375
LogLik time: 0.014832496643066406
EM time: 0.013680219650268555
LogLik time: 0.01606464385986328
EM time: 0.019208669662475586
LogLik time: 0.0071179866790771484
EM time: 0.006487131118774414
LogLik time: 0.007162570953369141
EM time: 0.006973981857299805
LogLik time: 0.006180286407470703
EM time: 0.006962299346923828
LogLik time: 0.006498575210571289
EM time: 0.011763811111450195
LogLik time: 0.00930929183959961
EM time: 0.00

In [32]:
from sklearn.mixture import GaussianMixture

gmm_sklearn = GaussianMixture(3, verbose=2, verbose_interval=1)

gmm_sklearn.fit(data)
print(f'Reference LogLik: {gmm_sklearn.score(data)}')

weights = gmm_sklearn.weights_
means = gmm_sklearn.means_
cov = gmm_sklearn.covariances_

gmm_jax.weights = jnp.asarray(jnp.expand_dims(weights, 0).repeat(10, 0))
gmm_jax.means = jnp.asarray(jnp.expand_dims(means, 0).repeat(10, 0))
gmm_jax.cov = jnp.asarray(jnp.expand_dims(cov, 0).repeat(10, 0))

gmm_jax.score(data)

Initialization 0
  Iteration 1	 time lapse 0.00637s	 ll change inf
  Iteration 2	 time lapse 0.00075s	 ll change 0.00130
  Iteration 3	 time lapse 0.00073s	 ll change 0.00007
Initialization converged: True	 time lapse 0.00789s	 ll -3.72336
Reference LogLik: -3.7233357658935575


[Array(-3.72333577, dtype=float64),
 Array(-3.72333577, dtype=float64),
 Array(-3.72333577, dtype=float64),
 Array(-3.72333577, dtype=float64)]

### Try on real data

In [332]:
data = load_cloud_dataset()
n_init = 10

gmm_jax = GuassianMixture(dim=data.shape[1], n_comp=5, n_init=n_init)

gmm_jax.fit(data, max_iter=10)

print(gmm_jax.score(data))

[Array(10.92622015, dtype=float64), Array(10.86955722, dtype=float64), Array(10.84269447, dtype=float64), Array(10.9388318, dtype=float64), Array(10.91938932, dtype=float64), Array(10.86551461, dtype=float64), Array(10.83850061, dtype=float64), Array(10.86955722, dtype=float64), Array(10.83950758, dtype=float64), Array(10.94752381, dtype=float64)]


In [327]:
data = load_cloud_dataset()

gmm_sklearn = GaussianMixture(5, verbose=0, verbose_interval=1, max_iter=10, init_params='kmeans', n_init=n_init)

gmm_sklearn.fit(data)

print(gmm_sklearn.score(data))

22.211747257375087




### Try on more challenging synthetics

In [310]:

data_synth = load_synthetic_dataset('synth_dim_10.npy')

gmm_jax = GuassianMixture(dim=data_synth.shape[1], n_comp=10, n_init=1)

gmm_jax.km_init(data_synth)

print(gmm_jax.score(data_synth))

gmm_jax.fit(data_synth, max_iter=10)




[Array(-9.5784723, dtype=float64)]
[0.09847184 0.10335273 0.09829694 0.09478363 0.10178235 0.10238119
 0.10483296 0.10070083 0.0932159  0.10218163]
[Array(5.57020881, dtype=float64)]
[0.09636037 0.10535736 0.09872072 0.09765784 0.10400059 0.09789174
 0.105164   0.09758194 0.09674804 0.1005174 ]
[Array(5.57085649, dtype=float64)]
[0.09454035 0.10730969 0.09912313 0.10012078 0.10615426 0.09378679
 0.10539643 0.09478114 0.09991338 0.09887405]
[Array(5.57139572, dtype=float64)]
[0.09298017 0.10921202 0.09951533 0.10221437 0.10824458 0.09002532
 0.10553828 0.09226239 0.10274692 0.09726062]
[Array(5.57184801, dtype=float64)]
[0.09165158 0.11106568 0.09990549 0.10397957 0.11027172 0.08657005
 0.10559613 0.08999337 0.10528368 0.09568274]
[Array(5.57223052, dtype=float64)]
[0.09052944 0.11287145 0.10029949 0.10545505 0.11223526 0.08338751
 0.10557555 0.0879453  0.10755734 0.09414362]
[Array(5.57255687, dtype=float64)]
[0.08959156 0.11462987 0.10070149 0.10667645 0.11413444 0.08044787
 0.1054814

In [299]:

gmm_sklearn = GaussianMixture(10, verbose=2, verbose_interval=1, max_iter=10, init_params='kmeans')

gmm_sklearn.fit(data_synth)

Initialization 0
  Iteration 1	 time lapse 0.06119s	 ll change inf
  Iteration 2	 time lapse 0.00853s	 ll change 0.02054
  Iteration 3	 time lapse 0.00586s	 ll change 0.01005
  Iteration 4	 time lapse 0.00576s	 ll change 0.00995
  Iteration 5	 time lapse 0.00578s	 ll change 0.00394
  Iteration 6	 time lapse 0.00574s	 ll change 0.00242
  Iteration 7	 time lapse 0.00577s	 ll change 0.00252
  Iteration 8	 time lapse 0.00575s	 ll change 0.00313
  Iteration 9	 time lapse 0.00580s	 ll change 0.00330
  Iteration 10	 time lapse 0.00571s	 ll change 0.00367
Initialization converged: False	 time lapse 0.11590s	 ll 10.86906




## Time measurements

In [39]:
from time import time

from sklearn.datasets import load_digits
digits = load_digits()['data']
data = load_synthetic_dataset('synth_dim_70.npy')

n_runs = 10

time_storage = []
for i in range(n_runs):
    gmm_sklearn = GaussianMixture(3, verbose=2, verbose_interval=1, max_iter=2, init_params='kmeans')
    
    start = time()

    gmm_sklearn.fit(data)
    
    time_storage.append(time() - start)
    
print(f'Time for one EM step: {np.mean(time_storage)} s +- {np.std(time_storage)}')

time_storage = []
for i in range(n_runs):
    
    gmm_jax = GuassianMixture(dim=data.shape[1], n_comp=3, n_init=1)

    start = time()
    gmm_jax.km_init(data)
    gmm_jax.fit(data, max_iter=2)
    
    time_storage.append(time() - start)

print(f'Time for one EM step jax: {np.mean(time_storage)} s +- {np.std(time_storage)}')

Initialization 0
  Iteration 1	 time lapse 0.01530s	 ll change inf
  Iteration 2	 time lapse 0.00752s	 ll change 0.00000
Initialization converged: True	 time lapse 0.02284s	 ll 68.88811
Initialization 0
  Iteration 1	 time lapse 0.01529s	 ll change inf
  Iteration 2	 time lapse 0.00929s	 ll change 0.00000
Initialization converged: True	 time lapse 0.02459s	 ll 68.64238
Initialization 0
  Iteration 1	 time lapse 0.01247s	 ll change inf
  Iteration 2	 time lapse 0.00609s	 ll change 0.00000
Initialization converged: True	 time lapse 0.01858s	 ll 68.59084
Initialization 0
  Iteration 1	 time lapse 0.01259s	 ll change inf
  Iteration 2	 time lapse 0.00600s	 ll change 0.00000
Initialization converged: True	 time lapse 0.01862s	 ll 68.59084
Initialization 0
  Iteration 1	 time lapse 0.01548s	 ll change inf
  Iteration 2	 time lapse 0.00648s	 ll change 0.00000
Initialization converged: True	 time lapse 0.02198s	 ll 68.59084
Initialization 0
  Iteration 1	 time lapse 0.01384s	 ll change inf
  I