### Imports

In [8]:
import sys
sys.path.append('../')
import jax.numpy as jnp
import jax
import optax
import chex
import matplotlib.pyplot as plt
import e3nn_jax as e3nn
from src import spectra
import plotly.graph_objects as go
from scipy.stats import special_ortho_group
from typing import Tuple



def visualize(geometry):
    """
    Plot the lmax=4 sum-of-Diracs signal of a geometry as a 3D plotly surface.

    Args:
    geometry: Forwarded unchanged to spectra.sum_of_diracs; presumably an
        array of 3D points, as at this notebook's call sites (TODO confirm
        against the spectra module).

    Returns:
    A go.Figure containing a single go.Surface trace named
    "spherical_harmonics_trace".
    """
    coefficients = spectra.sum_of_diracs(geometry, lmax=4)

    white = 'rgba(255,255,255,255)'

    def blank_axis():
        # All three axes share one look: no title, tick labels, grid or zero
        # line, a white backdrop and a fixed [-2.5, 2.5] extent.
        return dict(
            title='',
            showticklabels=False,
            showgrid=False,
            zeroline=False,
            backgroundcolor=white,
            range=[-2.5, 2.5],
        )

    scene = dict(
        xaxis=blank_axis(),
        yaxis=blank_axis(),
        zaxis=blank_axis(),
        bgcolor=white,
        aspectmode='cube',
        camera=dict(eye=dict(x=0.5, y=0.5, z=0.5)),
    )
    figure_layout = go.Layout(
        scene=scene,
        plot_bgcolor=white,
        paper_bgcolor=white,
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=True,
    )

    # Sample the coefficients on a spherical grid (resolution arguments 100
    # and 99, "soft" quadrature) and let the grid object emit the surface data.
    grid_signal = e3nn.to_s2grid(coefficients, 100, 99, quadrature="soft")
    surface_kwargs = grid_signal.plotly_surface(radius=0.5, scale_radius_by_amplitude=True)
    surface = go.Surface(surface_kwargs, name="spherical_harmonics_trace")
    return go.Figure(data=[surface], layout=figure_layout)

### Geometries to test with

In [9]:
# Reference geometry: three unit vectors in the z = 0 plane, 120 degrees apart
# (angles 0, 120 and 240 degrees from the +x axis).
true_geometry = jnp.array(
    [
        [1, 0, 0],
        [-0.5, jnp.sqrt(3) / 2, 0],
        [-0.5, -jnp.sqrt(3) / 2, 0],
    ]
)
# Signal of the reference geometry, used later as the alignment target.
true_signal = spectra.sum_of_diracs(true_geometry, lmax=4)

visualize(true_geometry)

In [10]:
# Draw a random 3x3 rotation from SO(3).
# NOTE(review): no random_state is passed, so each run produces a different
# matrix and the recorded outputs below cannot be reproduced exactly.
random_rotation_matrix = special_ortho_group.rvs(dim=3)

# Points are stored as rows, so each point p is mapped to p @ R.
rotated_geometry = true_geometry @ random_rotation_matrix
rotated_signal = spectra.sum_of_diracs(rotated_geometry, lmax=4)

visualize(rotated_geometry)

### Rotation functions

In [20]:

def compute_mean_squared_distance(quaternion: chex.Array, first_signal: e3nn.SphericalSignal, second_signal: e3nn.SphericalSignal):
    """
    Compute the mean squared distance between each point in rotated first_signal 
    to its closest point in second_signal.

    Each signal is turned into a point cloud by scaling its grid_vectors by
    the matching grid_values and flattening to rows of 3 components. For every
    point of the rotated first cloud the squared Euclidean distance to the
    nearest point of the second cloud is taken, and those minima are averaged.

    Args:
    quaternion: Quaternion array of shape (4,).
    first_signal: First spherical signal to be rotated (with lmax=4).
    second_signal: Second spherical signal for comparison.

    Returns:
    Mean of the minimum squared distances.

    Raises:
    AssertionError: If quaternion does not have shape (4,).
    """
    assert quaternion.shape == (4,)
    rotated_first_signal = first_signal.transform_by_quaternion(quaternion, lmax=4)

    # Scale each grid direction by the signal value there, then flatten the
    # grid axes so both clouds have shape (num_points, 3).
    rotated_vectors = (rotated_first_signal.grid_vectors * rotated_first_signal.grid_values[..., None]).reshape((-1, 3))
    second_vectors = (second_signal.grid_vectors * second_signal.grid_values[..., None]).reshape((-1, 3))

    # Pairwise differences: axis 0 indexes rotated points, axis 1 indexes
    # second-signal points.
    differences = rotated_vectors[:, None, :] - second_vectors[None, :, :]

    # Sum the squared components directly instead of squaring
    # jnp.linalg.norm: the norm goes through a sqrt whose derivative is
    # infinite at zero, so a pair of coincident points (the very minimum this
    # loss drives towards) would yield NaN gradients under jax.grad.
    squared_distances = jnp.sum(jnp.square(differences), axis=-1)
    return jnp.min(squared_distances, axis=-1).mean()


