In [301]:
import jax
from jax import jit
import jax.numpy as jnp
import genjax
from genjax import gen
import equinox as eq
from genjax import inverse_gamma, normal, categorical, smc, Target
from tensorflow_probability.substrates import jax as tfp

from genjax import ChoiceMapBuilder as C
from genjax import Pytree, Weight, gen, pretty
from genjax._src.generative_functions.distributions.distribution import Distribution
from genjax.typing import PRNGKey
from jaxtyping import Array, Float, Integer
import time


tfd = tfp.distributions
key = jax.random.PRNGKey(0)
pretty()

In [6]:
@Pytree.dataclass
class NormalInverseGamma(Distribution):
    def random_weighted(self, key: PRNGKey, mu, l, a, b):
        ig = tfd.InverseGamma(concentration=a, scale=b)
        key, subkey = jax.random.split(key)
        precision = ig.sample(seed=subkey)
        ig_logp = ig.log_prob(precision)

        normal = tfd.Normal(loc=mu, scale=precision / l)
        key, subkey = jax.random.split(key)
        mu = normal.sample(seed=subkey)
        mu_logp = normal.log_prob(mu)
        
        retval = jnp.stack([mu, precision], axis=1)
        inv_logp = -jnp.sum(ig_logp) - jnp.sum(mu_logp)
        return inv_logp, retval

    def estimate_logpdf(self, key: PRNGKey, x, mu, l, a, b):
        mu_sampled = x[:,0]
        precision = x[:,1]
        ig = tfd.InverseGamma(concentration=a, scale=b)
        ig_logp = ig.log_prob(precision)
        normal = tfd.Normal(loc=mu, scale= precision/l)
        mu_logp = normal.log_prob(mu_sampled)
        return jnp.sum(ig_logp) + jnp.sum(mu_logp)

nig = NormalInverseGamma()

@gen
def model():
    x = nig(jnp.zeros(10), jnp.ones(10), jnp.ones(10), jnp.ones(10)) @ "x"
    return x

model.simulate(key, ())

In [7]:
@Pytree.dataclass
class Dirichlet(Distribution):
    def random_weighted(self, key:PRNGKey, alpha):
        dir = tfd.Dirichlet(concentration = alpha)
        probs = dir.sample(seed=key)
        inv_weight = -dir.log_prob(probs)
        return inv_weight, probs
    def estimate_logpdf(self, key:PRNGKey, x, alpha):
        dir = tfd.Dirichlet(concentration = alpha)
        return dir.log_prob(x)
dirichlet = Dirichlet()

@gen
def model():
    dist = dirichlet(jnp.ones((5,5))) @ "x"
    return dist

tr = model.simulate(key, ())

In [8]:
"""
A class to store DP samples and the corresponding beta values. 

Used in GEM to avoid floating point error
"""
@Pytree.dataclass
class DPSample(Pytree):
    betas: Array
    pi: Array
    def __init__(self, betas, pi):
        self.betas = betas
        self.pi = pi


In [330]:
@Pytree.dataclass
class GEM(Distribution):
    C: int = Pytree.static(default=1)
    def __init__(self, C:int=10):
        self.C = jnp.asarray(C)
    def random_weighted(self, key: PRNGKey, alpha: Float):
        C = self.C
        beta = tfd.Beta(concentration1 = jnp.array(alpha), concentration0=jnp.array(1.0))
        betas = beta.sample(seed=key, sample_shape = C)
        inv_weight = -jnp.sum(beta.log_prob(betas))
        def fold(carry, b):
            return carry * (1-b) , carry * b
        _, pi = jax.lax.scan(fold, 1.0, betas)
        return inv_weight, DPSample(betas, pi)

    def estimate_logpdf(self, key: PRNGKey, dist:DPSample, alpha: Float):
        # assumes dist.pi corresponds to dist.betas
        betas = dist.betas
        beta = tfd.Beta(concentration1 = jnp.array(alpha), concentration0 = jnp.array(1.0))
        weight = jnp.sum(beta.log_prob(betas))
        return weight

gem = GEM(10)

