## 1. Introduction to e3nn-jax

[e3nn-jax](https://github.com/e3nn/e3nn-jax) is the [JAX](https://github.com/google/jax) implementation of the [e3nn](https://github.com/e3nn/e3nn) library that was originally written in PyTorch. JAX is a Python framework for numerical computing that allows for automatic differentiation and just-in-time compilation.

This notebook will provide a small introduction to the e3nn-jax library and show how to use it to build equivariant neural networks.
For more details on the e3nn library, please refer to the [e3nn paper](https://arxiv.org/abs/2207.09453). 


In [None]:
# Setup.
!pip install e3nn_jax jax plotly "numpy<2"

In [None]:
# Imports
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
import plotly.subplots

$$
\def\R{{\mathbf{R}}}
$$

Let's start with a brief introduction to 3D rotations.

**Rotations**: A rotation in 3 dimensions can be described by a $3 \times 3$ orthogonal matrix $\R$:
$$
\R^T \R = \R \R^T = I
$$
whose determinant is $1$.

The group of all $3 \times 3$ orthogonal matrices with determinant $1$ is called $SO(3)$, the special orthogonal group.

The group of all $3 \times 3$ orthogonal matrices, whose determinants are necessarily $1$ or $-1$, is called $O(3)$, the orthogonal group.

It turns out that every orthogonal matrix in 3 dimensions is either a rotation (i.e., has determinant 1) or the product of a rotation and the inversion $-I$:
$$
-I = \begin{pmatrix}
-1 & 0 & 0 \\
0 & -1 & 0 \\
0 & 0 & -1
\end{pmatrix}
$$
which is a reflection through the origin.

People often use the term "improper rotation" to refer to a rotation followed by a reflection, i.e., a matrix in $O(3)$ with determinant $-1$. Here, we will use the term "rotation" to refer to a matrix in $SO(3)$.
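These definitions are easy to check numerically. The sketch below is plain NumPy rather than e3nn: the helper `random_rotation` is hypothetical and samples a rotation via a QR decomposition (an assumption of this sketch, not necessarily how `e3nn.rand_matrix` works). It verifies that the sampled matrix is orthogonal with determinant $1$, and that composing it with the inversion $-I$ gives an orthogonal matrix with determinant $-1$, i.e., an improper rotation.

```python
import numpy as np


def random_rotation(seed: int = 0) -> np.ndarray:
    """Hypothetical helper: sample a 3x3 rotation matrix (orthogonal, det = +1)."""
    rng = np.random.default_rng(seed)
    # The Q factor of a Gaussian matrix is a random orthogonal matrix.
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so the factorization is unique
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one column so the determinant becomes +1
    return q


Q = random_rotation()
identity = np.eye(3)

# Q is a proper rotation: orthogonal with determinant +1.
print(np.allclose(Q.T @ Q, identity), np.isclose(np.linalg.det(Q), 1.0))

# (-I) Q is still orthogonal but has determinant -1: an improper rotation.
improper = -identity @ Q
print(np.allclose(improper.T @ improper, identity), np.isclose(np.linalg.det(improper), -1.0))
```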

In [None]:
# Plot three vectors and their rotated versions.
# Define three vectors
v1 = np.asarray([2, 3, 1])
v2 = np.asarray([1, -1, 1])
v3 = np.asarray([1, 1, 2])

# Create a random rotation.
key = jax.random.PRNGKey(0)  # Key for random number generation.
R = e3nn.rand_matrix(key)

# Create a list of vectors
vectors = [v1, v2, v3]
names = ["v1", "v2", "v3"]
rotated_vectors = [R @ vec for vec in vectors]
rotated_names = ["R v1", "R v2", "R v3"]

colors = ["red", "blue", "green"]

# Create the 3D scatter plot
fig = plotly.subplots.make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scene"}, {"type": "scene"}]],
    subplot_titles=("Original Vectors", "Rotated Vectors"),
)

for col, [vectors_list, names_list] in enumerate(
    zip([vectors, rotated_vectors], [names, rotated_names]), start=1
):
    for vec, color, name in zip(vectors_list, colors, names_list):
        # Plot the vectors
        fig.add_trace(
            go.Scatter3d(
                x=[0, vec[0]],
                y=[0, vec[1]],
                z=[0, vec[2]],
                mode="lines",
                line=dict(width=5, color=color),
                name=name,
                showlegend=(col == 1),
            ),
            row=1,
            col=col,
        )

        # Add dots at the end of each vector
        fig.add_trace(
            go.Scatter3d(
                x=[vec[0]],
                y=[vec[1]],
                z=[vec[2]],
                mode="markers",
                marker=dict(size=5, color=color),
                name=name + " endpoint",
                showlegend=False,
            ),
            row=1,
            col=col,
        )


fig.update_layout(title_text="Vectors and their Rotations")
fig.update_layout(
    scene=dict(
        xaxis=dict(range=[-6, 6]),
        yaxis=dict(range=[-6, 6]),
        zaxis=dict(range=[-6, 6]),
    )
)
fig.update_layout(
    scene2=dict(
        xaxis=dict(range=[-6, 6]),
        yaxis=dict(range=[-6, 6]),
        zaxis=dict(range=[-6, 6]),
    )
)

# Show the plot
fig.show()

Note how the angles between the vectors are preserved under rotations. This is because the dot product of two vectors is invariant under rotations.
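To see this numerically, the NumPy sketch below (with its own QR-based random rotation, an assumption of this sketch rather than e3nn's sampler) compares the matrix of all pairwise dot products (the Gram matrix) of `v1`, `v2`, `v3` before and after rotation. Since $(\R a) \cdot (\R b) = a^T \R^T \R b = a \cdot b$, lengths and angles are unchanged.

```python
import numpy as np

# Random rotation via QR (an assumption of this sketch, not e3nn's sampler).
rng = np.random.default_rng(1)
q, r = np.linalg.qr(rng.normal(size=(3, 3)))
R_demo = q * np.sign(np.diag(r))
if np.linalg.det(R_demo) < 0:
    R_demo[:, 0] = -R_demo[:, 0]

# v1, v2, v3 from the plot above, stored as rows.
vecs = np.array([[2.0, 3.0, 1.0], [1.0, -1.0, 1.0], [1.0, 1.0, 2.0]])
rotated_vecs = vecs @ R_demo.T  # row i is R v_i

# Gram matrices: entry (i, j) is the dot product v_i . v_j.
gram = vecs @ vecs.T
gram_rotated = rotated_vecs @ rotated_vecs.T
print("Dot products preserved:", np.allclose(gram, gram_rotated))


def angle_deg(a: np.ndarray, b: np.ndarray) -> float:
    """Angle between two vectors in degrees."""
    cos = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))


print(angle_deg(vecs[0], vecs[2]), angle_deg(rotated_vecs[0], rotated_vecs[2]))
```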

### Irreducible Representations

e3nn provides tools to build neural networks that respect 3D rotations and reflections, i.e., networks that are equivariant under $O(3)$.

What does this mean? Let's consider the example of a molecule, which consists of atoms living in 3D space. The molecule can be rotated and reflected in space, but the underlying physics of the molecule does not change. Our choice of coordinates is arbitrary, and we should be able to build models that respect this arbitrariness.

For example, the energy of the molecule does not change under rotations and inversion. We call such quantities 'invariants' or 'scalars', because under rotations $\R$ they transform as:
$$
E \mapsto_\R E
$$
and under the inversion $-I$ they transform as:
$$
E \mapsto_{-I} E
$$
We say that scalars have even parity under inversion.
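As a concrete illustration, the sketch below uses a made-up `toy_energy` (a hypothetical pairwise harmonic potential, not a physical model) that depends only on interatomic distances. Distances are unchanged by rotations and by inversion, so the energy is too.

```python
import numpy as np


def toy_energy(positions: np.ndarray) -> float:
    """Hypothetical energy: sum over atom pairs of (distance - 1)^2."""
    diffs = positions[:, None, :] - positions[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    i, j = np.triu_indices(len(positions), k=1)  # each unordered pair once
    return float(np.sum((dists[i, j] - 1.0) ** 2))


# Random rotation via QR (an assumption of this sketch, not e3nn's sampler).
rng = np.random.default_rng(0)
q, r = np.linalg.qr(rng.normal(size=(3, 3)))
R_demo = q * np.sign(np.diag(r))
if np.linalg.det(R_demo) < 0:
    R_demo[:, 0] = -R_demo[:, 0]

atom_positions = rng.normal(size=(4, 3))  # 4 hypothetical atoms, one per row
E = toy_energy(atom_positions)
E_rotated = toy_energy(atom_positions @ R_demo.T)  # rotate every atom
E_inverted = toy_energy(-atom_positions)  # invert every atom through the origin
print(E, E_rotated, E_inverted)
```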

In [None]:
# Don't worry about the notation below for now.
# We are just telling e3nn how energies transform under rotations and inversions.
# The notation 0e is short for (l, p) = (0, 1).
# l = 0 means that the quantity is invariant under rotations.
# p = 1 means that the quantity is unchanged under inversion (i.e., even).
energy = e3nn.IrrepsArray("0e", jnp.asarray([0.5]))
rotated_energy = energy.transform_by_matrix(R)
# -R = (-I) R is a rotation followed by the inversion, an improper rotation.
reflected_energy = energy.transform_by_matrix(-R)

print("Observed energy after rotation:", rotated_energy)
print("Expected energy after rotation:", energy)
print()
print("Observed energy after reflection:", reflected_energy)
print("Expected energy after reflection:", energy)

A [pseudo-scalar](https://en.wikipedia.org/wiki/Pseudoscalar) is a quantity that is invariant under rotations but changes sign under inversion:
$$
H \mapsto_{-I} -H
$$
Pseudo-scalars have odd parity under inversion.
An example of a pseudo-scalar is the handedness of a helix:

In [None]:
# Plot helices.
t = np.linspace(0, 10, 1000)
x = np.cos(t)
y = np.sin(t)
z = t

# Create subplots
fig = plotly.subplots.make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scene"}, {"type": "scene"}]],
    subplot_titles=("Left-Handed Helix", "Right-Handed Helix"),
)


# Add left-handed helix (note the negative y values)
fig.add_trace(
    go.Scatter3d(
        x=x,
        y=-y,
        z=z,
        mode="lines",
        line=dict(width=5, color=z, colorscale="plasma"),
        showlegend=False,
    ),
    row=1,
    col=1,
)


# Add right-handed helix
fig.add_trace(
    go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode="lines",
        line=dict(width=5, color=z, colorscale="plasma"),
        showlegend=False,
    ),
    row=1,
    col=2,
)

# Show the plot
axis = dict(
    title="",
    showticklabels=False,
    showgrid=False,
    zeroline=False,
    backgroundcolor="rgba(255,255,255,255)",
)
fig.update_layout(
    scene=dict(
        xaxis=axis,
        yaxis=axis,
        zaxis=axis,
    ),
    scene2=dict(
        xaxis=axis,
        yaxis=axis,
        zaxis=axis,
    ),
)
fig.show()
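One way to turn handedness into a number is the triple product $(d_1 \times d_2) \cdot d_3$ of three consecutive chord vectors along the curve, which for a finely sampled smooth curve has the sign of the curve's torsion. The `handedness` function below is a hypothetical helper written for this illustration, and it samples the helices more coarsely than the plot so that the triple products are not vanishingly small. It is invariant under a proper rotation but flips sign under inversion, exactly as a pseudo-scalar should.

```python
import numpy as np


def handedness(points: np.ndarray) -> float:
    """Hypothetical pseudo-scalar: mean triple product of consecutive chords.

    Positive for a right-handed helix, negative for a left-handed one.
    """
    d = np.diff(points, axis=0)  # chord vectors between consecutive points
    triple = np.einsum("ij,ij->i", np.cross(d[:-2], d[1:-1]), d[2:])
    return float(np.mean(triple))


t_coarse = np.linspace(0, 10, 50)
right = np.stack([np.cos(t_coarse), np.sin(t_coarse), t_coarse], axis=1)
left = np.stack([np.cos(t_coarse), -np.sin(t_coarse), t_coarse], axis=1)

# Rotation by 90 degrees about the x-axis (a proper rotation, det = +1).
Rx = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])

print(handedness(right) > 0, handedness(left) < 0)
# Compare relatively (atol=0) because the triple products are small numbers.
print(np.isclose(handedness(right @ Rx.T), handedness(right), atol=0))  # invariant under rotation
print(np.isclose(handedness(-right), -handedness(right), atol=0))  # flips sign under inversion
```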

An example of a quantity that does change under rotations is the force acting on each atom of the molecule.
Forces transform in a specific way under rotations and reflections: under a rotation $\R$, a force vector $f$ is mapped to $\R f$, where $\R$ is the $3 \times 3$ rotation matrix:
$$
f \mapsto_\R \R f
$$
We call such quantities 'vectors'.

In [None]:
# The notation 1o is short for (l, p) = (1, -1).
# l = 1 means that the quantity transforms as a vector under rotations.
# p = -1 means that the quantity flips sign under inversion (i.e., odd).
# Each of the three rows below is the force on one atom.
forces = e3nn.IrrepsArray(
    "1o", jnp.asarray([[1.0, 0.0, 2.0], [0.0, 5.0, 3.0], [1.0, 1.0, 1.0]])
)
rotated_forces = forces.transform_by_matrix(R)
reflected_forces = forces.transform_by_matrix(-R)

print("Observed forces after rotation:", rotated_forces.array)
print("Expected forces after rotation:", forces.array @ R.T)
print()
print("Observed forces after reflection:", reflected_forces.array)
print("Expected forces after reflection:", forces.array @ -R.T)

Note that under inversion, the force vector $f$ transforms to $-f$, so vectors have odd parity under inversion.
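The sketch below makes this concrete with a hypothetical `toy_forces` function: spring-like forces $F_i = \sum_j (x_j - x_i)$, which are $-\partial E / \partial x_i$ for the made-up energy $E = \tfrac{1}{2}\sum_{i<j} \lVert x_i - x_j \rVert^2$. Rotating all atoms rotates every force, and inverting all atoms flips every force.

```python
import numpy as np


def toy_forces(positions: np.ndarray) -> np.ndarray:
    """Hypothetical spring forces F_i = sum_j (x_j - x_i).

    These are -dE/dx_i for the made-up energy E = 1/2 * sum_{i<j} |x_i - x_j|^2.
    """
    diffs = positions[None, :, :] - positions[:, None, :]  # diffs[i, j] = x_j - x_i
    return diffs.sum(axis=1)


# Random rotation via QR (an assumption of this sketch, not e3nn's sampler).
rng = np.random.default_rng(2)
q, r = np.linalg.qr(rng.normal(size=(3, 3)))
R_demo = q * np.sign(np.diag(r))
if np.linalg.det(R_demo) < 0:
    R_demo[:, 0] = -R_demo[:, 0]

atom_positions = rng.normal(size=(4, 3))
F = toy_forces(atom_positions)

# Rotating the atoms rotates every force: F(x R^T) = F(x) R^T.
print(np.allclose(toy_forces(atom_positions @ R_demo.T), F @ R_demo.T))
# Inverting the atoms flips every force: F(-x) = -F(x), so forces have odd parity.
print(np.allclose(toy_forces(-atom_positions), -F))
```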

A [pseudo-vector](https://en.wikipedia.org/wiki/Pseudovector) is a quantity that transforms as a vector under rotations, but does not change sign under inversion:
$$
v \mapsto_\R \R v, \quad v \mapsto_{-I} v
$$

An example of a pseudo-vector is the cross product of two vectors, which is unchanged when both inputs are inverted:
$$
\text{cross}(-\mathbf{a}, -\mathbf{b}) = (-\mathbf{a}) \times (-\mathbf{b}) = \mathbf{a} \times \mathbf{b} = \text{cross}(\mathbf{a}, \mathbf{b})
$$

In [None]:
# Plot the cross product of two vectors.
# Define two vectors
v1 = np.array([2, 3, 1])
v2 = np.array([1, -1, 1])

# Calculate the cross product
v3 = np.cross(v1, v2)

# Create a list of vectors
vectors = [v1, v2, v3]
names = ["v1", "v2", "cross(v1, v2)"]
inverse_vectors = [-v1, -v2, np.cross(-v1, -v2)]
inverse_names = ["-v1", "-v2", "cross(-v1, -v2)"]

colors = ["red", "blue", "green"]

# Create the 3D scatter plot
fig = plotly.subplots.make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scene"}, {"type": "scene"}]],
    subplot_titles=("Original Vectors", "After Inverting v1 and v2"),
)

for col, [vectors_list, names_list] in enumerate(
    zip([vectors, inverse_vectors], [names, inverse_names]), start=1
):
    for vec, color, name in zip(vectors_list, colors, names_list):
        # Plot the vectors
        fig.add_trace(
            go.Scatter3d(
                x=[0, vec[0]],
                y=[0, vec[1]],
                z=[0, vec[2]],
                mode="lines",
                line=dict(width=5, color=color),
                name=name,
                showlegend=(col == 1),
            ),
            row=1,
            col=col,
        )

        # Add dots at the end of each vector
        fig.add_trace(
            go.Scatter3d(
                x=[vec[0]],
                y=[vec[1]],
                z=[vec[2]],
                mode="markers",
                marker=dict(size=5, color=color),
                name=name + " endpoint",
                showlegend=False,
            ),
            row=1,
            col=col,
        )


fig.update_layout(title_text="The Cross Product is a Pseudo-Vector")
fig.update_layout(
    scene=dict(
        xaxis=dict(range=[-6, 6], **axis),
        yaxis=dict(range=[-6, 6], **axis),
        zaxis=dict(range=[-6, 6], **axis),
    )
)
fig.update_layout(
    scene2=dict(
        xaxis=dict(range=[-6, 6], **axis),
        yaxis=dict(range=[-6, 6], **axis),
        zaxis=dict(range=[-6, 6], **axis),
    )
)

# Show the plot
fig.show()
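The sketch below checks both properties numerically with plain NumPy (the QR-based random rotation is an assumption of this sketch). Under a proper rotation the cross product rotates like a vector, under the inversion it keeps its sign, and for a general orthogonal $M$ the two effects combine into $(M\mathbf{a}) \times (M\mathbf{b}) = \det(M)\, M (\mathbf{a} \times \mathbf{b})$.

```python
import numpy as np

# Random rotation via QR (an assumption of this sketch, not e3nn's sampler).
rng = np.random.default_rng(3)
q, r = np.linalg.qr(rng.normal(size=(3, 3)))
R_demo = q * np.sign(np.diag(r))
if np.linalg.det(R_demo) < 0:
    R_demo[:, 0] = -R_demo[:, 0]

a = np.array([2.0, 3.0, 1.0])
b = np.array([1.0, -1.0, 1.0])
c = np.cross(a, b)

# Proper rotation: the cross product rotates like a vector, (R a) x (R b) = R (a x b).
print(np.allclose(np.cross(R_demo @ a, R_demo @ b), R_demo @ c))

# Inversion: the cross product keeps its sign, (-a) x (-b) = a x b.
print(np.allclose(np.cross(-a, -b), c))

# Any orthogonal M (here the improper M = -R): (M a) x (M b) = det(M) M (a x b).
M = -R_demo
print(np.allclose(np.cross(M @ a, M @ b), np.linalg.det(M) * (M @ c)))
```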

To keep track of how these different quantities transform under rotations and inversion, e3nn uses the 'irreducible representations' (irreps) of the group $O(3)$. Each irrep corresponds to a specific transformation rule; for example, it tells us whether a given quantity is a scalar, vector, pseudo-scalar, or pseudo-vector.
[Representation theory](https://sites.ualberta.ca/~vbouchar/MAPH464/section-representation-so3-tensors.html) tells us that there are many more kinds of quantities, in fact infinitely many, that transform under rotations and reflections in different ways.


**Irreducible Representations**: An irrep of O(3) is characterized by two numbers: $(l, p)$, where $l$ is the 'angular momentum' and $p$ is the parity under inversion.
$l$ can take values $0, 1, 2, ...$ and $p$ can take values $+1$ (even) or $-1$ (odd).

People often use the terminology 'x is an $(l, p)$-irrep of O(3)' to mean that x is a quantity that transforms according to the irrep $(l, p)$. We will also use this terminology for convenience.

To be precise, if $v$ transforms as the irrep $(l, p)$, then it has dimension $2l + 1$ and transforms under a rotation $\R$ according to the [Wigner D-matrix](https://en.wikipedia.org/wiki/Wigner_D-matrix) $D^{(l)}(\R)$:
$$
v \mapsto_\R D^{(l)}(\R) v
$$
and under inversion, according to the parity $p$:
$$
v \mapsto_{-I} p v
$$

From our discussion above, we can see that scalars correspond to the irrep $(0, +1)$, vectors correspond to the irrep $(1, -1)$, pseudo-scalars correspond to the irrep $(0, -1)$, and pseudo-vectors correspond to the irrep $(1, +1)$.

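Hence, in the $(x, y, z)$ basis used in the forces example above,
$$
D^{(0)}(\R) = 1, \quad D^{(1)}(\R) = \R
$$
In the real basis used by e3nn, every Wigner D-matrix is orthogonal, just like $\R$ itself.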
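To see a representation beyond $l = 1$, note that symmetric traceless $3 \times 3$ matrices form a 5-dimensional space that is mapped to itself by $S \mapsto \R S \R^T$. The sketch below writes this map as a $5 \times 5$ matrix `wigner_d2(R)` in a hand-picked Cartesian basis. The helper is hypothetical, and the basis and ordering are illustrative assumptions rather than e3nn's $l = 2$ convention, so the result agrees with e3nn's $D^{(2)}(\R)$ only up to a change of basis. It checks that these matrices are orthogonal, compose like rotations, and are unchanged when $\R$ is replaced by $-\R$, i.e., they have even parity.

```python
import numpy as np

# Orthonormal basis (Frobenius inner product) of symmetric traceless 3x3 matrices.
# Illustrative choice only; e3nn uses its own basis and ordering for l = 2.
s2, s6 = np.sqrt(2.0), np.sqrt(6.0)
BASIS = np.array([
    np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]]) / s2,   # xy
    np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0]]) / s2,   # yz
    np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]]) / s2,   # xz
    np.array([[1, 0, 0], [0, -1, 0], [0, 0, 0]]) / s2,  # x^2 - y^2
    np.array([[1, 0, 0], [0, 1, 0], [0, 0, -2]]) / s6,  # x^2 + y^2 - 2z^2
])


def wigner_d2(R: np.ndarray) -> np.ndarray:
    """Hypothetical helper: 5x5 matrix of S -> R S R^T in BASIS (an l = 2 representation)."""
    rotated = R @ BASIS @ R.T  # rotated[j] = R E_j R^T, broadcast over the 5 basis matrices
    return np.einsum("iab,jab->ij", BASIS, rotated)  # entry (i, j) = <E_i, R E_j R^T>


def random_rotation(seed: int) -> np.ndarray:
    """Random rotation via QR (an assumption of this sketch, not e3nn's sampler)."""
    q, r = np.linalg.qr(np.random.default_rng(seed).normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


R1, R2 = random_rotation(1), random_rotation(2)
D_R1, D_R2 = wigner_d2(R1), wigner_d2(R2)
print(D_R1.shape)                                  # dimension 2l + 1 = 5
print(np.allclose(D_R1 @ D_R1.T, np.eye(5)))       # orthogonal
print(np.allclose(wigner_d2(R1 @ R2), D_R1 @ D_R2))  # D(R1 R2) = D(R1) D(R2)
print(np.allclose(wigner_d2(-R1), D_R1))           # unchanged by inversion: even parity
```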

In e3nn, we keep track of how each quantity transforms under rotations and inversions by assigning it an irrep. The combination of the irreps with the actual data is called an IrrepsArray.

In [None]:
x = e3nn.IrrepsArray("0e + 1o", jnp.asarray([0.5, 1.0, 0.0, 2.0]))
print(x.irreps, x.array)
print("Printing the chunks:")
for irrep, chunk in zip(x.irreps, x.chunks):
    print(irrep, chunk)
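The flat `array` is the concatenation of the chunks, with each term of multiplicity $m$ and angular momentum $l$ occupying $m(2l + 1)$ consecutive entries. The sketch below is a hypothetical pure-Python helper, `irreps_slices` (not part of e3nn-jax, which has its own `e3nn.Irreps` parser), that computes these offsets for simple strings such as `"0e + 1o"` or `"2x0e + 1o"`.

```python
import re


def irreps_slices(irreps: str) -> list:
    """Hypothetical helper: map an irreps string like "2x0e + 1o" to flat-array slices.

    Each term is "[mul x]<l><p>", with p = 'e' (even) or 'o' (odd).
    Returns a list of (mul, l, p, start, stop) tuples.
    """
    out, offset = [], 0
    for term in irreps.split("+"):
        m = re.fullmatch(r"\s*(?:(\d+)x)?(\d+)([eo])\s*", term)
        if m is None:
            raise ValueError(f"cannot parse irreps term {term!r}")
        mul = int(m.group(1) or 1)
        l = int(m.group(2))
        p = 1 if m.group(3) == "e" else -1
        size = mul * (2 * l + 1)  # each copy of the irrep has 2l + 1 components
        out.append((mul, l, p, offset, offset + size))
        offset += size
    return out


data = [0.5, 1.0, 0.0, 2.0]
for mul, l, p, start, stop in irreps_slices("0e + 1o"):
    print(f"{mul}x(l={l}, p={p:+d})", data[start:stop])
```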

### Tensor Products

Given two IrrepsArrays, we can create a new IrrepsArray via the equivariant Clebsch-Gordan tensor product, which decomposes the tensor product of two irreps into a direct sum of irreps. This is a key operation in e3nn and a basic building block of equivariant neural networks.

In [None]:
x1 = e3nn.IrrepsArray("1o", jnp.asarray([10.0, 5.0, 2.0]))
x2 = e3nn.IrrepsArray("1o", jnp.asarray([7.0, 1.0, 1.0]))

y = e3nn.tensor_product(x1, x2)
y

You will notice that the (scalar) "0e" component above is, up to a factor of $1/\sqrt{3}$, the dot product of the two vectors. The dot product is invariant under both rotations and inversion, as a "0e" quantity must be.

In [None]:
e3nn.dot(x1, x2) / jnp.sqrt(3), y.slice_by_chunk[:1]

Similarly, up to a factor of $1/\sqrt{2}$, the "1e" component corresponds to the cross product of the two vectors. The cross product is a pseudo-vector: unlike a vector, it does not change sign under inversion:
$$
\text{cross}(-\mathbf{a}, -\mathbf{b}) = (-\mathbf{a}) \times (-\mathbf{b}) = \mathbf{a} \times \mathbf{b} = \text{cross}(\mathbf{a}, \mathbf{b})
$$

In [None]:
e3nn.cross(x1, x2) / jnp.sqrt(2), y.slice_by_chunk[1:2]

The "2e" component corresponds to the symmetric traceless part of the outer product of the two vectors. It transforms neither as a scalar nor as a vector, and it has even parity because the outer product is unchanged when both vectors flip sign. This is an example of a higher-order equivariant feature!

In [None]:
x1x, x1y, x1z = x1.array
x2x, x2y, x2z = x2.array

(
    jnp.asarray(
        [
            x1z * x2x + x1x * x2z,
            x1x * x2y + x1y * x2x,
            jnp.sqrt(4 / 3) * ((x1y * x2y) - (x1x * x2x + x1z * x2z) / 2),
            x1y * x2z + x1z * x2y,
            x1z * x2z - x1x * x2x,
        ]
    ) / jnp.sqrt(2),
    y.slice_by_chunk[2:3],
)
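The three components above have a direct Cartesian counterpart. The sketch below (plain NumPy, independent of e3nn's normalization and basis conventions) splits the $3 \times 3$ outer product $\mathbf{a}\mathbf{b}^T$ into a trace part, an antisymmetric part and a symmetric traceless part, with $1 + 3 + 5 = 9$ degrees of freedom, matching the dimensions of "0e", "1e" and "2e". Since $\mathbf{a}\mathbf{b}^T$ is unchanged when both vectors flip sign, all three pieces are even.

```python
import numpy as np

a = np.array([10.0, 5.0, 2.0])  # same values as x1
b = np.array([7.0, 1.0, 1.0])   # same values as x2

outer = np.outer(a, b)  # all 9 products a_i b_j

trace_part = np.trace(outer) / 3 * np.eye(3)          # 1 dof: (a . b) / 3 times identity, like "0e"
antisym = (outer - outer.T) / 2                       # 3 dof: encodes a x b, like "1e"
sym_traceless = (outer + outer.T) / 2 - trace_part    # 5 dof: symmetric with zero trace, like "2e"

print(np.allclose(trace_part + antisym + sym_traceless, outer))
# The antisymmetric part stores half of the cross product in its off-diagonal entries.
print(2 * np.array([antisym[1, 2], antisym[2, 0], antisym[0, 1]]), np.cross(a, b))
```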

We see that the "2e" component has dimension 5, as expected from $2l + 1 = 5$ for $l = 2$.
In general, the tensor product of two irreps of type $(l_1, p_1)$ and $(l_2, p_2)$ decomposes into a sum of irreps $(l, p)$ with $|l_1 - l_2| \leq l \leq l_1 + l_2$ and $p = p_1 p_2$, each appearing once. Thus, we can build higher-order equivariant features by taking tensor products of lower-order features.
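This selection rule fits in a few lines of plain Python. The helpers `tensor_product_irreps` and `irrep_name` below are hypothetical (e3nn performs the equivalent bookkeeping for you inside `e3nn.tensor_product`), and the loop checks that the dimensions add up: $(2l_1 + 1)(2l_2 + 1) = \sum_l (2l + 1)$.

```python
def tensor_product_irreps(l1: int, p1: int, l2: int, p2: int) -> list:
    """Hypothetical helper: output irreps (l, p) of the product (l1, p1) x (l2, p2).

    l runs from |l1 - l2| to l1 + l2, and the parity is p = p1 * p2.
    """
    p = p1 * p2
    return [(l, p) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def irrep_name(l: int, p: int) -> str:
    """Format (l, p) in e3nn's short notation, e.g. (1, -1) -> "1o"."""
    return f"{l}{'e' if p == 1 else 'o'}"


out = tensor_product_irreps(1, -1, 1, -1)  # 1o x 1o
print(" + ".join(irrep_name(l, p) for l, p in out))  # 0e + 1e + 2e

# Dimensions must add up: (2 l1 + 1)(2 l2 + 1) = sum over outputs of (2 l + 1).
for l1, l2 in [(1, 1), (1, 2), (2, 3)]:
    outs = tensor_product_irreps(l1, 1, l2, 1)
    assert (2 * l1 + 1) * (2 * l2 + 1) == sum(2 * l + 1 for l, _ in outs)
```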

If you are interested in learning more about the various kinds of equivariant tensor products that people have designed, you should check out our [recent GRaM paper](https://openreview.net/forum?id=0HHidbjwcf)!