In [None]:
import functools

import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import haiku as hk
import distrax
import optax

import plotly.graph_objects as go

# Radial

In [None]:
class RationalQuadraticSpline(hk.Module):
    """A conditional rational-quadratic spline flow on ``[range_min, range_max]``.

    The spline parameters of every layer are produced by an MLP applied to a
    scalar-only conditioning vector. The resulting bijector is the inverse of
    the chained spline layers and is composed with a uniform base distribution
    over the same interval.
    """

    def __init__(
        self,
        num_bins: int,
        range_min: float,
        range_max: float,
        num_layers: int,
        num_param_mlp_layers: int,
    ):
        """Stores the flow hyper-parameters.

        Args:
            num_bins: Number of spline bins per layer.
            range_min: Lower end of the spline / base-distribution interval.
            range_max: Upper end of the spline / base-distribution interval.
            num_layers: Number of spline layers chained together.
            num_param_mlp_layers: Number of layers in each parameter MLP.
        """
        super().__init__()
        self.num_bins = num_bins
        self.range_min = range_min
        self.range_max = range_max
        self.num_layers = num_layers
        self.num_param_mlp_layers = num_param_mlp_layers

    def create_flow(self, conditioning: e3nn.IrrepsArray) -> distrax.Bijector:
        """Creates a flow with the given conditioning.

        Raises:
            ValueError: If `conditioning` contains non-scalar irreps.
        """
        if not conditioning.irreps.is_scalar():
            raise ValueError("Conditioning for flow must be scalars only.")
        conditioning = conditioning.array

        # Each spline layer consumes 3 * num_bins + 1 parameters.
        param_dims = self.num_bins * 3 + 1

        layers = []
        for _ in range(self.num_layers):
            # Tiny random initialisation keeps the initial spline parameters
            # close to zero.
            params = hk.nets.MLP(
                [param_dims] * self.num_param_mlp_layers,
                activate_final=False,
                w_init=hk.initializers.RandomNormal(1e-4),
                b_init=hk.initializers.RandomNormal(1e-4),
            )(conditioning)
            layer = distrax.RationalQuadraticSpline(
                params,
                self.range_min,
                self.range_max,
                boundary_slopes="unconstrained",
                min_bin_size=1e-2,
            )
            layers.append(layer)

        flow = distrax.Inverse(distrax.Chain(layers))
        return flow

    def create_distribution(
        self, conditioning: e3nn.IrrepsArray
    ) -> distrax.Distribution:
        """Creates a distribution by composing a base distribution with a flow."""
        flow = self.create_flow(conditioning)
        base_distribution = distrax.Independent(
            distrax.Uniform(low=self.range_min, high=self.range_max),
            reinterpreted_batch_ndims=0,
        )
        dist = distrax.Transformed(base_distribution, flow)
        return dist

    def forward(
        self, base_samples: jnp.ndarray, conditioning: e3nn.IrrepsArray
    ) -> jnp.ndarray:
        """Applies the flow to the given samples from the base distribution."""
        flow = self.create_flow(conditioning)
        return flow.forward(base_samples)

    def log_prob(
        self, samples: jnp.ndarray, conditioning: e3nn.IrrepsArray
    ) -> jnp.ndarray:
        """Computes the log probability of the given samples.

        The leading (all but last) dimensions of `conditioning` and `samples`
        must agree.
        """
        assert conditioning.shape[:-1] == samples.shape[:-1], (
            conditioning.shape,
            samples.shape,
        )
        dist = self.create_distribution(conditioning)
        return dist.log_prob(samples)

    def sample(self, conditioning: e3nn.IrrepsArray, num_samples: int) -> jnp.ndarray:
        """Samples from the learned distribution (consumes one Haiku RNG key)."""
        dist = self.create_distribution(conditioning)
        rng = hk.next_rng_key()
        return dist.sample(seed=rng, sample_shape=(num_samples,))

    def langevin_sample(
        self,
        conditioning: e3nn.IrrepsArray,
        num_samples: int,
        beta: float,
        init_step_size: float,
        num_sampling_steps: int,
    ) -> jnp.ndarray:
        """Samples from the learned distribution using Langevin dynamics.

        `num_samples` chains are started from direct samples of the flow and
        advanced for `num_sampling_steps` steps. Each step is a gradient move
        plus Gaussian noise scaled by ``sqrt(2 * step_size / beta)``, wrapped
        periodically back into ``[range_min, range_max)``, followed by an
        accept/reject test on the log-probability difference. The step size
        is shrunk by a factor ``1 - 1 / num_sampling_steps`` after every step.

        Every chain uses its own PRNG key stream, and within a step the noise
        and the accept/reject draw use different sub-keys.

        NOTE(review): the score is taken with `jax.grad`, so each chain state
        must have a scalar log-probability -- this presumably requires a
        conditioning without batch dimensions; confirm for other use cases.

        Returns:
            ``((final_samples, final_step_sizes), trajectory)`` as produced by
            `jax.lax.scan`, vmapped over the chains (chain axis first).
        """
        dist = self.create_distribution(conditioning)
        range_width = self.range_max - self.range_min
        score_fn = jax.grad(dist.log_prob)

        def update(carry, rng):
            samples, step_size = carry
            # Independent randomness for the proposal noise and the
            # accept/reject decision.
            noise_rng, accept_rng = jax.random.split(rng)

            new_samples = samples + step_size * score_fn(samples)
            new_samples += jnp.sqrt(2 * step_size / beta) * jax.random.normal(
                noise_rng, samples.shape
            )
            # Periodic wrap into [range_min, range_max); the offset has to be
            # removed before the modulo so this also holds for range_min != 0.
            new_samples = self.range_min + (new_samples - self.range_min) % range_width

            log_acceptance_ratio = dist.log_prob(new_samples) - dist.log_prob(samples)
            # Clamp at 0 so the Bernoulli parameter is a valid probability.
            acceptance_ratio = jnp.exp(jnp.minimum(0.0, log_acceptance_ratio))
            accept = jax.random.bernoulli(accept_rng, acceptance_ratio)
            samples = jnp.where(accept, new_samples, samples)
            step_size = step_size * (1 - 1 / num_sampling_steps)
            return (samples, step_size), samples

        init_rng, chain_rng = jax.random.split(hk.next_rng_key())
        init_samples = dist.sample(seed=init_rng, sample_shape=(num_samples,))
        chain_rngs = jax.random.split(chain_rng, num_samples)

        def sample_single_seed_fn(init_sample, rng):
            return jax.lax.scan(
                update,
                (init_sample, init_step_size),
                jax.random.split(rng, num_sampling_steps),
            )

        return jax.vmap(sample_single_seed_fn)(init_samples, chain_rngs)

