In [1]:
import jax
import jax.numpy as jnp
import genjaxmix.vectorized as vectorized
from genjax import pretty
# Turn on genjax's pretty-printing for notebook display. Presumably this is what makes a
# bare trace expression (e.g. `tr_gt` at the end of a later cell) render as a readable
# tree rather than a raw repr — not verifiable from this notebook alone.
pretty()

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1732740992.315939 19079572 service.cc:145] XLA service 0x13f419b70 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1732740992.315961 19079572 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1732740992.317166 19079572 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1732740992.317172 19079572 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



In [22]:
# Ground-truth generative setup. The 5-tuple is passed positionally as the arguments of
# `model.simulate`; its meaning is defined by genjaxmix.vectorized.generate — TODO document
# each entry. Compare with the inference `model_args`, which differ in the third entry.
N_MAX = 1000                          # forwarded as `N_max` to vectorized.generate
GT_ARGS = (1.0, 0.0, 3.0, 3.0, 0.5)   # model arguments used to simulate the ground truth

key = jax.random.key(0)
model = vectorized.generate(N_max=N_MAX)
# Simulate one ground-truth trace (jit-compiled); it is displayed as this cell's output.
tr_gt = jax.jit(model.simulate)(key, GT_ARGS)
tr_gt

In [23]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

def plot_clusters(points, assignments, means, stds, path,
                  title='Clusters with Mean and Variance'):
    """Scatter-plot points coloured by cluster, mark cluster means, and save the figure.

    For every cluster index that occurs in ``assignments``:
    - the member points are drawn (only their first two coordinates);
    - the cluster mean is marked with a black ``x``;
    - an axis-aligned ellipse with half-widths ``stds[k][0]`` / ``stds[k][1]``,
      centred on the mean, is outlined.

    Args:
        points: per-point coordinates, indexed as ``points[i, 0]`` / ``points[i, 1]``.
        assignments: per-point integer cluster index, same length as ``points``.
        means: per-cluster means, indexed as ``means[k][0]`` / ``means[k][1]``.
        stds: per-cluster spreads used as ellipse half-widths, indexed like ``means``.
        path: file path the figure is saved to.
        title: figure title.

    The figure is closed after saving, so nothing is rendered inline.
    """
    # Pull everything to host NumPy once, instead of issuing a separate JAX device op
    # for every boolean mask / index inside the loop.
    points = np.asarray(points)
    assignments = np.asarray(assignments)
    means = np.asarray(means)
    stds = np.asarray(stds)

    fig, ax = plt.subplots(figsize=(10, 8))
    for k in np.unique(assignments):
        members = points[assignments == k]
        mean, std = means[k], stds[k]
        ax.scatter(members[:, 0], members[:, 1], label=f'Cluster {k}')
        ax.scatter(mean[0], mean[1], color='black', marker='x')
        ax.add_patch(matplotlib.patches.Ellipse(
            (mean[0], mean[1]), 2 * std[0], 2 * std[1],
            edgecolor='black', facecolor='none'))

    ax.set(xlabel='X', ylabel='Y', title=title)
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


gt_choices = tr_gt.get_choices()
plot_clusters(
    gt_choices["assignments", "y1"],
    gt_choices["assignments", "c"],
    gt_choices["hyperparameters", "mu"],
    # The stored "sigma" is square-rooted before use, i.e. treated as a variance here —
    # TODO confirm against the model definition.
    jnp.sqrt(gt_choices["hyperparameters", "sigma"]),
    'clusters_plot.png',
)


In [24]:
from genjaxmix.rejuvenation import gibbs_move
from genjaxmix.vectorized_rejuvenation import propose_parameters
from genjax._src.core.interpreters.incremental import Diff

# (A PRNG key was previously created here but was never used: the inference cell creates
# its own key before sampling.)

# Model arguments used for inference. NOTE(review): the third entry is 1.0 here but 3.0 in
# the arguments used to simulate tr_gt — presumably a deliberate prior mismatch; confirm.
model_args = (1.0, 0.0, 1.0, 3.0, 0.5)

# Observed data: the "assignments" sub-choicemap of the ground-truth trace; later cells
# index it as obs[:, "c"] and obs[:, "y1"].
obs = tr_gt.get_choices()("assignments")

# jit-compile the Gibbs move once; repeated calls with matching argument structure reuse
# JAX's compilation cache.
gibbs_jitted = jax.jit(gibbs_move)

In [32]:
from genjax import ChoiceMapBuilder as C
N_GIBBS_SWEEPS = 10  # number of Gibbs moves applied after the importance-sampled initialization

key = jax.random.key(313341)
key, subkey = jax.random.split(key)

# Condition the model on the observed assignments "c" and data "y1" (the two constraint
# choicemaps are combined with genjax's `^` operator).
constraint = C["assignments", "c"].set(obs[:, "c"]) ^ C["assignments", "y1"].set(obs[:, "y1"])

# Initial trace from importance sampling under the constraint; the second return value
# (presumably the importance weight) is not needed here.
tr, _ = model.importance(subkey, constraint, model_args)
for _ in range(N_GIBBS_SWEEPS):
    key, subkey = jax.random.split(key)
    tr = gibbs_jitted(model, propose_parameters, model_args, tr, obs, subkey)

In [33]:
# Inferred state: the trace produced by the last Gibbs move.
new_chm = tr.get_choices()

# Colour points by the assignments stored in the *inferred* trace, so that each point is
# drawn against the same cluster whose mu/sigma are plotted below. NOTE(review): if
# gibbs_move never resamples "c" (it is constrained at initialization) this is identical
# to obs[:, "c"]; if it does, using obs[:, "c"] would pair points with the wrong clusters.
c = np.asarray(new_chm["assignments", "c"])
y1 = np.asarray(obs[:, "y1"])
mu = np.asarray(new_chm["hyperparameters", "mu"])
# The stored "sigma" is square-rooted before use, i.e. treated as a variance here —
# TODO confirm against the model definition.
sigma = np.asarray(jnp.sqrt(new_chm["hyperparameters", "sigma"]))

fig, ax = plt.subplots(figsize=(10, 8))
for cluster in np.unique(c):
    cluster_points = y1[c == cluster]
    cluster_mu = mu[cluster]
    cluster_sigma = sigma[cluster]

    # Member points, then the cluster mean as a black cross.
    ax.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f'Cluster {cluster}')
    ax.scatter(cluster_mu[0], cluster_mu[1], color='black', marker='x')

    # Axis-aligned ellipse centred on the mean with half-widths sigma[0] / sigma[1].
    ax.add_patch(matplotlib.patches.Ellipse(
        (cluster_mu[0], cluster_mu[1]), 2 * cluster_sigma[0], 2 * cluster_sigma[1],
        edgecolor='black', facecolor='none'))

ax.set(xlabel='X', ylabel='Y', title='Clusters with Mean and Variance')
ax.legend()

# Save and close so the figure is not also rendered inline.
fig.savefig('clusters_plot_inferred.png')
plt.close(fig)
