# GNS Implementation Details

The model represents the physical system as a set of particles and approximates its dynamics by interactions among those particles. The objective of the model is to learn these interactions.

## Encoder
The encoder embeds the particle-based state representation, $X$, as a latent graph $G^0=\text{ENCODER}(X)$, where $G=(V,E,\mathbf{u})$, $\mathbf{v}_i\in V$, and $\mathbf{e}_{i,j}\in E$.
- The encoder constructs the graph structure $G^0$ by assigning a node to each particle and adding edges between particles within a connectivity radius, $R$. At each timestep the graph's edges are recomputed by a nearest-neighbor algorithm, implemented with a standard kd-tree, to reflect the current particle positions.
- The node embeddings, $\mathbf{v}_i=\varepsilon^v(\mathbf{x}_i)$, are learned functions of the particles' states.
- The edge embeddings, $\mathbf{e}_{i,j}=\varepsilon^e(\mathbf{r}_{i,j})$, are learned functions of the pairwise properties of the corresponding particles, $\mathbf{r}_{i,j}$, e.g., displacement between their positions, spring constant, etc.
- $\varepsilon^v$ and $\varepsilon^e$ are implemented as multilayer perceptrons (MLPs) that encode node features and edge features into the latent vectors, $\mathbf{v}_i$ and $\mathbf{e}_{i,j}$, of size $128$.
- The graph-level embedding, $\mathbf{u}$, could represent global properties such as gravity and magnetic fields. In the current implementation, however, these properties are supplied as node-level features instead.
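
The graph-construction step in the first bullet can be sketched without any dependencies. Note that this is a brute-force $O(N^2)$ search, not the kd-tree mentioned above; `radius_graph_bruteforce` and the sample positions are hypothetical, and the output layout (parallel lists of sender and receiver indices) is only assumed to mirror the `edge_index` tensor used in the cells of this notebook.

```python
import math

def radius_graph_bruteforce(positions, radius):
    """Return directed edges (senders, receivers) between particles within `radius`.

    Brute-force O(N^2) stand-in for a kd-tree query; self-loops are excluded.
    """
    senders, receivers = [], []
    for i, p_i in enumerate(positions):
        for j, p_j in enumerate(positions):
            if i != j and math.dist(p_i, p_j) <= radius:
                senders.append(i)
                receivers.append(j)
    return senders, receivers

# Hypothetical particle positions; 0.08 is the BoxBath connectivity radius.
positions = [(0.00, 0.00, 0.00), (0.05, 0.00, 0.00), (0.50, 0.50, 0.50)]
edge_index = radius_graph_bruteforce(positions, radius=0.08)
print(edge_index)
```

Only the first two particles are closer than $R$, so they are connected in both directions and the third particle stays isolated.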


## Neural network parameters
- All MLPs have **2 hidden layers** with **ReLU** activations, followed by a non-activated output layer, each layer with size **128**.
- All MLPs, except the output decoder, are followed by a **LayerNorm** layer.
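
As a sanity check on these sizes, the number of trainable parameters in one such MLP can be worked out by hand. The helper `mlp_param_count` is hypothetical; it assumes every linear layer has a bias and that LayerNorm contributes one scale and one shift per feature (the PyTorch defaults).

```python
def mlp_param_count(input_size, hidden_size=128, output_size=128, layer_norm=True):
    """Count weights and biases of a 2-hidden-layer MLP with an optional LayerNorm."""
    sizes = [input_size, hidden_size, hidden_size, output_size]
    # Each linear layer has (fan_in * fan_out) weights plus fan_out biases.
    total = sum(fan_in * fan_out + fan_out
                for fan_in, fan_out in zip(sizes, sizes[1:]))
    if layer_norm:
        total += 2 * output_size  # per-feature scale and shift
    return total

# A processor edge MLP sees two node latents and one edge latent: 3 * 128 inputs.
print(mlp_param_count(3 * 128))  # 82560
# The output decoder maps 128 latents to 3 values and has no LayerNorm.
print(mlp_param_count(128, output_size=3, layer_norm=False))  # 33411
```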



In [4]:
# setup autoreload
#%load_ext autoreload
#%autoreload 2
# Change working directory to root directory
%cd ".."

/workspace


In [10]:
%%writefile ../open_gns/models/encoder.py

import torch
from torch.nn import Sequential, Linear, ReLU, LayerNorm
from torch_geometric.nn import MetaLayer

def make_mlp(input_size, hidden_size=128, output_size=128, layer_norm=True):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    layers = [
        Linear(input_size, hidden_size),
        ReLU(),
        Linear(hidden_size, hidden_size),
        ReLU(),
        Linear(hidden_size, output_size),
    ]
    if layer_norm:
        layers.append(LayerNorm(output_size))
    return Sequential(*layers).to(device)

class EdgeModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size=128):
        super(EdgeModel, self).__init__()
        self.edge_mlp = make_mlp(input_size, hidden_size)

    def forward(self, src, dest, edge_attr, u, batch):
        features = [src, dest] if edge_attr is None else [src, dest, edge_attr]
        out = torch.cat(features, 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size=128):
        super(NodeModel, self).__init__()
        self.node_mlp = make_mlp(input_size, hidden_size)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # TODO: Do we need to combine with edge_attr?
        return self.node_mlp(x)
                
class Encoder(torch.nn.Module):
    def __init__(self, input_size, output_size=128):
        super(Encoder, self).__init__()
        self.encoder = MetaLayer(EdgeModel(2*input_size), NodeModel(input_size))

    def forward(self, x, edge_index):
        # TODO: The encoder needs to build the Graph
        # otherwise the graph would need to be pre-built
        return self.encoder(x, edge_index)


Overwriting ../open_gns/models/encoder.py


In [None]:
# The autoreload extension must be loaded before %autoreload can be used
%load_ext autoreload
%autoreload 2
# Testing the Encoder
from open_gns.dataset import GNSDataset
from open_gns.models.encoder import Encoder
# Load dataset
dataset = GNSDataset('./notebooks')
print(f'Samples: {len(dataset)}, Num node features: {dataset.num_node_features}')
data = dataset[0]
input_size = dataset.num_node_features
encoder = Encoder(input_size)
x, e, u = encoder(data.x.float(), data.edge_index)
print(x.size(), e.size())

## Processor
The *PROCESSOR* performs $M$ rounds of learned message-passing over the latent graphs, $G^0, \ldots, G^M$. This is implemented by a stack of $M$ GNs with identical structure, MLPs as internal edge and node update functions, and either shared or unshared parameters. GNs are used without global features or global updates, and with a residual connection between the input and output latent node and edge attributes. The `Processor` in this notebook builds $M=10$ separate GNs by default, i.e., it uses unshared parameters.
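
One round of this message passing can be illustrated on a toy graph with scalar features. This is only a sketch: `edge_update` and `node_update` are hypothetical stand-ins for the learned MLPs, and summing the incoming edges at each node is an assumption (the `NodeModel` in this notebook does not yet combine node and edge features).

```python
def gn_round(nodes, edges, senders, receivers, edge_update, node_update):
    """Apply one graph-network round with residual connections on nodes and edges."""
    # 1. Update every edge from its sender node, receiver node, and current value.
    new_edges = [edge_update(nodes[s], nodes[r], e)
                 for s, r, e in zip(senders, receivers, edges)]
    # 2. Sum the updated edges arriving at each node.
    aggregated = [0.0] * len(nodes)
    for r, e in zip(receivers, new_edges):
        aggregated[r] += e
    # 3. Update every node from its current value and its aggregated messages.
    new_nodes = [node_update(v, a) for v, a in zip(nodes, aggregated)]
    # 4. Residual connections: add the inputs back to the outputs.
    nodes_out = [v + nv for v, nv in zip(nodes, new_nodes)]
    edges_out = [e + ne for e, ne in zip(edges, new_edges)]
    return nodes_out, edges_out

def edge_update(v_sender, v_receiver, e):
    """Toy stand-in for the learned edge MLP."""
    return v_sender - v_receiver + e

def node_update(v, agg):
    """Toy stand-in for the learned node MLP."""
    return 0.5 * agg

nodes, edges = [1.0, 2.0, 4.0], [0.0, 0.0]
senders, receivers = [0, 1], [1, 2]
for _ in range(2):  # M = 2 rounds, here sharing the same update functions
    nodes, edges = gn_round(nodes, edges, senders, receivers, edge_update, node_update)
print(nodes, edges)
```

Reusing the same update functions in every round corresponds to the shared-parameter variant; the unshared variant would use a different pair of functions per round.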

In [10]:
%%writefile open_gns/models/processor.py

import torch
from torch_geometric.nn import MetaLayer
from open_gns.models.encoder import EdgeModel, NodeModel

class Processor(torch.nn.Module):
    def __init__(self, input_size, output_size=128, M=10):
        super(Processor, self).__init__()
        self.GNs = torch.nn.ModuleList([])
        for i in range(M):
            GN = MetaLayer(EdgeModel(3*input_size), NodeModel(input_size))
            self.GNs.append(GN) 

    def forward(self, x, edge_index, edge_attr):
        for GN in self.GNs:
            # TODO: Concatenate residuals instead?
            # Keep residuals
            node_residual = x
            edge_residual = edge_attr
            # Apply GN
            x, edge_attr, u = GN(x, edge_index, edge_attr)
            # Add residuals
            x = x + node_residual
            edge_attr = edge_attr + edge_residual
        return x, edge_attr, u

Overwriting open_gns/models/processor.py


In [None]:
# Testing the processor
from open_gns.models.processor import Processor

processor = Processor(128)
print(e.size())
x, e, u = processor(x, data.edge_index, e)
print(x.size(), e.size())

## Decoder
The *DECODER* extracts dynamics information from the nodes of the final latent graph, $G^M$, produced by the *PROCESSOR*: $\mathbf{y}_i=\delta^v(\mathbf{v}_i^M)$. The decoder's learned function, $\delta^v$, is an MLP. After the *DECODER*, the future position and velocity are updated using an Euler integrator, so $\mathbf{y}_i$ corresponds to the accelerations, $\ddot{\mathbf{p}}^t_i$.
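
The integration step can be sketched as follows. The text above only says "Euler integrator"; the semi-implicit variant used here (velocity first, then position from the new velocity) and the time step `dt` are assumptions, and `euler_step` is a hypothetical helper.

```python
def euler_step(position, velocity, acceleration, dt=1.0):
    """Advance one particle by one time step with semi-implicit Euler integration."""
    # Velocity is updated from the predicted acceleration first ...
    new_velocity = [v + dt * a for v, a in zip(velocity, acceleration)]
    # ... and the position is then advanced with the updated velocity.
    new_position = [p + dt * v for p, v in zip(position, new_velocity)]
    return new_position, new_velocity

# Hypothetical state of a single particle; `acceleration` plays the role of y_i.
position, velocity = [0.0, 1.0, 0.0], [0.5, 0.0, 0.0]
acceleration = [0.0, -0.25, 0.0]
new_position, new_velocity = euler_step(position, velocity, acceleration)
print(new_position, new_velocity)
```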

In [None]:
%%writefile open_gns/models/decoder.py

import torch
from open_gns.models.encoder import make_mlp

class Decoder(torch.nn.Module):
    def __init__(self, input_size):
        super(Decoder, self).__init__()
        self.decoder = make_mlp(input_size, output_size=3, layer_norm=False)
    
    def forward(self, x):
        return self.decoder(x)



In [None]:
from open_gns.models.decoder import Decoder
# Test the decoder
decoder = Decoder(input_size=x.size(1))
out = decoder(x)
print(out.size())

In [12]:
%%writefile ../open_gns/models/__init__.py
import torch
from open_gns.models.encoder import Encoder
from open_gns.models.processor import Processor
from open_gns.models.decoder import Decoder

class EncodeProcessDecode(torch.nn.Module):
    def __init__(self, input_size, hidden_size=128):
        super(EncodeProcessDecode, self).__init__()
        self.encoder = Encoder(input_size)
        self.processor = Processor(hidden_size)
        self.decoder = Decoder(hidden_size)

    def forward(self, x, edge_index):
        x, edge_attr, _ = self.encoder(x, edge_index)
        x, edge_attr, _ = self.processor(x, edge_index, edge_attr)
        return self.decoder(x)

Overwriting ../open_gns/models/__init__.py


In [None]:
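# The class was written to disk with %%writefile, so it must be imported first
from open_gns.models import EncodeProcessDecode

gns = EncodeProcessDecode(input_size)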
y = gns(data.x.float(), data.edge_index)
print(y.size())

## Parameters
Parameters for BoxBath:

- Connectivity radius: $R=0.08$