In [None]:
# Toy radial data: the radii the flow should learn to put mass on, together
# with a single constant scalar ("0e") conditioning feature.
target_radii = jnp.asarray([1.0, 1.5, 2.0, 2.5, 3.0])

conditioning = e3nn.IrrepsArray("0e", jnp.array([1.0]))


@hk.without_apply_rng
@hk.transform
def log_probs_fn(r, conditioning):
    """Log-probability of the radii `r` under the conditional spline flow."""
    spline_kwargs = dict(
        num_bins=8,
        range_min=0.0,
        range_max=4.0,
        num_layers=1,
        num_param_mlp_layers=1,
    )
    return RationalQuadraticSpline(**spline_kwargs).log_prob(r, conditioning)


# Initialise the flow parameters and the Adam optimiser state.
params = log_probs_fn.init(
    jax.random.PRNGKey(0),
    target_radii,
    conditioning,
)
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, conditioning, target_radii):
    """Runs one optimiser step on the mean negative log-likelihood.

    Compiled with `jax.jit` because it is called once per iteration of the
    training loop. It uses the module-level `log_probs_fn` and `tx`.

    Args:
        params: Current flow parameters.
        opt_state: Current optimiser state for `tx`.
        conditioning: Scalar conditioning passed to the flow.
        target_radii: Radii whose likelihood is maximised.

    Returns:
        A tuple ``(new_params, new_opt_state, loss)``.
    """

    def loss_fn(params):
        return -log_probs_fn.apply(params, target_radii, conditioning).mean()

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss


