In [2]:
import jax.numpy as jnp
import jax


In [19]:
def renew(key):
    """Derive a replacement PRNG key from ``key``.

    Splits ``key`` into a single child key and returns that child.
    """
    children = jax.random.split(key, num=1)
    return children[0]

In [5]:
def sample_along_chord(center, vec, cutoff, fn, key,
                       fwd_mult=1, bwd_mult=1, 
                       double_iters=10, resample_iters=10,
                       doubled=False):
    """Sample a point with ``fn(point) < cutoff`` on the line through ``center`` along ``vec``.

    The segment ``[center - bwd_mult * vec, center + fwd_mult * vec]`` is first
    grown (each end doubled independently) until both endpoints satisfy
    ``fn(endpoint) >= cutoff``. A multiplier is then drawn uniformly from
    ``[-bwd_mult, fwd_mult)``; if the resulting point is rejected, a new one is
    drawn. When no doubling has happened yet, both mults are halved before
    retrying, on the guess that the initial segment was too long.

    Args:
        center: starting point; must satisfy ``fn(center) < cutoff``.
        vec: direction of the chord (added to ``center`` scaled by a mult).
        cutoff: threshold that accepted points must be strictly below.
        fn: callable mapping a point to a scalar comparable with ``cutoff``.
        key: JAX PRNG key; it is split internally and never consumed directly.
        fwd_mult: initial extent of the segment along ``+vec``.
        bwd_mult: initial extent of the segment along ``-vec``.
        double_iters: budget of doubling rounds.
        resample_iters: budget of rejected draws.
        doubled: whether the mults should be treated as already doubled
            (disables the halving-on-rejection heuristic).

    Returns:
        ``(sampled_point, sampled_mult)`` with
        ``sampled_point == center + sampled_mult * vec``.

    Raises:
        ValueError: if either iteration budget is exhausted.
        AssertionError: if ``fn(center)`` is not below ``cutoff``.
    """
    center_checked = False

    while True:
        if double_iters < 0:
            raise ValueError(f"Reached maximum number of doubling steps at mults {fwd_mult}, {bwd_mult}. Consider increasing the initial mults.")
        if resample_iters < 0:
            raise ValueError(f"Reached maximum number of resampling steps at mults {fwd_mult}, {bwd_mult}. Consider decreasing the initial mults.")

        if not center_checked:
            # center never changes, so one evaluation of fn is enough.
            assert fn(center) < cutoff, "center must be below cutoff"
            center_checked = True

        # An endpoint still below the cutoff means the segment has not yet
        # left the accepted region on that side.
        fwd_inside = fn(center + fwd_mult * vec) < cutoff
        bwd_inside = fn(center - bwd_mult * vec) < cutoff

        if fwd_inside:
            fwd_mult *= 2
        if bwd_inside:
            bwd_mult *= 2

        if fwd_inside or bwd_inside:
            # chord is too short, double the mults and try again
            double_iters -= 1
            doubled = True
            continue

        # we've found a long enough chord

        # Split before sampling so the key consumed by `uniform` is never the
        # one carried forward to later draws (JAX keys must not be reused).
        key, draw_key = jax.random.split(key)
        sampled_mult = -bwd_mult + jax.random.uniform(draw_key, ()) * (fwd_mult + bwd_mult)
        sampled_point = center + sampled_mult * vec

        if fn(sampled_point) < cutoff:
            # sampled point is below cutoff, so we're done
            return sampled_point, sampled_mult

        # sampled point is above cutoff, so we need to resample

        if not doubled:
            # maybe initial mults were too big?
            fwd_mult /= 2
            bwd_mult /= 2

        resample_iters -= 1


In [22]:
def hit_and_run(center, cutoff, fn, key, n_steps, 
                 fwd_mult=1, bwd_mult=1, 
                 double_iters=10, resample_iters=10):
    """Draw ``n_steps`` points with ``fn(point) < cutoff`` along random chords through ``center``.

    Each step draws a direction from a standard normal of shape
    ``center.shape``, normalises it to unit length, and delegates to
    ``sample_along_chord`` to pick a point on that chord.

    Args:
        center: point every chord passes through; its ``.shape`` sets the
            shape of the random directions.
        cutoff: threshold forwarded to ``sample_along_chord``.
        fn: callable forwarded to ``sample_along_chord``.
        key: JAX PRNG key; split afresh on every step.
        n_steps: number of points to draw.
        fwd_mult, bwd_mult: initial chord extents, forwarded unchanged on
            every step.
        double_iters, resample_iters: iteration budgets, forwarded unchanged
            on every step.

    Returns:
        ``(points, mults)``: the sampled points and their multipliers, each
        stacked along a new leading axis of length ``n_steps``.
    """
    # NOTE(review): `center` is never advanced to the newly sampled point, so
    # all chords go through the initial `center`. Classical hit-and-run moves
    # to each sample before drawing the next direction — confirm which
    # behaviour is intended before changing it.
    sampled_points = []
    sampled_mults = []
    for _ in range(n_steps):
        # Split three ways up front: one key to carry forward, one for the
        # direction, one for the chord sampler. A key that has been consumed
        # by a sampler must not also be split (JAX keys must not be reused).
        key, dir_key, chord_key = jax.random.split(key, 3)
        vec = jax.random.normal(dir_key, center.shape)
        vec = vec / jnp.linalg.norm(vec)
        sampled_point, sampled_mult = sample_along_chord(center, vec, cutoff, fn, chord_key, 
                                                          fwd_mult, bwd_mult, 
                                                          double_iters, resample_iters)
        sampled_points.append(sampled_point)
        sampled_mults.append(sampled_mult)
    return jnp.stack(sampled_points), jnp.stack(sampled_mults)