In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import distrax
import haiku as hk
import optax
import seaborn as sns
# Enable 64-bit floats globally. `jax.config` is an attribute of the top-level
# `jax` package imported above, so no separate `import jax.config` is needed
# (that submodule import is deprecated/removed in newer JAX releases).
jax.config.update("jax_enable_x64", True)

In [None]:
# Target distribution the flow is trained to match: a narrow Gaussian at 2.
# To experiment with a bimodal target instead, use an equal-weight
# distrax.MixtureOfTwo of Normal(loc=1, scale=0.1) and Normal(loc=3, scale=0.1).
true_dist = distrax.Normal(scale=0.1, loc=2.0)


In [None]:
# Root PRNG key for the whole notebook (seed 1234).
rng = jax.random.PRNGKey(1234)

# Spline-flow architecture.
num_bins = 10              # bins per rational-quadratic spline
num_layers = 5             # spline bijectors stacked in the flow
num_param_mlp_layers = 2   # linear layers in each spline-parameter MLP

# Interval on which the splines (and the clipped base distribution) live.
range_min, range_max = -1, 10

# Target samples drawn per training step.
batch_size = 1

In [None]:
class RQSpline(hk.Module):
    """Conditional rational-quadratic spline flow over a scalar variable.

    The flow is a chain of ``num_layers`` ``distrax.RationalQuadraticSpline``
    bijectors. Each layer's spline parameters are predicted from the
    conditioning input by its own MLP. The chain is wrapped in
    ``distrax.Inverse`` and used as the bijector of a ``distrax.Transformed``
    distribution on top of a clipped-normal base on ``[range_min, range_max]``.

    Args:
        num_bins: Number of spline bins per layer.
        range_min: Lower end of the spline interval; also the lower clip of
            the base distribution.
        range_max: Upper end of the spline interval; also the upper clip.
        num_layers: Number of spline bijectors in the chain.
        num_param_mlp_layers: Number of linear layers in each parameter MLP.
            ``None`` (the default) falls back to the notebook-level
            ``num_param_mlp_layers`` setting, so existing call sites keep
            their behaviour.
    """

    def __init__(self, num_bins, range_min, range_max, num_layers: int,
                 num_param_mlp_layers=None):
        super().__init__()
        self.num_bins = num_bins
        self.range_min = range_min
        self.range_max = range_max
        self.num_layers = num_layers
        # None => resolved lazily in create_flow from the module-level setting.
        self.num_param_mlp_layers = num_param_mlp_layers

    def create_flow(self, conditioning):
        """Build the conditional flow for ``conditioning``.

        Returns:
            ``distrax.Inverse`` of a ``distrax.Chain`` of spline bijectors.
        """
        # The bare name in the else-branch is the notebook-level hyper-parameter;
        # an explicit constructor argument takes precedence over it.
        mlp_depth = (self.num_param_mlp_layers
                     if self.num_param_mlp_layers is not None
                     else num_param_mlp_layers)
        # Parameter count per spline, as expected by
        # distrax.RationalQuadraticSpline for `num_bins` bins. Loop-invariant.
        param_dims = 3 * self.num_bins + 1

        layers = []
        for _ in range(self.num_layers):
            # Tiny init scale so the predicted spline parameters start close to zero.
            params = hk.nets.MLP(
                [param_dims] * mlp_depth,
                activate_final=False,
                w_init=hk.initializers.RandomNormal(1e-4),
                b_init=hk.initializers.RandomNormal(1e-4),
            )(conditioning)
            layers.append(distrax.RationalQuadraticSpline(
                params, self.range_min, self.range_max,
                boundary_slopes='unconstrained', min_bin_size=1e-2))

        return distrax.Inverse(distrax.Chain(layers))

    def create_distribution(self, conditioning, inverse_temperature: float = 1.):
        """Return the flow applied on top of a clipped-normal base distribution.

        The base is centred on the midpoint of ``[range_min, range_max]`` with
        std ``(range_max - range_min) / (20 * inverse_temperature)``, clipped to
        that interval. A larger ``inverse_temperature`` gives a narrower base.
        """
        flow = self.create_flow(conditioning)

        mean = (self.range_max + self.range_min) / 2
        std = (self.range_max - self.range_min) / (20 * inverse_temperature)
        # reinterpreted_batch_ndims=0: presumably a no-op on shapes; kept as-is.
        base_distribution = distrax.Independent(
            distrax.ClippedNormal(mean, std, minimum=self.range_min, maximum=self.range_max),
            reinterpreted_batch_ndims=0)

        return distrax.Transformed(base_distribution, flow)

    def forward(self, samples, conditioning):
        """Push base-space ``samples`` through the flow (the direction used by
        ``distrax.Transformed`` when sampling)."""
        flow = self.create_flow(conditioning)
        return flow.forward(samples)

    def log_prob(self, samples, conditioning):
        """Log-density of ``samples`` at inverse temperature 1.

        ``samples`` and ``conditioning`` must agree on all but their last axis.
        """
        assert conditioning.shape[:-1] == samples.shape[:-1]
        dist = self.create_distribution(conditioning)
        return dist.log_prob(samples)

    def sample(self, conditioning, inverse_temperature):
        """Draw samples with sample shape ``conditioning.shape[:-1]``.

        Uses ``hk.next_rng_key()``, so it must run inside a transform that was
        given an RNG key.
        """
        dist = self.create_distribution(conditioning, inverse_temperature)
        rng = hk.next_rng_key()
        return dist.sample(seed=rng, sample_shape=conditioning.shape[:-1])

