# Jain-Neal Split-Merge Moves

## Random Split-Merge Procedure

In [28]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from dataclasses import dataclass

In [27]:
# Problem sizes: K_max mixture slots of which only the first K_true carry
# weight, D-dimensional observations, N data points.
K_max = 10
K_true = 3
D = 2
N = 10

# Hyperparameters: Dirichlet concentration and scale of the component means.
alpha = 1.0
sigma_hyper = 1.0

key = jax.random.key(0)

# One split yields the carry-over key plus one subkey per random draw below.
key, *subkeys = jax.random.split(key, 5)
mu_key, pi_key, z_key, noise_key = subkeys

# Component means, one row per slot.
mu_true = sigma_hyper * jax.random.normal(mu_key, (K_max, D))

# Mixture weights: Dirichlet over the K_true active components, then padded
# with zeros so the vector has length K_max.
active_weights = jax.random.dirichlet(pi_key, alpha * jnp.ones(K_true))
padding = jnp.zeros(K_max - K_true)
pi_true = jnp.concat([active_weights, padding], axis=0)

# Assignments; the zero-weight slots have logit -inf and are never drawn.
z_true = jax.random.categorical(z_key, jnp.log(pi_true), shape=(N,))

# Observations: unit-variance Gaussian noise around the assigned mean.
noise = jax.random.normal(noise_key, (N, D))
data = noise + mu_true[z_true]

In [73]:
@dataclass
class Latent:
    """Container for the latent clustering state passed to the split move."""

    # Upper bound on the number of clusters; the demo cell passes the
    # module-level K_max here.
    K_max: int
    # Presumably the concentration parameter of the partition prior (it is
    # filled from the module-level alpha in the demo cell) - TODO confirm.
    alpha: float
    # Cluster count; the demo cell passes 2 together with labels drawn from
    # two categories, so presumably the number of clusters in use - TODO confirm.
    K: int
    # Integer cluster label per data point; random_split indexes it by data
    # point and compares labels element-wise.
    z: jax.Array


def partition_log_ratio(alpha, M: int, N: int):
    """Return log(alpha * Gamma(M) * Gamma(N) / Gamma(M + N)).

    Args:
        alpha: Value whose logarithm enters the ratio.
        M: Size of the first part.
        N: Size of the second part.

    Returns:
        The log ratio as a JAX scalar (or array, if M and N are arrays).
    """
    log_gamma = jax.scipy.special.gammaln
    log_alpha = jnp.log(alpha)
    merged_term = log_gamma(M + N)
    # Same left-to-right accumulation as a single chained expression.
    return log_alpha - merged_term + log_gamma(M) + log_gamma(N)

def proposal_log_ratio(M, N):
    """Return (M + N - 2) * log(2).

    Args:
        M: Size of the first part.
        N: Size of the second part.

    Returns:
        The log ratio as a JAX scalar (or array, if M and N are arrays).
    """
    exponent = M + N - 2
    return exponent * jnp.log(2)

def likelihood_log_ratio(data):
    """Placeholder for the likelihood term: ignores `data`, returns 0.0."""
    log_ratio = 0.0
    return log_ratio

def random_split(key, data, latent: Latent, i: int, j: int) -> jax.Array:
    """Propose a random split of the cluster containing `i`; accept or reject.

    Every member of the cluster that holds data point `i` is sent, by a fair
    coin flip, either to the existing label `latent.z[i]` or to a new label
    `max(latent.z) + 1`. Point `i` is then pinned to the old label and point
    `j` to the new one. The proposal is accepted with probability
    `exp(min(0, A + B + C))`, where A, B and C come from
    `proposal_log_ratio`, `partition_log_ratio` and `likelihood_log_ratio`.

    The values A, B, C and the acceptance probability are printed.

    Args:
        key: JAX PRNG key; split internally for the coin flips and the
            accept/reject draw.
        data: Observations, forwarded to `likelihood_log_ratio`.
        latent: Current state; `latent.z` and `latent.alpha` are read.
        i: Index of the anchor point that keeps the old label.
        j: Index of the anchor point that receives the new label. The move
            is only meaningful when `i != j` and `latent.z[i] == latent.z[j]`;
            this is not checked.

    Returns:
        The proposed label vector if the move is accepted, otherwise
        `latent.z`. Only the labels are returned; `latent` is not modified.
    """
    z = latent.z
    old_label = z[i]
    # The new label is one past the current maximum; it is not checked
    # against latent.K_max.
    c_next = jnp.max(z) + 1

    key, subkey = jax.random.split(key)
    coin_flips = jax.random.bernoulli(subkey, shape=(z.shape[0],))
    splits = jnp.where(coin_flips, old_label, c_next)

    # Only members of the anchor's cluster are re-labelled.
    in_cluster = z == old_label
    z_proposal = jnp.where(in_cluster, splits, z)
    # Pin the two anchors to opposite sides of the split.
    z_proposal = z_proposal.at[i].set(old_label)
    z_proposal = z_proposal.at[j].set(c_next)

    Ni = jnp.count_nonzero(z_proposal == old_label)
    Nj = jnp.count_nonzero(z_proposal == c_next)

    A = proposal_log_ratio(Ni, Nj)
    # Use the concentration carried by the state rather than a module global.
    B = partition_log_ratio(latent.alpha, Ni, Nj)
    # likelihood_log_ratio takes a single argument.
    C = likelihood_log_ratio(data)
    print(f"A: {A} B: {B} C: {C}")

    # jnp.minimum (not the builtin min) so this also works on traced values.
    a = jnp.exp(jnp.minimum(0.0, A + B + C))
    print(a)

    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey)
    z_new = jnp.where(u < a, z_proposal, z)
    return z_new


# Demo: draw a random two-cluster labelling and attempt one split move with
# data points 0 and 1 as anchors.
init_key = jax.random.key(2)
latent = Latent(K_max, alpha, 2, jax.random.categorical(init_key, jnp.ones(2), shape=(N,)))
print(latent.z)

# The key that produced latent.z has been consumed; hand the move its own key
# instead of reusing it, so the two sets of draws are independent.
key = jax.random.key(3)
random_split(key, data, latent, 0, 1)

[1 1 0 0 1 1 1 1 0 1]
A: 3.465735912322998 B: -4.094344615936279 C: 0.0
0.53333324


Array([1, 1, 0, 0, 1, 1, 1, 1, 0, 1], dtype=int32)