In [1]:
import torch
import random
import pandas as pd
import torch_scatter
import torch.nn as nn
from torch.nn import Linear, Sequential, LayerNorm, ReLU
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import DataLoader

import numpy as np
import time
import torch.optim as optim
from tqdm import trange
import pandas as pd
import copy
import matplotlib.pyplot as plt
import os


# Report the installed PyTorch version so results can be tied to an environment.
print(f"PyTorch has version {torch.__version__}")

PyTorch has version 1.13.1+cu117


### Get the path for various directories.

In [2]:
# All working directories are resolved relative to the current working directory.
root_dir = os.getcwd()
dataset_dir, checkpoint_dir, postprocess_dir = (
    os.path.join(root_dir, folder)
    for folder in ('Airflow', 'best_models', 'animations')
)

In [3]:
# Show the resolved dataset directory.
print(f"dataset_dir {dataset_dir}")

dataset_dir F:\GNNsim2\Airflow


### Import libraries involved in loading the dataset

In [5]:
import numpy as np
import torch
import h5py
import tensorflow.compat.v1 as tf
import functools
import json
from torch_geometric.data import Data
import enum

### Let's look at the metadata of the Airfoil dataset

In [8]:
# Read the JSON description stored next to the Airfoil records.
with open(os.path.join(root_dir, 'Airfoil/meta.json')) as f:
    metadata = json.loads(f.read())

In [9]:
# Display the parsed metadata dictionary.
metadata

{'simulator': 'su2',
 'dt': 0.0002,
 'collision_radius': None,
 'features': {'node_type': {'type': 'static',
   'shape': [1, 5233, 1],
   'dtype': 'int32'},
  'cells': {'type': 'static', 'shape': [1, 10216, 3], 'dtype': 'int32'},
  'mesh_pos': {'type': 'static', 'shape': [1, 5233, 2], 'dtype': 'float32'},
  'density': {'type': 'dynamic', 'shape': [601, 5233, 1], 'dtype': 'float32'},
  'pressure': {'type': 'dynamic', 'shape': [601, 5233, 1], 'dtype': 'float32'},
  'velocity': {'type': 'dynamic',
   'shape': [601, 5233, 2],
   'dtype': 'float32'}},
 'field_names': ['node_type',
  'cells',
  'mesh_pos',
  'density',
  'pressure',
  'velocity'],
 'trajectory_length': 601}

### Load the dataset!

In [15]:
def _parse(proto, meta):
  """Decodes one serialized tf.Example into a dict of trajectory tensors.

  Args:
    proto: serialized example to parse.
    meta: dataset description. This function reads meta['field_names'],
      meta['trajectory_length'] and, for every entry of meta['features'],
      its 'dtype', 'shape' and 'type'.

  Returns:
    A dict with one entry per key of meta['features']. 'static' fields are
    tiled meta['trajectory_length'] times along the first axis,
    'dynamic_varlen' fields become ragged tensors split by their
    'length_<key>' feature, and 'dynamic' fields are returned as reshaped.

  Raises:
    ValueError: if a field's 'type' is none of 'static', 'dynamic_varlen'
      or 'dynamic'.
  """
  # Every field is stored as raw bytes, so each one is parsed as a string.
  string_spec = {}
  for name in meta['field_names']:
    string_spec[name] = tf.io.VarLenFeature(tf.string)
  parsed = tf.io.parse_single_example(proto, string_spec)

  trajectory = {}
  for name, spec in meta['features'].items():
    raw_bytes = parsed[name].values
    tensor = tf.io.decode_raw(raw_bytes, getattr(tf, spec['dtype']))
    tensor = tf.reshape(tensor, spec['shape'])

    kind = spec['type']
    if kind == 'static':
      # Repeat the single stored frame once per time step.
      tensor = tf.tile(tensor, [meta['trajectory_length'], 1, 1])
    elif kind == 'dynamic_varlen':
      row_lengths = tf.io.decode_raw(parsed['length_'+name].values, tf.int32)
      row_lengths = tf.reshape(row_lengths, [-1])
      tensor = tf.RaggedTensor.from_row_lengths(tensor, row_lengths=row_lengths)
    elif kind != 'dynamic':
      raise ValueError('invalid data format')

    trajectory[name] = tensor
  return trajectory