In [None]:
@hk.without_apply_rng
@hk.transform
def log_prob_fn(samples, conditioning):
    """Per-example log-density of ``samples`` given ``conditioning``.

    The model's ``log_prob`` is vmapped over the leading axis. No RNG is used,
    so ``apply`` takes ``(params, samples, conditioning)``.
    """
    model = RQSpline(num_bins, range_min, range_max, num_layers)
    batched_log_prob = hk.vmap(model.log_prob, split_rng=False)
    return batched_log_prob(samples, conditioning)


@hk.transform
def sample_fn(conditioning, inverse_temperature: float):
    """Draw flow samples for each row of ``conditioning`` (vmapped over axis 0).

    ``split_rng=True`` so every mapped row receives its own RNG key.
    """
    def draw_one(condition):
        spline = RQSpline(num_bins, range_min, range_max, num_layers)
        return spline.sample(condition, inverse_temperature)

    batched_draw = hk.vmap(draw_one, split_rng=True)
    return batched_draw(conditioning)


@hk.without_apply_rng
@hk.transform
def forward_fn(base_samples, conditioning):
    """Map ``base_samples`` through the flow, vmapped over the leading axis.

    Deterministic: ``apply`` takes ``(params, base_samples, conditioning)``.
    """
    model = RQSpline(num_bins, range_min, range_max, num_layers)
    batched_forward = hk.vmap(model.forward, split_rng=False)
    return batched_forward(base_samples, conditioning)

In [None]:
@jax.jit
def loss_fn(params, true_samples, true_conditioning):
    """Negative mean log-likelihood of a batch of target samples under the flow."""
    per_example_log_prob = log_prob_fn.apply(params, true_samples, true_conditioning)
    return -per_example_log_prob.mean()

# Initialise the flow parameters from dummy inputs shaped like a training batch.
# Split off a dedicated key: consuming `rng` here and then splitting the same
# key in the training loop would reuse randomness.
init_rng, rng = jax.random.split(rng)
dummy_samples = jnp.zeros((batch_size, 1))
dummy_conditioning = jnp.ones_like(dummy_samples)
params = log_prob_fn.init(init_rng, dummy_samples, dummy_conditioning)

