In [None]:
import functools

import chex
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import haiku as hk
import distrax
import optax

import plotly.graph_objects as go

In [None]:
class RationalQuadraticSpline(hk.Module):
    """A conditional 1D normalizing flow built from rational quadratic splines.

    The spline parameters of every layer are predicted by an MLP from the
    scalar ("0e") part of a conditioning embedding. The flow is applied to a
    uniform base distribution on [range_min, range_max].
    """

    def __init__(
        self,
        num_bins: int,
        range_min: float,
        range_max: float,
        num_layers: int,
        num_param_mlp_layers: int,
    ):
        """Stores the flow hyperparameters.

        Args:
            num_bins: Number of spline bins per layer.
            range_min: Lower end of the spline / base distribution interval.
            range_max: Upper end of the spline / base distribution interval.
            num_layers: Number of spline layers chained together.
            num_param_mlp_layers: Number of layers in each parameter MLP.
        """
        super().__init__()
        self.num_bins = num_bins
        self.range_min = range_min
        self.range_max = range_max
        self.num_layers = num_layers
        self.num_param_mlp_layers = num_param_mlp_layers

    def create_flow(self, conditioning: e3nn.IrrepsArray) -> distrax.Bijector:
        """Creates a flow with the given conditioning.

        Only the "0e" irreps of `conditioning` are used. The returned bijector
        is the inverse of the chained spline layers.
        """
        conditioning = conditioning.filter("0e")
        conditioning = conditioning.array

        # Size of the parameter vector handed to each distrax spline layer.
        param_dims = self.num_bins * 3 + 1

        layers = []
        for _ in range(self.num_layers):
            # Each layer gets its own MLP; the tiny init scale keeps the
            # predicted spline parameters close to zero at initialisation.
            params = hk.nets.MLP(
                [param_dims] * self.num_param_mlp_layers,
                activate_final=False,
                w_init=hk.initializers.RandomNormal(1e-4),
                b_init=hk.initializers.RandomNormal(1e-4),
            )(conditioning)
            layer = distrax.RationalQuadraticSpline(
                params,
                self.range_min,
                self.range_max,
                boundary_slopes="unconstrained",
                min_bin_size=1e-2,
            )
            layers.append(layer)

        flow = distrax.Inverse(distrax.Chain(layers))
        return flow

    def create_distribution(
        self, conditioning: e3nn.IrrepsArray
    ) -> distrax.Distribution:
        """Creates a distribution by composing a base distribution with a flow."""
        flow = self.create_flow(conditioning)
        base_distribution = distrax.Independent(
            distrax.Uniform(low=self.range_min, high=self.range_max),
            reinterpreted_batch_ndims=0,
        )
        dist = distrax.Transformed(base_distribution, flow)
        return dist

    def forward(
        self, base_sample: jnp.ndarray, conditioning: e3nn.IrrepsArray
    ) -> jnp.ndarray:
        """Applies the flow to the given samples from the base distribution."""
        flow = self.create_flow(conditioning)
        return flow.forward(base_sample)

    def log_prob(
        self, samples: jnp.ndarray, conditioning: e3nn.IrrepsArray
    ) -> jnp.ndarray:
        """Computes the log probability of the given samples."""
        dist = self.create_distribution(conditioning)
        return dist.log_prob(samples)

    def direct_sample(
        self, conditioning: e3nn.IrrepsArray, num_samples: int
    ) -> jnp.ndarray:
        """Samples from the learned distribution.

        Requires an RNG key to be available through `hk.next_rng_key()`.
        """
        dist = self.create_distribution(conditioning)
        return dist.sample(seed=hk.next_rng_key(), sample_shape=(num_samples,))

    def langevin_sample(
        self,
        conditioning: e3nn.IrrepsArray,
        num_samples: int,
        beta: float,
        init_step_size: float,
        num_sampling_steps: int,
        use_symmetric_proposal: bool = True,
    ) -> tuple:
        """Samples from the learned distribution using Langevin dynamics.

        Runs `num_samples` independent chains, each with a Metropolis-Hastings
        accept/reject step after every Langevin proposal.

        Args:
            conditioning: Embedding the flow is conditioned on.
            num_samples: Number of independent chains.
            beta: Inverse temperature; the injected noise has variance
                `2 * step_size / beta`.
            init_step_size: Step size of the first update; it is multiplied by
                `1 - 1 / num_sampling_steps` after every step.
            num_sampling_steps: Number of updates per chain.
            use_symmetric_proposal: If True, the acceptance ratio only compares
                the log-probabilities of the two states. If False, the
                forward/reverse Gaussian proposal log-densities are included.

        Returns:
            `((final_samples, final_step_sizes), trajectory)`, i.e. the final
            scan carry and the per-step samples of every chain, both with a
            leading `num_samples` axis.
        """
        dist = self.create_distribution(conditioning)
        range_width = self.range_max - self.range_min

        def score_fn(samples):
            return jax.grad(dist.log_prob)(samples)

        def update(carry, rng):
            samples, step_size = carry
            # Independent keys for the proposal noise and the accept/reject
            # draw; sharing one key would correlate the two decisions.
            noise_rng, accept_rng = jax.random.split(rng)

            new_samples = samples + step_size * score_fn(samples)
            new_samples += jnp.sqrt(2 * step_size / beta) * jax.random.normal(
                noise_rng, samples.shape
            )

            # Wrap samples that leave [range_min, range_max) back into the
            # interval periodically. The offset must be removed before the
            # modulo, otherwise the wrap is wrong whenever range_min != 0.
            new_samples = self.range_min + (new_samples - self.range_min) % range_width

            log_acceptance_ratio = dist.log_prob(new_samples) - dist.log_prob(samples)
            if not use_symmetric_proposal:
                # The proposal is N(x + step * score(x), 2 * step / beta), so
                # its log-density is -beta * residual^2 / (4 * step).
                reverse_residual = samples - new_samples - step_size * score_fn(new_samples)
                forward_residual = new_samples - samples - step_size * score_fn(samples)
                log_acceptance_ratio -= beta * jnp.square(reverse_residual) / (4 * step_size)
                log_acceptance_ratio += beta * jnp.square(forward_residual) / (4 * step_size)
            # Clamp before exponentiating so the probability stays in [0, 1]
            # and exp cannot overflow.
            acceptance_ratio = jnp.exp(jnp.minimum(0.0, log_acceptance_ratio))
            accept = jax.random.bernoulli(accept_rng, acceptance_ratio)
            samples = jnp.where(accept, new_samples, samples)

            step_size = step_size * (1 - 1 / (num_sampling_steps))
            return (samples, step_size), samples

        # Separate keys for drawing the initial states and for driving the chains.
        init_rng, chain_rng = jax.random.split(hk.next_rng_key())
        init_samples = dist.sample(seed=init_rng, sample_shape=(num_samples,))
        sample_rngs = jax.random.split(chain_rng, num_samples)

        def sample_single_seed_fn(init_sample, sample_rng):
            """Runs one chain for `num_sampling_steps` steps."""
            return jax.lax.scan(
                update,
                (init_sample, init_step_size),
                jax.random.split(sample_rng, num_sampling_steps),
            )

        return jax.vmap(sample_single_seed_fn)(init_samples, sample_rngs)

