In [1]:
import sys, os 
sys.path.insert(0, os.path.dirname(os.getcwd()))


import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_gnn as tfgnn

print(f'Running TF-GNN {tfgnn.__version__} with TensorFlow {tf.__version__}.')

2024-02-20 17:45:00.814238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-20 17:45:00.839897: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-20 17:45:00.839925: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-20 17:45:00.840691: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-20 17:45:00.845147: I tensorflow/core/platform/cpu_feature_guar

Running TF-GNN 1.0.2 with TensorFlow 2.15.0.


In [2]:
from wofscast import deep_typed_graph_net
from wofscast import grid_mesh_connectivity
from wofscast import icosahedral_mesh, square_mesh
from wofscast.encode_process_decode import EncodeProcessDecode
from wofscast import losses
from wofscast import model_utils, data_utils
from wofscast import my_graphcast as graphcast
#from . import predictor_base
#from . import typed_graph
from glob import glob
import xarray
import dataclasses

In [3]:
input_variables = ['U', 'V', 'W',]# 'T', 'P', 'REFL_10CM', 'UP_HELI_MAX']
target_variables = ['U', 'V', 'W',]# 'T', 'P', 'REFL_10CM', 'UP_HELI_MAX']
forcing_variables = ["XLAND"]

vars_2D = [] #['UP_HELI_MAX']

n_vars_2D = len(vars_2D)
n_vars_3D = len(input_variables) - n_vars_2D

# Weights used in the loss equation.
VARIABLE_WEIGHTS = {v : 1.0 for v in target_variables}
#VARIABLE_WEIGHTS['REFL_10CM'] = 2.0
#VARIABLE_WEIGHTS['UP_HELI_MAX'] = 2.0


# Not pressure levels, but just vertical array indices at the moment. 
pressure_levels = list(np.arange(0,30,4))
radius_query_fraction_edge_length=0.6
hidden_layers = 1

# Loads data from the past 10 minutes and 
# creates a target lead time 5-30, in 5 min intervals
input_duration = '10min'
train_lead_times = '5min' 
eval_lead_times = slice('10min', '25min')

In [4]:
task_config = graphcast.TaskConfig(
      input_variables=input_variables,
      target_variables=target_variables,
      forcing_variables=forcing_variables,
      pressure_levels=pressure_levels,
      input_duration=input_duration,
  )

In [5]:
# Load the data 
data_paths = glob(os.path.join('/work/mflora/wofs-cast-data/dataset*.nc'))
data_paths.sort()

data_paths = [data_paths[0]]*32

train_input_list = []
train_target_list = []
train_forcing_list = []

dataset = xarray.load_dataset(data_paths[0])

for path in data_paths:

    # @title Extract training and eval data
    example_batch = dataset.expand_dims(dim='batch', axis=0)

    _train_inputs, _train_targets, _train_forcings = data_utils.extract_inputs_targets_forcings(
        example_batch, target_lead_times=train_lead_times,
        **dataclasses.asdict(task_config))

    train_input_list.append(_train_inputs)
    train_target_list.append(_train_targets)
    train_forcing_list.append(_train_forcings)
    
# just load some data from the last dataset for evaluation. 
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
        example_batch, target_lead_times=eval_lead_times,
        **dataclasses.asdict(task_config))

train_inputs = xarray.concat(train_input_list, dim='batch')
train_targets = xarray.concat(train_target_list, dim='batch')
train_forcings = xarray.concat(train_forcing_list, dim='batch')
    
print("All Examples:  ", example_batch.dims.mapping)
print("*"*80)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("*"*80)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)

All Examples:   {'batch': 1, 'time': 6, 'level': 50, 'lat': 300, 'lon': 300, 'datetime': 6}
********************************************************************************
Train Inputs:   {'batch': 32, 'time': 2, 'level': 8, 'lat': 300, 'lon': 300}
Train Targets:  {'batch': 32, 'time': 1, 'level': 8, 'lat': 300, 'lon': 300}
Train Forcings: {'batch': 32, 'time': 1, 'lat': 300, 'lon': 300}
********************************************************************************
Eval Inputs:    {'batch': 1, 'time': 1, 'level': 8, 'lat': 300, 'lon': 300}
Eval Targets:   {'batch': 1, 'time': 4, 'level': 8, 'lat': 300, 'lon': 300}
Eval Forcings:  {'batch': 1, 'time': 4, 'lat': 300, 'lon': 300}


