# GNN based on NHERI AI Academy 2024

This notebook reuses the same architecture as the previous version.

In [71]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as pl
import seaborn as sns

import torch 
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, radius_graph
from torch_geometric.utils import add_self_loops
import torch_cluster
import torch_scatter

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import trange
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph
from torch_geometric.utils import add_self_loops
from torch_geometric.utils import to_networkx
import networkx as nx

## Data
Case 1a

In [72]:
# Load the pre-built samples for case 1a. The indexing (`dataset[idx]`) and
# attribute access (`.x`, `.edge_index`, ...) further down suggest a sequence of
# torch_geometric `Data` objects -- TODO confirm.
# The path is built with pathlib so it resolves on any OS; a backslash-separated
# literal only resolves on Windows.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# Recent PyTorch releases default to weights_only=True, which may refuse pickled
# `Data` objects; pass weights_only=False explicitly if loading fails.
from pathlib import Path

dataset = torch.load(Path("Data") / "Data4.pt")

# Model

## Classic MLP

In [73]:
class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers.

    Every hidden layer is ``Linear -> [LayerNorm] -> activation``; the final
    ``Linear`` is followed by ``output_activation`` instead.

    Args:
        input_size: width of the input features.
        layer_sizes: list of hidden widths, appended after ``input_size``.
        output_size: optional width of one extra final ``Linear``; when ``None``
            the last entry of ``layer_sizes`` is the output width.
        output_activation: module class instantiated after the last ``Linear``.
        activation: module class instantiated after every other ``Linear``.
        layernorm: if true, insert ``nn.LayerNorm`` before each hidden activation.
    """

    def __init__(
        self,
        input_size,
        layer_sizes,
        output_size=None,
        output_activation=nn.Identity,
        activation=nn.ReLU,
        layernorm=True
    ):
        super().__init__()
        dims = [input_size] + layer_sizes
        if output_size is not None:
            dims.append(output_size)

        final_idx = len(dims) - 2
        modules = []
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            modules.append(nn.Linear(d_in, d_out))
            if idx == final_idx:
                # Last projection: only the requested output activation.
                modules.append(output_activation())
                continue
            if layernorm:
                modules.append(nn.LayerNorm(d_out))
            modules.append(activation())

        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        out = self.model(x)
        return out

## Encoder

In [74]:
class Encoder(nn.Module):
    """Embed node and edge features with two independent MLPs.

    Args:
        in_node, in_edge: input widths of node / edge features.
        out_node, out_edge: widths of the encoded node / edge features.
        mlp_num_layers: number of hidden layers in each MLP.
        mlp_hidden_dim: width of every hidden layer.
    """

    def __init__(self, in_node, in_edge, out_node, out_edge, mlp_num_layers, mlp_hidden_dim):
        super().__init__()
        hidden = [mlp_hidden_dim] * mlp_num_layers
        self.node_encoder = MLP(in_node, hidden, out_node)
        self.edge_encoder = MLP(in_edge, hidden, out_edge)

    def forward(self, x, edge_features):
        """Return ``(encoded_nodes, encoded_edges)``."""
        node_latent = self.node_encoder(x)
        edge_latent = self.edge_encoder(edge_features)
        return node_latent, edge_latent


## Decoder

In [75]:
class Decoder(nn.Module):
    """Map latent node and edge features to outputs with two independent MLPs.

    Args:
        in_node, in_edge: latent widths of node / edge features.
        out_node, out_edge: widths of the decoded node / edge outputs.
        mlp_num_layers: number of hidden layers in each MLP.
        mlp_hidden_dim: width of every hidden layer.
    """

    def __init__(self, in_node, in_edge, out_node, out_edge, mlp_num_layers, mlp_hidden_dim):
        super().__init__()
        hidden = [mlp_hidden_dim] * mlp_num_layers
        self.node_decoder = MLP(in_node, hidden, out_node)
        self.edge_decoder = MLP(in_edge, hidden, out_edge)

    def forward(self, x, edge_features):
        """Return ``(decoded_nodes, decoded_edges)``."""
        node_pred = self.node_decoder(x)
        edge_pred = self.edge_decoder(edge_features)
        return node_pred, edge_pred

## Message Passing

In [76]:
class InteractionNetwork(MessagePassing):
    """One residual interaction-network layer: update edges, then nodes.

    For every edge, ``edge_fn`` sees the concatenation ``[x_i, x_j, e_ij]``. The
    resulting messages are summed onto their target nodes, and ``node_fn`` sees
    ``[x, summed messages]``. Both outputs are added back to their inputs
    (residual connections), so ``node_fn`` must return the node feature width
    and ``edge_fn`` the edge feature width.

    Args:
        node_fn: module applied to ``cat([x, aggregated_messages], dim=1)``.
        edge_fn: module applied to ``cat([x_i, x_j, edge_features], dim=1)``.
    """

    def __init__(self, node_fn, edge_fn):
        super(InteractionNetwork, self).__init__()
        self.node_fn = node_fn
        self.edge_fn = edge_fn

    def forward(self, x, edge_index, edge_attr):
        """Return ``(updated_nodes, updated_edges)`` after one message-passing step."""
        # aggregate() returns (messages, aggregated) and the default update() passes
        # it through unchanged, so propagate() yields both per-edge and per-node results.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_features=edge_attr)
        node_out = self.node_fn(torch.cat([x, aggr], dim=1))
        return x + node_out, edge_attr + edge_out

    def message(self, x_i, x_j, edge_features):
        """Compute the edge update from both endpoint features and the edge features."""
        # PyG fills `x_i` / `x_j` by their suffix: with the default source_to_target
        # flow these are the target / source node features of each edge.
        return self.edge_fn(torch.cat([x_i, x_j, edge_features], dim=1))

    def aggregate(self, inputs, index, dim_size=None):
        """Sum messages per target node and also return the raw per-edge messages."""
        aggregated = torch_scatter.scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
        )
        # The raw messages are returned too, because forward() uses them as the
        # edge-feature update.
        return inputs, aggregated

## Processor

In [77]:
class Processor(MessagePassing):
    """Apply ``num_passing_steps`` residual InteractionNetwork layers in sequence.

    Args:
        node_in: node feature width entering each layer.
        node_out: width of each node update; must equal ``node_in`` because the
            layer adds it back to its input.
        edge_in: edge feature width entering each layer.
        edge_out: width of each edge message/update; must equal ``edge_in``.
        num_passing_steps: number of message-passing layers.
        mlp_num_layers: number of hidden layers in each MLP.
        mlp_hidden_dim: width of every hidden layer.
    """

    def __init__(self, node_in, node_out, edge_in, edge_out, num_passing_steps, mlp_num_layers, mlp_hidden_dim):
        super(Processor, self).__init__()
        hidden = [mlp_hidden_dim] * mlp_num_layers
        # Build a fresh layer per step: `[layer] * n` would repeat one module object
        # and silently tie every step's weights together.
        self.gnn_stacks = nn.ModuleList([
            InteractionNetwork(
                # node_fn receives cat([x, aggregated messages]) -> node_in + edge_out.
                node_fn=MLP(node_in + edge_out, hidden, node_out),
                # edge_fn receives cat([x_i, x_j, e_ij]) -> 2 * node_in + edge_in.
                edge_fn=MLP(node_in * 2 + edge_in, hidden, edge_out),
            )
            for _ in range(num_passing_steps)
        ])

    def forward(self, x, edge_index, edge_attr):
        """Run every layer in order and return the final ``(nodes, edges)``."""
        for gnn in self.gnn_stacks:
            x, edge_attr = gnn(x, edge_index, edge_attr)
        return x, edge_attr

In [78]:
class GNN(nn.Module):
    """Encode-process-decode graph network.

    Raw node/edge features are encoded to ``latent_dim``, refined by
    ``num_passing_steps`` message-passing layers, then decoded to
    ``node_out`` / ``edge_out`` widths.
    """

    def __init__(self, node_in, node_out, edge_in, edge_out, latent_dim, num_passing_steps, mlp_num_layers, mlp_hidden_dim):
        super().__init__()
        mlp_cfg = (mlp_num_layers, mlp_hidden_dim)
        self.encoder = Encoder(node_in, edge_in, latent_dim, latent_dim, *mlp_cfg)
        self.processor = Processor(
            latent_dim, latent_dim, latent_dim, latent_dim, num_passing_steps, *mlp_cfg
        )
        self.decoder = Decoder(latent_dim, latent_dim, node_out, edge_out, *mlp_cfg)

    def forward(self, x, edge_index, edge_attr):
        """Return ``(node_output, edge_output)``."""
        h_node, h_edge = self.encoder(x, edge_attr)
        h_node, h_edge = self.processor(h_node, edge_index, h_edge)
        node_pred, edge_pred = self.decoder(h_node, h_edge)
        return node_pred, edge_pred

In [81]:
# Inspect one sample and print the shapes of its tensors.
idx = 0
data = dataset[idx]
x, edge_index, edge_attr, y = data.x, data.edge_index, data.edge_attr, data.y
print(*(tensor.shape for tensor in (x, edge_index, edge_attr, y)))

torch.Size([1, 6]) torch.Size([2, 10]) torch.Size([10, 3]) torch.Size([1, 6])


In [56]:
class Simulator(nn.Module):
    """Wrap a ``GNN`` and build radius graphs from node coordinates.

    Args:
        node_in, edge_in: raw node / edge feature widths fed to the encoder.
        node_out, edge_out: widths of the decoded node / edge outputs.
        latent_dim: width of the latent node and edge features.
        num_passing_steps: number of message-passing layers in the processor.
        mlp_num_layers: number of hidden layers in each MLP.
        mlp_hidden_dim: width of every hidden layer.
        connectivity_radius: distance threshold passed to ``radius_graph`` in
            ``create_graph``.
    """

    def __init__(self,
                 node_in: int, edge_in: int,
                 node_out: int, edge_out: int,
                 latent_dim: int, num_passing_steps: int,
                 mlp_num_layers: int, mlp_hidden_dim: int,
                 connectivity_radius: float):
        super(Simulator, self).__init__()
        self._connectivity_radius = connectivity_radius
        self.GNN = GNN(node_in, node_out, edge_in, edge_out, latent_dim, num_passing_steps, mlp_num_layers, mlp_hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Run the wrapped GNN and return its ``(node_output, edge_output)`` pair."""
        return self.GNN(x, edge_index, edge_attr)

    def create_graph(self, input: torch.Tensor, target: torch.Tensor, coordinates: torch.Tensor):
        """Connect nodes closer than the connectivity radius and build edge features.

        ``input`` and ``target`` are passed through unchanged. Each edge feature is
        the displacement ``coordinates[edge_index[0]] - coordinates[edge_index[1]]``
        concatenated with its Euclidean length, cast to float32. Assuming
        ``coordinates`` is ``(num_nodes, dim)`` -- TODO confirm -- the edge width is
        ``dim + 1``, which must match ``edge_in``.

        Returns:
            ``(input, edge_index, edge_attr, target)``.
        """
        # NOTE(review): radius_graph also caps neighbours per node via its
        # max_num_neighbors argument; confirm its default suits the data density.
        edge_index = radius_graph(coordinates, r=self._connectivity_radius, loop=False)

        # With PyG's default source_to_target flow, row 0 holds each message's
        # sender and row 1 its receiver.
        sender_pos = coordinates[edge_index[0]]
        receiver_pos = coordinates[edge_index[1]]
        displacement = sender_pos - receiver_pos
        distance = torch.norm(displacement, p=2, dim=1, keepdim=True)
        edge_attr = torch.cat([displacement, distance], dim=-1).float()

        return input, edge_index, edge_attr, target

    def predict_step(self, input: torch.Tensor, target: torch.Tensor, coordinates: torch.Tensor):
        """Build the radius graph, run the model, and return ``(output, target)``.

        ``output`` is the ``(node_output, edge_output)`` pair from ``forward``.
        """
        input, edge_index, edge_attr, target = self.create_graph(input, target, coordinates)
        output = self.forward(input, edge_index, edge_attr)
        return output, target