In [1]:
import sys
import os
current_dir = os.path.dirname(os.curdir)
parent_dir = os.path.abspath(os.path.join(current_dir, "../.."))
sys.path.append(parent_dir)
import torch
import system.units as units
import system.topology as topology
import system.box as box
import forces.twobody as twobody
import system.system as sys
from integrators.NVT import NVT
from integrators.NVE import NVE
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import freud
import numpy as np
from tqdm import trange
from utils import *
import numpy as np
from itertools import combinations
import torch.nn as nn
from tqdm import tqdm
from torch import Tensor
import wandb
device = "cuda"
dtype=torch.float32


╔═══════════════════════════════════════════════════╗
║                                                   ║
║  ██████╗   ██████╗    ██╗      ██████╗   ██╗  ██╗ ║
║ ██╔════╝  ██╔══██╗   ██╔██╗    ██╔══██╗  ██║ ██╔╝ ║
║ ╚█████╗   ██████╔╝  ██╔╝╚██╗   ██████╔╝  █████╔╝  ║
║  ╚═══██╗  ██╔═══╝  ██╔╝  ╚██╗  ██╔══██╗  ██╔═██╗  ║
║ ██████╔╝  ██║     ██╔╝    ╚██╗ ██║  ██║  ██║ ╚██╗ ║
║ ╚═════╝   ╚═╝     ╚═╝      ╚═╝ ╚═╝  ╚═╝  ╚═╝  ╚═╝ ║
║                                                   ║
║     Statistical Physics Autodiff Research Kit     ║
╚═══════════════════════════════════════════════════╝

          V(r)           ψ, φ              q
           │               │               │
           ○               ○               ○
         ╱ | ╲           ╱ | ╲           ╱ | ╲
        ○  ○  ○         ○  ○  ○         ○  ○  ○
         ╲ | ╱           ╲ | ╱           ╲ | ╱
           ○               ○               ○
           │               │               │
          g(r)             F         

### Physical System Setup

In [2]:
top, node_features, mass, energy_dict = build_top_and_features("alanine-dipeptide.prmtop")
B = 2048
pos = torch.tensor(pmd.load_file("alanine-dipeptide.pdb").coordinates,dtype=dtype,device=device).unsqueeze(0).expand(B, -1, -1).contiguous()
atomic_numbers = [a.atomic_number for a in pmd.load_file("alanine-dipeptide.pdb").atoms]
b = box.Box([1000,1000,1000],["s","s","s"])
u = units.UnitSystem.akma()
mom = 0.5*torch.randn_like(pos)

S = sys.System(pos, mom, mass, top, b, energy_dict, u, node_features)
S.potential_energy()
S.compile_force_fn()
S.pos = S.pos - (S.mass.unsqueeze(-1) * S.pos).sum(dim=1, keepdim=True) / S.mass.sum(dim=1, keepdim=True).unsqueeze(-1)

integrator = NVE(0.005)
print(integrator)

NVE(dt=0.005)


In [3]:
with torch.no_grad():
    M =S.mass.clone()
    Q =S.node_features['charge'].clone()
    P =S.mom.clone()
    F =S.force().clone()
    X =S.pos

### Equivariant Graph Neural Network

In [4]:
class VectorMatrixLayer(nn.Module):
    def __init__(self, A, B):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(A,B))
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)

    def forward(self, x):
        # x: [b, N, A, d]
        # weight: [A, B]
        # output: [B, N, B, d]
        return torch.einsum('AB,bnAd->bnBd', self.weight, x)
        
    def __repr__(self):
        return f"{self.__class__.__name__}(in_channels={self.weight.shape[0]}, out_channels={self.weight.shape[1]})"

