In [554]:
import functools

import e3nn_jax as e3nn
import jax
import jax.numpy as jnp

import plotly.graph_objects as go

# Make JAX raise an error as soon as a computation produces a NaN (debugging aid).
jax.config.update("jax_debug_nans", True)

In [555]:
def coeffs_to_logits(coeffs, res_beta, res_alpha):
    """Evaluate the spherical-harmonic expansion `coeffs` on an S2 grid.

    The resulting grid values are used directly as (unnormalised) logits.

    Args:
        coeffs: spherical-harmonic coefficients (e3nn IrrepsArray).
        res_beta: number of grid points along beta.
        res_alpha: number of grid points along alpha.

    Returns:
        The `e3nn.SphericalSignal` produced by `e3nn.to_s2grid` with the
        "soft" quadrature and parities p_val=1, p_arg=-1.
    """
    return e3nn.to_s2grid(
        coeffs,
        res_beta,
        res_alpha,
        quadrature="soft",
        p_val=1,
        p_arg=-1,
    )


# Target logits: four Dirac deltas on the sphere (lmax=10), scaled by an
# inverse temperature and combined into one coefficient array with e3nn.sum.
inverse_temperature = 10.0
coeffs = e3nn.sum(
    inverse_temperature
    * e3nn.s2_dirac(
        jnp.asarray(
            [
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [0.0, 1.0, 1.0],
            ]
        ),
        lmax=10,
    )
)

go.Figure(
    data=[
        go.Surface(
            coeffs_to_logits(coeffs, res_beta=90, res_alpha=89).plotly_surface(
                scale_radius_by_amplitude=True
            )
        )
    ],
    layout=go.Layout(title="Coefficients"),
)

In [556]:
def coeffs_to_probability_distribution(
    coeffs, res_beta, res_alpha
) -> e3nn.SphericalSignal:
    """Turn logits given as SH coefficients into a normalised density on S2.

    The coefficients are evaluated on a (res_beta, res_alpha) "soft" grid.
    The maximum is subtracted before exponentiating so that exp() cannot
    overflow, and the result is divided by its integral over the sphere.

    Returns:
        A SphericalSignal whose values integrate to 1.
    """
    logits = e3nn.to_s2grid(
        coeffs, res_beta, res_alpha, quadrature="soft", p_val=1, p_arg=-1
    )
    peak = jnp.max(logits.grid_values)
    unnormalised = logits.replace_values(logits.grid_values - peak).apply(jnp.exp)
    return unnormalised / unnormalised.integrate()


# Natural-log probability on a 90 x 89 grid; the 1e-9 floor keeps log()
# finite where the probability underflows to zero.
sig = coeffs_to_probability_distribution(
    coeffs, res_beta=90, res_alpha=89
).apply(lambda x: jnp.log(x + 1e-9))

cmin = float(sig.grid_values.min())
cmax = float(sig.grid_values.max())
print(cmin, cmax)

# Colour positions are in natural-log space (jnp.log above), so the tick
# labels must be mapped back with exp -- not 10** -- to show probabilities.
tick_values = jnp.linspace(cmin, cmax, 5)

go.Figure(
    [
        go.Surface(
            sig.plotly_surface(),
            colorscale='Viridis',
            cmin=cmin,  # colour range spans the log-probability values
            cmax=cmax,
            colorbar=dict(
                title='Probability',
                tickvals=tick_values.tolist(),
                ticktext=[f"{val:.2e}" for val in jnp.exp(tick_values).tolist()],
            ),
        )
    ]
)


-20.7232666015625 6.102669715881348


In [557]:
def probability(samples, coeffs):
    """Relative probabilities of a finite set of points under `coeffs` logits.

    Each point's logit is the SH expansion `coeffs` evaluated there
    (`e3nn.to_s2point`). A softmax over the last axis (the sample axis for
    a 1-D batch) turns the logits into weights that sum to 1 across the
    batch. Note this is NOT the density on the sphere: it depends on which
    points are in `samples`.
    """

    def _logit(point):
        return e3nn.to_s2point(coeffs, point).array

    stacked_logits = jnp.squeeze(jax.vmap(_logit)(samples), axis=-1)
    return jax.nn.softmax(stacked_logits)