def load_dataset(split):
  """Builds a tf.data pipeline of parsed trajectories for one split.

  Args:
    split: name of the split; the records are read from
      'Airfoil/<split>.tfrecord' under root_dir.

  Returns:
    A dataset whose elements are the dicts produced by _parse, using the
    description read from 'Airfoil/meta.json' under root_dir.
  """
  meta_path = os.path.join(root_dir, 'Airfoil/meta.json')
  with open(meta_path, 'r') as fp:
    meta = json.load(fp)

  record_path = os.path.join(root_dir, 'Airfoil/'+split+'.tfrecord')
  records = tf.data.TFRecordDataset(record_path)
  parse_fn = functools.partial(_parse, meta=meta)
  return records.map(parse_fn, num_parallel_calls=8).prefetch(1)

In [16]:
# Flatten the 'test' split along the first axis of every parsed trajectory
# and materialise the resulting elements in the list `l`.
ds = load_dataset('test').flat_map(tf.data.Dataset.from_tensor_slices)
l = [element for element in ds.prefetch(0)]

In [18]:
# Total number of elements fetched into `l`.
len(l)

60100

### Utility functions 

Here we define the functions that are needed for assisting in data processing.

triangles_to_edges: decomposes a 2D triangular mesh into its unique edges and returns them as two-way (sender, receiver) index lists, so every undirected edge appears once in each direction.

NodeType: a subclass of `enum.IntEnum` that assigns each node type a fixed integer code, so the codes stay consistent wherever they are used.

In [19]:
#Utility functions, provided in the release of the code from the original MeshGraphNets study:
#https://github.com/deepmind/deepmind-research/tree/master/meshgraphnets

def triangles_to_edges(faces):
  """Builds the two-way edge list of a triangular mesh.

     This helper comes from the code release of the MeshGraphNets paper by
     DeepMind, available here:
     https://github.com/deepmind/deepmind-research/tree/master/meshgraphnets

     Takes `faces`, whose rows hold the three node indices of one triangle,
     and returns a (senders, receivers) pair in which every unique mesh
     edge is listed once in each direction.
  """
  # Each triangle (a, b, c) contributes the sides (a, b), (b, c) and (c, a).
  side_ab = faces[:, 0:2]
  side_bc = faces[:, 1:3]
  side_ca = tf.stack([faces[:, 2], faces[:, 0]], axis=1)
  all_sides = tf.concat([side_ab, side_bc, side_ca], axis=0)

  # Interior sides show up twice (once per adjacent triangle), boundary sides
  # once. Order each pair as (max, min) and pack it into a single tf.int64 so
  # that duplicates can be dropped with one tf.unique call.
  low = tf.reduce_min(all_sides, axis=1)
  high = tf.reduce_max(all_sides, axis=1)
  packed = tf.bitcast(tf.stack([high, low], axis=1), tf.int64)
  unique_packed, _ = tf.unique(packed)

  # Unpack back to int32 pairs and emit both directions of every edge.
  senders, receivers = tf.unstack(tf.bitcast(unique_packed, tf.int32), axis=1)
  forward = tf.concat([senders, receivers], axis=0)
  backward = tf.concat([receivers, senders], axis=0)
  return (forward, backward)