# Adam with a small learning rate (optax.sgd(1e-2) is an alternative to try).
tx = optax.adam(1e-4)
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, true_samples, true_conditioning):
    """Apply one update of the notebook-level optimiser ``tx`` on a batch.

    Returns:
        ``(new_params, new_opt_state)``.
    """
    loss_grads = jax.grad(loss_fn)(params, true_samples, true_conditioning)
    updates, next_opt_state = tx.update(loss_grads, opt_state)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state


# Maximum-likelihood training on fresh target samples at every step.
num_steps = 10000
log_every = 100

losses = []
for step in range(num_steps):
    # Consume only the fresh subkey; the carried `rng` is used solely for splitting.
    # (Sampling with `rng` itself would reuse the key that is split next step.)
    step_rng, rng = jax.random.split(rng)
    true_samples = true_dist.sample(seed=step_rng, sample_shape=(batch_size, 1))
    true_conditioning = jnp.ones_like(true_samples)
    params, opt_state = train_step(params, opt_state, true_samples, true_conditioning)
    if step % log_every == 0:
        # Loss on the batch just trained on, evaluated with the updated params.
        loss = loss_fn(params, true_samples, true_conditioning)
        losses.append(float(loss))
        print("step:", step, "loss:", loss)


In [None]:
# Training curve. `losses` holds one entry per logged step (every 100 steps),
# so derive the x-axis from its length: this stays aligned even if training
# was interrupted or run for a different number of steps.
sns.set_style("darkgrid")
logged_steps = [100 * i for i in range(len(losses))]
fig, ax = plt.subplots()
ax.plot(logged_steps, losses)
ax.set(xlabel="Steps", ylabel="Loss", title="Training loss")
plt.show()

In [None]:
# Compare samples from the trained flow with samples from the target.
# Each draw gets its own subkey so reference and model samples are not
# generated from the same randomness.
truth_rng, rng = jax.random.split(rng)
true_samples = true_dist.sample(seed=truth_rng, sample_shape=(100, 1))
true_conditioning = jnp.ones_like(true_samples)
for inverse_temperature in [1.]:
    sample_rng, rng = jax.random.split(rng)
    samples = sample_fn.apply(params, sample_rng, true_conditioning, inverse_temperature=inverse_temperature)
    print(inverse_temperature, ':', samples.mean(), '+/-', samples.std())

# Histogram of the flow samples (last inverse temperature) over the target samples.
sns.set(style="darkgrid")
plt.hist(samples.flatten(), bins=100, color='C0', density=True, alpha=0.6, label="flow")
plt.hist(true_samples.flatten(), bins=100, color='C1', density=True, alpha=0.6, label="true")
plt.xlabel("x")
plt.ylabel("density")
plt.title("Flow samples vs. target samples")
plt.legend()
plt.show()

In [None]:
# Visualize log prob
# Evaluate true and learned log-densities on a dense grid over [1.5, 2.5],
# which brackets the target's mean of 2.
x = jnp.linspace(1.5, 2.5, 1000).reshape(-1, 1)
y = jnp.ones_like(x)
log_prob = log_prob_fn.apply(params, x, y)

fig, ax = plt.subplots()
ax.plot(x, true_dist.log_prob(x), label="True")
ax.plot(x, log_prob, label="Estimated")
ax.set_xlabel("x")
ax.set_ylabel("log p(x)")
ax.set_title("Log Probability")
ax.legend()
plt.show()

In [None]:
# Visualize flow
# The learned base -> data map over the spline interval, with the identity for
# reference. (The curve is the learned flow, not a ground truth.)
sns.set(style="darkgrid")
x = jnp.linspace(range_min, range_max, 1000).reshape(-1, 1)
y = forward_fn.apply(params, x, jnp.ones_like(x))
plt.title("Flow")
plt.xlabel("x")
plt.ylabel("f(x)")
plt.plot(x.flatten(), y.flatten(), label="learned flow")
plt.plot(x.flatten(), x.flatten(), linestyle="--", color="gray", label="identity")
plt.legend()
plt.show()