In [1]:
"""
Bispectrum Inversion for Molecular Environments - Optimized for Apple M2 Pro
"""
import os
import time
import pickle

import numpy as np
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
from functools import partial
import math

# Make the project's ``src`` directory importable. Assumes the notebook's
# working directory is one level below the project root — TODO confirm.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src_dir = os.path.join(project_root, 'src')
# Guard so that re-running this cell does not keep appending duplicates.
if src_dir not in sys.path:
    sys.path.append(src_dir)

from spectra import sum_of_diracs, trispectrum

#------------------------------------------------------------------------------
# JAX & Hardware Configuration
#------------------------------------------------------------------------------

# Global JAX switches, applied in this order: enable 64-bit dtypes, select
# the CPU platform, and request 'high' default matmul precision.
for _name, _setting in (
    ('jax_enable_x64', True),
    ('jax_platform_name', 'cpu'),
    ('jax_default_matmul_precision', 'high'),
):
    jax.config.update(_name, _setting)
del _name, _setting

#------------------------------------------------------------------------------
# Constants
#------------------------------------------------------------------------------

LMAX = 2          # degree argument passed to sum_of_diracs by the spectrum/inversion helpers
MAX_ENV_SIZE = 6  # default number of rows produced by pad_environment
NUM_INITS = 20    # number of random initializations (presumably for the multi-start search)

#------------------------------------------------------------------------------
# Helpers
#------------------------------------------------------------------------------

def determine_optimal_batch_size() -> int:
    """
    Pick a default batch size from the host CPU type.

    Returns 8 when ``platform.processor()`` reports ``'arm'`` (as on
    Apple-silicon macOS), otherwise 5. If querying the platform fails for
    any reason, 5 is returned.
    """
    import platform

    try:
        processor = platform.processor()
    except Exception:
        # Best-effort probe: hardware detection must never break the notebook,
        # but unlike a bare ``except:`` this does not swallow
        # KeyboardInterrupt / SystemExit.
        return 5
    return 8 if processor == 'arm' else 5

@partial(jax.jit, static_argnames=('max_size',))
def pad_environment(env, max_size=MAX_ENV_SIZE):
    """
    Coerce an (N × 3) point array to exactly ``max_size`` rows.

    The input is converted to a float64 JAX array. If N < max_size, the
    missing rows are filled with zeros. If N > max_size, only the first
    ``max_size`` rows are kept.
    """
    coords = jnp.array(env, dtype=jnp.float64)
    filler = jnp.zeros((max_size, coords.shape[1]), dtype=coords.dtype)
    # Appending max_size zero rows guarantees at least max_size rows exist,
    # so the slice always yields exactly max_size rows.
    return jnp.concatenate([coords, filler], axis=0)[:max_size]

@jax.jit
def compute_trispectrum_for(env):
    """
    Return ``trispectrum`` of ``sum_of_diracs(env, LMAX)``.

    ``env`` is an (M × 3) array of points. All-zero rows are allowed, such
    as the padding produced by ``pad_environment``.
    """
    expansion = sum_of_diracs(env, LMAX)
    return trispectrum(expansion)

@partial(jax.jit, static_argnames=('num_iters', 'lr'))
def invert_with_adam(true_bs, init_pts, num_iters=10000, lr=1e-2):
    """
    Fit point positions to a target spectrum with Adam.

    Minimises ``mean(|true_bs - trispectrum(sum_of_diracs(x, LMAX))|)`` over
    the point array ``x``. The run starts from ``init_pts`` and takes exactly
    ``num_iters`` Adam steps, with no early stopping.

    Parameters
    ----------
    true_bs : array
        Target spectrum, e.g. the output of ``compute_trispectrum_for``.
    init_pts : array
        Initial point positions (callers in this notebook pass (K, 3)
        arrays). The number of points stays fixed for the whole run.
    num_iters : int, static
        Number of optimisation steps. It is the ``lax.scan`` length, so it
        must be static (a Python int).
    lr : float, static
        Adam learning rate (default 1e-2, the previously hard-coded value).
        It is a static argument, so it must be hashable (a Python float).

    Returns
    -------
    array
        Final positions, with the same shape as ``init_pts``.
    """
    # ``static_argnames`` alone lets JAX infer the matching positional
    # indices, so ``invert_with_adam(a, b, 10000)`` and
    # ``invert_with_adam(a, b, num_iters=10000)`` both keep num_iters static.
    opt = optax.adam(lr)

    def loss_fn(x):
        pred_bs = trispectrum(sum_of_diracs(x, LMAX))
        return jnp.mean(jnp.abs(true_bs - pred_bs))

    grad_fn = jax.grad(loss_fn)

    def step(carry, _):
        pts, st = carry
        updates, st = opt.update(grad_fn(pts), st, pts)
        return (optax.apply_updates(pts, updates), st), None

    (final_pts, _), _ = jax.lax.scan(
        step, (init_pts, opt.init(init_pts)), None, length=num_iters
    )
    return final_pts

