In [249]:
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp

## Tensor Product in $O(L^2)$ time!

Fix $l_1, l_2, l_3$ in $l_1 \times l_2 \rightarrow l_3$, and let $L = \max(l_1, l_2, l_3)$.
Naively,
- Loop over $m_1$: $O(l_1)$ iterations
- Loop over $m_2$: $O(l_2)$ iterations
- Loop over $m_3$: $O(l_3)$ iterations
- Compute output of TP at $m_1$, $m_2$, $m_3$.

But this is $O(l_1 l_2 l_3) = O(L^3)$.

However, the selection rule says the Clebsch–Gordan coefficient for $(m_1, m_2, m_3)$ can only be nonzero when
$$
m_3 = \pm m_1 \pm m_2
$$
So, instead:
- Loop over $m_1$: $O(l_1)$ iterations
- Loop over $m_3$: $O(l_3)$ iterations
- Loop over $m_2 \in \{\pm m_3 \pm m_1\}$: at most 4 iterations, i.e. $O(1)$.
- Compute output of TP at $m_1$, $m_2$, $m_3$.

This is $O(l_1 l_3) = O(L^2)$.

In [250]:
# Test inputs: "2x2o" and "3x2o" arrays, each drawn from its own fixed PRNG
# key so that re-running the notebook reproduces the same values.
x1, x2 = (
    e3nn.normal(irreps, jax.random.PRNGKey(seed))
    for irreps, seed in [("2x2o", 0), ("3x2o", 1)]
)

# None means no filtering: every output irrep of the product is kept.
filter_ir_out = None

In [251]:
def _prepare_inputs(input1, input2):
    input1 = e3nn.as_irreps_array(input1)
    input2 = e3nn.as_irreps_array(input2)

    leading_shape = jnp.broadcast_shapes(input1.shape[:-1], input2.shape[:-1])
    input1 = input1.broadcast_to(leading_shape + (-1,))
    input2 = input2.broadcast_to(leading_shape + (-1,))
    return input1, input2, leading_shape


def _validate_filter_ir_out(filter_ir_out):
    if filter_ir_out is not None:
        if isinstance(filter_ir_out, str):
            filter_ir_out = e3nn.Irreps(filter_ir_out)
        if isinstance(filter_ir_out, e3nn.Irrep):
            filter_ir_out = [filter_ir_out]
        filter_ir_out = [e3nn.Irrep(ir) for ir in filter_ir_out]
    return filter_ir_out

## Original TP

In [252]:
def tensor_product(
    input1: e3nn.IrrepsArray,
    input2: e3nn.IrrepsArray,
    *,
    filter_ir_out=None,
) -> e3nn.IrrepsArray:
    """Reference (dense) tensor product of two IrrepsArrays.

    For each pair of input irreps, and each output irrep in ``ir_1 * ir_2``
    that passes ``filter_ir_out``, the two chunks are contracted with the full
    Clebsch-Gordan tensor. The output multiplicity is ``mul_1 * mul_2``.
    A zero (None) input chunk gives a None output chunk, but its irrep is still
    listed in the output. The returned IrrepsArray is sorted.
    """
    input1, input2, leading_shape = _prepare_inputs(input1, input2)
    filter_ir_out = _validate_filter_ir_out(filter_ir_out)

    irreps_out = []
    chunks = []
    for (mul_1, ir_1), chunk_1 in zip(input1.irreps, input1.chunks):
        for (mul_2, ir_2), chunk_2 in zip(input2.irreps, input2.chunks):
            allowed = [
                ir for ir in ir_1 * ir_2
                if filter_ir_out is None or ir in filter_ir_out
            ]
            for ir_out in allowed:
                mul_out = mul_1 * mul_2
                irreps_out.append((mul_out, ir_out))

                if chunk_1 is None or chunk_2 is None:
                    chunks.append(None)
                    continue

                cg = e3nn.clebsch_gordan(ir_1.l, ir_2.l, ir_out.l).astype(chunk_1.dtype)
                # (..., u, i) x (..., v, j) x (i, j, k) -> (..., u, v, k)
                product = jnp.einsum("...ui , ...vj , ijk -> ...uvk", chunk_1, chunk_2, cg)
                chunks.append(
                    product.reshape(product.shape[:-3] + (mul_out, ir_out.dim))
                )

    return e3nn.from_chunks(irreps_out, chunks, leading_shape, input1.dtype).sort()

In [253]:
# `filter_ir_out` is Python configuration (None, an irreps string, an Irrep, ...),
# not an array, so mark it static; otherwise jit rejects non-None values.
jax.jit(tensor_product, static_argnames="filter_ir_out")(
    x1, x2, filter_ir_out=filter_ir_out
)