class EquivariantGraphNeuralNetwork(nn.Module):

    def __init__(self, 
                f_1_layers, f_2_layers, f_3_layers, f_4_layers, f_5_layers,
                readout_layers, scalar_embed_layers, vector_embed_layers,
                message_passing_steps, p, activation_function,
                device
                ):

        super(EquivariantGraphNeuralNetwork, self).__init__()

         # --- extras ------------------------------------------------------------------------------------

        # Suppose we expect many body behavior. It may be beneficial to do multiple steps of node communication to
        # account for this. Therefore we have this extra parameter to indicate it.
        
        self.message_passing_steps = message_passing_steps

        # it may be helpful to store these
        
        self.activation_function = activation_function
        self.device = device
        self.p = p

        # I will be assuming the network is fully connected, the edge weight is the distance between nodes.
         
        # Fully connected graph (undirected, no self-loops)
        num_nodes = 5
        edges = list(combinations(range(num_nodes), 2))  # All unique pairs
        self.edge_index = torch.tensor(np.array(edges).T,device=device)  # 2 x E

        # --- message networks --------------------------------------------------------------------------------------

        # --- scalar features --------------------------------------------------------

        # f_1: s_j^(t), ||\vec{x}_{i,j}|| \rightarrow \mathbb{R}                                                    (1)
        
        # f_1 is a mapping from scalar features and scalarized vector features to an associated message. This message is
        # meant to be from node j to node i. If we take a permuation invariant aggregation operator (mean, sum) this 
        # scheme will be invariant to any permutative group action. Similarly, by scalarizing (applying the norm is a 
        # scalarizing type methdo) the vector features first we obtain invariance to any group action in SO(3). 
        # The corresponding message function can be read as

        # m_i^(t) := s_i^(t) + \sum_{j \in Neigh} f_1(s_j^(t), ||\vec{x}_{i,j}||)                                   (2)

        # Crucially the way this is defined is ambiugous to the channels of the network. Each channel corresponds to some 
        # feature of the graph at the node. An example of different "channels" of a physical system would be the charge,
        # mass, LJ parameters (given that mixing rules apply). Notably I am ignoring any features of a node which transform 
        # with respect to any element of the group. This choice allows me the freedom to apply non-linearties, mix information
        # across channels, and use any functional form of f_1 I please. Note that whenever I write out the mapping like in (1)
        # I am not being clear about the channels. Consider these functions on a per channel basis, each channel acts like (1).
        
        f_1_modules = []
        for in_features, out_features in zip(f_1_layers[:-1], f_1_layers[1:]):
            f_1_modules.append(nn.Linear(in_features, out_features))
            f_1_modules.append(activation_function)
            f_1_modules.append(torch.nn.Dropout(p=self.p))
        self.f_1 = nn.Sequential(*f_1_modules[:-2]) # exclude dropout and activation on the last pass to put in (-infty,infty)

        # Following the details of the PaiNN architechture the scalar message passing scheme we won't need  be anything beyond
        # a definition of f_1. Notably a key aspect is the ability to apply any 

        # --- vector features --------------------------------------------------------

        # f_2: s_j^(t), ||\vec{x}_{i,j}|| \rightarrow \mathbb{R}                                                    (3)

        # f_2 is exactly the same type of network as f_1. The reason for the distinction is to emphasize that they are 
        # explictly different networks. f_2 will act as a gate for the vector feature coming from node j. I like to think
        # about each node like a hairy point. Each hair of a persons haircut is not always important for their overall look.
        # Consider two cases, 1) a plain haired individual and 2) a power donut, the lary david type. Each hair on the sides
        # will be the same for both 1) and 2), implying there is nothing out of the ordinary. However, there is a clear 
        # distinction between the top of each head. Due to our mapping f_1 taking in scalar features (one of which could be
        # age in this example) we may want to use the age to put emphasis on the top of head hairs in order to distinguish
        # if this person is balding or not. The same can be said for the side hairs, we want to unweight this trait because
        # there is no indication in the scalar features that it is even needed. This situation can be compared with 
        # haircut 3) Skrillex's, his haircut has one side shaved completely. Another possible scalar feature might be career.
        # Using this we would know to weight hairs on the side more due to many EDM artists and DJs like Skrillex having such 
        # a haircut. Therefore we can consider the function f_2 as the vector feature gating as informed by the scalar features
        # and the scalarized vector features.

        f_2_modules = []
        for in_features, out_features in zip(f_2_layers[:-1], f_2_layers[1:]):
            f_2_modules.append(nn.Linear(in_features, out_features))
            f_2_modules.append(activation_function)
            f_2_modules.append(torch.nn.Dropout(p=self.p))
        self.f_2 = nn.Sequential(*f_2_modules[:-2]) # exclude dropout and activation on the last pass to put in (-infty,infty)


        # f_3: s_j^(t), ||\vec{x}_{i,j}|| \rightarrow \mathbb{R}                                                    (4)
        
        # f_3 is again more of the same. The only difference is that the gating is not done on the channel based features, 
        # rather this is on the edge based distance feature \vec{x}_{i,j}. 

        f_3_modules = []
        for in_features, out_features in zip(f_3_layers[:-1], f_3_layers[1:]):
            f_3_modules.append(nn.Linear(in_features, out_features))
            f_3_modules.append(activation_function)
            f_3_modules.append(torch.nn.Dropout(p=self.p))
        self.f_3 = nn.Sequential(*f_3_modules[:-2]) # exclude dropout and activation on the last pass to put in (-infty,infty)

        # Together these make up the full vector message like

        #  \vec{m}_i^(t) := \vec{v}_i^(t) + \sum_{j \in Neigh} [
        #          f_2(s_j^(t), ||\vec{x}_{i,j}||) \circ \vec{v_j}^(t) +  f_3(s_j^(t), ||\vec{x}_{i,j}||) \circ \vec{x_{i,j}}]  (5)

        # --- update networks --------------------------------------------------------------------------------------

        # The update networks act the similarly as the networks above. The key difference between the update and the message 
        # is there is no aggregation and there is no communication between the nodes. This implies there is no dependence on the
        # \vec{x_{i,j}}. This gives f_4 and f_5 as maps of the form. Also we do not interperet f_4 as a gating network as it 
        # just shifts the message to obtain the next scalar feature.

        # f_4: \vec{m}_i^(t), ||\vec{m}_i^(t)|| \rightarrow \mathbb{R}                                             (6)

        # f_5: \vec{m}_i^(t), ||\vec{m}_i^(t)|| \rightarrow \mathbb{R}                                             (7)

        # --- scalar features --------------------------------------------------------
        
        f_4_modules = []
        for in_features, out_features in zip(f_4_layers[:-1], f_4_layers[1:]):
            f_4_modules.append(nn.Linear(in_features, out_features))
            f_4_modules.append(activation_function)
            f_4_modules.append(torch.nn.Dropout(p=self.p))
        self.f_4 = nn.Sequential(*f_4_modules[:-2]) # exclude dropout and activation on the last pass to put in (-infty,infty)

        #  s_i^(t+1) := m_i^(t) + f_4(m_i^(t), ||\vec{m}_{i}^(t)||)                                                 (8)

        # --- vector features --------------------------------------------------------
        
        f_5_modules = []
        for in_features, out_features in zip(f_5_layers[:-1], f_5_layers[1:]):
            f_5_modules.append(nn.Linear(in_features, out_features))
            f_5_modules.append(activation_function)
            f_5_modules.append(torch.nn.Dropout(p=self.p))
        self.f_5 = nn.Sequential(*f_5_modules[:-2]) # exclude dropout and activation on the last pass to put in (-infty,infty)

        # At this point it may be beneficial to out learning to introduce the U and V matricies of the update function in a 
        # PaiNN network. I will not be doing this for pedagogical reasoning as of now. Perhaps if the learning is beans I will
        # come back and add it. 

        #  \vec{v}_i^(t+1) := \vec{m}_i^(t) + f_5(m_i^(t), ||\vec{m}_{i}^(t)||) \circ \vec{m}_{i}^(t)                (9)

        # --- readout --------------------------------------------------------------------------------------

        # Once all the message passing is complete there is a final readout step. Because we are intending to predict
        # vector features (momentum and position later) the readout will act on the vector features by taking a linear
        # combination of the vector channels. Linear combinations and rotations commute so this will be all good. 
        # Above we did not need to worry about permutation invariance because the graph based message passing handled it
        # for us. Here we can avoid any issues by not mixing nodes vector features. This would imply the vector readout
        # is only going to act like C -> 6, where the first three are the new node position and the last are the new
        # momentum components. Mathematically this looks like

        # G: \vec{m}_i \rightarrow \mathbb{R}^(6*3)                                                                 (10)

        # G(i) = M \cdot \vec{v}_i^(t+1)  where [6 x C] \cdot [C x 3] \rightarrow [6 x 3] linearly                  (11)
        
        self.readout = VectorMatrixLayer(readout_layers[0][0],readout_layers[1][0])

        # --- embeddings ------------------------------------------------------------------------------------

        # The initial scalar and vector features at each node are not of the channel shape. This provides a managerial issue
        # for the code. We really dont want to have multiple message and update functions as this causes headache. Therefore
        # we opt to just transform them into the correct shape initially. Just like the readout we need to be careful to
        # ensure that we remain equivariant. This means we may have a matrix for the vector features and then an MLP for the
        # scalar ones. 

        # Embed_vec(i) = M \cdot \vec{p}_i  where [C x num(p)] \cdot [num(p) x 3] \rightarrow [6 x 3] linearly       (12)

        # Embed_scalar(i) = MLP(p_i)                                                                                 (13)

        scalar_embed_modules = []
        for in_features, out_features in zip(scalar_embed_layers[:-1], scalar_embed_layers[1:]):
            scalar_embed_modules.append(nn.Linear(in_features, out_features))
            scalar_embed_modules.append(activation_function)
            scalar_embed_modules.append(torch.nn.Dropout(p=self.p))
        self.scalar_embed = nn.Sequential(*scalar_embed_modules[:-2])

        self.vector_embed = VectorMatrixLayer(vector_embed_layers[0][0],vector_embed_layers[1][0])

    def message(self, vectorial_feat: Tensor, scalar_feat: Tensor, node_pos: Tensor, edge_index: Tensor):
        r"""
        Parameters:
            vectorial_feat (torch.Tensor):
                Vectorial representations. Shape [B, N, C, 3]
            scalar_feat (torch.Tensor):
                Scalar representations. Shape [B, N, C]
            edge_index (torch.Tensor):
                Shape [2, E]
            node_pos (torch.Tensor):
                Atom's 3D coordinates. Shape [B, N, 3]

        Returns:
            vectorial_message (torch.Tensor):
                Shape [B, N, C, 3]
            scalar_message (torch.Tensor):
                Shape [B, N, C]
        """
        B, N, E = node_pos.shape[0], node_pos.shape[1], edge_index.shape[-1]
        source, target = edge_index # E, E
        Adj = torch.zeros(N, N, device=self.device)
        Adj[target, source] = 1  

        # vectorial quantity
        # compute all pairwise differences
        x_ij_full = node_pos[:, :, None, :] - node_pos[:, None, :, :]  # broadcasted over batch [B, N, N, 3]
        
        # create a mask to exclude diagonal of an N,N matrix
        mask = ~torch.eye(N, dtype=bool, device=node_pos.device) #  [N, N]
        
        # apply mask across batch, reshape
        x_ij_matrix = x_ij_full[:, mask].view(B, N, N - 1, 3) # [B, N, N-1, 3]

        # scalar quantity
        abs_x_ij = torch.sum(x_ij_matrix ** 2, dim = -1) ** 0.5 # [B, N, N-1]

        # cat features for f_something
        cat_feats = torch.cat([scalar_feat, abs_x_ij],dim=-1) # [B, N, C + N-1]
        
        # apply non-linearity
        transformed_scalar_feat = self.f_1(cat_feats) # [B, N, C]
        transformed_vector_feat = self.f_2(cat_feats) # [B, N, C]
        transformed_edge_feat   = self.f_3(cat_feats) # [B, N, C]

        # communicate scalars between nodes
        neighbor_sum_scalar = torch.einsum('ij,bjd->bid', Adj.to(dtype=transformed_scalar_feat.dtype), transformed_scalar_feat) # [B, N, C]

        # communicate vectors between nodes
        # unsqueeze on the transformed_vector_feat allows for conversion from [B, N, C] to [B, N, C, 1] so it gets projected
        # along all vector components.
        neighbor_sum_vector = torch.einsum('ij,bjck->bick',Adj,transformed_vector_feat.unsqueeze(-1)*vectorial_feat) # [B, N, C, 3]

        # communicate edge feats between nodes
        # The is transformed_edge_feat_expanded is duplicating along the N dimension. Giving B, N, N, C. This 
        # then gets the diagonal excluded. This is to prevent messages i-> itself In hindsight I shouldve probably just done 
        # einsum with the full distance matrix, the diagonal zeros woulda taken care of it.
        transformed_edge_feat_expanded = transformed_edge_feat[:, None, :, :].expand(B, N, N, C)[:, mask].view(B, N, N - 1, C)
        neighbor_sum_edge = torch.einsum('bnik,bnij->bnkj', transformed_edge_feat_expanded, x_ij_matrix) # [B, N, C, 3]

        # compute scalar message
        scalar_feat = scalar_feat + neighbor_sum_scalar  # [B, N, C]
        vectorial_feat = vectorial_feat + neighbor_sum_edge + neighbor_sum_vector  # [B, N, C, 3]
        
        return vectorial_feat, scalar_feat
        
    def update(self, vectorial_feat: Tensor, scalar_feat: Tensor):
        r"""
        Parameters:
            vectorial_feat (torch.Tensor):
                Vectorial representations. Shape [B, N, embedding_dim, 3]
            scalar_feat (torch.Tensor):
                Scalar representations. Shape [B, N, embedding_dim]

        Returns:
            vectorial_update (torch.Tensor):
                Shape [B, N, embedding_dim, 3]
            scalar_update (torch.Tensor):
                Shape [B, N, embedding_dim]
        """
        scalarize_vectorial_feat = (vectorial_feat**2).sum(dim=-1)**0.5 # [B, N, C]

        #                [B, N, C + C] ->  [B, N, C]
        scalar_delta = self.f_4(torch.cat([scalar_feat, scalarize_vectorial_feat],dim=-1)) # [B, N, C]

        vector_gate = self.f_5(torch.cat([scalar_feat, scalarize_vectorial_feat],dim=-1)) # [B, N, C]
        vector_delta = vector_gate.unsqueeze(dim=-1)*vectorial_feat # [B, N, C, 3]

        vectorial_update = vectorial_feat + vector_delta
        scalar_update = scalar_feat + scalar_delta

        return vectorial_update, scalar_update

    def forward(self, mass, charge, momentum, force, position):
        r"""
        Parameters:
            mass (torch.Tensor):
                Mass of each atom/node, scalar feature. Shape [B, N]
            charge (torch.Tensor):
                Charge of each atom/node, scalar feature. Shape [B, N]
            momentum (torch.Tensor):
                Momentum of each atom/node, vector feature. Shape [B, N, 3]
            force (torch.Tensor):
                Force of each atom/node, vector feature. Shape [B, N, 3]
            position (torch.Tensor):
                Position of each atom/node, vector feature. Shape [B, N, 3]

        Returns:
            momentum_later (torch.Tensor):
                Momentum of each atom/node after lag time, vector feature. Shape [B, N, 3]
            position_later (torch.Tensor):
                Position of each atom/node after lag time, vector feature. Shape [B, N, 3]
        """

        # embed the charge and the mass
        # [B, N, 2] -> [B, N, C]
        scalar_feat = self.scalar_embed(torch.cat([mass.unsqueeze(dim=-1),charge.unsqueeze(dim=-1)],dim=-1))

        # embed the force and momentum
        # [B, N, 2, 3] -> [B, N, C, 3]
        vectorial_feat = self.vector_embed(torch.stack([momentum, force], dim=2))

        # start a loop over the number of message + update steps
        for t in range(self.message_passing_steps):
            # apply message
            vectorial_feat_t, scalar_feat_t = self.message(vectorial_feat, scalar_feat, position, self.edge_index)
            # apply update
            vectorial_feat, scalar_feat = self.update(vectorial_feat_t, scalar_feat_t)
            
        # readout the new position and momentum from a linear combination of the chanels
        # [B, N, C, 3] -> [B, N, 2, 3]
        final_readout = self.readout(vectorial_feat)

        # split channels into position and momentum
        position_later = final_readout[:,:,0]
        momentum_later = final_readout[:,:,1]

        return momentum_later, position_later

