In [34]:
import tensorflow_gnn as tfgnn
import sys, os 
sys.path.insert(0, os.path.dirname(os.getcwd()))

from wofscast import model_utils
import xarray as xr
import numpy as np

In [None]:

# Load the 2020-05-05 WoFS dataset, keep only a few vertical levels, and add
# a leading 'batch' dimension (expand_dims inserts it at axis 0) so the
# dataset has the 'batch' dimension that stack_dataset relies on.
levels = [0, 1, 10]  # positional indices along 'level' (isel is index-based)
ds = (
    xr.load_dataset('/work/mflora/wofs-cast-data/dataset_20200505.nc')
    .drop_vars('datetime')
    .isel(level=levels)
    .expand_dims(dim='batch')
)

In [None]:
def stack_dataset(dataset: xr.Dataset) -> np.ndarray:
    """Flatten a dataset into a [num_grid_points, batch, n_channels] array.

    The 'lat' and 'lon' dimensions are stacked into a single 'nodes'
    dimension (one entry per grid point). Every remaining dimension other
    than 'batch' (e.g. 'time', 'level') is folded, together with the data
    variables themselves, into a single 'channels' dimension.

    Args:
        dataset (xarray.Dataset): Input dataset. It must have 'batch', 'lat'
            and 'lon' dimensions; any other dimensions become channels.

    Returns:
        numpy.ndarray: Array of shape [num_grid_points, batch, n_channels],
        where num_grid_points = len(lat) * len(lon) and n_channels is the
        number of data variables times the product of the sizes of all
        folded dimensions. Data variables are broadcast against each other
        first, so a variable lacking a dimension is repeated along it.
    """
    # Stack spatial dimensions into a single 'nodes' dimension.
    ds_stacked = dataset.stack(nodes=('lat', 'lon'))

    # Turn the data variables into one DataArray indexed by a new
    # 'variable' dimension (variables are broadcast to common dims).
    data_array = ds_stacked.to_array(dim='variable')

    # Take the dimensions to fold from the DataArray's ordered dims tuple,
    # not from Dataset.dims. The iteration order of Dataset.dims has varied
    # across xarray versions (sorted vs. insertion order, and a deprecated
    # mapping moving toward a set of names), which would silently permute
    # the channel order. Dataset.dims can also include dimensions used only
    # by coordinates, which are not valid dimensions to stack here.
    channel_dims = [
        dim for dim in data_array.dims
        if dim not in ('batch', 'nodes', 'variable')
    ]
    ds_combined = data_array.stack(channels=channel_dims + ['variable'])

    # Order dimensions as [nodes, batch, channels] and convert to NumPy.
    return ds_combined.transpose('nodes', 'batch', 'channels').values

In [None]:
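# grid: NWP grid points (nodes of the native model grid)
# mesh: nodes of the GNN mesh

# latent_size:   width of the latent embeddings that each node's n_channels
#                input features are projected to (also the MLP hidden width)
# hidden_layers: number of hidden layers in the MLPs used for message
#                passing in the GNN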

class MeshGNN:
    """Mesh GNN model; currently only the grid-to-mesh encoder is wired up.

    Grid-node features are built from an xarray Dataset and passed to
    ``_run_grid2mesh_gnn``, which is not implemented yet.

    Args:
        latent_size: Size of the latent node/edge embeddings and of the
            hidden layers of the message-passing MLPs.
        hidden_layers: Number of hidden layers in each message-passing MLP.
    """

    def __init__(self, latent_size=32, hidden_layers=1):
        # Encoder, which moves data from the grid to the mesh with a single
        # message-passing step.
        # NOTE(review): `deep_typed_graph_net` is not imported in this
        # notebook; it presumably comes from the `graphcast` package
        # (e.g. `from graphcast import deep_typed_graph_net`) -- confirm and
        # add that import, otherwise instantiating this class raises NameError.
        self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            embed_nodes=True,  # Embed raw features of the grid and mesh nodes.
            embed_edges=True,  # Embed raw features of the grid2mesh edges.
            edge_latent_size=dict(grid2mesh=latent_size),
            node_latent_size=dict(
                mesh_nodes=latent_size,
                grid_nodes=latent_size),
            mlp_hidden_size=latent_size,
            mlp_num_hidden_layers=hidden_layers,
            num_message_passing_steps=1,
            use_layer_norm=True,
            include_sent_messages_in_node_update=False,
            activation="swish",
            f32_aggregation=True,
            aggregate_normalization=None,
            name="grid2mesh_gnn",
        )

    def __call__(self, inputs, targets):
        """Encode ``inputs`` from the grid onto the mesh.

        Args:
            inputs: xarray.Dataset with dimensions (batch, time, lat, lon,
                level) and one or more data variables.
            targets: Currently unused.

        Returns:
            Tuple ``(latent_mesh_nodes, latent_grid_nodes)`` as returned by
            ``_run_grid2mesh_gnn``.
        """
        # Convert all input data into flat vectors for each of the grid nodes.
        # xarray.Dataset(batch, time, lat, lon, level, multiple vars, forcings)
        # -> [num_grid_points, batch, num_channels]
        grid_node_features = self._inputs_to_grid_node_features(inputs)

        # Transfer data from the grid to the mesh:
        # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size]
        latent_mesh_nodes, latent_grid_nodes = self._run_grid2mesh_gnn(
            grid_node_features)
        return latent_mesh_nodes, latent_grid_nodes

    def _inputs_to_grid_node_features(self, inputs: "xr.Dataset") -> np.ndarray:
        """xarray Dataset -> numpy [num_grid_nodes, batch, num_channels]."""
        return stack_dataset(inputs)

    def _run_grid2mesh_gnn(self, grid_node_features):
        """Run the grid2mesh encoder on per-grid-node features.

        Not implemented yet: building the grid2mesh graph and applying
        ``self._grid2mesh_gnn`` to it is still TODO. Raising here makes the
        gap explicit instead of failing with an AttributeError.
        """
        raise NotImplementedError(
            "MeshGNN._run_grid2mesh_gnn is not implemented yet: building the "
            "grid2mesh graph and running self._grid2mesh_gnn on it is TODO."
        )
    


    
        