## 1. Introduction to e3nn-jax

[e3nn-jax](https://github.com/e3nn/e3nn-jax) is the [JAX](https://github.com/google/jax) implementation of the [e3nn](https://github.com/e3nn/e3nn) library that was originally written in PyTorch. JAX is a Python framework for numerical computing that allows for automatic differentiation and just-in-time compilation.

This notebook will provide a small introduction to the e3nn-jax library and show how to use it to build equivariant neural networks.
For more details on the e3nn library, please refer to the [e3nn paper](https://arxiv.org/abs/2207.09453). 


In [None]:
# Imports
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

e3nn provides a way to build neural networks that understand 3D rotations and reflections.

What does this mean? Let's consider the example of a molecule, which consists of atoms living in 3D space. The molecule can be rotated and reflected in space, but the underlying physics of the molecule does not change.

For example, the energy of the molecule does not change under rotations and reflections. We call such properties 'invariants' or 'scalars'.


In [None]:
# Create a random rotation.
key = jax.random.PRNGKey(0) # Key for random number generation.
R = e3nn.rand_matrix(key)

In [None]:
# The notation 0e is a short form for (l, p) = (0, 1).
# l = 0 means that the quantity is invariant under rotations.
# p = 1 means that the quantity is invariant under reflections (i.e., even).
energy = e3nn.IrrepsArray("0e", jnp.asarray([0.5]))
rotated_energy = energy.transform_by_matrix(R)
reflected_energy = energy.transform_by_matrix(-R)

print("Observed energy after rotation:", rotated_energy)
print("Expected energy after rotation:", energy)
print()
print("Observed energy after reflection:", reflected_energy)
print("Expected energy after reflection:", energy)


The forces acting on the atoms in the molecule, by contrast, do change under rotation.
They transform in a specific way under rotations and reflections: under a rotation $R$, a force $f$ transforms to $R f$, where $R$ is the corresponding rotation matrix.
We call such properties 'vectors'.
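
The difference between an invariant and a vector can be checked numerically without e3nn-jax. The sketch below uses plain NumPy and a made-up energy function (`toy_energy`, `toy_forces`, and `rotation_matrix` are hypothetical helpers, not part of any library) to illustrate that the energy is unchanged by a rotation while the forces rotate along with the atoms.

```python
import numpy as np

def rotation_matrix(angle_z, angle_x):
    """Proper rotation: rotate about the x-axis, then about the z-axis."""
    cz, sz = np.cos(angle_z), np.sin(angle_z)
    cx, sx = np.cos(angle_x), np.sin(angle_x)
    Rz = np.array([[cz, -sz, 0.0], [sz, cz, 0.0], [0.0, 0.0, 1.0]])
    Rx = np.array([[1.0, 0.0, 0.0], [0.0, cx, -sx], [0.0, sx, cx]])
    return Rz @ Rx

def toy_energy(positions):
    """Hypothetical energy: half the sum of squared distances over all pairs."""
    diffs = positions[:, None, :] - positions[None, :, :]
    # Summing over ordered pairs counts each pair twice, hence 0.25 rather than 0.5.
    return 0.25 * np.sum(diffs ** 2)

def toy_forces(positions):
    """Forces are minus the gradient of toy_energy with respect to the positions."""
    diffs = positions[:, None, :] - positions[None, :, :]
    return -diffs.sum(axis=1)

positions = np.array([[1.0, 0.0, 2.0],
                      [0.0, 5.0, 3.0],
                      [1.0, 1.0, 1.0]])
R = rotation_matrix(0.3, 1.1)
rotated = positions @ R.T  # each row r becomes R r

# Scalar: same value before and after the rotation.
energy_is_invariant = np.isclose(toy_energy(rotated), toy_energy(positions))
# Vector: forces of the rotated molecule are the rotated forces.
forces_are_equivariant = np.allclose(toy_forces(rotated), toy_forces(positions) @ R.T)
print(energy_is_invariant, forces_are_equivariant)
```

Both checks should hold for any rotation, and they continue to hold if `R` is replaced by `-R`, since this toy energy depends only on distances between atoms.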

In [None]:
# The notation 1o is a short form for (l, p) = (1, -1).
# l = 1 means that the quantity is transformed as a vector under rotations.
# p = -1 means that the quantity flips sign under reflections (i.e., odd).
forces = e3nn.IrrepsArray("1o", jnp.asarray([[1.0, 0.0, 2.0],
                                             [0.0, 5.0, 3.0],
                                             [1.0, 1.0, 1.0]]))
rotated_forces = forces.transform_by_matrix(R)
reflected_forces = forces.transform_by_matrix(-R)

print("Observed forces after rotation:", rotated_forces.array)
print("Expected forces after rotation:", forces.array @ R.T)
print()
print("Observed forces after reflection:", reflected_forces.array)
print("Expected forces after reflection:", forces.array @ (-R).T)

      
# Plot the original and rotated forces as points in 3D.
fig = go.Figure(data=[
    go.Scatter3d(x=forces.array[:, 0], y=forces.array[:, 1], z=forces.array[:, 2],
                 mode='markers', name='Original Forces'),
    go.Scatter3d(x=rotated_forces.array[:, 0], y=rotated_forces.array[:, 1], z=rotated_forces.array[:, 2],
                 mode='markers', name='Rotated Forces'),
])
fig.show()
                

To keep track of how these different objects transform under rotations and reflections, e3nn uses the concept of 'irreducible representations' (irreps) of the 3D orthogonal group, O(3).
Each irrep corresponds to a specific way in which the object transforms under rotations and reflections. 

An irrep is characterized by two numbers $(l, p)$, where $l$ is the 'angular momentum' and $p$ is the parity.
$l$ can take the values $0, 1, 2, \dots$ and $p$ can take the values $+1$ or $-1$. The irrep $(0, +1)$ corresponds to scalars, the irrep $(1, -1)$ to vectors, and so on.
If you have heard of [pseudoscalars](https://en.wikipedia.org/wiki/Pseudoscalar) or [pseudovectors](https://en.wikipedia.org/wiki/Pseudovector), these correspond to the irreps $(0, -1)$ and $(1, +1)$ respectively.

In e3nn, we keep track of how each quantity transforms under rotations and reflections by assigning it an irrep. The combination of the irreps with the actual data is called an IrrepsArray.
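
To make the string notation concrete, here is a minimal pure-Python sketch of how a label such as "1o" maps to $(l, p)$ and to a dimension of $2l + 1$. The helpers `parse_irrep`, `irrep_dim`, and `total_dim` are hypothetical and written only for illustration; they are not the e3nn-jax parser and do not handle multiplicities.

```python
def parse_irrep(label):
    """Hypothetical helper: turn a label such as "1o" into (l, p)."""
    l = int(label[:-1])                   # everything before the last character
    p = {"e": 1, "o": -1}[label[-1]]      # "e" = even parity, "o" = odd parity
    return l, p

def irrep_dim(l):
    """An irrep with angular momentum l has 2l + 1 components."""
    return 2 * l + 1

def total_dim(irreps):
    """Total number of components of a sum of irreps written like "0e + 1o"."""
    return sum(irrep_dim(parse_irrep(term.strip())[0]) for term in irreps.split("+"))

print(parse_irrep("0e"))     # (0, 1)
print(parse_irrep("1o"))     # (1, -1)
print(total_dim("0e + 1o"))  # 4
```

A total dimension of 4 for "0e + 1o" means the accompanying data array needs four entries: one for the scalar and three for the vector.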

In [None]:
x = e3nn.IrrepsArray("0e + 1o", jnp.asarray([0.5, 1.0, 0.0, 2.0]))
print(x.irreps, x.array)
print("Printing the chunks:")
for irrep, chunk in zip(x.irreps, x.chunks):
    print(irrep, chunk)

Given two IrrepsArrays, we can create a new IrrepsArray via the equivariant Clebsch-Gordan tensor product, which decomposes the tensor product of two irreps into a direct sum of irreps. This is a key operation in e3nn and a basic building block of equivariant neural networks.
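
Which irreps appear in the output follows the standard selection rule: the product of $(l_1, p_1)$ and $(l_2, p_2)$ contains every $l$ from $|l_1 - l_2|$ to $l_1 + l_2$, each with parity $p_1 p_2$. The pure-Python sketch below only enumerates these output irreps; `tensor_product_irreps` and `label` are hypothetical helpers, and the sketch does not compute any Clebsch-Gordan coefficients.

```python
def tensor_product_irreps(irrep1, irrep2):
    """List the (l, p) irreps contained in the product of two irreps."""
    l1, p1 = irrep1
    l2, p2 = irrep2
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]

def label(irrep):
    """Format (l, p) in the string notation, e.g. (1, -1) -> "1o"."""
    l, p = irrep
    return f"{l}{'e' if p == 1 else 'o'}"

vector = (1, -1)  # the irrep "1o"
outputs = tensor_product_irreps(vector, vector)
print(" + ".join(label(irrep) for irrep in outputs))  # 0e + 1e + 2e

# The component count is preserved: 3 * 3 = 1 + 3 + 5.
output_dim = sum(2 * l + 1 for l, _ in outputs)
print(output_dim)  # 9
```

For two vectors this gives one scalar, one pseudovector, and one five-component $l = 2$ irrep, which together account for all nine products of the input components.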

In [None]:
x1 = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 2.0, 3.0]))
x2 = e3nn.IrrepsArray("1o", jnp.asarray([2.0, 5.0, 9.0]))

y = e3nn.tensor_product(x1, x2)
print(y)

You will notice that the scalar "0e" component above is the dot product of the two vectors, up to a normalization factor of $1/\sqrt{3}$. The dot product is invariant under rotations and reflections, as a scalar must be.

In [None]:
e3nn.dot(x1, x2) / jnp.sqrt(3), y.slice_by_chunk[:1]

Similarly, the "1e" component corresponds to the cross product of the two vectors, up to a normalization factor of $1/\sqrt{2}$. The cross product is a pseudovector because, unlike a vector, it does not change sign under reflections:
$$
\text{cross}(-\mathbf{a}, -\mathbf{b}) = (-\mathbf{a}) \times (-\mathbf{b}) = \mathbf{a} \times \mathbf{b} = \text{cross}(\mathbf{a}, \mathbf{b})
$$
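
The identity can be confirmed with plain NumPy, using the same two vectors as in the tensor product example. This is only an arithmetic check of the sign behaviour and does not use e3nn-jax.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 5.0, 9.0])

# A point reflection sends every vector v to -v.
cross_original = np.cross(a, b)     # components (3, -3, 1)
cross_reflected = np.cross(-a, -b)  # the two minus signs cancel

# Even parity: the cross product is unchanged, unlike a and b themselves.
print(np.array_equal(cross_reflected, cross_original))  # True

# The dot product is also even, matching the "0e" label.
print(np.dot(-a, -b) == np.dot(a, b))  # True
```

Both outputs of the product of two odd quantities are even, in line with the parity rule $p = p_1 p_2 = (-1)(-1) = +1$.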

In [None]:
e3nn.cross(x1, x2) / jnp.sqrt(2), y.slice_by_chunk[1:2]

The "2e" component corresponds to the symmetric traceless part of the outer product of the two vectors. This does not transform as a scalar or a vector, but as a rank-2 tensor. This is an example of a higher-order tensor that can be used to build more complex equivariant neural networks!
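
The split of the outer product into these three pieces can be written out explicitly. The NumPy sketch below decomposes the $3 \times 3$ matrix $\mathbf{a} \mathbf{b}^T$ into a trace part (1 number), an antisymmetric part (3 numbers), and a symmetric traceless part (5 independent numbers). It illustrates the structure only: e3nn-jax stores the "2e" component in its own basis and normalization, so the five values it prints are not literally these matrix entries.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 5.0, 9.0])

outer = np.outer(a, b)  # all 9 products a_i * b_j

# 1 number: the trace, i.e. the dot product (the "0e" part).
trace_part = np.trace(outer) / 3 * np.eye(3)

# 3 numbers: the antisymmetric part, which encodes the cross product (the "1e" part).
antisymmetric = (outer - outer.T) / 2

# 5 numbers: what remains is symmetric with zero trace (the "2e" part).
symmetric_traceless = (outer + outer.T) / 2 - trace_part

# The three pieces add back up to the full outer product.
print(np.allclose(trace_part + antisymmetric + symmetric_traceless, outer))  # True
```

A symmetric $3 \times 3$ matrix has six independent entries, and the zero-trace condition removes one, which is where the dimension 5 comes from.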

In [None]:
y.slice_by_chunk[2:3]

We see that the "2e" component has dimension 5.
In general, an irrep of the form $(l, p)$ has dimension $2l + 1$.

If you are interested in the various kinds of equivariant tensor products that people have designed, you should check out our [recent paper](https://openreview.net/forum?id=0HHidbjwcf)!