# Fit the flow to the target radii. The scalar is stored as `loss_value`
# (not `loss`) so it does not share a name with the `loss` function used by
# the angular section of this notebook.
for step in range(10000):
    params, opt_state, loss_value = train_step(
        params, opt_state, conditioning, target_radii
    )
    if step % 100 == 0:
        print(f"Step {step}, loss {loss_value}")

In [None]:
@hk.transform
def sample_fn(conditioning, num_samples):
    """Draws `num_samples` samples directly from the conditional spline flow."""
    spline_kwargs = dict(
        num_bins=8,
        range_min=0.0,
        range_max=4.0,
        num_layers=1,
        num_param_mlp_layers=1,
    )
    return RationalQuadraticSpline(**spline_kwargs).sample(conditioning, num_samples)


@hk.transform
def langevin_sample_fn(
    conditioning, num_samples, beta, init_step_size, num_sampling_steps
):
    """Runs the flow's Langevin sampler with the given sampling settings."""
    spline_kwargs = dict(
        num_bins=8,
        range_min=0.0,
        range_max=4.0,
        num_layers=1,
        num_param_mlp_layers=1,
    )
    sampler = RationalQuadraticSpline(**spline_kwargs)
    return sampler.langevin_sample(
        conditioning,
        num_samples,
        beta,
        init_step_size,
        num_sampling_steps,
    )


# Evaluate the learned density on a dense grid of radii, one query per row.
queries = jnp.linspace(0.0, 4.0, 1000)[:, None]
log_probs = jax.vmap(lambda r: log_probs_fn.apply(params, r, conditioning))(
    queries
).reshape(-1)
probs = jnp.exp(log_probs)

# Run the Langevin sampler; `trajectory` holds every chain at every step.
(samples, _), trajectory = langevin_sample_fn.apply(
    params,
    jax.random.PRNGKey(0),
    conditioning,
    num_samples=100,
    beta=10,
    init_step_size=1e-3,
    num_sampling_steps=100,
)
samples = samples.reshape(-1)

# Density curve with the final Langevin samples marked along the x-axis.
fig = go.Figure(
    data=[
        go.Scatter(x=queries.flatten(), y=probs, mode="lines"),
        go.Scatter(x=samples.flatten(), y=jnp.zeros_like(samples), mode="markers"),
    ]
)
fig.update_layout(title="Langevin dynamics samples")
fig.show()

In [None]:
# One animation frame per Langevin step, showing all chains on the x-axis.
frames = [
    go.Frame(
        data=[
            go.Scatter(
                x=trajectory[:, t],
                y=jnp.zeros_like(trajectory[:, t]),
                mode="markers",
                marker={"size": 5},
            )
        ],
        layout=go.Layout(title=f"Langevin Monte Carlo: Timestep {t + 1}"),
    )
    for t in range(trajectory.shape[1])
]

# Fixed axes plus Play / Pause buttons driving the frame animation.
anim_layout = go.Layout(
    xaxis={"range": [0, 4], "autorange": False},
    yaxis={"range": [-0.1, 0.1], "autorange": False},
    title="Langevin Monte Carlo Sampling of a Probability Distribution",
    updatemenus=[
        {
            "type": "buttons",
            "buttons": [
                {
                    "label": "Play",
                    "method": "animate",
                    "args": [
                        None,
                        {
                            "frame": {"duration": 50, "redraw": True},
                            "fromcurrent": True,
                            "transition": {"duration": 50, "easing": "linear"},
                        },
                    ],
                },
                {
                    "label": "Pause",
                    "method": "animate",
                    "args": [
                        [None],
                        {
                            "frame": {"duration": 0, "redraw": False},
                            "mode": "immediate",
                            "transition": {"duration": 0},
                        },
                    ],
                },
            ],
        }
    ],
)

# The first frame doubles as the initial static trace.
fig = go.Figure(data=[frames[0].data[0]], layout=anim_layout, frames=frames)
fig.show()

