In [72]:
from genjax import adev
from jax import make_jaxpr
import jax
from jaxtyping import Array
import jax.numpy as jnp
import jaxtyping

from genjax.adev import reinforce

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def cat_logpdf(v, logits):
    return tfd.Categorical(logits=logits).log_prob(v)

cat_reinforce = reinforce(
    lambda key, logits: tfd.Categorical(logits=logits).sample(seed=key),
    cat_logpdf,
)

In [73]:
# random maximum spanning tree
# θ is a matrix of size N^2 (N is number of leaves in tree).
# assume N is power of 2 -- at each step, you flatten θ,
# and sample an index from a categorical, this is the inner node 
#
# once you do that, you zero out the entire row and the entire column 
# for the 2 indices you got -- then continue in this loop.
# 
# That's one layer -- then there's an outer loop that 
# merges the theta to do higher layers.
#
# It's agglomerative clustering.

In [74]:
def sample(θ, N):
    logits = θ.flatten()
    idx = cat_reinforce(logits)
    # N * first_idx + snd_idx
    first_idx = idx // N
    snd_idx = jnp.mod(idx, N)
    mask = jnp.zeros((N, N))
    mask = mask.at[:, snd_idx].set(1.0)
    mask = mask.at[first_idx, :].set(1.0)
    new_θ = jnp.where(mask, -jnp.inf, θ)
    new_θ = new_θ - jax.nn.logsumexp(new_θ)
    return (first_idx, snd_idx), new_θ

@adev.expectation
def fn(θ: jaxtyping.Float[Array, "N N"]):
    list_of_pair_idxs = []
    N = jnp.shape(θ)[0]
    new_θ = θ
    for _ in range(int(N / 2)):
        (idx_1, idx_2), new_θ = sample(new_θ, N)
        list_of_pair_idxs.append((idx_1, idx_2))
    logits = jnp.array([θ[idx1, idx2] for (idx1, idx2) in list_of_pair_idxs])
    return jnp.sum(logits)

In [77]:
θ = jnp.ones((8, 8), dtype=float)
key = jax.random.key(1)
sub_keys = jax.random.split(key, 10000)
(θ_grads, ) = jax.vmap(fn.grad_estimate, in_axes=(0, None))(sub_keys, (θ, ))
jnp.mean(θ_grads, axis=0)

Array([[0.09098557, 0.04912024, 0.04733415, 0.0871511 , 0.0565511 ,
        0.06521481, 0.07132715, 0.05616816],
       [0.05592084, 0.0474706 , 0.06093879, 0.04676947, 0.05416715,
        0.05473091, 0.05415608, 0.08028414],
       [0.05597848, 0.08226778, 0.0634108 , 0.08944299, 0.07387813,
        0.0818228 , 0.06331954, 0.06372511],
       [0.06554484, 0.06269173, 0.06040351, 0.06267195, 0.04168061,
        0.06291707, 0.05631427, 0.06895594],
       [0.07168492, 0.05756315, 0.05186746, 0.0650251 , 0.06496514,
        0.04712012, 0.05578597, 0.06645662],
       [0.03733051, 0.05027314, 0.05912687, 0.04585521, 0.07194045,
        0.09018002, 0.05946372, 0.05192142],
       [0.05699509, 0.0856982 , 0.0548017 , 0.05940758, 0.03715845,
        0.03742853, 0.05359431, 0.051717  ],
       [0.07507776, 0.05132079, 0.08457854, 0.07906983, 0.07822269,
        0.06677518, 0.07489506, 0.07338294]], dtype=float32)