In [105]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, radius_graph
import enum
import torch
import torch.nn as nn
from torch.nn import Sequential
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import LayerNorm
import torch.nn.functional as F
from torch import optim
import torch_scatter
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from tqdm import trange
import random


# Graph construction

In [75]:
# Acceleration data: read the CSV and discard the column labelled "0"
df = pd.read_csv("Data/Data4.csv").drop(columns=["0"])

# Coordinates data: 31 nodes at x = 2.5 whose y coordinate is 5 * node index
df_coord = np.column_stack([np.full(31, 2.5), 5 * np.arange(31)])

In [76]:
# Directory holding the input CSV and receiving the processed graphs
data_path = "Data"

# Time step
dt = 1e-4

# Number of trajectories (not referenced by the other visible cells)
number_trajectories = 1

# Row index at which graph construction stops; -1 never matches a row
# index, so every row of the acceleration table is processed
number_ts = -1

# Accumulator for the graph objects built in the next cell
data_list = []

In [77]:
# The graph connectivity and the edge features depend only on the fixed node
# coordinates, so they are computed once and shared by every time step
# instead of being rebuilt on each loop iteration.
coordinates = torch.tensor(df_coord, dtype=torch.float)
edges_index = radius_graph(coordinates, r=1.1*5, loop=False).type(torch.long)

# Edge features: relative position vector and its Euclidean length
u_i = coordinates[edges_index[0]]
u_j = coordinates[edges_index[1]]
u_ij = u_i - u_j
u_ij_norm = torch.norm(u_ij, p=2, dim=1, keepdim=True)
edge_attr = torch.cat([u_ij, u_ij_norm], dim=-1).type(torch.float)

# The row counter is called `step` so that the time-step constant `dt`
# defined in the configuration cell is not overwritten by the loop.
for step in range(df.shape[0]):
    if step == number_ts:
        break

    # One acceleration value per node, shaped as a column ([num_nodes, 1])
    acceleration = torch.tensor(df.iloc[step].values, dtype=torch.float).unsqueeze(1)

    # Store data (the same tensor is used as node feature and as target)
    data_list.append(Data(x=acceleration, edge_index=edges_index, edge_attr=edge_attr, y=acceleration))

print("Done collecting data!")

# Save
torch.save(data_list, os.path.join(data_path, "Data4.pt"))
print("Data saved!")
print("Output Location: ", os.path.join(data_path, "Data4.pt"))

Done collecting data!
Data saved!
Output Location:  Data\Data4.pt


# Normalization
The features and output parameters are normalized to zero mean and unit variance in order to stabilize training. The method defined below, get_stats(), is run before training. It accepts the processed data_list and calculates the mean and standard deviation of the node features, edge features, and node outputs. These statistics are then passed to normalize() (and unnormalize()) to standardize the data.

In [78]:
def normalize(to_normalize,mean_vec,std_vec):
    """Standardize values: subtract ``mean_vec`` and divide by ``std_vec``."""
    centered = to_normalize - mean_vec
    return centered / std_vec

def unnormalize(to_unnormalize,mean_vec,std_vec):
    """Undo standardization: multiply by ``std_vec`` and add ``mean_vec``."""
    rescaled = to_unnormalize * std_vec
    return rescaled + mean_vec

def _mean_and_std(total, total_sq, count, eps=1e-8):
    """Convert a running sum and sum of squares into ``(mean, std)``, with ``std`` floored at ``eps``."""
    mean = total / count
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # round-off (e.g. for a constant feature). sqrt would then yield NaN,
    # so the variance is clamped at zero first.
    variance = torch.clamp(total_sq / count - mean ** 2, min=0.0)
    std = torch.clamp(torch.sqrt(variance), min=eps)
    return mean, std