In [6]:
def stack_dataset(dataset: xarray.Dataset) -> np.ndarray:
    """Stack a dataset from (batch, time, lat, lon, level, multiple vars)
       -> [n_examples, num_grid_points, n_channels]
       
    Args:
        dataset (xarray.Dataset): The input dataset with dimensions (batch, time, lat, lon, level, ...).
        
    Returns:
        numpy.ndarray: A NumPy array with shape [num_grid_points, batch_size, n_channels].
    """
    # Stack spatial dimensions into a single 'nodes' dimension
    ds_stacked = dataset.stack(nodes=('lat', 'lon'))
    
    # Combine all other dimensions (except 'batch') into a single 'channels' dimension
    # This includes 'time', 'level', and potentially multiple variables if the dataset has more than one data variable
    non_batch_dims = [dim for dim in ds_stacked.dims if dim != 'batch' and dim != 'nodes']
    ds_combined = ds_stacked.to_array(dim='variable').stack(channels=(non_batch_dims + ['variable']))
    
    # Transpose to order dimensions as [nodes, batch, channels]
    ds_ordered = ds_combined.transpose('nodes', 'batch', 'channels')
    
    # Convert to NumPy array
    result_array = ds_ordered.values
    
    # The result_array shape will be [num_grid_points, batch, n_channels]
    # where 'n_channels' includes all combinations of 'time', 'level', and variables
    return result_array

In [7]:
def _get_max_edge_distance(mesh):
    senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
    edge_distances = np.linalg.norm(
      mesh.vertices[senders] - mesh.vertices[receivers], axis=-1)
    
    return edge_distances.max()

