In [1]:
import torch 
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, radius_graph
from torch_geometric.utils import add_self_loops, degree

import ase
import torch.nn as nn
import torch.nn.functional as Func
from torch.nn import Embedding, Sequential, Linear, ModuleList, Module
import numpy as np
from torch import linalg as LA
import math

from torch_geometric.data import Data

In [2]:
class CosineCutoff(torch.nn.Module):

    def __init__(self, cutoff=5.0):
        super(CosineCutoff, self).__init__()
        #self.register_buffer("cutoff", torch.FloatTensor([cutoff]))
        self.cutoff = cutoff

    def forward(self, distances):
        """Compute cutoff.

        Args:
            distances (torch.Tensor): values of interatomic distances.

        Returns:
            torch.Tensor: values of cutoff function.

        """
        # Compute values of cutoff function
        cutoffs = 0.5 * (torch.cos(distances * np.pi / self.cutoff) + 1.0)
        # Remove contributions beyond the cutoff radius
        cutoffs *= (distances < self.cutoff).float()
        return cutoffs

In [3]:
class BesselBasis(torch.nn.Module):
    """
    Sine for radial basis expansion with coulomb decay. (0th order Bessel from DimeNet)
    """

    def __init__(self, cutoff=5.0, n_rbf=20):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions.
        """
        super(BesselBasis, self).__init__()
        # compute offset and width of Gaussian functions
        freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
        self.register_buffer("freqs", freqs)

    def forward(self, inputs):
        inputs = torch.norm(inputs, p=2, dim=1)
        a = self.freqs
        ax = torch.outer(inputs,a)
        sinax = torch.sin(ax)

        norm = torch.where(inputs == 0, torch.tensor(1.0, device=inputs.device), inputs)
        y = sinax / norm[:,None]

        return y

In [4]:
class PaiNN(torch.nn.Module):
    def __init__(self, num_feat, out_channels, num_nodes, cut_off=5.0, n_rbf=20, num_interactions=3):
        super(PaiNN, self).__init__() 
        '''PyG implementation of PaiNN network of Schütt et. al. Supports two arrays  
           stored at the nodes of shape (num_nodes,num_feat,1) and (num_nodes, num_feat,3). For this 
           representation to be compatible with PyG, the arrays are flattened and concatenated. 
           Important to note is that the out_channels must match number of features'''
        
        self.num_nodes = num_nodes
        self.num_interactions = num_interactions
        self.cut_off = cut_off
        self.n_rbf = n_rbf
        self.linear = Linear(num_feat,num_feat)
        self.silu = Func.silu
        
        self.list_message = nn.ModuleList(
            [
                MessagePassPaiNN(num_feat, out_channels, num_nodes, cut_off, n_rbf)
                for _ in range(self.num_interactions)
            ]
        )
        self.list_update = nn.ModuleList(
            [
                UpdatePaiNN(num_feat, out_channels, num_nodes)
                for _ in range(self.num_interactions)
            ]
        )


    def forward(self, s,v, edge_index, edge_attr):
        
        
        for i in range(self.num_interactions):
            
            s_temp,v_temp = self.list_message[i](s,v, edge_index, edge_attr)
            s, v = s_temp+s, v_temp+v
            s_temp,v_temp = self.list_update[i](s,v) 
            s, v = s_temp+s, v_temp+v       
        
        s = self.linear(s)
        s = self.silu(s)
        s = self.linear(s)
        
        return v

In [5]:
class MessagePassPaiNN(MessagePassing):
    def __init__(self, num_feat, out_channels, num_nodes, cut_off=5.0, n_rbf=20):
        super(MessagePassPaiNN, self).__init__(aggr='add') 
        
        self.lin1 = Linear(num_feat, out_channels) 
        self.lin2 = Linear(out_channels, 3*out_channels) 
        self.lin_rbf = Linear(n_rbf, 3*out_channels) 
        self.silu = Func.silu
        
        #self.prepare = Prepare_Message_Vector(num_nodes)
        self.RBF = BesselBasis(cut_off, n_rbf)
        self.f_cut = CosineCutoff(cut_off)
        self.num_nodes = num_nodes
    
    def forward(self, s,v, edge_index, edge_attr):
        
        s = s.flatten(-1)
        v = v.flatten(-2)
        
        flat_shape_v = v.shape[-1]
        flat_shape_s = s.shape[-1]
    
        x =torch.cat([s, v], dim = -1)
        
        
        x = self.propagate(edge_index, x=x, edge_attr=edge_attr
                            ,flat_shape_s=flat_shape_s, flat_shape_v=flat_shape_v)
            
        return x    
    
    def message(self, x_j, edge_attr, flat_shape_s, flat_shape_v):
        
        
        # Split Input into s_j and v_j
        s_j, v_j = torch.split(x_j, [flat_shape_s, flat_shape_v], dim=-1)
        
        # r_ij channel
        rbf = self.RBF(edge_attr)
        ch1 = self.lin_rbf(rbf)
        cut = self.f_cut(edge_attr.norm(dim=-1))
        W = torch.einsum('ij,i->ij',ch1, cut) # ch1 * f_cut
        
        # s_j channel
        phi = self.lin1(s_j)
        phi = self.silu(phi)
        phi = self.lin2(phi)
        
        # Split 
        left, dsm, right = torch.tensor_split(phi*W,3,dim=-1)
        
        # v_j channel
        normalized = Func.normalize(edge_attr, p=2, dim=1)
        
        v_j = v_j.reshape(-1, int(flat_shape_v/3), 3)
        hadamard_right = torch.einsum('ij,ik->ijk',right, normalized)
        hadamard_left = torch.einsum('ijk,ij->ijk',v_j,left)
        dvm = hadamard_left + hadamard_right 
        
        # Prepare vector for update
        x_j = torch.cat((dsm,dvm.flatten(-2)), dim=-1)
       
        return x_j
    
    def update(self, out_aggr,flat_shape_s, flat_shape_v):
        
        s_j, v_j = torch.split(out_aggr, [flat_shape_s, flat_shape_v], dim=-1)
        
        return s_j, v_j.reshape(-1, int(flat_shape_v/3), 3)

In [6]:
class UpdatePaiNN(torch.nn.Module):
    def __init__(self, num_feat, out_channels, num_nodes):
        super(UpdatePaiNN, self).__init__() 
        
        self.lin_up = Linear(2*num_feat, out_channels) 
        self.denseU = Linear(num_feat,out_channels, bias = False) 
        self.denseV = Linear(num_feat,out_channels, bias = False) 
        self.lin2 = Linear(out_channels, 3*out_channels) 
        self.silu = Func.silu
        
        
    def forward(self, s,v):
        
        # split and take linear combinations
        #s, v = torch.split(out_aggr, [flat_shape_s, flat_shape_v], dim=-1)
        
        s = s.flatten(-1)
        v = v.flatten(-2)
        
        flat_shape_v = v.shape[-1]
        flat_shape_s = s.shape[-1]
        
        v_u = v.reshape(-1, int(flat_shape_v/3), 3)
        v_ut = torch.transpose(v_u,1,2)
        U = torch.transpose(self.denseU(v_ut),1,2)
        V = torch.transpose(self.denseV(v_ut),1,2)
        
        
        # form the dot product
        UV =  torch.einsum('ijk,ijk->ij',U,V) 
        
        # s_j channel
        nV = torch.norm(V, dim=-1)

        s_u = torch.cat([s, nV], dim=-1)
        s_u = self.lin_up(s_u) 
        s_u = Func.silu(s_u)
        s_u = self.lin2(s_u)
        #s_u = Func.silu(s_u)
        
        # final split
        top, middle, bottom = torch.tensor_split(s_u,3,dim=-1)
        
        # outputs
        dvu = torch.einsum('ijk,ij->ijk',v_u,top) 
        dsu = middle*UV + bottom 
        
        #update = torch.cat((dsu,dvu.flatten(-2)), dim=-1)
        
        return dsu, dvu.reshape(-1, int(flat_shape_v/3), 3)

### Ethanol

In [25]:
F = 128
num_nodes = 9
s0 = torch.rand(num_nodes,F, dtype=torch.float)
PA = PaiNN(F, F, 9)

#### Unrotated and Untranslated Inputs

In [26]:
F = 128
num_nodes = 9

v0 = torch.zeros(num_nodes,F,3, dtype=torch.float)

r_ij = torch.tensor([[ 0.0072, -0.5687,  0.0000],
        [-1.2854,  0.2499,  0.0000],
        [ 1.1304,  0.3147,  0.0000],
        [ 0.0392, -1.1972,  0.8900],
        [ 0.0392, -1.1972, -0.8900],
        [-1.3175,  0.8784,  0.8900],
        [-1.3175,  0.8784, -0.8900],
        [-2.1422, -0.4239,  0.0000],
        [ 1.9857, -0.1365,  0.0000]], dtype = torch.float)

# edge_attr: inter_atomic distances
edge_index = radius_graph(r_ij, r=1.70, batch=None, loop=False)
row, col = edge_index
edge_attr = (r_ij[row] - r_ij[col])
#print(edge_index.dtype == torch.long)

V_1 = PA(s0,v0, edge_index,edge_attr)

#### Rotated and Translated inputs

In [27]:
a = 10 # Angle
b = 0 # Translation factor
F = 128
num_nodes = 9

v0 = torch.zeros(num_nodes,F,3, dtype=torch.float)
# Translation
translate = b*torch.ones(r_ij.shape[0], r_ij.shape[1])
trans_r_ij = r_ij + translate
# Rotation
rot_mat = torch.tensor([[1,0,0],
                       [0, np.cos(a), -np.sin(a)],
                      [0, np.sin(a), np.cos(a)]], dtype = torch.float) 
rot_r_ij = (rot_mat@trans_r_ij.T).T

# edge_attr: inter_atomic distances
edge_index = radius_graph(rot_r_ij, r=1.70, batch=None, loop=False)
row, col = edge_index
edge_attr_st = (rot_r_ij[row] - rot_r_ij[col])
#print(edge_index.dtype == torch.long)

V_2 = PA(s0,v0, edge_index,edge_attr_st)

In [28]:
rot_out = (rot_mat@V_1[0].T).T # rotate output with unrotated inputs (taking only first node)

In [29]:
max_err = max((abs((V_2[0]-rot_out)/V_2[0])).flatten()) # max.error between rotated output of network with unrotated inputs, and output of network with rotated input.
min_err = min((abs((V_2[0]-rot_out)/V_2[0])).flatten())

In [30]:
print('max:', max_err, 'min:', min_err)

max: tensor(7.4412e-05, grad_fn=<UnbindBackward>) min: tensor(0., grad_fn=<UnbindBackward>)


### Formaldehyde

In [13]:
# #Paramerts
# # F: Num. features, r_ij: cartesian positions
# F = int(128/2)
# num_nodes = 4
# s0 = torch.rand(num_nodes,F, dtype=torch.float)
# v0 = torch.zeros(num_nodes,F,3, dtype=torch.float)
# s10 = torch.rand(num_nodes,F, dtype=torch.float)
# v10 = torch.zeros(num_nodes,F,3, dtype=torch.float)
# r_ij =  torch.tensor([[0.000000,  0.000000,  -0.537500],
#   [0.000000,  0.000000,   0.662500],
#   [0.000000,  0.866025,  -1.037500],
#   [0.000000, -0.866025,  -1.037500]])  

In [14]:
# # edge_attr: inter_atomic distances
# edge_index = radius_graph(r_ij, r=1.30, batch=None, loop=False)
# row, col = edge_index
# edge_attr = (r_ij[row] - r_ij[col])
# #print(edge_index.dtype == torch.long)

In [15]:
# # PA = PaiNN(F, F, 4)
# # S = PA(s0,v0, edge_index,edge_attr)
# rev=RevPaiNN(F, F, 4)

### Water

In [16]:
# #Paramerts
# # F: Num. features, r_ij: cartesian positions
# F = 128
# num_nodes = 3
# s0 = torch.rand(num_nodes,F, dtype=torch.float)
# v0 = torch.zeros(num_nodes,F,3, dtype=torch.float)
# r_ij =  torch.tensor([[0.000000,  0.000000,  0.000000],
#    [0.758602,  0.000000,  0.504284],
#   [0.758602,  0.000000,  -0.504284]]) 

In [17]:
# # edge_attr: inter_atomic distances
# edge_index = radius_graph(r_ij, r=1.30, batch=None, loop=False)
# row, col = edge_index
# edge_attr = (r_ij[row] - r_ij[col])
# #print(edge_index.dtype == torch.long)

In [18]:
# PA = PaiNN(F, F, 3)
# h20 = PA(s0,v0, edge_index,edge_attr)

In [19]:
# test = UpdatePaiNN(F,F,4)
# prep = Prepare_Message_Vector(4)

In [20]:
lin = Linear(128,128, bias = False) 

In [21]:
vT = torch.transpose(v0, 1,2)

In [22]:
c = torch.transpose(lin(vT),1,2)

In [23]:
torch.einsum('ijk,ijk->ij',c,c).shape

torch.Size([9, 128])

In [24]:
c.shape

torch.Size([9, 128, 3])