In [None]:
import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

In [None]:
class WrappedFCTensorProduct(nn.Module):
    """Tensor product of two inputs followed by a linear layer (compact style).

    The module first takes ``e3nn.tensor_product`` of its two arguments and
    then passes the result through an ``e3nn.flax.Linear`` layer that is
    created inline inside ``__call__``.
    """

    # Irreps requested for the output; handed unchanged to the linear layer.
    irreps_out: e3nn.Irreps

    @nn.compact
    def __call__(self, input_1: e3nn.IrrepsArray, input_2: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
        """Combine both inputs and map the product to ``irreps_out``.

        Args:
            input_1: First operand of the tensor product.
            input_2: Second operand of the tensor product.

        Returns:
            The result of applying the linear layer to the tensor product.
        """
        product = e3nn.tensor_product(input_1, input_2)
        # The only submodule of this compact body; it is left unnamed so the
        # automatically assigned name (and the parameter tree) is unchanged.
        mix = e3nn.flax.Linear(irreps_out=self.irreps_out)
        return mix(product)

In [None]:
# Two identical inputs: a 16-element array of ones labelled with the irreps
# "0e + 1e + 2e + 3e" (16 matches the total dimension 1 + 3 + 5 + 7).
input_1 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
input_2 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
# Instantiate the @nn.compact variant with the requested output irreps.
tp = WrappedFCTensorProduct(irreps_out="2x0e + 5x1e + 8x2e + 11x3e")
# Initialise the parameters from a fixed PRNG key (seed 0), then run the
# module on the same inputs with those parameters.
params = tp.init(jax.random.PRNGKey(0), input_1, input_2)
output = tp.apply(params, input_1, input_2)
# Bare last expression so the notebook displays the result.
output

In [None]:
class WrappedFCTensorProductWithSetup(nn.Module):
    """Tensor product of two inputs followed by a linear layer (setup style).

    Performs the same two steps as a compact-style module would, but the
    ``e3nn.flax.Linear`` layer is created in ``setup`` and stored on the
    instance instead of being built inside ``__call__``.
    """

    # Irreps requested for the output; handed unchanged to the linear layer.
    irreps_out: e3nn.Irreps

    def setup(self):
        """Create the linear layer and keep it as ``self.linear``."""
        self.linear = e3nn.flax.Linear(irreps_out=self.irreps_out)

    def __call__(self, input_1: e3nn.IrrepsArray, input_2: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
        """Combine both inputs and map the product to ``irreps_out``.

        Args:
            input_1: First operand of the tensor product.
            input_2: Second operand of the tensor product.

        Returns:
            The result of applying ``self.linear`` to the tensor product.
        """
        return self.linear(e3nn.tensor_product(input_1, input_2))

In [None]:
# Two identical inputs: a 16-element array of ones labelled with the irreps
# "0e + 1e + 2e + 3e" (16 matches the total dimension 1 + 3 + 5 + 7).
input_1 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
input_2 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
# Instantiate the setup()-based variant with the requested output irreps.
# NOTE(review): this rebinds `tp`, `params` and `output` from any earlier cell.
tp = WrappedFCTensorProductWithSetup(irreps_out="2x0e + 5x1e + 8x2e + 11x3e")
# Initialise the parameters from a fixed PRNG key (seed 0), then run the
# module on the same inputs with those parameters.
params = tp.init(jax.random.PRNGKey(0), input_1, input_2)
output = tp.apply(params, input_1, input_2)
# Bare last expression so the notebook displays the result.
output