In [558]:
def sample_from_probability_distribution(sig: e3nn.SphericalSignal, key, num_samples):
    """Draw `num_samples` grid points from the discretised distribution `sig`.

    Each draw uses its own key split from `key`. `sig.sample` returns a
    (beta, alpha) grid index, which is mapped to the corresponding entry of
    `sig.grid_vectors`.

    Returns:
        An "1o" IrrepsArray with one 3-vector per draw.
    """
    beta_idx, alpha_idx = jax.vmap(sig.sample)(jax.random.split(key, num_samples))
    return e3nn.IrrepsArray("1o", sig.grid_vectors[beta_idx, alpha_idx])


def plot_samples(samples, title, probs=None):
    """3D scatter plot of points, optionally coloured by per-point values.

    Args:
        samples: IrrepsArray whose `.array` is indexed as (num_points, 3).
        title: figure title.
        probs: optional per-point values (e.g. from `probability`). When given,
            they colour the markers and are shown on hover. When None, the
            markers use plotly's default colour. Previously None crashed on
            `probs.min()`.

    Returns:
        A plotly `go.Figure` with axes fixed to [-1, 1] and a cube aspect.
    """
    points = samples.array
    scatter_kwargs = dict(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode="markers",
    )
    if probs is not None:
        # Convert once to Python floats; reused for colour, range and hover text.
        prob_values = [float(prob) for prob in probs]
        scatter_kwargs.update(
            marker=dict(
                color=prob_values,
                colorscale='Viridis',
                cmin=min(prob_values),
                cmax=max(prob_values),
                colorbar=dict(title='Probability'),
            ),
            text=[f"Prob: {round(prob, 3)}" for prob in prob_values],
            hoverinfo='text',
        )

    fig = go.Figure([go.Scatter3d(**scatter_kwargs)])
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[-1, 1]),
            yaxis=dict(range=[-1, 1]),
            zaxis=dict(range=[-1, 1]),
            aspectmode="cube",
        ),
        title=title,
    )
    return fig


# Sample directly from the discretised target on a 90 x 89 grid and colour
# each point by its relative probability within the batch.
key = jax.random.PRNGKey(0)
res_beta, res_alpha = 90, 89
samples = sample_from_probability_distribution(
    sig=coeffs_to_probability_distribution(
        coeffs, res_beta=res_beta, res_alpha=res_alpha
    ),
    key=key,
    num_samples=100,
)

plot_samples(
    samples,
    title=f"Discretized Samples: res_beta = {res_beta}, res_alpha = {res_alpha}",
    probs=probability(samples, coeffs),
)

In [559]:
# Initial points for the Langevin chains: drawn from the distribution of a
# constant (l=0 only) logit, i.e. uniform over the grid's area.
key = jax.random.PRNGKey(0)
# Split off a dedicated key so that `key` can be consumed again later
# without correlating with these draws (JAX keys must not be reused).
key, init_key = jax.random.split(key)
init_coeffs = e3nn.IrrepsArray("0e", jnp.ones(1))
num_samples = 100
init_samples = sample_from_probability_distribution(
    coeffs_to_probability_distribution(init_coeffs, res_beta=100, res_alpha=99),
    init_key,
    num_samples,
)

# plot_samples(init_samples, title="Initial Samples")

In [560]:
@jax.jit
def score(sample, coeffs):
    """Gradient of the unnormalised log-density at `sample`.

    The log-density is the SH expansion `coeffs` evaluated at the point via
    `e3nn.to_s2point`. The returned gradient is the ambient (3D Euclidean)
    gradient; it is not projected onto the sphere's tangent plane here.
    """
    return e3nn.grad(lambda point: e3nn.to_s2point(coeffs, point))(sample)