class NodeType(enum.IntEnum):
    """
    Integer codes identifying the type of each mesh node.

    The codes are consistent with the ones provided in the original
    MeshGraphNets study:
    https://github.com/deepmind/deepmind-research/tree/master/meshgraphnets

    Being an IntEnum, every member can be used wherever a plain int is
    expected.
    """
    NORMAL = 0
    OBSTACLE = 1
    AIRFOIL = 2
    HANDLE = 3
    INFLOW = 4
    OUTFLOW = 5
    WALL_BOUNDARY = 6
    # Used elsewhere in this notebook as the depth of the one-hot node-type
    # encoding. The codes 7 and 8 are not assigned to any member here.
    SIZE = 9

### Represent Features to create torch graph data

In [21]:
# Number of trajectories to convert into training graphs.
number_trajectories = 5

# Number of time steps stored per trajectory ('trajectory_length' in meta.json).
trajectory_length = 601

# Split the flat list of time steps `l` into one slice per trajectory.
# Slice i has to start at i*trajectory_length: a slice that always starts at
# index 0 would make every entry begin with trajectory 0, and the loop below
# would then build the same graphs for every trajectory.
data = [l[i * trajectory_length:(i + 1) * trajectory_length]
        for i in range(len(l) // trajectory_length)]

# Time interval used to turn velocity differences into training targets.
# NOTE(review): meta.json for this dataset reports 'dt': 0.0002, which does
# not match the 0.01 used here. The value is left unchanged because it only
# rescales the target `y` and other cells may rely on it -- confirm, and
# switch to the metadata value if appropriate.
dt = 0.01

# data_list collects one PyTorch Geometric graph per (trajectory, time step).
data_list = []
for i in range(min(number_trajectories, len(data))):
    print("Trajectory: ",i)
    trajectory = data[i]
    first_step = trajectory[0]

    # node_type, cells and mesh_pos are 'static' fields: the parser tiles the
    # same values over every time step, so they (and everything derived from
    # them) are computed once per trajectory and shared by all of its graphs.
    node_type = torch.tensor(
        np.array(tf.one_hot(first_step['node_type'], NodeType.SIZE))).squeeze(1)
    cells = torch.tensor(np.array(first_step['cells']))
    mesh_pos = torch.tensor(np.array(first_step['mesh_pos']))

    # Edge indices: two-way connectivity derived from the triangle cells
    # (see triangles_to_edges), stacked into a [2, num_edges] long tensor.
    edges = triangles_to_edges(tf.convert_to_tensor(np.array(first_step['cells'])))
    edge_index = torch.stack((torch.tensor(edges[0].numpy()),
                              torch.tensor(edges[1].numpy())), dim=0).type(torch.long)

    # Edge features: relative position of the two end nodes and its L2 norm.
    u_ij = mesh_pos[edge_index[0]] - mesh_pos[edge_index[1]]
    u_ij_norm = torch.norm(u_ij, p=2, dim=1, keepdim=True)
    edge_attr = torch.cat((u_ij, u_ij_norm), dim=-1).type(torch.float)

    # Iterate over all time steps except the last one, which has no following
    # time step from which to compute the node output values.
    for ts in range(trajectory_length - 1):
        v_t = torch.tensor(np.array(trajectory[ts]['velocity']))
        v_tp1 = torch.tensor(np.array(trajectory[ts + 1]['velocity']))

        # Node features: current velocity concatenated with the one-hot node type.
        x = torch.cat((v_t, node_type), dim=-1).type(torch.float)

        # Node outputs for training: finite-difference rate of change of velocity.
        y = ((v_tp1 - v_t) / dt).type(torch.float)

        # Node outputs for testing the integrator (pressure).
        p = torch.tensor(np.array(trajectory[ts]['pressure']))

        # cells and mesh_pos are carried along for the visualization code.
        data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, p=p,
                              cells=cells, mesh_pos=mesh_pos))

Trajectory:  0
Trajectory:  1
Trajectory:  2
Trajectory:  3
Trajectory:  4


In [23]:
# Number of graphs built by the loop above.
# NOTE(review): that loop appends 600 graphs for each of the 5 trajectories
# (3000 in total), so a recorded output of 2995 looks stale -- re-run to confirm.
len(data_list)

2995