def get_stats(data_list):
    '''
    Compute normalization statistics for a processed data_list.

    For the node features (``x``), edge features (``edge_attr``) and node
    outputs (``y``) of the graphs in ``data_list``, the per-column mean and
    standard deviation are accumulated over the rows of all graphs. The data
    itself is not modified; pass the returned statistics to ``normalize``.

    Args:
        data_list: non-empty sequence of objects exposing 2-D tensors
            ``x``, ``edge_attr`` and ``y``.

    Returns:
        list: ``[mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge,
        mean_vec_y, std_vec_y]``. Every standard deviation is at least 1e-8
        so that it can safely be used as a divisor.
    '''
    fields = ("x", "edge_attr", "y")
    first = data_list[0]

    # Running sums, sums of squares and row counts, one entry per field.
    # The accumulators live on the same device as the data they summarize.
    totals = {}
    totals_sq = {}
    counts = {}
    for field in fields:
        reference = getattr(first, field)
        totals[field] = torch.zeros(reference.shape[1:], device=reference.device)
        totals_sq[field] = torch.zeros(reference.shape[1:], device=reference.device)
        counts[field] = 0

    # Define the maximum number of accumulations to perform such that we do
    # not encounter memory issues
    max_accumulations = 10**6

    # Iterate through the data in the list to accumulate statistics
    for dp in data_list:
        for field in fields:
            values = getattr(dp, field)
            totals[field] += torch.sum(values, dim=0)
            totals_sq[field] += torch.sum(values**2, dim=0)
            counts[field] += values.shape[0]

        # Stop once any of the counters exceeds the accumulation budget
        if any(count > max_accumulations for count in counts.values()):
            break

    mean_std_list = []
    for field in fields:
        mean_vec, std_vec = _mean_and_std(totals[field], totals_sq[field], counts[field])
        mean_std_list.extend([mean_vec, std_vec])

    return mean_std_list

# Encoder

In [79]:
class Encoder(torch.nn.Module):
    """Embeds raw node and edge features into a latent space of size ``hidden_dim``.

    Nodes and edges each get their own MLP
    (Linear -> ReLU -> Linear -> LayerNorm).
    """

    def __init__(self, input_dim_node, input_dim_edge, hidden_dim):
        super().__init__()

        # The node MLP is created before the edge MLP, so parameter
        # initialization draws from the RNG in that order.
        self.node_encoder = self._build_mlp(input_dim_node, hidden_dim)
        self.edge_encoder = self._build_mlp(input_dim_edge, hidden_dim)

    @staticmethod
    def _build_mlp(in_features, hidden_dim):
        """Two-layer MLP followed by layer normalization."""
        layers = [
            Linear(in_features, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim),
            LayerNorm(hidden_dim),
        ]
        return Sequential(*layers)

    def forward(self, x, edge_attr):
        """Return the encoded ``(x, edge_attr)`` pair."""
        return self.node_encoder(x), self.edge_encoder(edge_attr)

# Decoder

In [80]:
class Decoder(torch.nn.Module):
    """Maps latent node embeddings to the output space via Linear -> ReLU -> Linear."""

    def __init__(self, hidden_dim, output_dim_node):
        super().__init__()

        hidden_layer = Linear(hidden_dim, hidden_dim)
        output_layer = Linear(hidden_dim, output_dim_node)
        self.node_decoder = Sequential(hidden_layer, ReLU(), output_layer)

    def forward(self, x):
        """Decode node embeddings ``x`` into ``output_dim_node`` values per row."""
        decoded = self.node_decoder(x)
        return decoded
        

# GNN