In [None]:
# Direct samples from the flow, for comparison with the Langevin samples.
samples = sample_fn.apply(
    params,
    jax.random.PRNGKey(0),
    conditioning,
    100,
).reshape(-1)

fig = go.Figure(
    data=[
        go.Scatter(x=queries.flatten(), y=probs, mode="lines"),
        go.Scatter(x=samples.flatten(), y=jnp.zeros_like(samples), mode="markers"),
    ]
)
fig.show()

# Angular

In [None]:
# Shared 3-D scene: a fixed [-1, 1]^3 cube so the unit sphere is undistorted.
layout = {
    "scene": {
        **{
            f"{axis}axis": {
                "title": axis.upper(),
                "range": [-1, 1],
                "autorange": False,
            }
            for axis in "xyz"
        },
        "aspectmode": "cube",
    }
}

In [None]:
# Target directions on the unit sphere, stored as "1o" vectors (one per row).
targets = e3nn.IrrepsArray(
    "1o",
    jnp.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0 / jnp.sqrt(2), 1.0 / jnp.sqrt(2)],
            [0.0, 0.0, 1.0],
            [0.0, -1.0, 0.0],
        ]
    ),
)

In [None]:
def compute_logits(coeffs, target):
    """Evaluates the unnormalised log-density at a single direction.

    Every channel of `coeffs` is evaluated at `target` with
    `e3nn.to_s2point`, and the channels are combined with a log-sum-exp.

    Args:
        coeffs: Spherical-harmonic coefficients of shape
            ``(num_channels, coeffs.irreps.dim)``.
        target: A single direction of shape ``(3,)``.

    Returns:
        The scalar logit wrapped in an ``IrrepsArray("0e", ...)``.
    """
    assert target.shape == (3,)

    num_channels, coeffs_dim = coeffs.shape[0], coeffs.irreps.dim
    assert coeffs.shape == (num_channels, coeffs_dim)

    per_channel = e3nn.to_s2point(coeffs, target).array.squeeze(-1)
    assert per_channel.shape == (num_channels,), per_channel.shape

    logits = jax.scipy.special.logsumexp(per_channel, axis=-1)
    assert logits.shape == ()

    # NOTE(review): `logits` is 0-d here while it is wrapped as "0e" --
    # confirm the installed e3nn_jax version accepts this.
    return e3nn.IrrepsArray("0e", logits)


def loss(coeffs, targets):
    """Mean negative log-likelihood of `targets` under the spherical density.

    The unnormalised log-density is `compute_logits`. The log-partition
    function is obtained by integrating the exponentiated, channel-summed
    signal on a Gauss-Legendre grid. The grid maximum is subtracted before
    exponentiating and added back afterwards, so large coefficients cannot
    overflow `exp`.

    Args:
        coeffs: Spherical-harmonic coefficients with a leading channel axis.
        targets: Target directions; the leading axis enumerates targets.

    Returns:
        Scalar ``-mean(logits) + log_Z``.
    """
    num_targets = targets.shape[0]
    num_channels = coeffs.shape[0]

    # Compute the logits for each target.
    logits = jax.vmap(lambda target: compute_logits(coeffs, target).array)(targets)
    assert logits.shape == (num_targets,), logits.shape

    # To compute the log-partition function, we need to integrate the signal.
    res_beta = 100
    res_alpha = 99
    log_signal = e3nn.to_s2grid(
        coeffs, res_beta=res_beta, res_alpha=res_alpha, quadrature="gausslegendre"
    )
    assert log_signal.shape == (num_channels, res_beta, res_alpha), log_signal.shape

    # log(int sum_c exp(f_c)) = m + log(int sum_c exp(f_c - m)) for any m;
    # choosing m as the grid maximum keeps every exponent <= 0.
    max_value = jnp.max(log_signal.grid_values)
    prob_signal = log_signal.replace_values(
        jnp.exp(log_signal.grid_values - max_value)
    )
    prob_signal = prob_signal.replace_values(jnp.sum(prob_signal.grid_values, axis=-3))
    assert prob_signal.shape == (res_beta, res_alpha), prob_signal.shape

    log_Z = max_value + jnp.log(prob_signal.integrate().array[0])
    assert log_Z.shape == (), log_Z.shape

    return -jnp.mean(logits) + log_Z