# Visualise the score field on a coarse 10 x 9 grid of points on the sphere.
sig = e3nn.SphericalSignal.zeros(res_beta=10, res_alpha=9, quadrature="soft")
vectors = sig.grid_vectors.reshape(-1, 3)


def _score_as_array(point):
    """Score at one raw 3-vector, returned as a plain array."""
    return score(e3nn.IrrepsArray("1o", point), coeffs).array


# Nested vmap maps over both grid axes; flatten afterwards to match `vectors`.
values = jax.vmap(jax.vmap(_score_as_array))(sig.grid_vectors).reshape(-1, 3)

go.Figure(
    data=[
        go.Cone(
            x=vectors[:, 0], y=vectors[:, 1], z=vectors[:, 2],
            u=values[:, 0], v=values[:, 1], w=values[:, 2],
            colorscale="Viridis",
            sizemode="scaled",
            sizeref=5,
            showscale=True,
        )
    ]
)

In [561]:


def project_update_on_tangent_space(sample, update):
    """Remove the component of `update` along `sample`.

    For a unit-norm `sample` this is the orthogonal projection onto the
    tangent plane of the sphere at `sample`.
    """
    radial_part = e3nn.dot(sample, update) * sample
    return update - radial_part


def apply_exponential_map(sample, update):
    """Move from `sample` along the great circle in direction `update`.

    Computes cos(|v|) x + sin(|v|) v / |v| for x = `sample`, v = `update`
    (the sphere's exponential map when x is a unit vector and v is tangent).

    sin(|v|)/|v| is written as sinc(|v|/pi), which equals 1 at |v| = 0. A zero
    update therefore returns `sample` instead of 0/0 = NaN, which would
    otherwise trip `jax_debug_nans`.
    """
    update_norm = jnp.linalg.norm(update.array)
    return jnp.cos(update_norm) * sample + jnp.sinc(update_norm / jnp.pi) * update



@functools.partial(jax.jit, static_argnames=("num_steps",))
def langevin_monte_carlo(
    init_sample, coeffs, key, num_steps, init_step_size, step_size_decay=0.999
):
    """Run one Metropolis-adjusted Langevin (MALA) chain on the unit sphere.

    Each step, with eps the current step size, proposes
        v  = P_x(eps * grad log p(x) + sqrt(2 eps) * xi),   xi ~ N(0, I_3)
        x' = exp_x(v)
    where P_x projects onto the tangent plane at x. P_x(xi) is an isotropic
    2D Gaussian in that plane, so v ~ N(mu(x), 2 eps I_2) with
    mu(x) = P_x(eps * grad log p(x)).

    The reverse velocity at x' is v_rev = |v| sin|v| x - cos|v| v (the
    flipped geodesic velocity). It satisfies exp_{x'}(v_rev) = x for any |v|.
    (x, v) -> (x', v_rev) is a volume-preserving involution, so the exact
    acceptance ratio is
        log a = log p(x') - log p(x)
                - |v_rev - mu(x')|^2 / (4 eps) + |v - mu(x)|^2 / (4 eps).
    The proposal terms are required: without them the asymmetric Langevin
    proposal biases the chain's stationary distribution.

    Args:
        init_sample: "1o" IrrepsArray; assumed to be a unit vector.
        coeffs: SH coefficients; log p(x) = to_s2point(coeffs, x) + const.
        key: PRNG key for the whole chain.
        num_steps: number of steps (static under jit).
        init_step_size: eps for the first step.
        step_size_decay: factor applied to eps after every step.

    Returns:
        `jax.lax.scan` output: ((final_sample, final_step_size), trajectory),
        where trajectory stacks the sample after each step.
    """

    def log_density(x):
        # Unnormalised log-density (the logit) at x, reduced to a scalar.
        return e3nn.to_s2point(coeffs, x).array[..., 0]

    def drift(x, step_size):
        # mu(x) = P_x(step_size * grad log p(x)).
        return project_update_on_tangent_space(x, step_size * score(x, coeffs))

    def squared_norm(v):
        return jnp.sum(v.array**2)

    def step(state, step_key):
        sample, step_size = state
        noise_key, mh_key = jax.random.split(step_key)

        # Forward proposal: drift plus Gaussian noise, both in T_x S^2.
        forward_drift = drift(sample, step_size)
        noise = e3nn.IrrepsArray("1o", jax.random.normal(noise_key, (3,)))
        update = forward_drift + project_update_on_tangent_space(
            sample, jnp.sqrt(2 * step_size) * noise
        )
        proposal = apply_exponential_map(sample, update)

        # Reverse move: |v| sin|v| x - cos|v| v (no division, safe at |v| = 0).
        update_norm = jnp.linalg.norm(update.array)
        reverse_update = (
            update_norm * jnp.sin(update_norm) * sample
            - jnp.cos(update_norm) * update
        )
        reverse_drift = drift(proposal, step_size)

        # Metropolis-Hastings correction including the proposal-density ratio.
        log_acceptance_ratio = (
            log_density(proposal)
            - log_density(sample)
            - squared_norm(reverse_update - reverse_drift) / (4 * step_size)
            + squared_norm(update - forward_drift) / (4 * step_size)
        )
        accept = jax.random.bernoulli(
            mh_key, jnp.exp(jnp.minimum(0.0, log_acceptance_ratio))
        )
        new_sample = e3nn.IrrepsArray(
            "1o", jnp.where(accept, proposal.array, sample.array)
        )
        return (new_sample, step_size * step_size_decay), new_sample

    return jax.lax.scan(
        step, (init_sample, init_step_size), jax.random.split(key, num_steps)
    )