In [81]:
class MeshGraphNet(torch.nn.Module):
    """Encode-process-decode graph network.

    Node and edge features are normalized and encoded, passed through
    ``args.num_layers`` message-passing layers, and the resulting node
    embeddings are decoded into ``output_dim`` values per node.

    Args:
        input_dim_node: number of raw node features.
        input_dim_edge: number of raw edge features.
        hidden_dim: width of every latent embedding.
        output_dim: number of values predicted per node.
        args: object exposing ``num_layers`` (must be >= 1).
        emb: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, input_dim_node, input_dim_edge, hidden_dim, output_dim, args, emb=False):
        super().__init__()

        self.num_layers = args.num_layers

        # Encoder
        self.encoder = Encoder(input_dim_node, input_dim_edge, hidden_dim)

        # Processor: a stack of identical message-passing layers
        self.processor = nn.ModuleList()
        assert (self.num_layers >= 1), 'Number of message passing layers is not >=1'

        layer_cls = self.build_processor_model()
        for _ in range(self.num_layers):
            message_passing_layer = layer_cls(hidden_dim, hidden_dim)
            self.processor.append(message_passing_layer)

        # Decoder
        self.decoder = Decoder(hidden_dim, output_dim)

    def build_processor_model(self):
        """Return the class used to build each processor layer."""
        return ProcessorLayer

    def forward(self,data,mean_vec_x,std_vec_x,mean_vec_edge,std_vec_edge):
        """Predict node outputs for ``data`` using the given input statistics."""
        edge_index = data.edge_index

        # Normalize the input data
        node_latent = normalize(data.x, mean_vec_x, std_vec_x)
        edge_latent = normalize(data.edge_attr, mean_vec_edge, std_vec_edge)

        # Encoder
        node_latent, edge_latent = self.encoder(node_latent, edge_latent)

        # Processor
        for layer_idx in range(self.num_layers):
            layer = self.processor[layer_idx]
            node_latent, edge_latent = layer(node_latent, edge_index, edge_latent)

        # Decoder
        return self.decoder(node_latent)

    def loss(self, pred, inputs, mean_vec_y, std_vec_y):
        """Per-row Euclidean error between ``pred`` and the normalized ``inputs.y``.

        Returns one value per row (no reduction over rows is applied here).
        """
        # Normalize the output data
        target = normalize(inputs.y, mean_vec_y, std_vec_y)

        # Squared error summed over the output dimension, then its square root
        squared_error = (pred - target) ** 2
        return torch.sqrt(torch.sum(squared_error, axis=1))

# ProcessorLayer Class: Edge Message Passing, Aggregation, and Updating

Now let's implement the processor, which overrides the "[MessagePassing](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html)" base class. Following the prototype of the base class, we need to implement three main methods, namely message passing, aggregation, and updating. Also, two types of MLP layers, namely the node MLP and the edge MLP, are defined and used during the construction of the processor; their details are given in the cell below.

Essentially, our processor class serves as the GNN layer composed of message passing, aggregation, and updating, which updates the information of each node at each layer of the computational graph. The message passing process can be described as:

1.   **Message passing**

Initiated by the propagate function, the message function most generally calculates messages, m, for edge u at layer l with function MSG given previous embeddings h_u:
$$m_u^{(l)}=MSG^{(l)}(h_u^{(l-1)})$$

Note that for MeshGraphNets, messages are calculated for edges and passed to nodes. This function thus takes edge embeddings and the adjacent node embeddings and concatenates them. These concatenated previous embeddings constitute h_u above. These are then put through an MLP (our MSG function) to give the final messages, m_u, which are passed to the aggregate function.

2.   **Aggregation**

Aggregation takes the updated edge embeddings and aggregates them over the connectivity matrix indexing using sum reduction. Most generally, we have:

$$h_v^{(l)}=AGG^{(l)}(\{m_u^{(l)},u\in N(v)\})$$

For MeshGraphNets, aggregation (AGG) for node v is sum over the neighbor nodes. However, there is also an additional aggregation step: aggregating with the self embedding. This is done outside of the aggregation function, in the forward function after the return of propagate:

$$h_v^{(l)}=\{h_v^{(l-1)},AGG^{(l)}(\{m_u^{(l)},u\in N(v)\})\}$$

3.   **Updating**

The node embeddings are finally updated by passing $h_v^{(l)}$ through the node MLP with a skip connection. This is most generally written as:

$$h_v^{(l)}=Processor(h_v^{(l)})$$

Where for us the Processor is an MLP.


In [82]:
class ProcessorLayer(MessagePassing):
    """One message-passing step that updates both edge and node embeddings.

    Edge update: the two endpoint node embeddings and the edge embedding are
    concatenated and passed through ``edge_mlp``.
    Node update: the updated edge embeddings are summed per node, concatenated
    with the node's own embedding, passed through ``node_mlp`` and added back
    to the node embedding (residual connection).

    Args:
        in_channels: size of the incoming node and edge embeddings.
        out_channels: size of the produced embeddings. The residual
            connection in ``forward`` requires it to match ``in_channels``.
        **kwargs: forwarded to ``MessagePassing.__init__``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(ProcessorLayer, self).__init__(**kwargs)

        # Note that the node and edge encoders both have the same hidden dimension
        # size. This means that the input of the edge processor will always be
        # three times the specified hidden dimension
        # (input: adjacent node embeddings and self embeddings)
        self.edge_mlp = Sequential(Linear( 3* in_channels , out_channels),
                                   ReLU(),
                                   Linear( out_channels, out_channels),
                                   LayerNorm(out_channels))

        # Input: the node's own embedding concatenated with its aggregated messages
        self.node_mlp = Sequential(Linear( 2* in_channels , out_channels),
                                   ReLU(),
                                   Linear( out_channels, out_channels),
                                   LayerNorm(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the Linear layers of both MLPs."""
        self.edge_mlp[0].reset_parameters()
        self.edge_mlp[2].reset_parameters()

        self.node_mlp[0].reset_parameters()
        self.node_mlp[2].reset_parameters()


    def forward(self, x, edge_index, edge_attr, size=None):
        """Run one round of message passing.

        Returns:
            tuple: ``(updated_nodes, updated_edges)``.
        """
        # `out` holds the summed messages, one row per node;
        # `updated_edges` holds one row per edge.
        out, updated_edges = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        updated_nodes = torch.cat([x, out], dim=1) # Complete the aggregation through self-aggregation
        updated_nodes = x + self.node_mlp(updated_nodes) # Residual connection

        return updated_nodes, updated_edges

    def message(self, x_i, x_j, edge_attr):
        """Compute the updated edge embeddings from both endpoints and the edge itself."""
        updated_edges = torch.cat([x_i, x_j, edge_attr], dim=1) # Shape of [E, 3*in_channels]
        updated_edges = self.edge_mlp(updated_edges)

        return updated_edges

    def aggregate(self, updated_edges, edge_index, dim_size = None):
        """Sum the edge messages per node (indexed by ``edge_index[0]``).

        ``dim_size`` fixes the number of output rows. It is forwarded to
        ``scatter`` so that nodes at the end of the index range that have no
        edges still receive an (all-zero) row; otherwise the result would be
        shorter than ``x`` and the concatenation in ``forward`` would fail.
        When ``dim_size`` is None, ``scatter`` infers the size from the
        largest index.
        """
        # The axis along which to index number of nodes.
        node_dim = 0

        out = torch_scatter.scatter(updated_edges, edge_index[0, :], dim=node_dim,
                                    dim_size=dim_size, reduce = 'sum')

        return out, updated_edges



In [83]:
class objectview:
    """Attribute-style view over a dict: ``objectview({'a': 1}).a == 1``.

    The dict itself becomes the instance namespace (it is not copied), so
    later changes to the dict are visible through the view and vice versa.
    """

    def __init__(self, d):
        object.__setattr__(self, "__dict__", d)

In [88]:
# Minimal settings for a quick instantiation check of the network
args = objectview({
    "num_layers": 2,
    "hidden_dim": 64,
    "output_dim": 1
})

# Positional arguments: input_dim_node=1, input_dim_edge=3, hidden_dim=64, output_dim=1
model = MeshGraphNet(1, 3, 64, 1, args)

In [91]:
# Smoke test: compute the normalization statistics and push the first graph
# through the untrained model to check that the shapes line up.
stats = get_stats(data_list)
(mean_vec_x, std_vec_x,
 mean_vec_edge, std_vec_edge,
 mean_vec_y, std_vec_y) = stats

# The forward pass only needs the node and edge input statistics (first four entries)
model(data_list[0], *stats[:4])
print("Model works!")

Model works!


# Optimizer

In [92]:
def build_optimizer(args, params):
    """Create the optimizer and, optionally, a learning-rate scheduler.

    Args:
        args: object exposing ``opt`` ('adam', 'sgd', 'rmsprop' or 'adagrad'),
            ``lr``, ``weight_decay`` and ``opt_scheduler`` ('none', 'step' or
            'cos'). The 'step' scheduler additionally reads ``opt_decay_step``
            and ``opt_decay_rate``; 'cos' reads ``opt_restart``.
        params: iterable of parameters; only those with ``requires_grad``
            set are handed to the optimizer.

    Returns:
        tuple: ``(scheduler, optimizer)`` where ``scheduler`` is None when
        ``args.opt_scheduler == 'none'``.

    Raises:
        ValueError: if ``args.opt`` or ``args.opt_scheduler`` is not one of
            the supported names.
    """
    weight_decay = args.weight_decay
    trainable_params = [p for p in params if p.requires_grad]

    if args.opt == 'adam':
        optimizer = optim.Adam(trainable_params, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(trainable_params, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(trainable_params, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(trainable_params, lr=args.lr, weight_decay=weight_decay)
    else:
        # Without this branch an unknown name surfaced as an UnboundLocalError
        raise ValueError(
            f"Unsupported optimizer {args.opt!r}; expected 'adam', 'sgd', 'rmsprop' or 'adagrad'"
        )

    if args.opt_scheduler == 'none':
        return None, optimizer
    elif args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    else:
        raise ValueError(
            f"Unsupported scheduler {args.opt_scheduler!r}; expected 'none', 'step' or 'cos'"
        )
    return scheduler, optimizer

# Training

In [111]:
def train(dataset, stats_list, args, device="cpu"):
    """Train a MeshGraphNet on ``dataset`` and return it with its loss history.

    Args:
        dataset: indexable collection of graph objects exposing ``x`` and
            ``edge_attr``; it is batched with a ``DataLoader``.
        stats_list: the six tensors returned by ``get_stats``, in the order
            ``[mean_x, std_x, mean_edge, std_edge, mean_y, std_y]``.
        args: object exposing ``batch_size``, ``hidden_dim``, ``num_layers``,
            ``epochs`` and the optimizer settings read by ``build_optimizer``.
        device: device on which the model, statistics and batches are placed.

    Returns:
        tuple: ``(model, losses)`` where ``losses`` holds one value per
        epoch, the mean of that epoch's per-batch mean losses.
    """
    # DataLoader
    # NOTE(review): batches are always drawn in dataset order; args.shuffle
    # is not consulted here. Confirm this is intended.
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # The statistics of the data decomposed and moved to the training device
    (mean_vec_x, std_vec_x,
     mean_vec_edge, std_vec_edge,
     mean_vec_y, std_vec_y) = [stat.to(device) for stat in stats_list]

    # Build model
    num_node_features = dataset[0].x.shape[1]
    num_edge_features = dataset[0].edge_attr.shape[1]
    output_dim = 1

    model = MeshGraphNet(num_node_features, num_edge_features, args.hidden_dim, output_dim, args).to(device)
    scheduler, optimizer = build_optimizer(args, model.parameters())

    # Training
    losses = []

    for epoch in trange(args.epochs, desc="Training", unit="Epochs"):
        total_loss = 0
        model.train()
        num_loops = 0
        for batch in loader:
            optimizer.zero_grad()
            batch = batch.to(device)
            pred = model(batch, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge)
            loss = model.loss(pred, batch, mean_vec_y, std_vec_y)
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_loops += 1

        total_loss /= num_loops
        losses.append(total_loss)

        # Advance the learning-rate schedule once per epoch. Previously the
        # scheduler was created but never stepped, so a configured schedule
        # had no effect.
        if scheduler is not None:
            scheduler.step()

        if epoch % 100 == 0:
            print(f"Epoch {epoch} | Loss: {total_loss}")

    return model, losses



In [112]:
# Hyper-parameters of the training run, exposed with attribute access
args = objectview(
    {'model_type': 'meshgraphnet',
     'num_layers': 2,
     'batch_size': 16,
     'hidden_dim': 10,
     'epochs': 500,
     'opt': 'adam',
     'opt_scheduler': 'none',
     'opt_restart': 0,
     'weight_decay': 5e-4,
     'lr': 0.001,
     'train_size': 45,
     'test_size': 10,
     'device':'cpu',
     'shuffle': True,
     'save_velo_val': True,
     'save_best_model': True,
     'checkpoint_dir': './best_models/',
     'postprocess_dir': './2d_loss_plots/'}
)

# Seed every random number generator in use so runs are repeatable
torch.manual_seed(5)  # Torch
random.seed(5)        # Python
np.random.seed(5)     # NumPy

In [109]:
# Load dataset: read back the list of graphs written by the graph-construction cell
dataset = torch.load("Data/Data4.pt")

# Normalization statistics, computed from the reloaded dataset
stats_list = get_stats(dataset)

In [110]:
# Train: `device` is not passed, so train() falls back to its default ("cpu")
model, losses = train(dataset, stats_list, args)

Training:   0%|          | 0/500 [00:00<?, ?Epochs/s]

Training:   1%|          | 5/500 [01:19<2:10:57, 15.87s/Epochs]


KeyboardInterrupt: 