In [None]:
@jax.jit
def train_step(coeffs, targets, opt_state):
    """One optimiser step on `loss`; returns ``(coeffs, loss_value, opt_state)``.

    Uses the module-level optimiser `tx`.
    """
    loss_value, grads = jax.value_and_grad(loss)(coeffs, targets)
    updates, new_opt_state = tx.update(grads, opt_state)
    new_coeffs = optax.apply_updates(coeffs, updates)
    return new_coeffs, loss_value, new_opt_state


# Randomly initialised spherical-harmonic coefficients: `num_channels`
# independent signals, each with the irreps given by `e3nn.s2_irreps(lmax)`.
lmax = 5
num_channels = 5
irreps = e3nn.s2_irreps(lmax)
coeffs = e3nn.normal(
    irreps,
    key=jax.random.PRNGKey(0),
    leading_shape=(num_channels,),
)

# Clip the raw gradients *before* Adam. Transformations in `optax.chain` are
# applied in order, so clipping placed after Adam would act on the already
# rescaled updates (about the size of the learning rate per element) and
# would practically never trigger.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-3),
)
opt_state = tx.init(coeffs)

# Optimise the coefficients, reporting the loss every 1000 iterations.
for i in range(10000):
    coeffs, loss_value, opt_state = train_step(coeffs, targets, opt_state)
    if not i % 1000:
        print("loss", loss_value)

In [None]:
# Sample the learned logits on a Gauss-Legendre grid and show them as a surface.
logits_sig = e3nn.SphericalSignal.from_function(
    lambda direction: compute_logits(
        coeffs, e3nn.IrrepsArray("1o", direction)
    ).array,
    res_beta=100,
    res_alpha=99,
    quadrature="gausslegendre",
)
print(logits_sig.grid_vectors.shape)
go.Figure(data=[go.Surface(logits_sig.plotly_surface())]).show()

In [None]:
def coeffs_to_prob_sig(coeffs, res_beta=100, res_alpha=99):
    """Turns coefficients into a normalised probability signal on the sphere.

    The coefficients are evaluated on a Gauss-Legendre grid, shifted by the
    global maximum, exponentiated, summed over the leading channel axis and
    finally divided by the integral of the result.

    Args:
        coeffs: Spherical-harmonic coefficients with a leading channel axis.
        res_beta: Grid resolution along beta (default 100).
        res_alpha: Grid resolution along alpha (default 99).

    Returns:
        A spherical signal of shape ``(res_beta, res_alpha)`` whose integral
        is normalised to one.
    """
    # Take the channel count from the argument itself instead of relying on a
    # notebook-level global of the same name.
    num_channels = coeffs.shape[0]
    sig = e3nn.to_s2grid(
        coeffs, res_beta=res_beta, res_alpha=res_alpha, quadrature="gausslegendre"
    )
    assert sig.shape == (num_channels, res_beta, res_alpha), sig.shape
    # Subtracting the maximum keeps exp() from overflowing; the constant
    # factor cancels in the normalisation below.
    sig = sig.replace_values(sig.grid_values - jnp.max(sig.grid_values))
    sig = sig.replace_values(jnp.exp(sig.grid_values))
    sig = sig.replace_values(jnp.sum(sig.grid_values, axis=-3))
    sig /= sig.integrate().array[0]
    return sig


# Learned probability on the sphere, with the targets drawn slightly outside
# the surface (scaled by 1.1) as green markers.
prob_sig = coeffs_to_prob_sig(coeffs)
target_points = targets.array * 1.1
go.Figure(
    data=[
        go.Surface(prob_sig.plotly_surface(scale_radius_by_amplitude=False)),
        go.Scatter3d(
            x=target_points[:, 0],
            y=target_points[:, 1],
            z=target_points[:, 2],
            mode="markers",
            marker={"size": 10, "color": "green"},
            visible=True,
        ),
    ],
    layout=layout,
).show()