class GNNDataPreprocessor:
    def __init__(self, inputs, 
                 n_levels, 
                 n_vars_3D, 
                 n_vars_2D,
                 n_meshes=2, 
                 grid2mesh_radius=0.1, 
                 mesh2grid_edge_normalization_factor=None):
        
        # For the spatial feature creation for the Grid2Mesh and Mesh2Grid Graphs
        self._spatial_features_kwargs = dict(
            add_node_positions=False,
            add_node_latitude=True,
            add_node_longitude=True,
            add_relative_positions=True,
            relative_longitude_local_coordinates=True,
            relative_latitude_local_coordinates=True,
        )

        # Init the multi-mesh structure
        self._init_mesh(n_meshes)
        
        self.num_outputs = self.determine_num_outputs(n_levels, n_vars_3D, n_vars_2D)
        
        # Obtain the query radius in absolute units for the unit-sphere for the
        # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`.
        self._query_radius = (_get_max_edge_distance(self._finest_mesh)
                          * grid2mesh_radius)
        
        self._mesh2grid_edge_normalization_factor = (
            mesh2grid_edge_normalization_factor
        )
        
        self._init_mesh_properties()
        
        self._init_grid_properties(
              grid_lat=inputs.lat, grid_lon=inputs.lon)
        
        #self.compute_grid2mesh_features()
        
        #self.build_grid2mesh_tensor= True
        
        
    def determine_num_outputs(self, n_levels, n_vars_3D, n_vars_2D):
        num_outputs = n_vars_2D + (n_levels * n_vars_3D)
        
        return num_outputs
    
    def _init_mesh(self, n_meshes):
        # Specification of the multimesh.
        self._meshes = (
        icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
            splits=n_meshes))
    
    @property
    def _finest_mesh(self):
        return self._meshes[-1]
    
    def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray):
        """Inits static properties that have to do with grid nodes."""
        self._grid_lat = grid_lat.astype(np.float32)
        self._grid_lon = grid_lon.astype(np.float32)
        # Initialized the counters.
        self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0]

        # Initialize lat and lon for the grid.
        grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat)
        self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32)
        self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32)
    

    def _init_mesh_properties(self):
        """Inits static properties that have to do with mesh nodes."""
        self._num_mesh_nodes = self._finest_mesh.vertices.shape[0]
        mesh_phi, mesh_theta = model_utils.cartesian_to_spherical(
            self._finest_mesh.vertices[:, 0],
            self._finest_mesh.vertices[:, 1],
            self._finest_mesh.vertices[:, 2])
        (
            mesh_nodes_lat,
            mesh_nodes_lon,
        ) = model_utils.spherical_to_lat_lon(
            phi=mesh_phi, theta=mesh_theta)
        # Convert to f32 to ensure the lat/lon features aren't in f64.
        self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32)
        self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32)
    
    
    def compute_grid2mesh_features(self, grid_node_features):
        # Create some edges according to distance between mesh and grid nodes.
        # Create some edges according to distance between mesh and grid nodes.
        assert self._grid_lat is not None and self._grid_lon is not None
        (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices(
            grid_latitude=self._grid_lat,
            grid_longitude=self._grid_lon,
            mesh=self._finest_mesh,
            radius=self._query_radius)

        # Edges sending info from grid to mesh.
        senders = grid_indices
        receivers = mesh_indices

        # Precompute structural node and edge features according to config options.
        # Structural features are those that depend on the fixed values of the
        # latitude and longitudes of the nodes.
        (senders_node_features, receivers_node_features,
         edge_features) = model_utils.get_bipartite_graph_spatial_features(
             senders_node_lat=self._grid_nodes_lat,
             senders_node_lon=self._grid_nodes_lon,
             receivers_node_lat=self._mesh_nodes_lat,
             receivers_node_lon=self._mesh_nodes_lon,
             senders=senders,
             receivers=receivers,
             **self._spatial_features_kwargs,
         )

        batch_size = grid_node_features.shape[1]
        
        # Get the sender node features (from the grid)
        data = _add_batch_second_axis(senders_node_features, batch_size)
        senders_node_features = np.concatenate([grid_node_features, data], axis=-1) 
        
        # Get the receiver node features (to the mesh)
        dummy_mesh_node_features = np.zeros((self._num_mesh_nodes, batch_size, grid_node_features.shape[-1]),
                dtype=grid_node_features.dtype)
        
        data = _add_batch_second_axis(receivers_node_features, batch_size)
        receivers_node_features =tf.concat([dummy_mesh_node_features, data],axis=-1)
        
        # Get the edge features (grid2mesh edges)
        edge_features =  _add_batch_second_axis(edge_features, batch_size)
        
        self.grid2mesh_features = {'grid_nodes': {'senders_node_features' : senders_node_features}, 
                            'mesh_nodes': {'receivers_node_features': receivers_node_features}, 
                            'edge_features' : edge_features, 
                            'senders' : senders, 
                            'receivers': receivers,
                           }
        
    def _init_grid2mesh_tensor(self, grid_node_features) -> tfgnn.GraphTensor: 
        """Build Grid2Mesh graph tensor for a single xample. 
        
            grid_node_features : Flattened version of the input (n_examples, n_nodes, n_channels) 
            
        """
        senders = self.grid2mesh_data['senders']
        receivers = self.grid2mesh_data['receivers']
        senders_node_features = self.grid2mesh_data['senders_node_features']
        receivers_node_features = self.grid2mesh_data['receivers_node_features']
        edge_features = self.grid2mesh_data['edge_features']
        
        n_grid_node = np.array([self._num_grid_nodes])
        n_mesh_node = np.array([self._num_mesh_nodes])
        n_edge = np.array([receivers.shape[0]])
        n_grid_features = senders_node_features.shape[-1]
        n_mesh_features = receivers_node_features.shape[-1]
        
        # Expand for a single batch size.
        ###print(f"{grid_node_features.shape=}, {type(grid_node_features)=}") # (None, 90000, 48)
        #grid_node_features = tf.transpose(grid_node_features, perm=[1,0,2])
        #print(f"New Shape: {grid_node_features.shape=}")
        
        # Concatenate node structural features with input features.
        #batch_size = grid_node_features.shape[1]
        
        # Concatenate with the sample input. 
        #data = tf.convert_to_tensor(_add_batch_second_axis(senders_node_features, batch_size, symbolic=True))
        data = tf.convert_to_tensor(senders_node_features)

        senders_node_features = tf.concat([grid_node_features, data], axis=-1)
        # To make sure capacity of the embedded is identical for the grid nodes and
        # the mesh nodes, we also append some dummy zero input features for the
        # mesh nodes.
        dummy_mesh_node_features = tf.zeros((self._num_mesh_nodes, grid_node_features.shape[-1]),
                dtype=grid_node_features.dtype)
        
        receiver_data = tf.convert_to_tensor(receivers_node_features)
        receivers_node_features =tf.concat([dummy_mesh_node_features, receiver_data],axis=-1)
   
        # Batch size needs to be the middle axis
        #senders_node_features = tf.transpose(senders_node_features, perm=[1,0,2])
        #receivers_node_features = tf.transpose(receivers_node_features, perm=[1,0,2])

        # Initialize NodeSets (x,y coordindates of the senders and receivers)
        senders_node_features_dict = {tfgnn.HIDDEN_STATE: senders_node_features}
        receivers_node_features_dict = {tfgnn.HIDDEN_STATE: receivers_node_features}
        
        grid_node_set = tfgnn.NodeSet.from_fields(features=senders_node_features_dict, 
                                            sizes=n_grid_node)
        mesh_node_set = tfgnn.NodeSet.from_fields(features=receivers_node_features_dict, 
                                            sizes=n_mesh_node)
        
        # Create adjacency using source and target nodes
        adjacency = tfgnn.Adjacency.from_indices(
            source=("grid_nodes", senders),
            target=("mesh_nodes", receivers)
        )

        edge_features =  tf.convert_to_tensor(edge_features)
        n_edge_features = edge_features.shape[-1]
     
        # Initialize EdgeSet
        edge_set = tfgnn.EdgeSet.from_fields(features={tfgnn.HIDDEN_STATE: edge_features},
                                       sizes=n_edge, 
                                       adjacency=adjacency)
        
        # Define the number of graph components; for a single component, this would be [1]
        num_components = tf.constant([1], dtype=tf.int32)  # Assuming a single graph component

        # Initialize Context without specific features but with the number of components
        graph_context = tfgnn.Context.from_fields(features = {}, 
                                                  sizes=num_components)
        
        # Constructing the GraphTensor
        graph_tensor = tfgnn.GraphTensor.from_pieces(
            #context = graph_context, 
            node_sets={
                "grid_nodes": grid_node_set,
                "mesh_nodes": mesh_node_set
            },
            edge_sets={
                "grid2mesh_edges": edge_set
            }
        )
        
        return graph_tensor
        
    def _init_mesh_tensor(self, latent_mesh_nodes) -> tfgnn.GraphTensor:
        """Build Mesh graph tensor"""
        # Add the structural edge features of this graph. Note we don't need
        # to add the structural node features, because these are already part of
        # the latent state, via the original Grid2Mesh gnn, however, we need
        # the edge ones, because it is the first time we are seeing this particular
        # set of edges.
        batch_size = latent_mesh_nodes.shape[1]
        
        merged_mesh = icosahedral_mesh.merge_meshes(self._meshes)

        # Work simply on the mesh edges.
        senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces)

        # Precompute structural node and edge features according to config options.
        # Structural features are those that depend on the fixed values of the
        # latitude and longitudes of the nodes.
        assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
        node_features, edge_features = model_utils.get_graph_spatial_features(
            node_lat=self._mesh_nodes_lat,
            node_lon=self._mesh_nodes_lon,
            senders=senders,
            receivers=receivers,
            **self._spatial_features_kwargs,
        )

        n_mesh_node = np.array([self._num_mesh_nodes])
        n_edge = np.array([senders.shape[0]])
        assert n_mesh_node == len(node_features)
        
        # Initialize NodeSets (x,y coordindates of the senders and receivers
        node_features_dict = {tfgnn.HIDDEN_STATE: latent_mesh_nodes}
        
        mesh_node_set = tfgnn.NodeSet.from_fields(features=node_features_dict, 
                                                  sizes=n_mesh_node)
        
        # Create adjacency using source and target nodes
        adjacency = tfgnn.Adjacency.from_indices(
            source=("mesh_nodes", senders),
            target=("mesh_nodes", receivers)
        )
        
        # Initialize EdgeSet
        edge_features =  _add_batch_second_axis(edge_features, batch_size)
        edge_set = tfgnn.EdgeSet.from_fields(features={tfgnn.HIDDEN_STATE: edge_features},
                                       sizes=n_edge, 
                                       adjacency=adjacency)
        
        # Define the number of graph components; for a single component, this would be [1]
        num_components = tf.constant([1], dtype=tf.int32)  # Assuming a single graph component

        # Initialize Context without specific features but with the number of components
        graph_context = tfgnn.Context.from_fields(sizes=num_components)
        
        # Constructing the GraphTensor
        graph_tensor = tfgnn.GraphTensor.from_pieces(
            context = graph_context, 
            node_sets={
                "mesh_nodes": mesh_node_set
            },
            edge_sets={
                "mesh_edges": edge_set
            }
        )
        
        return graph_tensor
    
    def _init_mesh2grid_tensor(self, mesh_nodes, grid_nodes) -> tfgnn.GraphTensor:
        """Build Mesh2Grid graph from the updated latent mesh nodes and the 
        original latent grid nodes from the grid2mesh transformation."""

        batch_size = mesh_nodes.shape[1]
        
        # Create some edges according to how the grid nodes are contained by
        # mesh triangles.
        (grid_indices,
         mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(
             grid_latitude=self._grid_lat,
             grid_longitude=self._grid_lon,
             mesh=self._finest_mesh)

        # Edges sending info from mesh to grid.
        senders = mesh_indices.astype(int)#, dtype=tf.int32
        receivers = grid_indices.astype(int)#, dtype=tf.int32

        # Precompute structural node and edge features according to config options.
        assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
        
        (senders_node_features, receivers_node_features,
         edge_features) = model_utils.get_bipartite_graph_spatial_features(
             senders_node_lat=self._mesh_nodes_lat,
             senders_node_lon=self._mesh_nodes_lon,
             receivers_node_lat=self._grid_nodes_lat,
             receivers_node_lon=self._grid_nodes_lon,
             senders=senders,
             receivers=receivers,
             edge_normalization_factor=self._mesh2grid_edge_normalization_factor,
             **self._spatial_features_kwargs,
         )

        n_grid_node = np.array([self._num_grid_nodes])
        n_mesh_node = np.array([self._num_mesh_nodes])
        n_edge = np.array([senders.shape[0]])

        # Initialize NodeSets for grid and mesh nodes
        grid_node_set = tfgnn.NodeSet.from_fields(features = {tfgnn.HIDDEN_STATE: 
                                                              grid_nodes
                                       }, 
                                                  sizes=n_grid_node)
        
        mesh_node_set = tfgnn.NodeSet.from_fields(features={tfgnn.HIDDEN_STATE: 
                                                            mesh_nodes}, 
                                                  sizes=n_mesh_node)

        # Create adjacency for EdgeSet using source (senders) and target (receivers) nodes
        adjacency = tfgnn.Adjacency.from_indices(
            source=("mesh_nodes", senders),
            target=("grid_nodes", receivers)
        )

        # Initialize EdgeSet
        edge_features =  _add_batch_second_axis(edge_features, batch_size)
        edge_set = tfgnn.EdgeSet.from_fields(features={tfgnn.HIDDEN_STATE : edge_features},
                                             sizes=n_edge, 
                                             adjacency=adjacency)

        graph_context = tfgnn.Context.from_fields(features={}, sizes=tf.constant([1]))
        
        # Constructing the GraphTensor
        graph_tensor = tfgnn.GraphTensor.from_pieces(
            context = graph_context, 
            node_sets={
                "grid_nodes": grid_node_set,
                "mesh_nodes": mesh_node_set
            },
            edge_sets={
                "mesh2grid_edges": edge_set
            }
        )

        return graph_tensor
    
    
    def _grid_node_outputs_to_prediction(
      self,
      grid_node_outputs,
      targets_template: xarray.Dataset,
      ) -> xarray.Dataset:
        """[num_grid_nodes, batch, num_outputs] -> xarray."""
    
        # numpy array with shape [lat_lon_node, batch, channels]
        # to xarray `DataArray` (batch, lat, lon, channels)
        assert self._grid_lat is not None and self._grid_lon is not None
        grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0])

        grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
            grid_shape + grid_node_outputs.shape[1:])
        dims = ("lat", "lon", "batch", "channels")
    
        grid_xarray_lat_lon_leading = xarray.DataArray(
            data=grid_outputs_lat_lon_leading,
            dims=dims)
    
        # Monte: Possibly Deprecated with newest version of jax
        grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading)

        # xarray `DataArray` (batch, lat, lon, channels)
        # to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars)

        return model_utils.stacked_to_dataset(
            grid_xarray.variable, targets_template)
    
