# GNS Implementation Details

The model works by adopting a particle-based representation of the physical system. Physical dynamics are approximated by interactions among the particles. The objective of the model is to learn these interactions.

## Encoder
The encoder embeds the particle-based state representation, $X$, as a latent graph $G^0=\text{ENCODER}(X)$, where $G=(V,E,\mathbf{u})$, $\mathbf{v}_i\in V$, and $\mathbf{e}_{i,j}\in E$.
- The encoder constructs the graph structure $G^0$ by assigning a node to each particle and adding edges between particles within a connectivity radius, $R$. At every timestep the edges are recomputed with a nearest-neighbor search, implemented with a standard kd-tree, so that they reflect the current particle positions.
- The node embeddings, $\mathbf{v}_i=\varepsilon^v(\mathbf{x}_i)$, are learned functions of the particles' states.
- The edge embeddings, $\mathbf{e}_{i,j}=\varepsilon^e(\mathbf{r}_{i,j})$, are learned functions of the pairwise properties of the corresponding particles, $\mathbf{r}_{i,j}$, e.g., the displacement between their positions, a spring constant, etc.
- $\varepsilon^v$ and $\varepsilon^e$ are multilayer perceptrons (MLPs) that encode the node and edge features into latent vectors, $\mathbf{v}_i$ and $\mathbf{e}_{i,j}$, of size $128$.
- The graph-level embedding, $\mathbf{u}$, could represent global properties such as gravity and magnetic fields. In this implementation, however, such global properties are included as node-level features instead.
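A minimal plain-Python sketch of the graph construction described above. For readability it swaps the kd-tree for a brute-force pairwise search. `radius_graph`, `edge_features` and `to_edge_index` are hypothetical helper names. Using the sender-minus-receiver displacement plus its norm as $\mathbf{r}_{i,j}$ is an illustrative assumption.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]
Edge = Tuple[int, int]


def radius_graph(positions: Sequence[Point], radius: float) -> List[Edge]:
    """Directed edges (sender, receiver) between distinct particles closer than `radius`.

    Brute-force O(N^2) stand-in for the kd-tree neighbor query.
    """
    edges = []
    for i, p_i in enumerate(positions):
        for j, p_j in enumerate(positions):
            if i != j and math.dist(p_i, p_j) < radius:
                edges.append((i, j))
    return edges


def edge_features(positions: Sequence[Point], edges: List[Edge]) -> List[Tuple[float, ...]]:
    """Assumed pairwise features r_ij: displacement p_sender - p_receiver and its length."""
    feats = []
    for s, r in edges:
        disp = tuple(a - b for a, b in zip(positions[s], positions[r]))
        feats.append(disp + (math.hypot(*disp),))
    return feats


def to_edge_index(edges: List[Edge]) -> List[List[int]]:
    """Two rows, [senders, receivers], mirroring PyTorch Geometric's edge_index layout."""
    return [[s for s, _ in edges], [r for _, r in edges]]


positions = [(0.0, 0.0), (0.05, 0.0), (0.2, 0.0)]
edges = radius_graph(positions, radius=0.08)
print(edges)  # [(0, 1), (1, 0)]
```

The brute-force loop costs O(N²) distance checks per timestep. This is why a spatial index such as a kd-tree is preferable once the graph must be rebuilt for thousands of particles at every step.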

## Processor
The *PROCESSOR* performs $M$ rounds of learned message passing on the latent graph, producing the sequence of updated latent graphs $G^1, \dots, G^M$ from $G^0$. It is implemented as a stack of $M$ graph networks (GNs) with identical structure, MLPs as the internal edge and node update functions, and either shared or unshared parameters. The GNs have no global features or global updates, and each uses a residual connection between its input and output latent node and edge attributes.
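A sketch of the processor in plain PyTorch (without PyTorch Geometric). It assumes sum aggregation of incoming edge messages and edge updates that see the edge latent together with both endpoint node latents. `GNBlock` and `Processor` are hypothetical names. The `shared` flag reuses a single block for all $M$ steps, which illustrates the shared-parameter variant.

```python
import torch
from torch import nn


def mlp(in_size: int, latent: int = 128) -> nn.Sequential:
    """2 hidden ReLU layers, a linear output layer, then LayerNorm."""
    return nn.Sequential(
        nn.Linear(in_size, latent), nn.ReLU(),
        nn.Linear(latent, latent), nn.ReLU(),
        nn.Linear(latent, latent), nn.LayerNorm(latent),
    )


class GNBlock(nn.Module):
    """One GN step without global features (hypothetical sketch)."""

    def __init__(self, latent: int = 128):
        super().__init__()
        self.edge_fn = mlp(3 * latent, latent)  # input: [e_ij, v_sender, v_receiver]
        self.node_fn = mlp(2 * latent, latent)  # input: [v_i, summed incoming messages]

    def forward(self, v, e, edge_index):
        senders, receivers = edge_index
        e_update = self.edge_fn(torch.cat([e, v[senders], v[receivers]], dim=-1))
        # Sum the updated edge latents into their receiver nodes
        aggregated = torch.zeros_like(v).index_add_(0, receivers, e_update)
        v_update = self.node_fn(torch.cat([v, aggregated], dim=-1))
        # Residual connections on both node and edge latents
        return v + v_update, e + e_update


class Processor(nn.Module):
    """Stack of M GN blocks with shared or unshared parameters."""

    def __init__(self, num_steps: int, latent: int = 128, shared: bool = False):
        super().__init__()
        if shared:
            block = GNBlock(latent)
            self.blocks = nn.ModuleList([block] * num_steps)
        else:
            self.blocks = nn.ModuleList([GNBlock(latent) for _ in range(num_steps)])

    def forward(self, v, e, edge_index):
        for block in self.blocks:  # G^0 -> G^1 -> ... -> G^M
            v, e = block(v, e, edge_index)
        return v, e


# Usage: 5 nodes, 3 directed edges, M = 2 message-passing steps
v0, e0 = torch.randn(5, 128), torch.randn(3, 128)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
vM, eM = Processor(num_steps=2)(v0, e0, edge_index)
```

The latent sizes stay fixed across steps, so the residual additions are always well defined.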

## Decoder
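The *DECODER* extracts each particle's dynamics from the final latent graph, $\mathbf{y}_i=\delta^v(\mathbf{v}_i^M)$, where the learned function $\delta^v$ is an MLP. The output $\mathbf{y}_i$ is interpreted as the per-particle acceleration, $\ddot{\mathbf{p}}^t_i$. An Euler integrator then uses it to update each particle's velocity and position.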

## Neural network parameters
- All MLPs have **2 hidden layers** with **ReLU** activations, followed by a non-activated output layer, each layer with size **128**.
- All MLPs, except the output decoder, are followed by a **LayerNorm** layer.
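The MLP specification above can be captured in one hypothetical builder, `build_mlp`. The 20 input features match the node-feature count printed by the dataset cell in this notebook. The 3-dimensional decoder output assumes 3D accelerations. Both sizes are illustrative.

```python
import torch
from torch import nn


def build_mlp(in_size: int, out_size: int, hidden_size: int = 128,
              num_hidden: int = 2, layer_norm: bool = True) -> nn.Sequential:
    """ReLU hidden layers, a non-activated output layer, and an optional LayerNorm."""
    layers = []
    size = in_size
    for _ in range(num_hidden):
        layers += [nn.Linear(size, hidden_size), nn.ReLU()]
        size = hidden_size
    layers.append(nn.Linear(size, out_size))  # output layer: no activation
    if layer_norm:
        layers.append(nn.LayerNorm(out_size))
    return nn.Sequential(*layers)


node_encoder = build_mlp(20, 128)              # encoder/processor MLPs end in LayerNorm
decoder = build_mlp(128, 3, layer_norm=False)  # the output decoder does not
accelerations = decoder(torch.zeros(4, 128))   # one 3D acceleration per particle
```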


## Loss function
The loss function is the $L_2$ loss on the predicted per-particle accelerations. The model parameters are optimized with the Adam optimizer, using a minibatch size of 2.
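A sketch of one training step, under stated assumptions. Target accelerations come from three consecutive positions by finite differences, $\ddot{\mathbf{p}}^t = \mathbf{p}^{t+1} - 2\mathbf{p}^t + \mathbf{p}^{t-1}$. With $\Delta t = 1$ this is consistent with semi-implicit Euler when velocities are backward differences of positions. The $L_2$ loss is taken as the mean squared error. A single `nn.Linear` stands in for the full encoder-processor-decoder. The helper names and the learning rate are hypothetical.

```python
import torch
from torch import nn
import torch.nn.functional as F


def finite_difference_acceleration(p_prev, p_curr, p_next, dt: float = 1.0):
    """Target acceleration (p^{t+1} - 2 p^t + p^{t-1}) / dt^2."""
    return (p_next - 2 * p_curr + p_prev) / dt ** 2


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               features: torch.Tensor, target_acc: torch.Tensor) -> float:
    """One Adam update on the L2 (mean squared error) acceleration loss."""
    predicted_acc = model(features)
    loss = F.mse_loss(predicted_acc, target_acc)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Hypothetical stand-in for the full GNS: 6 input features -> 3D acceleration
model = nn.Linear(6, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```

In a real run, each minibatch would hold the particles of 2 graphs, concatenated along the node dimension.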


## Update mechanism
The implementation uses semi-implicit Euler integration to update the next state based on the predicted accelerations:
$$\dot{\mathbf{p}}^{t+1}=\dot{\mathbf{p}}^t+\Delta t\cdot \ddot{\mathbf{p}}^t $$
$$\mathbf{p}^{t+1}=\mathbf{p}^t+\Delta t\cdot \dot{\mathbf{p}}^{t+1}$$
where $\Delta t=1$. 
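The update rule can be checked with a small plain-Python sketch (the function names are hypothetical). The velocity is updated first. The new velocity is then used to advance the position, which is what distinguishes semi-implicit Euler from explicit Euler.

```python
from typing import Callable, List, Tuple

Vec = List[float]


def semi_implicit_euler(position: Vec, velocity: Vec, acceleration: Vec,
                        dt: float = 1.0) -> Tuple[Vec, Vec]:
    """One semi-implicit Euler step for a single particle."""
    # Velocity first: v^{t+1} = v^t + dt * a^t
    next_velocity = [v + dt * a for v, a in zip(velocity, acceleration)]
    # Position uses the updated velocity: p^{t+1} = p^t + dt * v^{t+1}
    next_position = [p + dt * v for p, v in zip(position, next_velocity)]
    return next_position, next_velocity


def rollout(position: Vec, velocity: Vec,
            predict_acceleration: Callable[[Vec, Vec], Vec],
            steps: int, dt: float = 1.0) -> List[Vec]:
    """Repeatedly query an acceleration predictor (e.g. the decoder) and integrate."""
    trajectory = [position]
    for _ in range(steps):
        acceleration = predict_acceleration(position, velocity)
        position, velocity = semi_implicit_euler(position, velocity, acceleration, dt)
        trajectory.append(position)
    return trajectory


# Constant unit acceleration along x, starting at rest
traj = rollout([0.0, 0.0], [0.0, 0.0], lambda p, v: [1.0, 0.0], steps=3)
print(traj)  # [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [6.0, 0.0]]
```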

# Creating Message Passing Networks using PyTorch Geometric



In [None]:
%%writefile open_gns/models/encoder.py

import torch
from torch.nn import Sequential, Linear, ReLU, LayerNorm
from torch_geometric.nn import MetaLayer

class EdgeModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size=128):
        super().__init__()
        # 2 hidden ReLU layers, a linear output layer, then LayerNorm
        self.edge_mlp = Sequential(
            Linear(input_size, hidden_size),
            ReLU(),
            Linear(hidden_size, hidden_size),
            ReLU(),
            Linear(hidden_size, hidden_size),
            LayerNorm(hidden_size)
        )

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [num_edges, num_node_features] features of each edge's endpoints.
        # Concatenate along the feature dimension (dim=1), not the edge dimension.
        out = torch.cat([src, dest], dim=1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size=128):
        super().__init__()
        self.node_mlp = Sequential(
            Linear(input_size, hidden_size),
            ReLU(),
            Linear(hidden_size, hidden_size),
            ReLU(),
            Linear(hidden_size, hidden_size),
            LayerNorm(hidden_size)
        )

    def forward(self, x, edge_index, edge_attr, u, batch):
        # The encoder embeds each node from its own features only
        return self.node_mlp(x)

class Encoder(torch.nn.Module):
    def __init__(self, input_size, output_size=128):
        super().__init__()
        # Edge inputs are the concatenated features of both endpoints (2 * input_size)
        self.encoder = MetaLayer(EdgeModel(2 * input_size, output_size),
                                 NodeModel(input_size, output_size))

    def forward(self, x, edge_index):
        # TODO: The encoder needs to build the graph;
        # otherwise the graph has to be pre-built.
        # Returns a tuple (node latents, edge latents, global latents)
        return self.encoder(x, edge_index)


In [13]:
%load_ext autoreload
%autoreload 2
from torch_geometric.nn import MetaLayer
from open_gns.dataset import GNSDataset
from open_gns.models.encoder import Encoder, EdgeModel, NodeModel
# Load dataset
dataset = GNSDataset('./notebooks')
print(len(dataset), dataset.num_edge_features, dataset.num_node_features)
data = dataset[0]
input_size = dataset.num_node_features
encoder = Encoder(input_size)
# The dataset stores features as float64, while the model weights are float32,
# so cast the node features before the forward pass
x, e, u = encoder(data.x.float(), data.edge_index)

143 0 20

# Parameters
Parameters for BoxBath:
- Trajectory length: 150
- Number of rollouts (train/validation/test): 2700/150/150
- Connectivity radius: $R=0.08$
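The BoxBath settings above could be gathered into one configuration object together with the network settings from earlier sections. `BoxBathConfig` and its field names are hypothetical. The values are the ones stated in this notebook.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BoxBathConfig:
    # Dataset (BoxBath)
    trajectory_length: int = 150
    num_train_rollouts: int = 2700
    num_valid_rollouts: int = 150
    num_test_rollouts: int = 150
    connectivity_radius: float = 0.08  # R
    # Network and training
    latent_size: int = 128
    mlp_hidden_layers: int = 2
    batch_size: int = 2
    dt: float = 1.0

    @property
    def total_rollouts(self) -> int:
        return self.num_train_rollouts + self.num_valid_rollouts + self.num_test_rollouts


cfg = BoxBathConfig()
```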

In [None]:
%cd ".."
!pwd