In [None]:
def discrete_sample(coeffs, beta, num_samples=100, key=None):
    """Samples grid directions from the tempered probability signal.

    The coefficients are scaled by `beta`, converted to a normalised signal
    with `coeffs_to_prob_sig`, and grid cells are then drawn with the
    signal's own `sample` method.

    Args:
        coeffs: Spherical-harmonic coefficients with a leading channel axis.
        beta: Factor the coefficients are multiplied by before exponentiating.
        num_samples: Number of directions to draw (default 100).
        key: PRNG key; defaults to ``jax.random.PRNGKey(0)``.

    Returns:
        Array of shape ``(num_samples, 3)`` holding the sampled grid vectors.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    prob_sig = coeffs_to_prob_sig(coeffs * beta)
    rngs = jax.random.split(key, num_samples)
    beta_indices, alpha_indices = jax.vmap(prob_sig.sample)(rngs)
    # Look up the unit vector of every sampled (beta, alpha) grid cell.
    return prob_sig.grid_vectors[beta_indices, alpha_indices]


# Draw grid samples at inverse temperature `beta` and show them in 3-D.
beta = 10
samples = discrete_sample(coeffs, beta)
go.Figure(
    data=[
        go.Scatter3d(
            x=samples[:, 0],
            y=samples[:, 1],
            z=samples[:, 2],
            mode="markers",
            marker={"size": 10, "color": "red"},
        )
    ],
    layout=layout,
).show()

In [None]:
@jax.jit
def score(sample, coeffs: e3nn.IrrepsArray):
    """Gradient of `compute_logits` with respect to the direction `sample`."""
    logits_grad_fn = e3nn.grad(compute_logits, argnums=1)
    return logits_grad_fn(coeffs, sample)


@jax.jit
def project_update_on_tangent_space(sample, update):
    """Removes from `update` its component along `sample`."""
    radial_component = e3nn.dot(sample, update) * sample
    return update - radial_component


@jax.jit
def apply_exponential_map(sample, update):
    """Moves `sample` along the tangent vector `update`.

    Computes ``cos(|u|) * sample + sin(|u|) / |u| * update``. The factor
    ``sin(|u|) / |u|`` is evaluated with `jnp.sinc`, which returns 1 at zero,
    so a vanishing update leaves `sample` unchanged instead of producing NaN
    through a division by zero.

    Args:
        sample: Current point (presumably a unit "1o" vector -- confirm
            against the caller).
        update: Tangent update; only its `.array` norm is used for the angle.

    Returns:
        The moved point.
    """
    update_norm = jnp.linalg.norm(update.array)
    # jnp.sinc(x) = sin(pi * x) / (pi * x), hence sin(n) / n = sinc(n / pi).
    sin_over_norm = jnp.sinc(update_norm / jnp.pi)
    return jnp.cos(update_norm) * sample + sin_over_norm * update


@functools.partial(jax.jit, static_argnames=("num_steps",))
def langevin_monte_carlo(init_sample, coeffs, beta, key, num_steps, init_step_size):
    """Runs one Metropolis-adjusted Langevin chain on the sphere.

    Each step combines a score move and Gaussian noise scaled by
    ``sqrt(2 * step_size / beta)``, projects the result onto the tangent
    space at the current point, applies the exponential map, and accepts or
    rejects the proposal based on the difference in `compute_logits`. The
    step size is multiplied by ``1 - 1 / num_steps`` after every step.

    NOTE(review): the acceptance ratio uses the raw logit difference (not
    scaled by `beta`) and has no proposal-asymmetry correction -- confirm
    this is the intended target distribution.

    Returns:
        ``((final_sample, final_step_size), trajectory)`` from `jax.lax.scan`.
    """
    step_decay = 1 - 1 / (num_steps)

    def mh_step(state, step_key):
        current, step_size = state

        # Langevin proposal: drift along the score plus Gaussian noise.
        step_key, noise_key = jax.random.split(step_key)
        drift = step_size * score(current, coeffs)
        noise = jnp.sqrt(2 * step_size / beta) * e3nn.normal("1o", noise_key)
        tangent_move = project_update_on_tangent_space(current, drift + noise)
        proposal = apply_exponential_map(current, tangent_move)

        # Metropolis-Hastings accept/reject on the logit difference.
        step_key, mh_key = jax.random.split(step_key)
        log_ratio = (
            compute_logits(coeffs, proposal) - compute_logits(coeffs, current)
        ).array
        accept_prob = jnp.exp(jnp.minimum(0, log_ratio))
        accepted = jax.random.bernoulli(mh_key, accept_prob)
        next_sample = e3nn.IrrepsArray(
            "1o", jnp.where(accepted, proposal.array, current.array)
        )

        return (next_sample, step_size * step_decay), next_sample

    step_keys = jax.random.split(key, num_steps)
    return jax.lax.scan(mh_step, (init_sample, init_step_size), step_keys)


def sample_from_uniform_distribution_on_sphere(rng):
    """Draws one point uniformly from the unit sphere in 3-D.

    A standard normal vector is rotation invariant, so normalising it gives a
    uniformly distributed direction. (Normalising a sample that is uniform in
    the cube ``[-1, 1]^3`` does not: it over-represents the directions of the
    cube corners.)

    Args:
        rng: A JAX PRNG key.

    Returns:
        Array of shape ``(3,)`` with unit Euclidean norm.
    """
    z = jax.random.normal(rng, (3,))
    return z / jnp.linalg.norm(z)


# Sampling settings for the spherical Langevin chains.
beta = 10
step_size = 1.0
num_steps = 1000
num_samples = 100
key = jax.random.PRNGKey(0)

# Use separate key streams for the starting points and for the chain
# dynamics; reusing the same keys for both would correlate each chain's
# noise with its own initial position.
init_key, chain_key = jax.random.split(key)
init_samples = e3nn.IrrepsArray(
    "1o",
    jax.vmap(sample_from_uniform_distribution_on_sphere)(
        jax.random.split(init_key, num_samples)
    ),
)
keys = jax.random.split(chain_key, num_samples)

# Run all chains in parallel; `trajectory` keeps every intermediate state.
(samples, _), trajectory = jax.vmap(
    lambda sample, key: langevin_monte_carlo(
        sample, coeffs, beta, key, num_steps, step_size
    )
)(init_samples, keys)

samples = samples.array
go.Figure(
    [
        go.Scatter3d(
            x=samples[:, 0],
            y=samples[:, 1],
            z=samples[:, 2],
            mode="markers",
            marker=dict(size=10, color="red"),
        ),
    ],
    layout=layout,
).show()

In [None]:
# Swap the first two axes of the chain trajectories so the leading axis can
# be iterated frame by frame.
plot_trajectory = trajectory.array.transpose(1, 0, 2)

# Split the last axis into its three Cartesian components.
x = plot_trajectory[:, :, 0]
y = plot_trajectory[:, :, 1]
z = plot_trajectory[:, :, 2]

# One animation frame per entry of the leading axis.
frames = [
    go.Frame(
        data=[
            go.Scatter3d(
                x=x[t], y=y[t], z=z[t], mode="markers", marker={"size": 5}
            )
        ],
        layout=go.Layout(title=f"Langevin Monte Carlo: Timestep {t + 1}"),
    )
    for t in range(plot_trajectory.shape[0])
]

# Reuse the shared 3-D scene and add Play / Pause buttons.
anim_layout = go.Layout(
    scene=layout["scene"],
    title="Langevin Monte Carlo Sampling of a Probability Distribution",
    updatemenus=[
        {
            "type": "buttons",
            "buttons": [
                {
                    "label": "Play",
                    "method": "animate",
                    "args": [
                        None,
                        {
                            "frame": {"duration": 50, "redraw": True},
                            "fromcurrent": True,
                            "transition": {"duration": 50, "easing": "linear"},
                        },
                    ],
                },
                {
                    "label": "Pause",
                    "method": "animate",
                    "args": [
                        [None],
                        {
                            "frame": {"duration": 0, "redraw": False},
                            "mode": "immediate",
                            "transition": {"duration": 0},
                        },
                    ],
                },
            ],
        }
    ],
)

# The first frame doubles as the initial static trace.
fig = go.Figure(data=[frames[0].data[0]], layout=anim_layout, frames=frames)
fig.show()