def _add_batch_second_axis(data, batch_size):
  # data [leading_dim, trailing_dim]
  assert data.ndim == 2
  ones = np.ones([batch_size, 1], dtype=data.dtype)
  return data[:, None] * ones  # [leading_dim, batch, trailing_dim]



In [8]:
class GraphCast(tf.keras.Model):
    
    def __init__(self, 
                 preprocessor, 
                 latent_size=8, 
                 num_mlp_hidden_layers=1, 
                 mlp_hidden_size=4, 
                 num_message_passing_steps=2,
                ):
        super().__init__()
        
        self.preprocessor = preprocessor
        self.num_mlp_hidden_layers = num_mlp_hidden_layers
        self.latent_size = latent_size
        self.mlp_hidden_size = mlp_hidden_size
        self.num_message_passing_steps = num_message_passing_steps
        
        #TODO: determine num_outputs? 
        num_outputs = self.preprocessor.num_outputs
        
        self._init_grid2mesh_gnn()
        self._init_mesh_gnn()
        self._init_mesh2grid_gnn(num_outputs)
    

    def call(self, inputs, training=False):

        # GRID2MESH ---------------------------------------------
        input_tensor = self.preprocessor._init_grid2mesh_tensor(inputs)
        result = self.grid2mesh_gnn(input_tensor)

        latent_mesh_nodes = result.node_sets['mesh_nodes'][tfgnn.HIDDEN_STATE]
        latent_grid_nodes = result.node_sets['grid_nodes'][tfgnn.HIDDEN_STATE]

        # MESH ---------------------------------------------
        latent_tensor = self.preprocessor._init_mesh_tensor(latent_mesh_nodes)
        result = self.mesh_gnn(latent_tensor)
        updated_latent_mesh_nodes = result.node_sets['mesh_nodes'][tfgnn.HIDDEN_STATE]

        # MESH2GRID ---------------------------------------------
        final_tensor = self.preprocessor._init_mesh2grid_tensor(updated_latent_mesh_nodes, 
                                                                latent_grid_nodes)
        final_result = self.mesh2grid_gnn(final_tensor)
        output_grid_nodes = final_result.node_sets['grid_nodes'][tfgnn.HIDDEN_STATE]

        # Output Nodes to Grid 
        # Convert output flat vectors for the grid nodes to the format of the output.
        # [num_grid_nodes, batch, output_size] ->
        # xarray (batch, one time step, lat, lon, level, multiple vars)
        return output_grid_nodes
        
    def _init_grid2mesh_gnn(self):
        self.grid2mesh_gnn = EncodeProcessDecode(
            encode_edges=True, # Encode raw features of the grid2mesh edges.
            encode_nodes=True, # Encode raw features of the grid and mesh nodes.
            edge_output_size=False,  
            node_output_size=False,  
            context_output_size=None,  # Don't need this output.
            # Other configurable hyperparameters (most combinations should train).
            num_message_passing_steps=1,
            num_mlp_hidden_layers=self.num_mlp_hidden_layers,
            mlp_hidden_size=self.mlp_hidden_size,
            latent_size=self.latent_size,
            use_layer_norm=True,
            shared_processors=False,
            )
        
    def _init_mesh_gnn(self):
        
        self.mesh_gnn = EncodeProcessDecode(
            encode_edges=True,  # Encode raw features of the multi-mesh edges.
            encode_nodes=False, # Node features already embdded by previous layers.
            edge_output_size=False,  
            node_output_size=False,  
            context_output_size=None,  # Don't need this output.
            # Other configurable hyperparameters (most combinations should train).
            num_message_passing_steps=self.num_message_passing_steps,
            num_mlp_hidden_layers=self.num_mlp_hidden_layers,
            mlp_hidden_size=self.mlp_hidden_size,
            latent_size=self.latent_size,
            use_layer_norm=True,
            shared_processors=False,
            )
     
    def _init_mesh2grid_gnn(self, num_outputs):
        self.mesh2grid_gnn = EncodeProcessDecode(
            encode_edges=True,  # Encode raw features of the mesh2grid edges.
            encode_nodes=False, # Node features already embdded by previous layers.
            edge_output_size=None,  
            node_output_size={'grid_nodes' : num_outputs}, 
            # Back to the number of inputs features from the grid + edges.  
            context_output_size=None,  # Don't need this output.
            # Other configurable hyperparameters (most combinations should train).
            num_message_passing_steps=1,
            num_mlp_hidden_layers=self.num_mlp_hidden_layers,
            mlp_hidden_size=self.mlp_hidden_size,
            latent_size=self.latent_size,
            use_layer_norm=True,
            shared_processors=False,
            )