In [5]:
C = 32
N = 22

f_1_layers = [C + N-1, C + N - 1, C]
f_2_layers = [C + N-1, C + N - 1, C]
f_3_layers = [C + N-1, C + N - 1, C]
f_4_layers = [C + C,   C + C    , C]
f_5_layers = [C + C,   C + C    , C]

readout_layers      = [(C,3), (2,3)]
scalar_embed_layers = [2, C, C, C]
vector_embed_layers = [(2,3), (C,3)]
                
message_passing_steps = 5
p = 0
activation_function = nn.Tanh()

EGNN = EquivariantGraphNeuralNetwork(f_1_layers, f_2_layers, f_3_layers, f_4_layers, f_5_layers,
                readout_layers, scalar_embed_layers, vector_embed_layers,
                message_passing_steps, p,  activation_function,
                device)
EGNN.to(EGNN.device)

num_iters = 100

median_total_mom_error = []
median_total_pos_error = []

mean_total_mom_error = []
mean_total_pos_error = []

eps_mom   = 0.0
eps_pos   = 0.0

with torch.no_grad():

    for blah in range(num_iters):
        # Sample rotation
        RP, R, _ = random_rotation_3d(P)      # rotate positions
        RF, _, _ = random_rotation_3d(F, R)   # share the same R for everything
        RX, _, _ = random_rotation_3d(X, R)
    
        mom_rot, pos_rot   = EGNN(M, Q, RP, RF, RX)   # rotated inputs
        mom_ref, pos_ref   = EGNN(M, Q, P,  F,  X)    # original inputs
        Rmom_ref, _, _     = random_rotation_3d(mom_ref, R) # rotate the output of the original inputs
        Rpos_ref, _, _     = random_rotation_3d(pos_ref, R)  
    
        # median errors
        median_total_mom_error.append(torch.median((mom_rot - Rmom_ref)**2).item()**0.5)
        median_total_pos_error.append(torch.median((pos_rot - Rpos_ref)**2).item()**0.5)
    
        # mean errors
        mean_total_mom_error.append(torch.mean((mom_rot - Rmom_ref)**2).item()**0.5)
        mean_total_pos_error.append(torch.mean((pos_rot - Rpos_ref)**2).item()**0.5)

        # relative equivariance error ε = ||Δ|| / ||ref||
        eps_mom += torch.linalg.vector_norm(mom_rot - Rmom_ref) / (
                   torch.linalg.vector_norm(mom_ref) + 1e-12)
        eps_pos += torch.linalg.vector_norm(pos_rot - Rpos_ref) / (
                   torch.linalg.vector_norm(pos_ref) + 1e-12)
    
    eps_mom /= num_iters
    eps_pos /= num_iters
    
    print(f"mean relative error in momentum over {num_iters} rotations: {eps_mom:.2e}")
    print(f"mean relative error in position over {num_iters} rotations: {eps_pos:.2e}")
    print()

    print(f"Median momentum RMSD averaged over {num_iters} rotations:", sum(median_total_mom_error) / num_iters)
    print(f"Median position RMSD difference averaged over {num_iters} rotations:", sum(median_total_pos_error) / num_iters)
    print()
    
    print(f"Mean momentum RMSD averaged over {num_iters} rotations:", sum(mean_total_mom_error) / num_iters)
    print(f"Mean position RMSD difference averaged over {num_iters} rotations:", sum(mean_total_pos_error) / num_iters)
    print()

mean ε_momentum over 100 rotations: 1.43e-06
mean ε_position over 100 rotations: 4.94e-06

Median momentum RMSD averaged over 100 rotations: 0.00015666961669921875
Median position RMSD difference averaged over 100 rotations: 9.467899799346924e-05

Mean momentum RMSD averaged over 100 rotations: 0.0020901413679104345
Mean position RMSD difference averaged over 100 rotations: 0.0010432152005589387

