In [1]:
import torch
import torch.nn as nn
from e3nn import o3

# Invariant GNN block

Let's set up the feature dimension and the filter dimension.

In [2]:
# Width of the per-atom feature vectors and of the intermediate filter vectors.
feature_dimension, filter_dimension = 8, 16

Embedding (here, let's say we have 5 different elements)

In [3]:
n_el = 5  # number of distinct chemical elements
# one learnable vector of length feature_dimension per element index
embeding_block = nn.Embedding(num_embeddings=n_el, embedding_dim=feature_dimension)
# look up the embedding of element index 0
element_label = torch.tensor([0])
embeding_vector = embeding_block(element_label)

In [None]:
# Show the embedding looked up for element index 0 in the previous cell.
print(embeding_vector)

Linear blocks for atom-wise interactions (feature to feature) and for mapping features to filters (feature to filter)

In [None]:
# Bias-free linear map from the feature space to the filter space.
feature_to_filter = nn.Linear(
    in_features=feature_dimension,
    out_features=filter_dimension,
    bias=False,
)
filter_vector = feature_to_filter(embeding_vector)
print(filter_vector)

Activation functions are available in `torch.nn.functional`

In [None]:
# Apply the SiLU activation element-wise to the filter vector.
activated_filter = nn.functional.silu(filter_vector)
# Show the activated result (previously the un-activated filter_vector was printed by mistake).
print(activated_filter)

Hint for the edge-to-filter mapping
Here I will only use the distances to the neighbours within the cutoff. I will assume that we have two atoms at distances 3.0 and 2.3 from the central atom.

In [7]:
# 10 evenly spaced offsets between 0 and 6
offset = torch.linspace(start=0.0, end=6, steps=10)
# distances of the two neighbours from the central atom
d_ij = torch.tensor([3.0, 2.3])
# broadcast (n_neighbours, 1) against (n_offsets,) -> (n_neighbours, n_offsets)
j = d_ij[:, None] - offset

In [None]:
# Linear map from the expanded distances (one column per offset) to the filter space.
# The input width is taken from `j` instead of a hard-coded 10, so it stays in sync
# with the number of offsets chosen in the previous cell.
tmp = nn.Linear(j.shape[-1], filter_dimension)
tmp(j)

Class for the distance expansion used by the RBF (distance minus a grid of offsets)

In [9]:
class RBFdistance(nn.Module):
    """Expand distances against a fixed grid of offsets.

    For every distance d and every offset mu_k in ``linspace(0, cutoff, n_distances)``
    the module returns ``d - mu_k``. This is the shifted distance a radial basis
    function would be evaluated on.

    Args:
        cutoff: largest offset (end point of the grid).
        n_distances: number of offsets in the grid.
    """

    def __init__(self, cutoff, n_distances):
        super().__init__()
        # Keep the grid on the module. Previously it was a local variable and forward()
        # silently relied on a global `offset`. A buffer also follows .to(device)/.double().
        self.register_buffer("offset", torch.linspace(0.0, cutoff, n_distances))

    def forward(self, d_ij):
        """Return ``d_ij[..., None] - offset`` with shape ``(*d_ij.shape, n_distances)``."""
        return d_ij.unsqueeze(-1) - self.offset
        

# Basics of e3nn

`o3` refers to O(3), the group of rotations and inversions in 3D; it contains no translations.

 Here I will show some e3nn features that won't really be needed, just to see how the tensor product behaves. This was also very heavily influenced by the [MACE tutorial](https://colab.research.google.com/drive/1AlfjQETV_jZ0JQnV5M3FGwAM2SGCl2aU#scrollTo=E-lt2fyV_03E) 

## Basic operation

Creation of the irreps is easily done like this

In [None]:
from e3nn import o3
irreps1 = o3.Irreps("3x2o+5x3e")
# total dimension = sum over the irreps of mul * (2l + 1)
print(irreps1.dim)
# you can create a random irreps array (batch of 1, last axis of length irreps1.dim):
my_ir_random = irreps1.randn(1, -1)
# or you can just create a torch.tensor of length irreps1.dim by hand:
my_ir = torch.tensor([0.3 for _ in range(irreps1.dim)])

## Spherical harmonics

First, the spherical harmonics needed for the filter

In [None]:
# Ylm evaluator for l = 0, 1, 2; normalize=True normalizes the input vector before evaluating.
spherical_harmonics = o3.SphericalHarmonics([0, 1, 2], normalize=True)

# evaluate the spherical harmonics on a single 3D vector
vector = torch.tensor([1.0, 0.2, 0.75])
sh_values = spherical_harmonics(vector)
print(sh_values)  # 9 values: scalar (1) + vector (3) + rank-2 tensor (5)

# Tensor product
Here is the full tensor product, which in the end won't be used today

In [None]:
# Full tensor product (no learnable path weights): every pair of input irreps is coupled
# into every output irrep allowed by the selection rules.
tp = o3.FullTensorProduct('2x0e + 3x1o', '5x0e + 7x1e')
print(tp)
tp.visualize();

This is how to set up the tensor product we will need today

In [None]:
# Fully connected tensor product: inputs 1 and 2 are coupled and projected onto the
# requested output irreps (positional order: irreps_in1, irreps_in2, irreps_out).
tp = o3.FullyConnectedTensorProduct('5x0e + 5x1e', '6x0e + 4x1e', '15x0e + 3x1e')
print(tp)
tp.visualize();

In [None]:
# set up a tensor product.
# This couples a single l=1 (odd) vector with an l=0,1,2 array (3x0e + 1x1o + 5x2e)
# and maps the result to three l=0 (even scalar) values.
# Only the 1o x 1o -> 0e coupling can produce even scalars here, so only that path carries weights.
# internal_weights=False: the weights are not stored in the module; they are passed at call time.
tensor_product = o3.FullyConnectedTensorProduct(
    irreps_in1=o3.Irreps("1x1o"),
    irreps_in2=o3.Irreps("3x0e + 1x1o + 5x2e"),
    irreps_out=o3.Irreps("3x0e"),
    internal_weights=False
)
print(tensor_product)
tensor_product.visualize();

In [15]:
irresp1 = o3.Irreps("1x1o")
irresp2 = o3.Irreps("3x0e + 1x1o + 5x2e")
# a fixed l=1 vector with a batch dimension of 1 (instead of irresp1.randn(1, -1)),
# so this input is the same on every run
my_irresp1 = torch.tensor([[-1.2, 0.2, 0.3]])
my_irresp2 = irresp2.randn(1, -1)

In [None]:
# one shape per line: (batch, irresp1.dim) and (batch, irresp2.dim)
print(my_irresp1.shape, my_irresp2.shape, sep="\n")

For the convolution you need to find out how many weights are required. Unlike our invariant case, where you chose this number yourself, here it depends on the number of channels and on the $\ell$ values of the irreps. Find it like this: 

In [None]:
# number of path weights the caller must supply (internal_weights=False)
n_weights = tensor_product.weight_numel
print(f"{tensor_product} needs {n_weights} weights")

In [None]:
# A weight vector for this tensor product needs exactly weight_numel entries.
# arange(1, weight_numel) stops one short, so the end point is weight_numel + 1 (exclusive).
print(torch.arange(1, tensor_product.weight_numel + 1).shape)

In [None]:
# product the arrays
product = tensor_product(
    my_irresp1.unsqueeze(0),
    my_irresp2.unsqueeze(0),
    weight=torch.rand(tensor_product.weight_numel) # the product has weights which can be trained - here I will just provide them as random, you have to provide filter values
)
print('invariant outputs:', product)

Rotate one vector and let's see whether the scalar values change

In [None]:
# random rotation and the matching representation matrix for the 1x1o irreps
R = o3.rand_matrix()
D_R = irresp1.D_from_matrix(R)
# norm of the rotated vector (row-vector batch, hence the transposes)
torch.linalg.norm((D_R @ my_irresp1.T).T)

In [None]:
# Norm of the unrotated vector, for comparison with the norm of the rotated vector computed in the previous cell.
torch.linalg.norm(my_irresp1)

In [22]:
# rotate the first input by R (D acts on column vectors, so transpose in and out)
rot_irreps1 = (irresp1.D_from_matrix(R) @ my_irresp1.T).T

In [None]:
# product the arrays
product = tensor_product(
    my_irresp1.unsqueeze(0),
    my_irresp2.unsqueeze(0),
    weight=torch.rand(tensor_product.weight_numel) # the product has weights which can be trained - here I will just provide them as random, you have to provide filter values
)
print('invariant outputs:', product)

- Create irrep

In [24]:
# A single l=1 irrep with even parity ("1e"), the same parity type used for the pseudovector below.
irreps = o3.Irreps("1x1e")

In [25]:
# two independent random 1x1e arrays, each with a batch dimension of 1 (ir1 is drawn first)
ir1, ir2 = irreps.randn(1, -1), irreps.randn(1, -1)

Tensor products are bilinear and equivariant

In [None]:
full_tp = o3.FullTensorProduct("1x1e", "1x1e")

# Equivariance check: rotating both inputs and then taking the product gives the same
# result as taking the product and then rotating the output (row vectors, hence x @ D.T).
R_check = o3.rand_matrix()
D_in = full_tp.irreps_in1.D_from_matrix(R_check)
D_out = full_tp.irreps_out.D_from_matrix(R_check)
rotate_then_tp = full_tp(ir1 @ D_in.T, ir2 @ D_in.T)
tp_then_rotate = full_tp(ir1, ir2) @ D_out.T
print("equivariant:", torch.allclose(rotate_then_tp, tp_then_rotate, atol=1e-5))

Getting an invariant product that still depends on non-invariant (directional) information

In [None]:
# set up a tensor product.
# This does the product of two l=0,1,2 arrays, and maps the result to three l=0 values.
# Both inputs are l = 0, 1, 2 arrays with one channel each; the result is mapped to three even scalars.
# The couplings that can produce 0e are 0e x 0e, 1o x 1o and 2e x 2e.
# internal_weights=False: the path weights are passed in at call time instead of stored.
l012_irreps = o3.Irreps("1x0e + 1x1o + 1x2e")
tensor_product = o3.FullyConnectedTensorProduct(
    l012_irreps,
    l012_irreps,
    o3.Irreps("3x0e"),
    internal_weights=False,
)
print(tensor_product)
tensor_product.visualize()

### Bilinear
- For nerds: this shows the bilinearity of the tensor product. We can discuss it, but it is not needed
$$
(\alpha x_1 + x_2) \otimes y = \alpha x_1 \otimes y + x_2 \otimes y 
\quad \text{and} \quad 
x \otimes (\alpha y_1 + y_2) = \alpha x \otimes y_1 + x \otimes y_2
$$ 

In [None]:
import torch
from e3nn.o3 import Irreps, TensorProduct


irreps_in1 = Irreps("1x0e + 1x1e")
irreps_in2 = Irreps("1x0e + 1x1e")
irreps_out = Irreps("1x0e + 1x1e")  # not used below: FullTensorProduct derives its own output irreps

tp = o3.FullTensorProduct(irreps_in1, irreps_in2)

# arbitrary scalar so that the homogeneity part of bilinearity (the alpha in the formula) is tested too
alpha = 2.5

a1 = irreps_in1.randn(1,-1)
a2 = irreps_in1.randn(1,-1)
b1 = irreps_in2.randn(1,-1)
b2 = irreps_in2.randn(1,-1)

lhs1 = tp(alpha * a1 + a2, b1)
rhs1 = alpha * tp(a1, b1) + tp(a2, b1)

lhs2 = tp(a1, alpha * b1 + b2)
rhs2 = alpha * tp(a1, b1) + tp(a1, b2)

# explicit atol: entries close to zero would otherwise fail on float32 round-off alone
print("Check TP(alpha*a1 + a2, b1) == alpha*TP(a1, b1) + TP(a2, b1):", torch.allclose(lhs1, rhs1, atol=1e-6))
print("Check TP(a1, alpha*b1 + b2) == alpha*TP(a1, b1) + TP(a1, b2):", torch.allclose(lhs2, rhs2, atol=1e-6))


Different D matrices for pseudo vector and vector

irresp1 = o3.Irreps("1x1o")

In [29]:
# Point inversion r -> -r: a 3x3 matrix with determinant -1.
inversion_matrix = torch.eye(3).neg()

# "1e" is the pseudovector (even parity), "1o" the ordinary vector (odd parity);
# D_from_matrix is expected to apply each irrep's parity for det = -1 (see printout below).
D_pseudovector = o3.Irreps("1e").D_from_matrix(inversion_matrix)
D_vector = o3.Irreps("1o").D_from_matrix(inversion_matrix)

In [None]:
# pseudovector first, then ordinary vector
for caption, D in (
    ("No change after the inversion\n", D_pseudovector),
    ("Change after the inversion\n", D_vector),
):
    print(caption, D)