In [20]:


def to_graph_tensor(data, edge_set_name):
    # Initialize Context without specific features but with the number of components
    graph_context = tfgnn.Context.from_fields(features = {}, 
                                            sizes=tf.constant([1], dtype=tf.int32))
    
    node_features = [f for f in data.keys() if 'node' in f]
    
    need_source_and_target=False
    if len(node_features) > 1:
        need_source_and_target = True
    
    node_sets = {}
    
    pairs = {}
    source=None
    target=None
    for node_set_name in node_features:
        subkey = list(data[node_set_name].keys())[0]
        size = (data[node_set_name][subkey]).shape[0] # (n_nodes, batch_size, n_channels)
        print(f"{node_set_name=}, {subkey=}, {size=}, {data[node_set_name][subkey].shape=}")
        node_sets[node_set_name] = tfgnn.NodeSet.from_fields(
                                    features={tfgnn.HIDDEN_STATE: 
                                              data[node_set_name][subkey]}, 
                                              sizes=(size,))
        
       

        if subkey.split('_')[0] == 'senders':
            source = (node_set_name, data['senders']) 
        else:
            target = (node_set_name, data['receivers']) 
           
    #source=("grid_nodes", senders),
    #target=("mesh_nodes", receivers)
        
    # Get adjacency matrix
    adjacency=tfgnn.Adjacency.from_indices(source=source,target=target)
    
    # Constructing the GraphTensor
    graph_tensor = tfgnn.GraphTensor.from_pieces(
            context = graph_context, 
            node_sets=node_sets,
            edge_sets={
                edge_set_name: tfgnn.EdgeSet.from_fields(features={tfgnn.HIDDEN_STATE: 
                                                                   data['edge_features']},
                                       sizes=np.array([data['senders'].shape[0]]), 
                                       adjacency=adjacency)}
        )
        
    return graph_tensor