def stack_points(points: np.ndarray, cos_thresh: float = 0.866,
                 rel_norm_thresh: float = 0.5) -> jnp.ndarray:
    """
    Merge roughly co-directional points and drop weak merged clusters.

    Points are grouped greedily. The first unassigned point seeds a cluster,
    and every unassigned point whose unit direction has cosine similarity
    ``> cos_thresh`` with the seed joins it. Each cluster is the vector sum
    of its members. Clusters whose norm is below ``rel_norm_thresh`` times
    the largest cluster norm are then discarded.

    Zero-length and non-finite points have no direction and are ignored.

    Parameters
    ----------
    points : array_like, shape (N, 3)
        Candidate points (e.g. the output of ``invert_with_adam``).
    cos_thresh : float
        Cosine-similarity cutoff for joining a cluster (default 0.866,
        roughly cos 30°).
    rel_norm_thresh : float
        Fraction of the largest cluster norm a cluster must reach to be
        kept (default 0.5).

    Returns
    -------
    jnp.ndarray, shape (K, 3)
        Surviving cluster vectors as float64. K may be 0.
    """
    pts = np.asarray(points, dtype=np.float64)
    norms = np.linalg.norm(pts, axis=1)

    # Only points with a well-defined direction can seed or join a cluster.
    # A zero/NaN row would otherwise fail to align even with itself and
    # halt clustering of every point after it.
    unassigned = np.isfinite(norms) & (norms > 0)
    unit = np.zeros_like(pts)
    unit[unassigned] = pts[unassigned] / norms[unassigned, None]
    cos_sim = unit @ unit.T

    clusters = []
    while unassigned.any():
        seed = np.argmax(unassigned)  # first still-unassigned index
        members = (cos_sim[seed] > cos_thresh) & unassigned
        members[seed] = True  # always consume the seed, so the loop progresses
        clusters.append(pts[members].sum(axis=0))
        unassigned[members] = False

    if not clusters:
        return jnp.zeros((0, 3), jnp.float64)

    merged = np.stack(clusters)
    merged_norms = np.linalg.norm(merged, axis=1)
    keep = merged_norms >= rel_norm_thresh * merged_norms.max()
    return jnp.array(merged[keep], jnp.float64)

In [9]:
from tqdm import tqdm

# Reference geometry: six points on the ±x, ±y, ±z unit axes (float so the
# jitted spectrum code is not traced with an integer dtype).
octahedron = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [-1, 0, 0],
    [0, -1, 0],
    [0, 0, -1]
], dtype=np.float64)

NUM_INIT_POINTS = 15    # points per random init; clustering merges duplicates
NUM_ADAM_ITERS = 10000  # Adam steps for each of the two inversion passes

master_rng = jax.random.PRNGKey(0)

true_ts = compute_trispectrum_for(octahedron)

best_loss = np.inf
best_geometry = None

# Multi-start search: NUM_INITS independent random restarts, keep the best.
for _ in tqdm(range(NUM_INITS)):
    # 1) split RNG, sample a fresh init
    master_rng, subkey = jax.random.split(master_rng)
    init_pts = jax.random.normal(subkey, (NUM_INIT_POINTS, 3), jnp.float64)

    # 2) first inversion
    final_pts = invert_with_adam(true_ts, init_pts, num_iters=NUM_ADAM_ITERS)

    # 3) cluster the result
    stacked = stack_points(np.array(final_pts))
    if stacked.shape[0] == 0:
        continue  # nothing survived clustering; this restart is unusable

    # 4) second inversion on clustered points
    pred_geometry = invert_with_adam(true_ts, stacked, num_iters=NUM_ADAM_ITERS)
    loss = np.mean(np.abs(compute_trispectrum_for(pred_geometry) - true_ts))

    # 5) keep if best
    if loss < best_loss:
        best_loss = loss
        best_geometry = pred_geometry

100%|██████████| 20/20 [04:32<00:00, 13.62s/it]


In [10]:
# Mean absolute trispectrum error of the best reconstruction found above.
best_loss

np.float64(1.367926608622252e-06)

In [11]:
from utils.plotters import visualize_geometry

# best_geometry stays None when no restart produced a finite loss; fail
# clearly here instead of deep inside the plotting code.
assert best_geometry is not None, "no successful inversion; nothing to plot"
# NOTE(review): plotted with lmax=6 while the inversion used LMAX — confirm intended.
fig = visualize_geometry(best_geometry, lmax=6)
fig.show()