6x0e+6x1e+6x2e+6x3e+6x4e
[-0.30594468  0.65587217  0.8156828   0.43714806 -0.04631896 -0.8187588
  0.07375173  0.13654615 -0.26689768 -0.1876133   0.10305031 -0.3133805
  0.33506107 -0.03605647 -0.21220711 -0.37972197 -0.19940192 -0.18087739
  0.40132776 -0.49416944  0.3329233  -0.05897088 -0.5503977   0.20598856
  0.32779682 -0.10878956 -0.18368451  0.1075291   0.03287186  0.06962056
  0.06809835 -0.06248124  0.5807315   0.11280058  0.11679052  0.26604736
 -0.26138645  0.5083903   0.19740054 -0.324839    0.18261841  0.19158448
  0.04468798 -0.29023015 -0.12107493  0.18791169  0.35540935 -0.3038459
 -0.49011904 -0.05223124 -0.24283403  0.20550597 -0.26107872 -0.75380576
  0.24322642  0.08006992  0.04031758  0.28545618  0.1339949  -0.04789218
  0.12223046  0.2194841  -0.04944647 -0.0719676   0.41960192  0.01269295
  0.0945653   0.273924    0.3423461   0.18685225  0.00434232  0.33810347
  0.15372635 -0.08647925  0.21101776 -0.09164129  0.01596775  0.02483024
 -0.19484337 -0.40859804 -0.0

## Optimized TP

In [254]:
def tensor_product_optimized(
    input1: e3nn.IrrepsArray,
    input2: e3nn.IrrepsArray,
    *,
    filter_ir_out=None,
) -> e3nn.IrrepsArray:
    """Tensor product that visits only CG entries allowed by the m-selection rule.

    Produces the same irreps as ``tensor_product`` and, up to float round-off,
    the same values. Each (l1, l2) -> l3 path only visits (m1, m2, m3) with
    m3 = ±m1 ± m2: O(l1 * l3) terms instead of O(l1 * l2 * l3).

    Args:
        input1, input2: IrrepsArrays, or anything ``e3nn.as_irreps_array``
            accepts. Their leading dimensions are broadcast together.
        filter_ir_out: optional restriction on the output irreps: None, an
            irreps string, a single Irrep, or an iterable of irreps.

    Returns:
        The sorted output IrrepsArray, with multiplicity ``mul_1 * mul_2``
        for every path.
    """
    input1, input2, leading_shape = _prepare_inputs(input1, input2)
    filter_ir_out = _validate_filter_ir_out(filter_ir_out)

    irreps_out = []
    chunks = []
    for (mul_1, ir_1), x1 in zip(input1.irreps, input1.chunks):
        for (mul_2, ir_2), x2 in zip(input2.irreps, input2.chunks):
            for ir_out in ir_1 * ir_2:
                if filter_ir_out is not None and ir_out not in filter_ir_out:
                    continue

                # Record the irrep even for zero (None) chunks, so the output
                # layout matches `tensor_product`.
                irreps_out.append((mul_1 * mul_2, ir_out))

                if x1 is None or x2 is None:
                    chunks.append(None)
                    continue

                cg = e3nn.clebsch_gordan(ir_1.l, ir_2.l, ir_out.l)
                chunks.append(_sparse_cg_chunk(x1, x2, cg))

    output = e3nn.from_chunks(irreps_out, chunks, leading_shape, input1.dtype)
    output = output.sort()
    return output


def _sparse_cg_chunk(x1, x2, cg):
    """Contract one (l1, l2) -> l3 path, visiting only m3 = ±m1 ± m2 entries.

    Matches ``einsum("...ui, ...vj, ijk -> ...uvk", x1, x2, cg)`` reshaped to
    ``(..., mul_1 * mul_2, 2*l3 + 1)``, as long as ``cg`` vanishes outside
    the selection rule.

    Args:
        x1: array of shape (..., mul_1, 2*l1 + 1).
        x2: array of shape (..., mul_2, 2*l2 + 1). Its leading dims must
            broadcast with those of ``x1``.
        cg: coefficients of shape (2*l1 + 1, 2*l2 + 1, 2*l3 + 1), indexed
            by l + m along each axis.

    Returns:
        Array of shape (..., mul_1 * mul_2, 2*l3 + 1), with dtype
        ``jnp.result_type(x1, x2)``.
    """
    l1, l2, l3 = ((d - 1) // 2 for d in cg.shape)
    mul_1, mul_2 = x1.shape[-2], x2.shape[-2]
    leading = jnp.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
    dtype = jnp.result_type(x1, x2)

    columns = []
    for m3 in range(-l3, l3 + 1):
        column = jnp.zeros(leading + (mul_1, mul_2), dtype=dtype)
        for m1 in range(-l1, l1 + 1):
            # m3 = ±m1 ± m2  <=>  m2 in {±m3 ± m1}: at most 4 candidates.
            for m2 in sorted({m3 - m1, m3 + m1, -m3 + m1, -m3 - m1}):
                if abs(m2) > l2:
                    continue
                # A Python float is weakly typed, so it does not promote the
                # input dtype (e.g. float32 stays float32).
                coeff = float(cg[l1 + m1, l2 + m2, l3 + m3])
                if coeff == 0.0:
                    continue  # exactly-zero coefficients contribute nothing
                # (..., mul_1, 1) * (..., 1, mul_2) -> (..., mul_1, mul_2)
                column = column + coeff * (
                    x1[..., :, None, l1 + m1] * x2[..., None, :, l2 + m2]
                )
        columns.append(column)

    chunk = jnp.stack(columns, axis=-1)  # (..., mul_1, mul_2, 2*l3 + 1)
    return jnp.reshape(chunk, leading + (mul_1 * mul_2, 2 * l3 + 1))

In [255]:
# `filter_ir_out` is Python configuration (None, an irreps string, an Irrep, ...),
# not an array, so mark it static; otherwise jit rejects non-None values.
jax.jit(tensor_product_optimized, static_argnames="filter_ir_out")(
    x1, x2, filter_ir_out=filter_ir_out
)

6x0e+6x1e+6x2e+6x3e+6x4e
[-0.30594468  0.65587217  0.8156828   0.43714795 -0.04631888 -0.8187588
  0.07375173  0.13654613 -0.2668977  -0.1876133   0.10305033 -0.3133805
  0.33506107 -0.03605647 -0.2122071  -0.37972197 -0.19940192 -0.18087742
  0.40132776 -0.49416947  0.3329233  -0.05897085 -0.5503977   0.20598854
  0.32779682 -0.10878956 -0.18368451  0.10752911  0.03287186  0.06962059
  0.06809834 -0.06248123  0.58073145  0.11280058  0.11679052  0.26604736
 -0.26138645  0.50839025  0.19740053 -0.32483903  0.18261841  0.1915845
  0.04468796 -0.29023015 -0.12107492  0.18791166  0.35540938 -0.30384594
 -0.49011907 -0.05223123 -0.24283405  0.20550597 -0.26107872 -0.7538057
  0.24322641  0.08006992  0.04031758  0.28545618  0.1339949  -0.04789218
  0.12223046  0.2194841  -0.04944647 -0.07196759  0.41960192  0.01269295
  0.09456529  0.273924    0.34234613  0.18685226  0.00434231  0.3381035
  0.15372634 -0.08647925  0.21101774 -0.09164129  0.01596775  0.02483024
 -0.19484337 -0.40859804 -0.026

In [256]:
# Run each implementation once. Check that they agree on the output layout
# (irreps), not just on the flattened values.
out_reference = tensor_product(x1, x2, filter_ir_out=filter_ir_out)
out_optimized = tensor_product_optimized(x1, x2, filter_ir_out=filter_ir_out)

assert out_reference.irreps == out_optimized.irreps, (
    out_reference.irreps,
    out_optimized.irreps,
)
assert jnp.allclose(out_reference.array, out_optimized.array)

# Element-wise difference; should be on the order of float round-off.
out_reference.array - out_optimized.array

Array([ 0.0000000e+00,  0.0000000e+00,  5.9604645e-08,  8.9406967e-08,
       -7.4505806e-08,  0.0000000e+00, -7.4505806e-09,  1.4901161e-08,
        0.0000000e+00,  1.4901161e-08, -1.4901161e-08,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00, -2.9802322e-08,  0.0000000e+00,
        0.0000000e+00,  2.9802322e-08,  2.9802322e-08,  2.9802322e-08,
        0.0000000e+00,  1.4901161e-08, -5.9604645e-08,  4.4703484e-08,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        3.7252903e-09, -1.4901161e-08,  0.0000000e+00, -1.1175871e-08,
        5.9604645e-08, -7.4505806e-09,  0.0000000e+00,  0.0000000e+00,
       -2.9802322e-08,  0.0000000e+00,  1.4901161e-08,  2.9802322e-08,
        0.0000000e+00,  1.4901161e-08,  1.1175871e-08,  0.0000000e+00,
       -1.4901161e-08,  4.4703484e-08, -2.9802322e-08,  2.9802322e-08,
        0.0000000e+00, -1.4901161e-08,  1.4901161e-08,  0.0000000e+00,
        0.0000000e+00, -5.9604645e-08,  0.0000000e+00,  0.0000000e+00,
      