In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from nanodl import GaussianMixtureModel

# Generate synthetic data using JAX
def generate_data(seed=0, n_samples=300):
    """Sample a two-cluster 2-D dataset for the GMM demo.

    Draws ``n_samples`` points from N([0, 0], I) and ``n_samples`` points from
    N([3, 3], [[1, -0.5], [-0.5, 1]]), then stacks them row-wise.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``.
        n_samples: Number of points drawn per component (default 300).

    Returns:
        Array of shape ``(2 * n_samples, 2)``. The first ``n_samples`` rows come
        from component 1 and the remaining rows from component 2.
    """
    # Use independent keys for the two draws. Reusing one key makes the second
    # cluster a deterministic linear transform of the first cluster's noise.
    key1, key2 = jax.random.split(jax.random.PRNGKey(seed))

    mean1 = jnp.array([0.0, 0.0])
    cov1 = jnp.array([[1.0, 0.0], [0.0, 1.0]])
    data1 = jax.random.multivariate_normal(key1, mean1, cov1, (n_samples,))

    mean2 = jnp.array([3.0, 3.0])
    cov2 = jnp.array([[1.0, -0.5], [-0.5, 1.0]])
    data2 = jax.random.multivariate_normal(key2, mean2, cov2, (n_samples,))

    return jnp.vstack([data1, data2])

# Generate data
X = generate_data()

# Fit a two-component GMM (seed fixed for reproducibility)
gmm = GaussianMixtureModel(n_components=2, seed=42)
gmm.fit(X)

# Plot the data and the estimated means
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], s=10, label='Data')
ax.scatter(gmm.means[:, 0], gmm.means[:, 1], color='red', label='Estimated Means')
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.legend()
ax.set_title('Gaussian Mixture Model (JAX)')
plt.show()

# Print estimated parameters
print("Estimated Means:", gmm.means)
print("Estimated Covariances:", gmm.covariances)
print("Estimated Weights:", gmm.weights)

# Per-sample cluster assignments. `predict` belongs to the fitted model, not to
# the data array (a JAX array has no `predict` method).
labels = gmm.predict(X)
labels


In [None]:
from nanodl import PCA

# Create dummy data: 1000 samples with 10 features each
data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))

# Create an instance of the PCA class
pca = PCA(n_components=2)

# Fit the model
pca.fit(data)

# Project the data onto the 2 fitted components
transformed_data = pca.transform(data)

# Map the projection back to the original 10-D space. Despite the name, this is
# a reconstruction from 2 components, not the original `data` values.
original_data = pca.inverse_transform(transformed_data)

# Sample from the model with an explicit key, so the draw is reproducible and
# does not depend on how PCA.sample handles key=None.
X_sampled = pca.sample(n_samples=1000, key=jax.random.PRNGKey(1))

print(X_sampled.shape, original_data.shape, transformed_data.shape)

In [None]:
from nanodl import LogisticRegression

num_samples = 100
input_dim = 2

# Use independent keys for training and held-out inputs. Reusing PRNGKey(0) for
# both with the same shape would make test_data an exact copy of x_data.
train_key, test_key = jax.random.split(jax.random.PRNGKey(0))

x_data = jax.random.normal(train_key, (num_samples, input_dim))
# Labels come from the fixed linear rule 0.5*x0 - 0.5*x1 - 0.1 > 0.
logits = jnp.dot(x_data, jnp.array([0.5, -0.5])) - 0.1
y_data = (logits > 0).astype(jnp.float32)

lr_model = LogisticRegression(input_dim)
lr_model.train(x_data, y_data)

test_data = jax.random.normal(test_key, (num_samples, input_dim))
predictions = lr_model.predict(test_data)
print("Predictions:", predictions)