@gen
def model():
    pi = gem(jnp.array(2.0)) @ "pi"
    print(pi)
    return pi

# jax.jit(model.simulate)(key, ())
tr = model.simulate(key, ())
# pi = tr.retval
tr

DPSample(...)


In [331]:
@gen
def model():
    pi = gem(jnp.array(2.0)) @ "pi"
    return pi

simulate_jitted = jax.jit(jax.vmap(model.simulate, in_axes=(0,None)))
keys = jax.random.split(key, 1000)
simulate_jitted(keys, ())

In [334]:
@Pytree.dataclass
class GemByDirichlet(Distribution):
    C: int = Pytree.static(default=1)
    def __init__(self, C:int=10):
        self.C = jnp.asarray(C)
    def random_weighted(self, key: PRNGKey, concentration):
        dir = tfd.Dirichlet(concentration=concentration)
        probs = dir.sample(seed=key)
        # invert betas
        def unfold(carry, x):
            beta = x / carry
            return carry * (1-beta) , carry * beta
        _, betas = jax.lax.scan(unfold, jnp.array(1.0), probs)
        x = DPSample(betas, probs)
        inv_weight = -dir.log_prob(probs)
        return inv_weight, x

    def estimate_logpdf(self, key: PRNGKey, x:DPSample, concentration):
        # assumes dist.pi corresponds to dist.betas
        dir = tfd.Dirichlet(concentration = concentration)
        return dir.log_prob(x.pi)

gbd = GemByDirichlet(10)
@gen
def model():
    pi = gbd(jnp.ones((10,))) @ "pi"
    return pi

model.simulate(key, ())

In [336]:
@Pytree.dataclass
class MixtureModel(Distribution):
    def random_weighted(self, key, pi, categorical_probs):
        key_0, key_1 = jax.random.split(key, 2)
        cluster_dist = tfd.Categorical(pi)
        c = cluster_dist.sample(seed=key_0)
        c_logp = cluster_dist.log_prob(c)
        label_dist = tfd.Categorical(categorical_probs[c])
        y = label_dist.sample(seed=key_1)
        y_logp = label_dist.log_prob(y)
        return -c_logp-y_logp, (c,y)

    def estimate_logpdf(self, key:PRNGKey, x, pi, categorical_probs):
        c, y = x
        cluster_dist = tfd.Categorical(pi)
        label_dist = tfd.Categorical(categorical_probs[c])
        logp = cluster_dist.log_prob(c) + label_dist.log_prob(y)
        return logp

cmm = MixtureModel()

@genjax.repeat(n=100)
@gen 
def cluster(pi, probs):
    assignments = cmm(pi, probs) @ "assignments"
    return assignments

pi = jnp.ones(10) / 10
categorical_probs = jax.random.uniform(key, (10, 36, 19))
tr = cluster.simulate(key, (pi, categorical_probs,))
tr.get_choices()[0, "assignments"].unmask()

In [355]:
@genjax.repeat(n=100)
@gen 
def cluster(pi, probs):
    assignments = cmm(pi, probs) @ "assignments"
    return assignments

@gen
def model():
    C = 20
    n_features = 36
    n_labels = 20
    pi = gem(jnp.array(2.0)) @ "pi"
    cluster_parameters = dirichlet(jnp.ones((C, n_features, n_labels))) @ "parameters"
    assignments = cluster(pi.pi, cluster_parameters) @ "assignments"
    return assignments

simulate_jitted = jax.jit(model.simulate)
simulate_jitted(key, ())

In [356]:
tr = simulate_jitted(key, ())
tr.get_choices()["assignments", 0, "assignments"].unmask()

### Inference?

In [358]:
from genjax import ChoiceMapBuilder as C
import genspn.distributions 

key = jax.random.PRNGKey(0)
obs = C["assignments",0,"assignments"].set((jnp.array(2, dtype=jnp.int32), jnp.zeros((36,), dtype=jnp.int32)))
obs = tr.get_choices()
args = ()
model.importance(key, obs, args)

