In [1]:
import os
import sys
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from matplotlib import gridspec
import matplotlib.ticker as ticker

# Enable 64-bit precision. This is done *before* importing the project modules
# below: JAX documents that jax_enable_x64 should be set at startup, and any
# arrays those modules build at import time would otherwise be created at the
# default 32-bit precision.
# NOTE(review): assumes utils.geometries / spectra may create JAX arrays at
# import time — confirm.
jax.config.update("jax_enable_x64", True)

# Add src directory to path (only once, so re-running this cell does not keep
# appending duplicate entries to sys.path)
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src_dir = os.path.join(project_root, 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

from utils.plotters import visualize_geometry, colorplot, visualize_signal, compare_geometries, compare_four_geometries
from utils.geometries import octahedron, trigonal_prism, see_saw, square_plane, tetrahedron, trigonal_plane
import spectra
from spectra import sum_of_diracs, bispectrum


def interpolate_geometries(geom1, geom2, n):
    """
    Linearly interpolate between two geometries.

    Vertex i of `geom1` moves in a straight line towards vertex i of `geom2`.

    Parameters:
    geom1: First geometry (array of vertices, e.g. shape (N, 3))
    geom2: Second geometry, same shape as geom1
    n: Number of interpolated geometries to return (including start and end);
       values below 2 are raised to 2

    Returns:
    JAX array of shape (n, *geom1.shape); entry k is
    geom1 + t_k * (geom2 - geom1) with t_k = linspace(0, 1, n)[k]

    Raises:
    ValueError: if geom1 and geom2 do not have the same shape (a vertex-wise
    interpolation between different vertex counts is meaningless).
    """
    # Ensure we have at least the two endpoints
    n = max(2, n)

    start = jnp.asarray(geom1)
    end = jnp.asarray(geom2)
    if start.shape != end.shape:
        raise ValueError(
            f"geometries must have the same shape, got {start.shape} and {end.shape}"
        )

    # One interpolation parameter per output geometry, shaped (n, 1, ..., 1) so
    # it broadcasts against the vertex array instead of looping in Python.
    t_values = jnp.linspace(0, 1, n).reshape((-1,) + (1,) * start.ndim)
    return start + t_values * (end - start)

# 200 geometries morphing the octahedron into the trigonal prism.
geometries = interpolate_geometries(octahedron, trigonal_prism, 200)


def compute_bispectrum(x):
    """Return `bispectrum(sum_of_diracs(x, lmax=4))` for the point set `x`."""
    return bispectrum(sum_of_diracs(x, lmax=4))


# Batched over the leading (interpolation-step) axis of `geometries`.
bispectra = jax.vmap(compute_bispectrum)(geometries)

def plot_bispectra_interpolation(bispectra, start_label='Octahedron',
                                 end_label='Trigonal Prism', show=True):
    """
    Create a colorplot of bispectra vectors across the geometry interpolation.

    Each column of the heat map is one interpolation step and each row is one
    bispectrum component. The colour scale is symmetric around zero.

    Parameters:
    bispectra: Array of shape (n_steps, n_components) containing bispectrum values
    start_label: Text drawn under the first column (default 'Octahedron')
    end_label: Text drawn under the last column (default 'Trigonal Prism')
    show: If True (default), call plt.show() before returning

    Returns:
    The matplotlib Figure that was drawn. It is captured when created, because
    plt.show() may close it and a later plt.gcf() would return a new, empty
    figure. If this call is the last line of a cell, end it with ';' to avoid
    displaying the returned figure a second time.
    """
    values = np.asarray(bispectra)
    n_steps, n_components = values.shape

    # Reduced height; keep explicit handles instead of relying on pyplot state
    fig, ax = plt.subplots(figsize=(9, 3))

    # Colour limits symmetric around zero so the diverging PuOr map is centred
    max_abs_val = float(np.max(np.abs(values)))

    im = ax.imshow(values.T, aspect='auto', interpolation='nearest',
                   cmap='PuOr', origin='lower',
                   vmin=-max_abs_val, vmax=max_abs_val)

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Bispectrum Value')

    ax.set_ylabel('Component')
    ax.set_xticks([])

    # With origin='lower' the bottom row spans y in [-0.5, 0.5], so y = -1.5
    # (data coordinates) places the endpoint labels just below the image.
    ax.text(0, -1.5, start_label, ha='left', va='center')
    ax.text(n_steps - 1, -1.5, end_label, ha='right', va='center')

    # One y tick per component index
    ax.set_yticks(range(n_components))

    fig.tight_layout()
    if show:
        plt.show()

    return fig

# Call the function with the bispectra data
# plot_bispectra_interpolation(bispectra)

In [2]:
"""
Bispectrum Inversion for Molecular Environments - Optimized for Apple M2 Pro
"""
import os
import time
import pickle

import numpy as np
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
from functools import partial
import math

# Add src directory to path (only once, so re-running this cell does not keep
# appending duplicate entries to sys.path)
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src_dir = os.path.join(project_root, 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

from spectra import sum_of_diracs, bispectrum

#------------------------------------------------------------------------------
# JAX & Hardware Configuration
#------------------------------------------------------------------------------

jax.config.update('jax_enable_x64', True)
# NOTE(review): JAX may not honour a platform change once a backend has been
# initialised (cell 1 already ran jax.vmap). If CPU-only execution matters,
# consider setting JAX_PLATFORMS=cpu in the environment before `import jax` —
# confirm this line actually takes effect.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update('jax_default_matmul_precision', 'high')

#------------------------------------------------------------------------------
# Constants
#------------------------------------------------------------------------------

LMAX = 4          # lmax passed to sum_of_diracs inside invert_with_adam
MAX_ENV_SIZE = 4  # NOTE(review): not referenced by any visible cell

#------------------------------------------------------------------------------
# Helpers
#------------------------------------------------------------------------------

@partial(jax.jit, static_argnames=("lr", "num_steps"))
def invert_with_adam(true_bs, init_pts, lr=1e-2, num_steps=10000):
    """
    Adam-based inversion: fit point positions whose bispectrum matches `true_bs`.

    The loss is the mean absolute difference between `true_bs` and
    bispectrum(sum_of_diracs(points, LMAX)). Exactly `num_steps` Adam updates
    are run inside a `jax.lax.scan`; there is no early stopping, and the loss
    is not returned.

    Parameters
    ----------
    true_bs : array
        Target bispectrum.
    init_pts : array
        Initial point positions (the optimisation variable).
    lr : float, default 1e-2
        Adam learning rate. Static under jit, so it must be a hashable Python
        number; each distinct value triggers a recompile.
    num_steps : int, default 10000
        Number of optimisation steps. Static because it sets the scan length.

    Returns
    -------
    array
        Point positions after `num_steps` updates, same shape as `init_pts`.
    """
    opt = optax.adam(lr)
    state = opt.init(init_pts)

    def loss_fn(x):
        pred_bs = bispectrum(sum_of_diracs(x, LMAX))
        return jnp.mean(jnp.abs(true_bs - pred_bs))

    def step(carry, _):
        pts, st = carry
        g = jax.grad(loss_fn)(pts)
        updates, st = opt.update(g, st, pts)
        return (optax.apply_updates(pts, updates), st), None

    (final_pts, _), _ = jax.lax.scan(step, (init_pts, state), None, length=num_steps)
    return final_pts

def stack_points(points: np.ndarray, cos_threshold: float = 0.866,
                 rel_norm_threshold: float = 0.5) -> jnp.ndarray:
    """
    Merge nearly-parallel vectors and drop weak ones.

    Points are clustered greedily by direction. The lowest-index unassigned
    point seeds a cluster. The cluster holds every unassigned point whose
    unit direction has cosine > `cos_threshold` with the seed
    (0.866 ≈ cos 30°). Each cluster is replaced by the un-normalised sum of
    its members. Clusters whose summed norm is below `rel_norm_threshold`
    times the largest summed norm are discarded.

    Points with zero or non-finite norm carry no usable direction and are
    ignored. If such a point seeded a cluster, it would match nothing (not
    even itself), and the clustering would stop early and silently drop all
    remaining points.

    Parameters
    ----------
    points : (M, 3) array_like
        Vectors to merge; converted to a float array (integer input is allowed).
    cos_threshold : float, default 0.866
        Minimum cosine with the seed direction for joining its cluster.
    rel_norm_threshold : float, default 0.5
        Clusters are kept when norm >= rel_norm_threshold * max cluster norm.

    Returns
    -------
    (K, 3) JAX float64 array of merged vectors; K may be 0.
    """
    pts = np.asarray(points, dtype=float)
    norms = np.linalg.norm(pts, axis=1)

    # Only points with a well-defined direction take part in clustering.
    unassigned = np.isfinite(norms) & (norms > 0)
    directions = np.zeros_like(pts)
    directions[unassigned] = pts[unassigned] / norms[unassigned, None]
    cosines = directions @ directions.T

    clusters = []
    while unassigned.any():
        seed = np.argmax(unassigned)  # first still-unassigned index
        members = (cosines[seed] > cos_threshold) & unassigned
        # Always include the seed so each iteration makes progress, even if
        # rounding (or cos_threshold >= 1) excludes cos(seed, seed).
        members[seed] = True
        clusters.append(pts[members].sum(axis=0))
        unassigned &= ~members

    if not clusters:
        return jnp.zeros((0, 3), jnp.float64)

    summed = np.stack(clusters)
    cluster_norms = np.linalg.norm(summed, axis=1)
    keep = cluster_norms >= rel_norm_threshold * cluster_norms.max()
    return jnp.array(summed[keep], jnp.float64)

In [3]:
from tqdm import tqdm


# Target bispectra of two reference geometries, inverted in the round-trip below.
tetrahedron_bis, trigonal_plane_bis = (
    compute_bispectrum(reference) for reference in (tetrahedron, trigonal_plane)
)

# NOTE(review): neither of these is read by any visible cell.
best_loss, best_geometry = np.inf, None

def invert(bispectrum, n_points=15, seed=0):
    """
    Recover a point geometry whose bispectrum approximates `bispectrum`.

    Starts from `n_points` standard-normal random points, fits them with
    `invert_with_adam`, then merges and prunes the optimised points with
    `stack_points`.

    Parameters
    ----------
    bispectrum : array
        Target bispectrum. Note that this parameter shadows the module-level
        `bispectrum` function, but only inside this body.
    n_points : int, default 15
        Number of initial points to optimise.
    seed : int, default 0
        Seed for `jax.random.PRNGKey`. Repeated calls with the same seed start
        from the same initial points.

    Returns
    -------
    (K, 3) array of merged points; K may be 0.
    """
    key = jax.random.PRNGKey(seed)
    init_pts = jax.random.normal(key, (n_points, 3), jnp.float64)
    final_pts = invert_with_adam(bispectrum, init_pts)
    return stack_points(np.array(final_pts))


# Round trip: invert the bispectra of the two known reference geometries.
inverted_tetrahedron, inverted_trigonal_plane = (
    invert(tetrahedron_bis),
    invert(trigonal_plane_bis),
)

In [4]:
# Visualise the geometry recovered by invert(tetrahedron_bis) in the cell above.
visualize_geometry(inverted_tetrahedron)

In [5]:
import sys
import os

# Make the project's src/ directory importable. It is added only once, so
# re-running this cell does not keep appending duplicate sys.path entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src_dir = os.path.join(project_root, 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

import spectra as spectra
from utils.plotters import visualize_signal
from utils.alignment import point_distance
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import numpy as np
from scipy.spatial.transform import Rotation
import itertools


def align_predicted_geometry(true_geometry: np.ndarray,
                             predicted_geometry: np.ndarray,
                             rng=None,
                            ) -> tuple[np.ndarray, float]:
    """
    Rigidly rotate `predicted_geometry` onto `true_geometry`.

    There is no translation, reflection or reordering. Row i of
    `predicted_geometry` is matched with row i of `true_geometry`; points are
    NOT sorted or permuted here.

    Both point sets are normalised to unit directions. Each corresponding pair
    is then multiplied by the same random weight drawn from [0, 1), which is
    meant to break geometric symmetry. The optimal rotation is computed with
    `Rotation.align_vectors`; scaling a pair by w_i weights it by w_i**2 in
    that least-squares fit. Both inputs are assumed to be centred at the origin.

    Parameters
    ----------
    true_geometry : (N, 3) array_like
        Reference point cloud, centred at the origin.
    predicted_geometry : (N, 3) array_like
        Point cloud to align, also centred.
    rng : numpy.random.Generator, optional
        Source of the random weights. If None (default), the legacy global
        `np.random` state is used, so results vary between calls unless
        `np.random.seed` has been set.

    Returns
    -------
    aligned_predicted : (N, 3) ndarray
        `predicted_geometry`, at its original radii, rotated by the fitted rotation.
    rssd : float
        The second value returned by `Rotation.align_vectors`: the square root
        of the *sum* (not the mean) of squared distances between the weighted
        direction vectors after alignment.

    Raises
    ------
    ValueError
        If the inputs are not both 2-D arrays of the same shape (N, 3).
    """
    true_pts = np.asarray(true_geometry, dtype=float)
    pred_pts = np.asarray(predicted_geometry, dtype=float)
    # Check ndim first so that 1-D input raises ValueError rather than IndexError.
    if true_pts.ndim != 2 or true_pts.shape != pred_pts.shape or true_pts.shape[1] != 3:
        raise ValueError("Both inputs must have shape (N, 3)")

    def _unit_rows(pts):
        # Normalise each row; zero rows stay zero instead of becoming NaN.
        radii = np.linalg.norm(pts, axis=1, keepdims=True)
        return np.divide(pts, radii, out=np.zeros_like(pts), where=radii > 0)

    true_dirs = _unit_rows(true_pts)
    pred_dirs = _unit_rows(pred_pts)

    # One random weight per point pair, shared by both sets
    n_points = true_dirs.shape[0]
    if rng is None:
        weights = np.random.rand(n_points, 1)
    else:
        weights = rng.random((n_points, 1))

    rot, rssd = Rotation.align_vectors(true_dirs * weights, pred_dirs * weights)

    # Apply the rotation to the original (un-normalised) predicted points
    return rot.apply(pred_pts), rssd


def alignment_mae(true_geometry, predicted_geometry):
    """
    Best mean point-to-point distance between two geometries over all point orderings.

    For every permutation of the rows of `predicted_geometry`, the permuted
    points are rigidly rotated onto `true_geometry` with
    `align_predicted_geometry`. The error for that permutation is the mean
    Euclidean distance between corresponding points, and the smallest error
    wins.

    Notes
    -----
    - Despite the name, the error is a mean of per-point Euclidean distances,
      not a mean absolute error over coordinates.
    - Cost grows as N! for N points (720 alignments for 6 points), so this is
      only practical for small geometries.
    - `align_predicted_geometry` draws random weights, so results for inexact
      matches can vary slightly between calls.

    Returns
    -------
    best_error : float
        Smallest mean distance found, or inf if the shapes differ.
    best_aligned : (N, 3) ndarray or None
        The rotated, permuted prediction achieving `best_error`. None if the
        shapes differ or no finite error was found.
    """
    true_geometry = np.array(true_geometry)
    predicted_geometry = np.array(predicted_geometry)

    if true_geometry.shape != predicted_geometry.shape:
        return float('inf'), None

    best_error = float('inf')
    best_aligned = None

    n_points = predicted_geometry.shape[0]
    for perm in itertools.permutations(range(n_points)):
        aligned, _ = align_predicted_geometry(true_geometry, predicted_geometry[list(perm)])
        # Plain NumPy (not jnp): avoids one JAX dispatch per permutation and
        # yields a Python float.
        error = float(np.mean(np.linalg.norm(aligned - true_geometry, axis=1)))
        if error < best_error:
            best_error = error
            best_aligned = aligned

    return best_error, best_aligned

In [26]:
def interpolate_bispectra(bis1, bis2, n):
    """
    Return `n` bispectra evenly spaced on the straight line from `bis1` to `bis2`.

    Parameters:
    bis1: First bispectrum vector (the t = 0 endpoint)
    bis2: Second bispectrum vector (the t = 1 endpoint)
    n: How many bispectra to produce, endpoints included; anything below 2 is treated as 2

    Returns:
    JAX array stacking bis1 + t * (bis2 - bis1) for each t in linspace(0, 1, n)
    """
    count = n if n > 2 else 2
    delta = bis2 - bis1

    # Column of interpolation weights, broadcast against the bispectrum axes
    weights = jnp.linspace(0, 1, count)
    weights = weights.reshape(weights.shape + (1,) * jnp.ndim(delta))

    return bis1 + weights * delta


def invert_and_align_bispectra(start_geom, end_geom, n_steps=50):
    """
    Interpolate linearly between the bispectra of two geometries and invert each one.

    NOTE: despite the name, no alignment is performed here. Each inverted
    geometry is returned exactly as produced by `invert`, and successive
    results may contain different numbers of points, because `invert` merges
    and prunes points.

    Parameters:
    start_geom: Starting geometry (its bispectrum is the t = 0 endpoint)
    end_geom: Ending geometry (its bispectrum is the t = 1 endpoint)
    n_steps: Number of interpolation steps, including both endpoints
             (interpolate_bispectra raises values below 2 to 2)

    Returns:
    inverted_geoms: List of inverted geometries, one per interpolated bispectrum
    interpolated_bispectra: Array of the interpolated bispectra, one row per step
    """
    # Compute bispectra for start and end geometries
    start_bis = compute_bispectrum(start_geom)
    end_bis = compute_bispectrum(end_geom)

    # Interpolate between bispectra
    interpolated_bispectra = interpolate_bispectra(start_bis, end_bis, n_steps)

    # Invert each interpolated bispectrum independently
    print("Inverting interpolated bispectra...")
    inverted_geoms = [invert(bis) for bis in tqdm(interpolated_bispectra)]

    return inverted_geoms, interpolated_bispectra

# Example usage: interpolate between the octahedron and trigonal plane bispectra.
# NOTE(review): the end geometry is `trigonal_plane`; if the trigonal prism was
# intended, pass `trigonal_prism` instead and update the message below.
print("Interpolating between octahedron and trigonal plane bispectra...")
inverted_geoms, interpolated_bispectra = invert_and_align_bispectra(
    octahedron,
    trigonal_plane,
    n_steps=100  # Reduced for faster computation
)


Interpolating between octahedron and trigonal prism bispectra...
Inverting interpolated bispectra...


100%|██████████| 100/100 [01:01<00:00,  1.63it/s]


In [27]:
# List how many points each inverted geometry has, one line per interpolation step.
for i in range(len(inverted_geoms)):
    print(i, inverted_geoms[i].shape)

0 (6, 3)
1 (6, 3)
2 (6, 3)
3 (6, 3)
4 (6, 3)
5 (6, 3)
6 (6, 3)
7 (6, 3)
8 (6, 3)
9 (6, 3)
10 (6, 3)
11 (6, 3)
12 (6, 3)
13 (6, 3)
14 (6, 3)
15 (6, 3)
16 (6, 3)
17 (6, 3)
18 (6, 3)
19 (6, 3)
20 (6, 3)
21 (6, 3)
22 (6, 3)
23 (6, 3)
24 (6, 3)
25 (6, 3)
26 (6, 3)
27 (6, 3)
28 (6, 3)
29 (6, 3)
30 (6, 3)
31 (6, 3)
32 (6, 3)
33 (6, 3)
34 (6, 3)
35 (6, 3)
36 (6, 3)
37 (6, 3)
38 (6, 3)
39 (6, 3)
40 (6, 3)
41 (6, 3)
42 (6, 3)
43 (6, 3)
44 (6, 3)
45 (6, 3)
46 (6, 3)
47 (6, 3)
48 (6, 3)
49 (6, 3)
50 (6, 3)
51 (6, 3)
52 (6, 3)
53 (6, 3)
54 (6, 3)
55 (6, 3)
56 (6, 3)
57 (4, 3)
58 (4, 3)
59 (4, 3)
60 (4, 3)
61 (4, 3)
62 (4, 3)
63 (4, 3)
64 (4, 3)
65 (4, 3)
66 (4, 3)
67 (4, 3)
68 (4, 3)
69 (4, 3)
70 (4, 3)
71 (4, 3)
72 (4, 3)
73 (4, 3)
74 (3, 3)
75 (3, 3)
76 (3, 3)
77 (3, 3)
78 (3, 3)
79 (3, 3)
80 (3, 3)
81 (3, 3)
82 (3, 3)
83 (3, 3)
84 (3, 3)
85 (3, 3)
86 (3, 3)
87 (3, 3)
88 (3, 3)
89 (3, 3)
90 (3, 3)
91 (3, 3)
92 (3, 3)
93 (3, 3)
94 (3, 3)
95 (3, 3)
96 (3, 3)
97 (3, 3)
98 (3, 3)
99 (3, 3)


In [32]:
# Compare the inverted geometries on either side of the two steps where the
# recovered point count changes. Per the shape listing above, step 56 -> 57
# goes from 6 to 4 points and step 73 -> 74 from 4 to 3 points.
# Labels read "<step>/<last step>".
compare_four_geometries(inverted_geoms[56], '56/99', inverted_geoms[57], '57/99', inverted_geoms[73], '73/99', inverted_geoms[74], '74/99', lmax=4, show_points=False)

In [20]:
# Compare the first four inverted geometries (interpolation steps 0-3).
# NOTE(review): this cell and the next differ only in their indices; a small
# helper or loop over index groups would avoid the copy-paste.
compare_four_geometries(inverted_geoms[0], '0', inverted_geoms[1], '1', inverted_geoms[2], '2', inverted_geoms[3], '3', lmax=4, show_points=False)

In [21]:
# Compare inverted geometries for interpolation steps 4-7.
compare_four_geometries(inverted_geoms[4], '4', inverted_geoms[5], '5', inverted_geoms[6], '6', inverted_geoms[7], '7', lmax=4, show_points=False)