@jax.jit
def compute_optimal_rotation(
    input_signal: e3nn.IrrepsArray,
    target_signal: e3nn.IrrepsArray,
    *,
    optimization_steps: int = 256,
    learning_rate: float = 1e-2,
) -> chex.Array:
    """
    Returns the rotation matrix found by minimising
    compute_mean_squared_distance between the rotated input_signal and
    target_signal.

    Both signals are sampled onto a spherical grid, then Adam is run from 10
    randomly initialised quaternions in parallel. The quaternion with the
    lowest final (non-NaN) loss is converted to a matrix and returned.

    Args:
    input_signal: Coefficients of the signal to be rotated (passed to
        e3nn.to_s2grid).
    target_signal: Coefficients of the signal to align with (passed to
        e3nn.to_s2grid).
    optimization_steps: Number of steps for optimization. Default is 256.
    learning_rate: Learning rate for optimization. Default is 1e-2.

    Returns:
    The rotation matrix for the best quaternion. Only the matrix is returned;
    the alignment error is not part of the result.
    """
    # The seed is fixed, so the same 10 starting quaternions are used on
    # every call. Only the first half of the split key is needed.
    quaternion_rng, _ = jax.random.split(jax.random.PRNGKey(0))

    input_grid = e3nn.to_s2grid(input_signal, 30, 29, quadrature="soft")
    target_grid = e3nn.to_s2grid(target_signal, 30, 29, quadrature="soft")

    optimizer = optax.adam(learning_rate)
    # Only the gradient is needed inside the loop; the loss value is
    # evaluated once per restart after optimisation.
    loss_gradient = jax.grad(compute_mean_squared_distance)

    def optimize(quaternion):
        def update(_, carry):
            quaternion, opt_state = carry
            grad = loss_gradient(quaternion, input_grid, target_grid)
            updates, opt_state = optimizer.update(grad, opt_state, quaternion)
            quaternion = optax.apply_updates(quaternion, updates)
            return quaternion, opt_state

        initial_carry = (quaternion, optimizer.init(quaternion))
        quaternion, _ = jax.lax.fori_loop(0, optimization_steps, update, initial_carry)
        return quaternion

    # 10 independent restarts, optimised in parallel.
    quaternions = jax.random.normal(quaternion_rng, (10, 4))
    optimal_quaternions = jax.vmap(optimize)(quaternions)
    losses = jax.vmap(compute_mean_squared_distance, (0, None, None))(optimal_quaternions, input_grid, target_grid)

    # nanargmin ignores restarts whose loss ended up NaN.
    optimal_quaternion = optimal_quaternions[jnp.nanargmin(losses)]
    optimal_rotation_matrix = e3nn.quaternion_to_matrix(optimal_quaternion)
    return optimal_rotation_matrix



### Attempt to recover rotation matrix

In [22]:
# Search for the rotation that aligns rotated_signal (input) with true_signal
# (target), using the default optimization_steps and learning_rate.
predicted_rotation_matrix = compute_optimal_rotation(rotated_signal, true_signal)


Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.



In [23]:
# Display the rotation matrix returned by compute_optimal_rotation.
predicted_rotation_matrix

Array([[-0.44835693,  0.732217  ,  0.5126736 ],
       [-0.08057908, -0.60432786,  0.79265046],
       [ 0.8902151 ,  0.31407958,  0.32995594]], dtype=float32)

In [24]:
# Display the rotation that was applied to true_geometry, for comparison with
# the prediction above.
# NOTE(review): in the recorded run the two matrices share their first row and
# differ by a sign in rows 2 and 3, i.e. predicted = diag(1, -1, -1) @ random.
# diag(1, -1, -1) is a 180-degree turn about x that maps true_geometry onto
# itself (it swaps the 2nd and 3rd points), so the two rotations produce the
# same point set; an exact match should not be expected for this geometry.
random_rotation_matrix

array([[-0.44835637,  0.73221703,  0.51267415],
       [ 0.08057893,  0.60432819, -0.79265028],
       [-0.89021548, -0.31407907, -0.32995566]])