# Run one independent chain per initial sample, each with its own key.
step_size = 1.0
print(f"Initial step size: {step_size}")
num_steps = 250
keys = jax.random.split(key, num_samples)


def _run_chain(init_sample, chain_key):
    """Run a single Langevin chain from `init_sample`."""
    return langevin_monte_carlo(init_sample, coeffs, chain_key, num_steps, step_size)


(samples, _), trajectory = jax.vmap(_run_chain)(init_samples, keys)

plot_samples(
    samples,
    title="Metropolis-Adjusted Langevin Monte Carlo Samples",
    probs=probability(samples, coeffs),
)

Initial step size: 1.0


In [562]:
# `trajectory` comes from vmapping a lax.scan over the chains, so its array is
# (num_samples, num_steps, 3). Transpose to (num_steps, num_samples, 3) so each
# animation frame shows every chain at one step. Fetch it to host memory once
# rather than transferring a device slice for every frame.
plot_trajectory = jax.device_get(trajectory.array).transpose(1, 0, 2)

x = plot_trajectory[:, :, 0]
y = plot_trajectory[:, :, 1]
z = plot_trajectory[:, :, 2]

frames = [
    go.Frame(
        data=[
            go.Scatter3d(x=x[t], y=y[t], z=z[t], mode="markers", marker=dict(size=5))
        ],
        layout=go.Layout(title=f"Metropolis-Adjusted Langevin Monte Carlo: Timestep {t+1}"),
    )
    for t in range(plot_trajectory.shape[0])
]

layout = go.Layout(
    scene=dict(
        xaxis=dict(title="X", range=[-1, 1], autorange=False),
        yaxis=dict(title="Y", range=[-1, 1], autorange=False),
        zaxis=dict(title="Z", range=[-1, 1], autorange=False),
        aspectmode="cube",
    ),
    title="Metropolis-Adjusted Langevin Monte Carlo Sampling",
    updatemenus=[
        dict(
            type="buttons",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    args=[
                        None,
                        {
                            "frame": {"duration": 50, "redraw": True},
                            "fromcurrent": True,
                            "transition": {"duration": 50, "easing": "linear"},
                        },
                    ],
                ),
                dict(
                    label="Pause",
                    method="animate",
                    args=[[None], {"frame": {"duration": 0, "redraw": False},
                                "mode": "immediate", "transition": {"duration": 0}}]
                )
            ],
        )
    ],
)

fig = go.Figure(data=[frames[0].data[0]], layout=layout, frames=frames)
fig.show()