TypeError: true_fun and false_fun output must have identical types, got
('ShapedArray(float32[36])', 'DIFFERENT ShapedArray(float32[36]) vs. ShapedArray(float32[], weak_type=True)', ('ShapedArray(int32[])', 'ShapedArray(int32[36])')).

In [342]:
@gen
def rejuvenate_pi(assignments):
    C = 10
    cluster_counts = jnp.sum(jax.nn.one_hot(assignments, num_classes= C), axis=0)
    pi = gbd(cluster_counts) @ "pi"
    return pi

assignments, _ = tr.get_choices()["assignments", ..., "assignments"]

rejuvenate_pi.simulate(key, (assignments,))

In [357]:
@gen
def rejuvenate_parameters(assignments, data):
    C = 10
    n_features = 20
    # update posterior of hyperparameters
    one_hot_x = jax.nn.one_hot(data, num_classes=n_features)
    counts = jax.ops.segment_sum(one_hot_x, assignments, num_segments=C)
    concentration = jnp.ones((C, data.shape[1], n_features)) + counts
    parameters = dirichlet(concentration) @ "parameters"
    return parameters

assignments, _ = tr.get_choices()["assignments", ..., "assignments"]
data = jax.random.randint(key, (100, 36), minval=0, maxval=19)
rejuvenate_parameters.simulate(key, (assignments, data))

In [None]:
@gen
def rejuvenate_parameters(assignments, data):
    C = 10
    n_features = 20
    # update posterior of hyperparameters
    one_hot_x = jax.nn.one_hot(data, num_classes=n_features)
    counts = jax.ops.segment_sum(one_hot_x, assignments, num_segments=C)
    concentration = jnp.ones((C, data.shape[1], n_features)) + counts
    parameters = dirichlet(concentration) @ "parameters"
    return parameters

assignments, _ = tr.get_choices()["assignments", ..., "assignments"]
data = jax.random.randint(key, (100, 36), minval=0, maxval=19)
rejuvenate_parameters.simulate(key, (assignments, data))

In [None]:
# @gen
# def rejuvenate_assignments(parameters, pi, data):
#     C = 10
#     n_features = 20
    # log_likelihoods = jax.vmap(jax.vmap(logpdf, in_axes=(0, None)), in_axes=(None, 0))(f, data)
    # log_likelihoods = log_likelihoods + log_likelihood_mask
    # log_score = log_likelihoods + jnp.log(pi)

    # assignments = jax.random.categorical(key, log_score, axis=-1).astype(int)
    # categorical()
    # return assignments

# assignments, _ = tr.get_choices()["assignments", ..., "assignments"]
# data = jax.random.randint(key, (100, 36), minval=0, maxval=19)
# rejuvenate_parameters.simulate(key, (assignments, data))

In [43]:
train, test = load_huggingface("AutoML/soybean")

shape: (307, 36)
┌────────────────┬──────┬─────────────┬────────┬───┬───────────────┬───────────┬────────────┬──────┐
│ class          ┆ date ┆ plant_stand ┆ precip ┆ … ┆ seed_discolor ┆ seed_size ┆ shriveling ┆ root │
│ ---            ┆ ---  ┆ ---         ┆ ---    ┆   ┆ ---           ┆ ---       ┆ ---        ┆ ---  │
│ str            ┆ str  ┆ str         ┆ str    ┆   ┆ str           ┆ str       ┆ str        ┆ str  │
╞════════════════╪══════╪═════════════╪════════╪═══╪═══════════════╪═══════════╪════════════╪══════╡
│ phytophthora_r ┆ 2    ┆ 1           ┆ 1      ┆ … ┆ null          ┆ null      ┆ null       ┆ 1    │
│ ot             ┆      ┆             ┆        ┆   ┆               ┆           ┆            ┆      │
│ frog_eye_leaf_ ┆ 3    ┆ 0           ┆ 2      ┆ … ┆ 0             ┆ 0         ┆ 0          ┆ 0    │
│ spot           ┆      ┆             ┆        ┆   ┆               ┆           ┆            ┆      │
│ phytophthora_r ┆ 1    ┆ 1           ┆ 2      ┆ … ┆ 0             ┆ 0    

  df = df.cast(pl.Categorical)
