In [1]:
"""
so3_vector_conv_jax.py
======================
Pure‑JAX implementation of an **SO(3)‑equivariant convolution** that maps
vector (ℓ = 1) features → vector (ℓ = 1) features.

The kernel is constrained to the irreducible basis
    K(r̂) = a(r) I + b(r) [r̂]× + c(r) Q(r̂)
with
  • I         : 3×3 identity (ℓ_f = 0)
  • [r̂]×     : skew‑symmetric cross‑product matrix (ℓ_f = 1)
  • Q(r̂)     : quadrupole = 3 r̂ ⊗ r̂ − I       (ℓ_f = 2)
so that the layer is *exactly equivariant* under global rotations.

No external equivariant libraries (e3nn, lie‑torch, etc.) are used –
only JAX + Flax.
"""

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Callable, Optional
jax.config.update("jax_enable_x64", True)


# -----------------------------------------------------------------------------
#  Helper: skew‑symmetric cross‑product matrix
# -----------------------------------------------------------------------------

def cross_matrix(v: jnp.ndarray) -> jnp.ndarray:
    """Return the 3×3 matrix [v]× such that [v]×w = v × w.
    v: [..., 3]
    Returns: [..., 3, 3]
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    O = jnp.zeros_like(x)
    mat = jnp.stack(
        [jnp.stack([ O, -z,  y], axis=-1),
         jnp.stack([ z,  O, -x], axis=-1),
         jnp.stack([-y,  x,  O], axis=-1)],
        axis=-2)
    return mat  # shape [..., 3, 3]

# -----------------------------------------------------------------------------
#  Radial model g(r)  →  (a(r), b(r), c(r))
# -----------------------------------------------------------------------------

class RadialMLP(nn.Module):
    """Simple 1‑D MLP that outputs the three scalar radial weights a, b, c.
    By default uses two hidden layers with SiLU activation.
    """
    hidden: int = 32
    act: Callable = nn.silu

    @nn.compact
    def __call__(self, r: jnp.ndarray) -> jnp.ndarray:
        # r: [..., 1]
        x = nn.Dense(self.hidden)(r)
        x = self.act(x)
        x = nn.Dense(self.hidden)(x)
        x = self.act(x)
        x = nn.Dense(3)(x)  # → [..., 3]  → (a, b, c)
        return x

# -----------------------------------------------------------------------------
#  SO(3)‑equivariant vector convolution layer
# -----------------------------------------------------------------------------

class SO3VectorConv(nn.Module):
    """Equivariant convolution for ℓ = 1 vector fields.

    Args
    ----
    radial_model : flax Module mapping r[...,1] → (a, b, c)
    cutoff       : optional distance cutoff; interactions beyond this are ignored
    """
    radial_model: nn.Module = RadialMLP()
    cutoff: Optional[float] = None

    @nn.compact
    def __call__(self,
                 feats: jnp.ndarray,   # (N, 3)
                 pos: jnp.ndarray) -> jnp.ndarray:  # (N, 3)
        N = pos.shape[0]
        # Pairwise relative vectors r_ij = x_j − x_i
        rel = pos[:, None, :] - pos[None, :, :]        # (N, N, 3)
        dist = jnp.linalg.norm(rel, axis=-1, keepdims=True)  # (N, N, 1)
        # Avoid division by zero for i=j
        inv_dist = jnp.where(dist > 0, 1.0 / dist, 0.0)
        r_hat = rel * inv_dist                            # (N, N, 3)

        # Optional cutoff mask
        if self.cutoff is not None:
            mask = (dist[..., 0] < self.cutoff).astype(feats.dtype)  # (N, N)
        else:
            mask = jnp.ones((N, N), dtype=feats.dtype)
        mask = mask[..., None, None]  # (N, N, 1, 1) for broadcasting

        # Radial weights a(r), b(r), c(r)
        abc = self.radial_model(dist)           # (N, N, 3)
        a, b, c = [abc[..., i] for i in range(3)]  # each (N, N)

        # Basis tensors
        I = jnp.eye(3, dtype=feats.dtype)              # (3, 3)
        I = jnp.broadcast_to(I, (N, N, 3, 3))          # (N, N, 3, 3)
        S = cross_matrix(r_hat)                       # skew‑symm (ℓ_f = 1)
        Q = 3.0 * (r_hat[..., :, None] * r_hat[..., None, :]) - I  # quadrupole (ℓ_f = 2)

        # Kernel K_ij = a I + b S + c Q
        K = (a[..., None, None] * I +
             b[..., None, None] * S +
             c[..., None, None] * Q) * mask            # (N, N, 3, 3)

        # Contraction: out_i = Σ_j  K_ij · feats_j
        out = jnp.einsum('ijab,jb->ia', K, feats)      # (N, 3)
        return out

# -----------------------------------------------------------------------------
#  Example usage
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import jax.random as rnd

    key = rnd.PRNGKey(3)
    N = 4
    # Random demo data
    pos = rnd.normal(key, (N, 3))
    feats = rnd.normal(key, (N, 3))  # vector features (ℓ=1)

    conv = SO3VectorConv(RadialMLP(hidden=64), cutoff=2.5)
    params = conv.init(key, feats, pos)  # initialise parameters
    out = conv.apply(params, feats, pos)

    # Quick equivariance check: rotate input and compare
    import math
    phi = math.pi / 3
    R = jnp.array([[ jnp.cos(phi), -jnp.sin(phi), 0.0],
                   [ jnp.sin(phi),  jnp.cos(phi), 0.0],
                   [          0.0,           0.0, 1.0]])
    feats_rot = feats @ R.T  # rotate vectors
    pos_rot = pos @ R.T      # rotate coordinates

    out1 = conv.apply(params, feats_rot, pos_rot)      # then apply conv
    out2 = conv.apply(params, feats, pos) @ R.T        # apply conv then rotate

    err = jnp.max(jnp.abs(out1 - out2))
    print("max|Δ| ≈", err, out1)

max|Δ| ≈ 8.881784197001252e-16 [[-0.61660119  1.40856409  0.50063294]
 [ 1.03411431  0.02944124  0.57241899]
 [ 1.72121069 -0.86492014  0.33603109]
 [ 0.          0.          0.        ]]