In [22]:
# Create the graph inputs. 
preprop = GNNDataPreprocessor(train_inputs, 
                 n_levels=len(pressure_levels), 
                 n_vars_3D=n_vars_3D, 
                 n_vars_2D=n_vars_2D,
                 n_meshes=2, 
                 grid2mesh_radius=0.1, 
                 mesh2grid_edge_normalization_factor=None)

grid_node_features = stack_dataset(train_inputs)

preprop.compute_grid2mesh_features(grid_node_features)
data = preprop.grid2mesh_features

graph = to_graph_tensor(data, 'grid2mesh_edges')



node_set_name='grid_nodes', subkey='senders_node_features', size=90000, data[node_set_name][subkey].shape=(90000, 32, 51)


2024-02-20 17:48:42.945451: W external/local_tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 560.30MiB (rounded to 587520000)requested by op _EagerConst
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
2024-02-20 17:48:42.945493: I external/local_tsl/tsl/framework/bfc_allocator.cc:1039] BFCAllocator dump for GPU_0_bfc
2024-02-20 17:48:42.945501: I external/local_tsl/tsl/framework/bfc_allocator.cc:1046] Bin (256): 	Total Chunks: 6, Chunks in use: 4. 1.5KiB allocated for chunks. 1.0KiB in use in bin. 32B client-requested in use in bin.
2024-02-20 17:48:42.945505: I external/local_tsl/tsl/framework/bfc_allocator.cc:1046] Bin (512): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2024-02-20 17:48:42.945509: I external/lo

