In [50]:
import math
from typing import Dict, Optional

import torch

from e3nn import o3
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import ExtractIr, FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode
from e3nn.math import soft_unit_step
import torch
from e3nn import nn
from torch_scatter import scatter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch_cluster import radius_graph


In [57]:
# Irreps layouts shared by the cells below.
# "15x0e" is 15 even-parity scalar (l=0) channels, "128x0e" is 128 of them.
irreps_output = o3.Irreps("128x0e")
irreps_key = o3.Irreps("15x0e")
irreps_query = o3.Irreps("15x0e")
irreps_input = o3.Irreps("15x0e")

In [58]:
# Seed torch's global RNG so that the random positions/features drawn here (and
# anything drawn from the same RNG afterwards when the notebook is run top to
# bottom) are identical after every "Restart Kernel -> Run All".
torch.manual_seed(0)

num_nodes = 20

# Random 3-D point cloud and random input features, one row per node.
pos = torch.randn(num_nodes, 3)
f = irreps_input.randn(num_nodes, -1)

# create graph: link nodes that lie within `max_radius` of each other
# (subject to radius_graph's own neighbour limit).
max_radius = 1.3
edge_src, edge_dst = radius_graph(pos, max_radius)
# Relative vector and length of every edge.
edge_vec = pos[edge_src] - pos[edge_dst]
edge_length = edge_vec.norm(dim=1)

In [59]:
# Sanity check: `pos` holds one 3-component position per node.
pos.shape

torch.Size([20, 3])

In [60]:
# Linear map applied to the node features to produce the per-node queries.
h_q = o3.Linear(irreps_in=irreps_input, irreps_out=irreps_query)



In [61]:
number_of_basis = 10

# Expand every edge length on `number_of_basis` "smooth_finite" radial functions
# spanning [0, max_radius]. With cutoff=True the embedding goes (smoothly) to
# zero at `start` and `end`. The result is rescaled by sqrt(number_of_basis).
edge_length_embedded = soft_one_hot_linspace(
    edge_length,
    start=0.0,
    end=max_radius,
    number=number_of_basis,
    basis="smooth_finite",
    cutoff=True,
) * number_of_basis**0.5

In [62]:
# Per-edge envelope: the argument reaches 0 when an edge is exactly `max_radius`
# long, which is where the soft step switches the edge off.
edge_weight_cutoff = soft_unit_step((1.0 - edge_length / max_radius) * 10.0)

In [63]:
# Spherical harmonics of the (normalised) edge vectors up to lmax = 3.
irreps_sh = o3.Irreps.spherical_harmonics(lmax=3)
edge_sh = o3.spherical_harmonics(
    irreps_sh,
    edge_vec,
    normalize=True,
    normalization="component",
)

In [64]:
# Keys: tensor product of source features with the edge harmonics. Its weights
# are not stored in the module (shared_weights=False); they are supplied per
# edge by a small MLP acting on the embedded edge length.
tp_k = FullyConnectedTensorProduct(
    irreps_input, irreps_sh, irreps_key, shared_weights=False
)
fc_k = FullyConnectedNet(
    [number_of_basis, 16, tp_k.weight_numel],
    act=torch.nn.functional.silu,
)

# Values: same construction, but mapping to the output irreps.
tp_v = FullyConnectedTensorProduct(
    irreps_input, irreps_sh, irreps_output, shared_weights=False
)
fc_v = FullyConnectedNet(
    [number_of_basis, 16, tp_v.weight_numel],
    act=torch.nn.functional.silu,
)

# Scalar ("0e") product between a query and a key.
dot = FullyConnectedTensorProduct(irreps_query, irreps_key, "0e")



In [65]:
# compute the queries (per node), keys (per edge) and values (per edge)
q = h_q(f)
k = tp_k(f[edge_src], edge_sh, fc_k(edge_length_embedded))
v = tp_v(f[edge_src], edge_sh, fc_v(edge_length_embedded))

