In [1]:
import silence_tensorflow.auto
import numpy as np
from tensorboard import data
import tensorflow as tf
import tensorflow_probability as tfp

from tqdm import tqdm # progress-meter

In [2]:
def em(dataset, n_classes, n_iterations, random_seed, min_sigma=1e-6):
    """Fit a univariate Gaussian mixture model with Expectation-Maximization.

    The E-step is evaluated in log-space (log-sum-exp normalisation), so
    samples whose density underflows under every component no longer produce
    0/0 = NaN responsibilities.

    Parameters
    ----------
    dataset : array-like
        Observations. Flattened to a 1-D float64 array.
    n_classes : int
        Number of mixture components to fit.
    n_iterations : int
        Number of EM iterations to run (no convergence check is performed).
    random_seed : int or None
        Seed for the random initial guess. A private ``RandomState`` is used,
        so the global NumPy random state is left untouched while the drawn
        initial values match what ``np.random.seed(random_seed)`` would give.
    min_sigma : float, optional
        Lower bound applied to every standard deviation, preventing a
        component from collapsing onto a single point (scale of zero).

    Returns
    -------
    class_probs, mus, sigmas : tuple of np.ndarray
        Mixture weights, means and standard deviations, each of shape
        ``(n_classes,)``.

    Raises
    ------
    ValueError
        If ``dataset`` is empty or ``n_classes`` is smaller than 1.
    """
    samples = np.asarray(dataset, dtype=np.float64).reshape(-1)
    n_samples = samples.shape[0]
    if n_samples == 0:
        raise ValueError("dataset must contain at least one sample")
    if n_classes < 1:
        raise ValueError("n_classes must be at least 1")

    # Local generator: same stream as the legacy global seed, no side effects.
    rng = np.random.RandomState(random_seed)

    # Initial guesses for the parameters (draw order matters for the seed).
    mus = rng.rand(n_classes)
    sigmas = np.maximum(rng.rand(n_classes), min_sigma)
    class_probs = rng.dirichlet(np.ones(n_classes))

    tiny = np.finfo(np.float64).tiny
    log_sqrt_two_pi = 0.5 * np.log(2.0 * np.pi)
    column = samples[:, np.newaxis]  # (n_samples, 1), broadcasts against classes

    for _ in tqdm(range(n_iterations)):
        # E-Step: log(pi_c) + log N(x | mu_c, sigma_c) for every sample/class.
        standardized = (column - mus) / sigmas
        with np.errstate(divide="ignore"):  # log(0) -> -inf for a dead class
            log_weighted = (
                np.log(class_probs)
                - np.log(sigmas)
                - log_sqrt_two_pi
                - 0.5 * standardized**2
            )
        # Subtracting the row maximum makes the largest term exp(0) = 1, so
        # the normalising row sum can never be zero.
        log_weighted -= np.max(log_weighted, axis=1, keepdims=True)
        responsibilities = np.exp(log_weighted)
        responsibilities /= np.sum(responsibilities, axis=1, keepdims=True)

        class_responsibilities = np.sum(responsibilities, axis=0)
        # Guard the denominators against a class that received no mass.
        safe_totals = np.maximum(class_responsibilities, tiny)

        # M-Step (vectorised over classes)
        class_probs = class_responsibilities / n_samples
        mus = responsibilities.T @ samples / safe_totals
        variances = np.sum(responsibilities * (column - mus) ** 2, axis=0) / safe_totals
        sigmas = np.maximum(np.sqrt(variances), min_sigma)

    return class_probs, mus, sigmas


def main():
    """Sample from a known two-component Gaussian mixture and re-fit it with EM.

    Prints the estimated mixture weights, means and standard deviations
    (in that order), one array per line.
    """
    tfd = tfp.distributions

    # Experiment settings
    seed = 42  # fixed for reproducibility
    sample_count = 1000
    em_steps = 10
    component_count = 2

    # Ground-truth parameters of the generating mixture
    true_weights = [0.6, 0.4]
    true_means = [2.5, 4.8]
    true_stds = [0.6, 0.3]

    # Build the generating distribution and draw the synthetic data set
    mixture_weights = tfd.Categorical(probs=true_weights)
    gaussian_components = tfd.Normal(loc=true_means, scale=true_stds)
    ground_truth_gmm = tfd.MixtureSameFamily(
        mixture_distribution=mixture_weights,
        components_distribution=gaussian_components,
    )
    dataset = ground_truth_gmm.sample(sample_count, seed=seed).numpy()

    # Fit and report: weights, means, standard deviations
    estimates = em(dataset, component_count, em_steps, seed)
    for estimate in estimates:
        print(estimate)

# Entry point: run the demo only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

100%|██████████| 10/10 [00:00<00:00, 110.83it/s]

[0.4109166 0.5890834]
[4.7861406  2.47294214]
[0.30559934 0.565118  ]