InternalError: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.

In [None]:
graph = graph

In [18]:
preprop.grid2mesh_features

{'grid_nodes': {'senders_node_features': array([[[ 3.2143555e+00,  4.0122070e+00,  5.9156418e-03, ...,
            5.1528192e-01, -7.5325884e-02,  9.9715894e-01],
          [ 3.2143555e+00,  4.0122070e+00,  5.9156418e-03, ...,
            5.1528192e-01, -7.5325884e-02,  9.9715894e-01],
          [ 3.2143555e+00,  4.0122070e+00,  5.9156418e-03, ...,
            5.1528192e-01, -7.5325884e-02,  9.9715894e-01],
          ...,
          [ 3.2143555e+00,  4.0122070e+00,  5.9156418e-03, ...,
            5.1528192e-01, -7.5325884e-02,  9.9715894e-01],
          [ 3.2143555e+00,  4.0122070e+00,  5.9156418e-03, ...,
            5.1528192e-01, -7.5325884e-02,  9.9715894e-01],
          [ 3.2143555e+00,  4.0122070e+00,  5.9156418e-03, ...,
            5.1528192e-01, -7.5325884e-02,  9.9715894e-01]],
  
         [[ 3.2749023e+00,  3.8867185e+00,  2.9263494e-03, ...,
            5.1528192e-01, -7.5875051e-02,  9.9711734e-01],
          [ 3.2749023e+00,  3.8867185e+00,  2.9263494e-03, ...,
          

In [None]:
preprop.grid2mesh_features

In [None]:

    
class ToMeshGraphTensor(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__()
        self.units = units
    def call(self, inputs):
        return inputs 
    
class ToMesh2GridGraphTensor(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__()
        self.units = units
    def call(self, inputs):
        return inputs 
    
    

In [None]:
input_tensor = self.preprocessor._init_grid2mesh_tensor(inputs)


def model_fn(gtspec: tfgnn.GraphTensorSpec): 
    graph = inputs = tf.keras.layers.Input(type_spec=gtspec)

    # GRID2MESH ---------------------------------------------
    #input_tensor = self.preprocessor._init_grid2mesh_tensor(inputs)
    graph = self.grid2mesh_gnn(gtspec)

    latent_mesh_nodes = graph.node_sets['mesh_nodes'][tfgnn.HIDDEN_STATE]
    latent_grid_nodes = graph.node_sets['grid_nodes'][tfgnn.HIDDEN_STATE]

    # MESH ---------------------------------------------
    graph = self.preprocessor._init_mesh_tensor(latent_mesh_nodes)
    graph = self.mesh_gnn(graph)
    updated_latent_mesh_nodes = graph.node_sets['mesh_nodes'][tfgnn.HIDDEN_STATE]

    # MESH2GRID ---------------------------------------------
    graph = self.preprocessor._init_mesh2grid_tensor(updated_latent_mesh_nodes, 
                                                                latent_grid_nodes)
    graph = self.mesh2grid_gnn(graph)
    
    return tf.keras.Model(inputs, graph)

In [None]:
preprop = GNNDataPreprocessor(train_inputs, 
                 n_levels=len(pressure_levels), 
                 n_vars_3D=n_vars_3D, 
                 n_vars_2D=n_vars_2D,
                 n_meshes=2, 
                 grid2mesh_radius=0.1, 
                 mesh2grid_edge_normalization_factor=None)


grid_node_features = stack_dataset(train_inputs)
target_node_features = stack_dataset(train_targets)

# Re-order so batch size is first 
grid_node_features_reshaped = tf.transpose(grid_node_features, perm=[1,0,2])
target_node_features_reshaped = tf.transpose(target_node_features, perm=[1,0,2])

train_ds = tf.data.Dataset.from_tensor_slices((grid_node_features_reshaped, target_node_features_reshaped))
train_ds = train_ds.shuffle(
            buffer_size=2).batch(2).prefetch(tf.data.AUTOTUNE)

val_ds = train_ds

In [None]:
model = GraphCast(preprop)

#inputs = tf.keras.layers.Input(shape=(None, 90000, 48), name='my_input')
#model.build((None, 90000, 48))

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss = 'mse',
             )
#print(model.summary())
model.fit(train_ds, shuffle=False, epochs=5)#, validation_data=val_ds)