In [1]:
import torch
import e3nn
from torch_cluster import radius_graph
from torch_scatter import scatter

In [2]:
MAX_Z = 10
N_ATOMS = 4
DIM_ATOMS = 12
ATOM_IRREPS = e3nn.o3.Irreps(f"{DIM_ATOMS}x0e")
R_CUT = 2.0
DIM_R = 1
R_IRREPS = e3nn.o3.Irreps(f"{DIM_R}x0e")
SH_IRREPS = e3nn.o3.Irreps.spherical_harmonics(2)
FEATURE_IRREPS = R_IRREPS + SH_IRREPS
DIM_INPUT = 9
DIM_OUTPUT = 30
N_INTERACTIONS = 3

In [3]:
z = torch.randint(1, MAX_Z, (N_ATOMS,))
z

tensor([6, 8, 9, 6])

In [4]:
one_hot = torch.nn.functional.one_hot(z, MAX_Z).float()
one_hot

tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])

In [5]:
atom_filter = torch.nn.Linear(MAX_Z, DIM_ATOMS)
x = atom_filter(one_hot)
x

tensor([[-0.0770,  0.3873,  0.0529, -0.2166, -0.1137,  0.0419,  0.0978, -0.2285,
         -0.0447,  0.1377, -0.0778, -0.0754],
        [-0.1061,  0.3869,  0.0288, -0.1677, -0.2974, -0.0076, -0.1399, -0.2070,
         -0.1173, -0.2545, -0.5210,  0.0836],
        [-0.0951, -0.0466, -0.4967, -0.2758, -0.1829, -0.0439,  0.0979, -0.3595,
         -0.1889, -0.1096, -0.1910,  0.3734],
        [-0.0770,  0.3873,  0.0529, -0.2166, -0.1137,  0.0419,  0.0978, -0.2285,
         -0.0447,  0.1377, -0.0778, -0.0754]], grad_fn=<AddmmBackward0>)

In [6]:
pos = torch.randn([N_ATOMS, 3])
pos

tensor([[ 1.9915,  0.2242, -0.6689],
        [ 0.2612,  0.1146, -0.6114],
        [ 1.0665, -0.5098,  1.6085],
        [-0.8244,  1.0469,  0.0753]])

In [7]:
src, dst = radius_graph(pos, R_CUT)
src, dst

(tensor([1, 0, 3, 1]), tensor([0, 1, 1, 3]))

In [32]:
edges = pos[src] - pos[dst]
edges

tensor([[-1.7304, -0.1096,  0.0575],
        [ 1.7304,  0.1096, -0.0575],
        [-1.0855,  0.9323,  0.6868],
        [ 1.0855, -0.9323, -0.6868]])

In [9]:
r = edges.norm(dim=-1)
r

tensor([1.7348, 1.7348, 1.5872, 1.5872])

In [10]:
sh = e3nn.o3.spherical_harmonics(SH_IRREPS, edges, normalize=True, normalization="component")
sh, sh.shape


(tensor([[ 1.0000, -1.7276, -0.1094,  0.0574, -0.1280,  0.2441, -1.1046, -0.0081,
          -1.9245],
         [ 1.0000,  1.7276,  0.1094, -0.0574, -0.1280,  0.2441, -1.1046, -0.0081,
          -1.9245],
         [ 1.0000, -1.1846,  1.0174,  0.7495, -1.1462, -1.5559,  0.0392,  0.9843,
          -0.5433],
         [ 1.0000,  1.1846, -1.0174, -0.7495, -1.1462, -1.5559,  0.0392,  0.9843,
          -0.5433]]),
 torch.Size([4, 9]))

In [11]:
rot = e3nn.o3.rand_matrix()
rot

tensor([[ 0.5539,  0.5101, -0.6580],
        [-0.8325,  0.3490, -0.4302],
        [ 0.0102,  0.7861,  0.6180]])

In [12]:
sh_rot = SH_IRREPS.D_from_matrix(rot)
sh_rot

tensor([[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.5539,  0.5101, -0.6580,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000, -0.8325,  0.3490, -0.4302,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0102,  0.7861,  0.6180,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3356,  0.4406,  0.6946, -0.2020,
         -0.4123],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3095, -0.2314,  0.3084, -0.4491,
          0.7442],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.6203, -0.5033, -0.3173, -0.2601,
         -0.4400],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.5189, -0.6509,  0.4752, -0.1225,
         -0.2574],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3708, -0.2745,  0.3098,  0.8215,
          0.1278]])

In [13]:
edge_features = torch.cat((r.unsqueeze(-1), sh), dim=-1)
edge_features

tensor([[ 1.7348,  1.0000, -1.7276, -0.1094,  0.0574, -0.1280,  0.2441, -1.1046,
         -0.0081, -1.9245],
        [ 1.7348,  1.0000,  1.7276,  0.1094, -0.0574, -0.1280,  0.2441, -1.1046,
         -0.0081, -1.9245],
        [ 1.5872,  1.0000, -1.1846,  1.0174,  0.7495, -1.1462, -1.5559,  0.0392,
          0.9843, -0.5433],
        [ 1.5872,  1.0000,  1.1846, -1.0174, -0.7495, -1.1462, -1.5559,  0.0392,
          0.9843, -0.5433]])

In [14]:
data = {'pos': pos,
        'z': z,
        'x': x,
        'src': src,
        'dst': dst,
        'r': r,
        'sh': sh,
        'edge_features': edge_features,
}

In [15]:
class InteractionLayer(torch.nn.Module):
    def __init__(self, irreps_feature, irreps_atom, irreps_mid):
        super().__init__()
        self.tp1 = e3nn.o3.FullyConnectedTensorProduct(irreps_feature, irreps_feature, irreps_mid)
        self.tp2 = e3nn.o3.FullyConnectedTensorProduct(irreps_atom, irreps_mid, irreps_atom)

    def forward(self, data):
        mid1 = self.tp1(data['edge_features'], data['edge_features'])
        mid2 = self.tp2(data['x'][data['dst']], mid1)
        out = scatter(mid2, data['src'], dim=0, out=torch.zeros_like(data['y']))
        data['y'] = data['y'] + out
        return data

In [29]:
class ReadoutLayer(torch.nn.Module):
    def __init__(self, dim_atoms, dim_mid):
        super().__init__()
        self.layer1 = torch.nn.Linear(dim_atoms, dim_mid)
        self.layer2 = torch.nn.Linear(dim_mid, 1)

    def forward(self, data):
        out = self.layer1(data['y'])
        out = torch.relu(out)
        out = self.layer2(out)
        return out.sum()

In [30]:
class InteractionModel(torch.nn.Module):
    def __init__(self, dim_atoms, irreps_feature, irreps_atom, irreps_mid, n_interactions):
        super().__init__()
        self.layers = torch.nn.ModuleList([InteractionLayer(irreps_feature, irreps_atom, irreps_mid)] * n_interactions)
        self.n_interactions = n_interactions
        self.readout = ReadoutLayer(dim_atoms, dim_atoms // 2)

    def forward(self, data):
        data['y'] = data['x']
        for layer in self.layers:
            data = layer(data)
        readout = self.readout(data)
        return data, readout

In [31]:
model = InteractionModel(DIM_ATOMS, FEATURE_IRREPS, ATOM_IRREPS, '16x0e + 4x1o + 1x2e', N_INTERACTIONS)
result = model(data)
result[1]



tensor(0.0935, grad_fn=<SumBackward0>)