# compute the softmax (per edge)
exp = edge_weight_cutoff[:, None] * dot(q[edge_dst], k).exp()  # compute the numerator
z = scatter(exp, edge_dst, dim=0, dim_size=len(f))  # compute the denominator (per nodes)
z[z == 0] = 1  # to avoid 0/0 when all the neighbors are exactly at the cutoff
alpha = exp / z[edge_dst]

# compute the outputs (per node)
f_out = scatter(alpha.relu().sqrt() * v, edge_dst, dim=0, dim_size=len(f))

In [66]:
# Sanity check: one output row per node (dim_size=len(f) in the scatter above).
f_out.shape

torch.Size([20, 128])

In [67]:
# Display the random input node features.
f

tensor([[ 1.0108, -0.4120, -1.2691,  0.5434, -0.3108, -0.0259,  1.2164, -0.0178,
          0.7902, -1.5214, -0.4848,  0.4506, -0.0061, -1.0693, -1.6660],
        [ 1.9457, -0.9373, -0.6052, -0.0843,  0.5191,  1.3375, -0.7168, -1.1653,
          0.7648, -1.1443,  0.5113, -0.7004,  0.9156, -0.0039, -1.2587],
        [ 0.7415,  0.8114, -0.4896,  0.7137, -0.5670,  0.6979, -0.3470,  0.9645,
         -2.4666,  0.9865,  1.0848,  0.4396, -0.1511,  0.4272,  0.8656],
        [ 1.0039, -0.8731,  0.1023,  0.8909,  2.0612, -0.3998,  0.5176,  0.5354,
         -1.7104,  1.0056, -0.6347,  0.0039,  1.2202,  0.3048,  1.2008],
        [ 1.7200, -1.4296,  0.7554,  1.1488,  0.2106,  2.5101,  0.1621,  0.0049,
          1.0198,  0.5128,  0.0367,  1.4464,  0.5385,  0.6567,  1.2070],
        [-0.9541,  1.0299,  0.6233,  1.7747, -0.5644,  0.3849, -1.1908,  0.7297,
          0.3244,  0.7336,  1.0145, -0.7193, -0.7312, -1.5739,  0.2230],
        [-1.8391,  0.6928,  0.0222,  1.5460,  1.8051, -1.2119, -0.9513,  0.5

In [68]:
# Display the attention output. An all-zero row presumably belongs to a node
# that received no message (no incoming edge) - TODO confirm.
f_out

tensor([[ 0.8893,  0.4559, -0.2268,  ..., -0.7073,  1.0102, -0.7480],
        [ 1.0796,  0.9648, -0.0740,  ...,  0.0138, -0.4639,  0.8033],
        [ 0.1390,  1.8978,  1.2474,  ...,  0.4028, -2.0364, -0.6878],
        ...,
        [-0.1836, -1.1423, -0.2519,  ...,  1.7280, -0.4552,  0.2658],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2556, -0.8711, -0.5214,  ...,  0.1657,  0.2269,  0.2245]],
       grad_fn=<ScatterAddBackward0>)

In [32]:
from torch_geometric.data import Data

In [33]:
# Bundle the node positions and node features into a single Data object.
dataset = Data(pos=pos, x=f)

In [35]:
class ETG(torch.nn.Module):
    """Holder for the irreps describing an equivariant attention layer.

    The class currently only records the four irreps it is given; it defines
    no sub-modules and no ``forward``.

    Note: the base class is spelled ``torch.nn.Module`` explicitly because in
    this notebook the bare name ``nn`` is bound to ``e3nn.nn`` (which provides
    ``FullyConnectedNet`` etc. but no ``Module``).

    Args:
        irreps_in: irreps of the input node features.
        irreps_query: irreps of the queries.
        irreps_key: irreps of the keys.
        irreps_output: irreps of the output node features.
    """

    def __init__(self,
                 irreps_in: o3.Irreps,
                 irreps_query: o3.Irreps,
                 irreps_key: o3.Irreps,
                 irreps_output: o3.Irreps,
                 ):
        super().__init__()
        self.irreps_in = irreps_in
        self.irreps_query = irreps_query
        self.irreps_key = irreps_key
        self.irreps_output = irreps_output
        
    

In [36]:
def transformer(
        irreps_input: o3.Irreps,
        irreps_query: o3.Irreps,
        irreps_key: o3.Irreps,
        irreps_output: o3.Irreps,
        dataset,
):
    """Run one equivariant attention pass over the point cloud in ``dataset``.

    A radius graph is built from ``dataset.pos``; queries are computed per
    node, keys and values per edge, and the values are summed into each
    destination node with weights ``sqrt(alpha)``, where ``alpha`` is a softmax
    over the node's incoming edges damped by a smooth distance cutoff.

    All e3nn modules are constructed inside the function, so every call uses
    newly initialised weights; nothing is shared between calls.

    Args:
        irreps_input: irreps of the node features stored in ``dataset.x``.
        irreps_query: irreps of the per-node queries.
        irreps_key: irreps of the per-edge keys.
        irreps_output: irreps of the per-edge values and of the result.
        dataset: object exposing ``pos`` (node positions) and ``x`` (node
            features), e.g. a ``torch_geometric.data.Data``.

    Returns:
        Tensor with one row per node of ``dataset.x`` holding the aggregated
        output features.

    Note:
        The cutoff radius is read from the module-level ``max_radius``.
    """
    # Take features and positions from the argument rather than from the
    # notebook-level `f`/`pos`, so the result depends only on what is passed in.
    node_features = dataset.x
    positions = dataset.pos
    num_nodes = len(node_features)

    # Graph: edges between nodes within `max_radius`, with their vectors/lengths.
    edge_src, edge_dst = radius_graph(positions, max_radius)
    edge_vec = positions[edge_src] - positions[edge_dst]
    edge_length = edge_vec.norm(dim=1)

    # Radial embedding of the edge lengths, vanishing at 0 and at `max_radius`.
    number_of_basis = 10
    edge_length_embedded = soft_one_hot_linspace(
        edge_length,
        start=0.0,
        end=max_radius,
        number=number_of_basis,
        basis='smooth_finite',
        cutoff=True
    )
    edge_length_embedded = edge_length_embedded.mul(number_of_basis**0.5)
    # Smooth per-edge weight that switches an edge off as it reaches the cutoff.
    edge_weight_cutoff = soft_unit_step(10 * (1 - edge_length / max_radius))

    # Angular part: spherical harmonics of the edge directions up to l = 3.
    irreps_sh = o3.Irreps.spherical_harmonics(3)
    dot = o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "0e")
    edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, True, normalization='component')

    # Query / key / value modules. The tensor products take their weights per
    # edge from small MLPs acting on the radial embedding (shared_weights=False).
    h_q = o3.Linear(irreps_input, irreps_query)
    tp_k = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_key, shared_weights=False)
    fc_k = nn.FullyConnectedNet([number_of_basis, 16, tp_k.weight_numel], act=torch.nn.functional.silu)

    tp_v = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_output, shared_weights=False)
    fc_v = nn.FullyConnectedNet([number_of_basis, 16, tp_v.weight_numel], act=torch.nn.functional.silu)

    q = h_q(node_features)
    k = tp_k(node_features[edge_src], edge_sh, fc_k(edge_length_embedded))
    v = tp_v(node_features[edge_src], edge_sh, fc_v(edge_length_embedded))

    # Softmax over incoming edges: numerator per edge, denominator per node.
    exp = edge_weight_cutoff[:, None] * dot(q[edge_dst], k).exp()
    z = scatter(exp, edge_dst, dim=0, dim_size=num_nodes)
    z[z == 0] = 1  # avoid 0/0 for nodes whose neighbours all sit at the cutoff
    alpha = exp / z[edge_dst]

    # Aggregate the weighted values into their destination nodes.
    return scatter(alpha.relu().sqrt() * v, edge_dst, dim=0, dim_size=num_nodes)
