# In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from nanodl import GaussianMixtureModel

# Generate synthetic data using JAX
def generate_data(seed=0, n_samples=300):
    """Draw a synthetic 2-D dataset made of two Gaussian clusters.

    Args:
        seed: Integer seed used to build the JAX PRNG key.
        n_samples: Number of points drawn from each cluster.

    Returns:
        Array of shape ``(2 * n_samples, 2)``; the first ``n_samples`` rows
        come from the cluster centred at (0, 0), the rest from the cluster
        centred at (3, 3).
    """
    # JAX PRNG keys must not be reused: sampling twice with the same key
    # produces the same underlying noise, which would make the second
    # cluster a deterministic transform of the first. Split into one
    # subkey per draw so the clusters are independent.
    key1, key2 = jax.random.split(jax.random.PRNGKey(seed))

    mean1 = jnp.array([0.0, 0.0])
    cov1 = jnp.array([[1.0, 0.0], [0.0, 1.0]])
    data1 = jax.random.multivariate_normal(key1, mean1, cov1, (n_samples,))

    mean2 = jnp.array([3.0, 3.0])
    cov2 = jnp.array([[1.0, -0.5], [-0.5, 1.0]])
    data2 = jax.random.multivariate_normal(key2, mean2, cov2, (n_samples,))

    # Stack the clusters row-wise into a single design matrix.
    return jnp.vstack([data1, data2])

# Build the synthetic two-cluster dataset
X = generate_data()

# Create a two-component mixture model and fit it to the data
gmm = GaussianMixtureModel(n_components=2, seed=42)
gmm.fit(X)

# Scatter the samples, then overlay the fitted component means in red.
# Transposing yields one row per feature, which unpacks into (x, y).
plt.scatter(*X.T, s=10, label='Data')
plt.scatter(*gmm.means.T, color='red', label='Estimated Means')
plt.title('Gaussian Mixture Model (JAX)')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.show()

# Report the parameters stored on the fitted model
print("Estimated Means:", gmm.means)
print("Estimated Covariances:", gmm.covariances)
print("Estimated Weights:", gmm.weights)

# Assign each sample to a mixture component. `predict` is a method of the
# fitted model, not of the data array, so it must be called on `gmm`.
labels = gmm.predict(X)
labels