In [None]:
class AngularPredictor(hk.Module):
    """Predicts an unnormalised angular distribution on the sphere.

    Spherical-harmonic coefficients are computed from a conditioning embedding
    modulated by a learned radial embedding. The log-density at a direction is
    the logsumexp over channels of the signal evaluated at that direction.
    """

    def __init__(
        self,
        max_ell: int,
        num_channels: int,
        radial_mlp_layers: int,
        radial_embed_dim: int,
        max_radius: float,
        res_beta: int,
        res_alpha: int,
        quadrature: str,
    ):
        """Stores the hyperparameters.

        Args:
            max_ell: Maximum spherical-harmonic degree of the coefficients.
            num_channels: Number of coefficient channels.
            radial_mlp_layers: Number of layers of the radial MLP.
            radial_embed_dim: Size of the Bessel radial embedding.
            max_radius: `x_max` passed to the Bessel embedding.
            res_beta: Grid resolution along beta used for integration.
            res_alpha: Grid resolution along alpha used for integration.
            quadrature: Quadrature name passed to `e3nn.to_s2grid`.
        """
        super().__init__()
        self.max_ell = max_ell
        self.num_channels = num_channels
        self.radial_mlp_layers = radial_mlp_layers
        self.radial_embed_dim = radial_embed_dim
        self.max_radius = max_radius
        self.res_beta = res_beta
        self.res_alpha = res_alpha
        self.quadrature = quadrature

    def coeffs(self, radius: float, conditioning: e3nn.IrrepsArray):
        """Computes the spherical-harmonic coefficients for a given radius.

        Returns an array of shape `(num_channels, (max_ell + 1) ** 2)`.
        """
        radial_embed = e3nn.bessel(radius, self.radial_embed_dim, x_max=self.max_radius)
        radial_embed = jnp.atleast_2d(radial_embed)
        # The MLP outputs one scalar gate per irrep of the conditioning.
        radial_embed = e3nn.haiku.MultiLayerPerceptron(
            [self.radial_embed_dim] * (self.radial_mlp_layers - 1)
            + [conditioning.irreps.num_irreps],
            act=jax.nn.swish,
            output_activation=True,
        )(radial_embed)

        conditioning = conditioning * radial_embed
        coeffs = e3nn.haiku.Linear(
            irreps_out=e3nn.s2_irreps(self.max_ell), channel_out=self.num_channels
        )(conditioning)
        assert coeffs.shape == (self.num_channels, (self.max_ell + 1) ** 2)

        return coeffs

    def logits(self, position: e3nn.IrrepsArray, conditioning: e3nn.IrrepsArray):
        """Computes the unnormalised log-density at a single 3D position."""
        assert position.shape == (3,), position.shape

        radius = jnp.linalg.norm(position.array)
        coeffs = self.coeffs(radius, conditioning)
        # NOTE(review): a zero-length position makes this a 0 / 0 division.
        normalized_position = position / radius
        return self.logits_with_coeffs(normalized_position, coeffs)

    def logits_with_coeffs(self, normalized_position: e3nn.IrrepsArray, coeffs: e3nn.IrrepsArray):
        """Evaluates the signal at one direction and combines the channels."""
        assert normalized_position.shape == (3,), normalized_position.shape

        vals = e3nn.to_s2point(coeffs, normalized_position)
        vals = vals.array.squeeze(-1)
        assert vals.shape == (self.num_channels,), vals.shape

        logits = jax.scipy.special.logsumexp(vals, axis=-1)
        assert logits.shape == (), logits.shape
        return logits

    def log_partition_function(self, radius: float, conditioning: e3nn.IrrepsArray):
        """Computes the log of the integral over the sphere of exp(signal), summed over channels."""
        coeffs = self.coeffs(radius, conditioning)
        prob_signal = e3nn.to_s2grid(coeffs, res_beta=self.res_beta, res_alpha=self.res_alpha, quadrature=self.quadrature)
        assert prob_signal.shape == (
            self.num_channels,
            self.res_beta,
            self.res_alpha,
        )

        # Subtract the maximum before exponentiating and add it back after the
        # log, so large signal values cannot overflow exp. The shift cancels
        # analytically, hence it is excluded from differentiation.
        grid_values = prob_signal.grid_values
        max_value = jax.lax.stop_gradient(jnp.max(grid_values))
        prob_signal = prob_signal.replace_values(
            jnp.sum(jnp.exp(grid_values - max_value), axis=-3)
        )
        assert prob_signal.shape == (self.res_beta, self.res_alpha), prob_signal.shape

        log_Z = jnp.log(prob_signal.integrate().array.sum()) + max_value
        assert log_Z.shape == (), log_Z.shape

        return log_Z

    @staticmethod
    def coeffs_to_probability_distribution(coeffs, res_beta, res_alpha, quadrature):
        """Converts coefficients into a normalised probability signal on a grid.

        Returns a signal of shape `(res_beta, res_alpha)` whose integral is 1.
        """
        num_channels = coeffs.shape[-2]

        prob_signal = e3nn.to_s2grid(coeffs, res_beta=res_beta, res_alpha=res_alpha, quadrature=quadrature)
        assert prob_signal.shape == (
            num_channels,
            res_beta,
            res_alpha,
        )

        # Shift by the maximum for numerical stability; the shift cancels in
        # the normalisation below.
        prob_signal = prob_signal.replace_values(prob_signal.grid_values - jnp.max(prob_signal.grid_values))
        prob_signal = prob_signal.replace_values(jnp.exp(prob_signal.grid_values))
        prob_signal = prob_signal.replace_values(jnp.sum(prob_signal.grid_values, axis=-3))
        prob_signal /= prob_signal.integrate().array.sum()
        assert prob_signal.shape == (
            res_beta,
            res_alpha,
        )
        return prob_signal

    def langevin_sample(self, radius: float, conditioning: e3nn.IrrepsArray,
                        num_samples: int, beta: float, num_sampling_steps: int, init_step_size: float):
        """Samples directions with Langevin dynamics on the sphere, scaled by `radius`.

        Runs `num_samples` independent chains; each step takes a Langevin
        update in the tangent space, maps it back to the sphere and applies a
        Metropolis-Hastings accept/reject step.

        Returns:
            `(positions, trajectory)`: the final position of every chain and
            the position after every step, both multiplied by `radius`.
        """

        coeffs = self.coeffs(radius, conditioning)

        def sample_from_uniform_distribution_on_sphere(key: chex.PRNGKey):
            """Returns a uniformly distributed unit vector and the advanced key."""
            key, z_key = jax.random.split(key)
            # A normalised isotropic Gaussian is uniform on the sphere;
            # normalising a point drawn uniformly from a cube is not.
            z = jax.random.normal(z_key, (3,))
            z /= jnp.linalg.norm(z)
            z = e3nn.IrrepsArray("1o", z)
            return z, key

        def score(sample: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
            return jax.grad(self.logits_with_coeffs, argnums=0)(sample, coeffs)

        def project_update_on_tangent_space(sample: e3nn.IrrepsArray, update: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
            return update - e3nn.dot(sample, update) * sample

        def apply_exponential_map(sample: e3nn.IrrepsArray, update: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
            update_norm = jnp.linalg.norm(update.array)
            # Avoid 0 / 0 for a vanishing update; sin(0) = 0 then leaves the
            # sample unchanged.
            safe_norm = jnp.where(update_norm > 0, update_norm, 1.0)
            return jnp.cos(update_norm) * sample + jnp.sin(update_norm) * update / safe_norm

        def update(state, key: chex.PRNGKey):
            sample, step_size = state

            # Compute Langevin dynamics update.
            key, noise_key = jax.random.split(key)
            tangent_update = step_size * score(sample)
            tangent_update += jnp.sqrt(2 * step_size / beta) * e3nn.normal("1o", noise_key)
            tangent_update = project_update_on_tangent_space(sample, tangent_update)
            new_sample = apply_exponential_map(sample, tangent_update)

            # Apply Metropolis-Hastings correction.
            key, mh_key = jax.random.split(key)
            log_acceptance_ratio = (
                self.logits_with_coeffs(new_sample, coeffs)
                - self.logits_with_coeffs(sample, coeffs)
            )
            log_acceptance_ratio = jnp.minimum(0, log_acceptance_ratio)
            acceptance_ratio = jnp.exp(log_acceptance_ratio)
            acceptance = jax.random.bernoulli(mh_key, acceptance_ratio)
            new_sample = jnp.where(acceptance, new_sample.array, sample.array)
            new_sample = e3nn.IrrepsArray("1o", new_sample)

            new_step_size = step_size * (1 - 1 / (num_sampling_steps))
            return (new_sample, new_step_size), new_sample

        rng = hk.next_rng_key()
        sample_rngs = jax.random.split(rng, num_samples)
        init_samples, sample_rngs = jax.vmap(sample_from_uniform_distribution_on_sphere)(sample_rngs)

        def sample_single_seed_fn(init_sample, sample_rng):
            """Samples the chain for a single seed."""
            return jax.lax.scan(
                update,
                (init_sample, init_step_size),
                jax.random.split(sample_rng, num_sampling_steps),
            )

        (positions, _), trajectory = jax.vmap(sample_single_seed_fn)(init_samples, sample_rngs)
        # Scale by the radius.
        positions = e3nn.IrrepsArray("1o", positions.array * radius)
        trajectory = e3nn.IrrepsArray("1o", trajectory.array * radius)
        return positions, trajectory



In [None]:
def create_radial_predictor():
    """Builds the radial flow: two 16-bin spline layers on the interval [0, 5]."""
    spline_config = dict(
        num_bins=16,
        range_min=0.0,
        range_max=5.0,
        num_layers=2,
        num_param_mlp_layers=2,
    )
    return RationalQuadraticSpline(**spline_config)


def create_angular_predictor():
    """Builds the angular predictor used throughout the notebook."""
    predictor_config = dict(
        max_ell=5,
        num_channels=2,
        radial_mlp_layers=2,
        radial_embed_dim=8,
        max_radius=5.0,
        # Grid resolution used when integrating over the sphere.
        res_beta=100,
        res_alpha=99,
        quadrature="gausslegendre",
    )
    return AngularPredictor(**predictor_config)


@hk.without_apply_rng
@hk.transform
def radial_predictor_fn(position, focus_node_embedding):
    """Log-probability of the radius (norm) of `position` under the radial flow."""
    radius = jnp.linalg.norm(position.array)
    predictor = create_radial_predictor()
    return predictor.log_prob(radius, conditioning=focus_node_embedding)


@hk.without_apply_rng
@hk.transform
def angular_predictor_fn(position, focus_node_embedding):
    """Returns the angular logits at `position` and the log partition function at its radius."""
    predictor = create_angular_predictor()
    position_logits = predictor.logits(position, conditioning=focus_node_embedding)
    radius = jnp.linalg.norm(position.array)
    log_Z = predictor.log_partition_function(radius, conditioning=focus_node_embedding)
    return position_logits, log_Z


@jax.jit
def loss(params, targets):
    """Mean negative log-likelihood of the target positions.

    `params` holds the focus node embedding and the parameters of the radial
    and angular predictors; `targets` is a batch of 3D positions.
    """
    embedding = params["focus_node_embedding"]
    radial_params = params["radial_predictor"]
    angular_params = params["angular_predictor"]

    def log_likelihood(target: e3nn.IrrepsArray) -> float:
        # Radial log-density plus the normalised angular log-density.
        radial_term = radial_predictor_fn.apply(radial_params, target, embedding)
        angular_term, angular_log_Z = angular_predictor_fn.apply(
            angular_params, target, embedding
        )
        return radial_term + angular_term - angular_log_Z

    per_target = jax.vmap(log_likelihood)(targets)
    assert per_target.shape == (targets.shape[0],), per_target.shape

    return -per_target.mean()

In [None]:
# Random embedding of the focus node; it is stored in `params` alongside the
# predictor parameters.
focus_node_embedding = e3nn.normal(e3nn.s2_irreps(5) * 10, jax.random.PRNGKey(0))

# Both predictors are initialised with the same key and a zero dummy position.
radial_predictor_params, angular_predictor_params = (
    predictor_fn.init(jax.random.PRNGKey(42), e3nn.zeros("1o"), focus_node_embedding)
    for predictor_fn in (radial_predictor_fn, angular_predictor_fn)
)
params = dict(
    focus_node_embedding=focus_node_embedding,
    radial_predictor=radial_predictor_params,
    angular_predictor=angular_predictor_params,
)

In [None]:
# Four target positions, one per row.
targets = e3nn.IrrepsArray(
    "1o",
    jnp.asarray([
        [3.0, 0.0, 0.0],
        [0.0, 1.0 / jnp.sqrt(2), 1.0 / jnp.sqrt(2)],
        [0.0, 0.0, 2.0],
        [0.0, -1.0, 0.0],
    ]),
)

In [None]:
# Clip the raw gradients first, then let Adam rescale them. With the clip
# placed after Adam it would act on the already-scaled updates (on the order
# of the learning rate) and would practically never trigger.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-3),
)
opt_state = tx.init(params)


@jax.jit
def train_step(rng, params, opt_state, targets):
    """Runs one optimisation step.

    Returns the advanced RNG key, the updated parameters, the updated
    optimizer state and the loss value before the update.
    """
    # The second key is reserved for (currently disabled) target noise.
    next_rng, _noise_rng = jax.random.split(rng)
    loss_val, grads = jax.value_and_grad(loss)(params, targets)
    updates, next_opt_state = tx.update(grads, opt_state)
    return next_rng, optax.apply_updates(params, updates), next_opt_state, loss_val


# Train for 1000 steps, reporting the loss every 100 steps.
rng = jax.random.PRNGKey(42)
for step in range(1000):
    rng, params, opt_state, loss_val = train_step(rng, params, opt_state, targets)
    if step % 100:
        continue
    print(f"Step {step}, loss {loss_val}")

In [None]:
@hk.without_apply_rng
@hk.transform
def log_prob_radial(r, focus_node_embedding):
    """Log-probability of the radius `r` under the radial flow."""
    predictor = create_radial_predictor()
    return predictor.log_prob(r, conditioning=focus_node_embedding)

@hk.transform
def direct_sample_radius(focus_node_embedding, num_samples):
    """Draws `num_samples` radii directly from the radial flow."""
    predictor = create_radial_predictor()
    return predictor.direct_sample(focus_node_embedding, num_samples)


# Evaluate the learned radial density on a grid and overlay direct samples.
queries = jnp.linspace(0.0, 5.0, 100)
log_probs = jax.vmap(log_prob_radial.apply, in_axes=(None, 0, None))(
    params["radial_predictor"], queries, params["focus_node_embedding"]
)
probs = jnp.exp(log_probs)
samples = direct_sample_radius.apply(
    params["radial_predictor"],
    jax.random.PRNGKey(0),
    params["focus_node_embedding"],
    100,
)

fig = go.Figure(
    data=[
        go.Scatter(x=queries, y=probs),
        # Samples are drawn as markers along the x-axis.
        go.Scatter(x=samples, y=jnp.zeros_like(samples), mode="markers"),
    ]
)
fig.update_layout(
    title="Radial distribution", xaxis_title="Radius", yaxis_title="Probability"
)
fig.show()

In [None]:
@hk.transform
def langevin_sample_radius(focus_node_embedding, num_samples):
    """Draws radii from the radial flow with Langevin dynamics."""
    sampler_kwargs = dict(beta=100, init_step_size=0.001, num_sampling_steps=1000)
    predictor = create_radial_predictor()
    return predictor.langevin_sample(focus_node_embedding, num_samples, **sampler_kwargs)


# Run 100 Langevin chains; keep the final radii and the full trajectory.
(radius_samples, _), trajectory = langevin_sample_radius.apply(
    params["radial_predictor"],
    jax.random.PRNGKey(0),
    params["focus_node_embedding"],
    num_samples=100,
)

fig = go.Figure(
    data=[
        go.Scatter(x=queries, y=probs),
        go.Scatter(
            x=radius_samples, y=jnp.zeros_like(radius_samples), mode="markers"
        ),
    ]
)
fig.update_layout(
    title="Langevin sampled radii",
    xaxis_title="Radius",
    yaxis_title="Count",
    xaxis=dict(range=[0.0, 5.0], autorange=False),
)
fig.show()

In [None]:
# One animation frame per Langevin step, showing the radius of every chain.
frames = [
    go.Frame(
        data=[
            go.Scatter(
                x=trajectory[:, t],
                y=jnp.zeros_like(trajectory[:, t]),
                mode="markers",
                marker=dict(size=5),
            )
        ],
        layout=go.Layout(title=f"Langevin Monte Carlo: Timestep {t+1}"),
    )
    for t in range(trajectory.shape[1])
]

play_button = dict(
    label="Play",
    method="animate",
    args=[
        None,
        {
            "frame": {"duration": 50, "redraw": True},
            "fromcurrent": True,
            "transition": {"duration": 50, "easing": "linear"},
        },
    ],
)
pause_button = dict(
    label="Pause",
    method="animate",
    args=[
        [None],
        {
            "frame": {"duration": 0, "redraw": False},
            "mode": "immediate",
            "transition": {"duration": 0},
        },
    ],
)

anim_layout = go.Layout(
    # Same x-range as the static radial plots (the flow lives on [0, 5]), so
    # chains that move above radius 4 stay visible.
    xaxis=dict(range=[0, 5], autorange=False),
    yaxis=dict(range=[-0.1, 0.1], autorange=False),
    title="Langevin Monte Carlo Sampling of a Probability Distribution",
    updatemenus=[dict(type="buttons", buttons=[play_button, pause_button])],
)

fig = go.Figure(data=[frames[0].data[0]], layout=anim_layout, frames=frames)
fig.show()

In [None]:
@hk.without_apply_rng
@hk.transform
def get_angular_coeffs(radius, focus_node_embedding):
    """Spherical-harmonic coefficients of the angular predictor at `radius`."""
    predictor = create_angular_predictor()
    return predictor.coeffs(radius, focus_node_embedding)


# Plot the angular distribution at every radius that has a target within 0.1.
target_radii = jnp.linalg.norm(targets.array, axis=-1)
for radius in [0.5, 1.0, 2.0, 3.0, 4.0]:
    coeffs = get_angular_coeffs.apply(
        params["angular_predictor"], radius, params["focus_node_embedding"]
    )
    prob_sig = AngularPredictor.coeffs_to_probability_distribution(
        coeffs, res_beta=100, res_alpha=99, quadrature="gausslegendre"
    )
    targets_at_this_radius = targets.array[jnp.abs(target_radii - radius) < 0.1]
    if targets_at_this_radius.shape[0] == 0:
        continue

    axis_limit = radius * 1.5
    scene = {
        f"{axis_name}axis": dict(
            title=axis_name.upper(), range=[-axis_limit, axis_limit], autorange=False
        )
        for axis_name in "xyz"
    }
    scene["aspectmode"] = "cube"

    surface_trace = go.Surface(
        prob_sig.plotly_surface(scale_radius_by_amplitude=False, radius=radius)
    )
    # Targets are pushed out by a factor of 1.2 so the markers sit above the surface.
    target_trace = go.Scatter3d(
        x=targets_at_this_radius[:, 0] * 1.2,
        y=targets_at_this_radius[:, 1] * 1.2,
        z=targets_at_this_radius[:, 2] * 1.2,
        mode="markers",
        marker=dict(size=10, color="green"),
    )

    fig = go.Figure([surface_trace, target_trace], layout=dict(scene=scene))
    fig.update_layout(title=f"Angular distribution at radius {radius}")
    fig.show()

In [None]:
# Stack translucent spherical shells, one per radius, coloured by the joint
# density (radial probability times angular probability).
count = 0
cmin = 0.
cmax = 100
prob_traces = []
for radius in jnp.linspace(0., 5., 51):
    coeffs = get_angular_coeffs.apply(
        params["angular_predictor"], radius, params["focus_node_embedding"]
    )
    prob_sig = AngularPredictor.coeffs_to_probability_distribution(
        coeffs, res_beta=100, res_alpha=99, quadrature="gausslegendre"
    )
    radius_log_prob = log_prob_radial.apply(
        params["radial_predictor"], radius, params["focus_node_embedding"]
    )
    radius_prob = jnp.exp(radius_log_prob)
    prob_sig *= radius_prob

    # Skip shells whose peak density is below 1% of the colour scale maximum.
    if prob_sig.grid_values.max() < 1e-2 * cmax:
        continue

    print("Found a good one", radius, prob_sig.grid_values.max())
    count += 1
    prob_traces.append(
        go.Surface(
            **prob_sig.plotly_surface(radius=radius),
            colorscale=[[0, "rgba(4, 59, 192, 0.)"], [1, "rgba(4, 59, 192, 1.)"]],
            showscale=False,
            cmin=cmin,
            cmax=cmax,
            name="Position Probabilities",
            legendgroup="Position Probabilities",
            # Only the first shell contributes a legend entry.
            showlegend=(count == 1),
        )
    )

scene = {
    f"{axis_name}axis": dict(title=axis_name.upper(), range=[-5, 5], autorange=False)
    for axis_name in "xyz"
}
scene["aspectmode"] = "cube"

fig = go.Figure(prob_traces, layout=dict(scene=scene))
fig.update_layout(title="Position distribution")
fig.show()

In [None]:
@hk.transform
def sample_angular_fn(radii, focus_node_embedding, num_samples):
    """Langevin-samples `num_samples` positions on the sphere of radius `radii`."""
    predictor = create_angular_predictor()
    final_positions, _trajectory = predictor.langevin_sample(
        radii,
        focus_node_embedding,
        num_samples,
        beta=10.,
        num_sampling_steps=1000,
        init_step_size=10.,
    )
    return final_positions

rng = jax.random.PRNGKey(0)
# Use independent keys for the radii and for the angular chains of every
# radius. Reusing one key would give every radius the same initial points and
# the same noise sequence.
radius_rng, angular_rng = jax.random.split(rng)
radii = direct_sample_radius.apply(
    params["radial_predictor"], radius_rng, params["focus_node_embedding"], 10
)
angular_rngs = jax.random.split(angular_rng, radii.shape[0])
positions = jax.jit(
    jax.vmap(
        lambda key, r: sample_angular_fn.apply(
            params["angular_predictor"], key, r, params["focus_node_embedding"], 10
        )
    )
)(angular_rngs, radii)
# Flatten (num_radii, num_samples, 3) into a list of 3D points.
positions = positions.array.reshape(-1, 3)

fig = go.Figure()
fig.add_trace(go.Scatter(x=radii, y=jnp.zeros_like(radii), mode="markers"))
fig.update_layout(title="Sampled radii", xaxis_title="Radius", yaxis_title="Probability")
fig.show()

fig = go.Figure()
fig.add_trace(
    go.Scatter3d(
        x=targets.array[:, 0],
        y=targets.array[:, 1],
        z=targets.array[:, 2],
        mode="markers",
        marker=dict(size=10, color="green"),
        name="Targets",
    )
)
fig.add_trace(
    go.Scatter3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        mode="markers",
        marker=dict(size=5),
        name="Samples",
    )
)
fig.update_layout(
    title="Sampled positions",
    scene=dict(
        xaxis=dict(title="X", range=[-3, 3], autorange=False),
        yaxis=dict(title="Y", range=[-3, 3], autorange=False),
        zaxis=dict(title="Z", range=[-3, 3], autorange=False),
        aspectmode="cube",
    ),
)

fig.show()
