In [1]:
from collections.abc import Mapping

import haiku as hk
import jax

def count_mlp_params(input_size, n_hidden=512, n_output=512, add_layer_norm=True):
    """Count the parameters of a single-hidden-layer Haiku MLP, optionally followed by LayerNorm.

    The network is `hk.nets.MLP(output_sizes=[n_hidden, n_output])` with swish
    activation. When `add_layer_norm` is True, an `hk.LayerNorm` with learnable
    scale and offset is appended after the MLP.

    Args:
    - input_size (int): Number of input features per example.
    - n_hidden (int): Width of the hidden layer.
    - n_output (int): Width of the output layer.
    - add_layer_norm (bool): Whether to append a LayerNorm after the MLP.

    Returns:
    - The total parameter count, as computed by `count_total_parameters` on the
      initialized Haiku parameters.
    """
    def forward(X):
        # Layer widths come from the arguments; they were previously hard-coded
        # to [512, 512], which silently ignored n_hidden and n_output.
        net = hk.nets.MLP(
            output_sizes=[n_hidden, n_output], activation=jax.nn.swish)

        if add_layer_norm:
            layer_norm = hk.LayerNorm(
                axis=-1, create_scale=True, create_offset=True)
            net = hk.Sequential([net, layer_norm])

        # Only `init` is called below, so the value returned here is never used;
        # the forward pass exists to make Haiku create the parameters.
        return net(X)

    # Fixed key: the parameter count does not depend on the random values.
    rng = jax.random.PRNGKey(42)

    # Dummy batch of one example; only its trailing dimension matters for the
    # shape of the first weight matrix.
    X = jax.numpy.zeros([1, input_size])

    transformed = hk.without_apply_rng(hk.transform(forward))

    # Initialize the model parameters
    params = transformed.init(rng, X)

    return count_total_parameters(params)


def calculate_mlp_with_layernorm_params(input_size, hidden_size=512, output_size=512):
    """Analytic parameter count for a one-hidden-layer MLP followed by LayerNorm.

    Args:
    - input_size (int): Number of input features.
    - hidden_size (int): Width of the hidden layer.
    - output_size (int): Width of the output layer.

    Returns:
    - int: Weights and biases of both dense layers plus the LayerNorm scale and
      shift vectors.
    """
    # Each dense layer contributes (fan_in * fan_out) weights and fan_out biases.
    dense_layers = [(input_size, hidden_size), (hidden_size, output_size)]
    dense_total = sum(fan_in * fan_out + fan_out for fan_in, fan_out in dense_layers)

    # LayerNorm has one scale and one shift value per output feature.
    norm_total = output_size * 2

    return dense_total + norm_total

def count_total_parameters(params_dict):
    """
    Count the total number of parameters in a nested mapping of parameters.

    Any value that is a `collections.abc.Mapping` (a plain `dict` or a
    read-only/immutable mapping type) is descended into; every other value is
    treated as a leaf that exposes a `size` attribute.

    Args:
    - params_dict (Mapping): A (possibly nested) mapping of parameters.

    Returns:
    - int: The total number of parameters (0 for an empty mapping).
    """
    total_params = 0
    for value in params_dict.values():
        if isinstance(value, Mapping):
            # Nested container: add the count of everything below it.
            total_params += count_total_parameters(value)
        else:
            # Leaf: assumed to be an array-like object with a `size` attribute.
            total_params += value.size
    return total_params

def get_grid_node_input_size(N_2D, N_3D, N_levels, N_timesteps, N_forcings, N_constants):
    """Return the number of input features per grid node.

    The count is the sum of three terms:
    - state variables: (N_2D + N_3D * N_levels) for each of N_timesteps,
    - forcings: N_forcings for each of (N_timesteps + 1) times,
    - constants: N_constants, counted once.
    """
    per_step_state = N_2D + N_3D * N_levels
    state_features = per_step_state * N_timesteps
    forcing_features = N_forcings * (N_timesteps + 1)
    return state_features + forcing_features + N_constants
    

ModuleNotFoundError: No module named 'jax'

In [None]:
# Configuration for WoFSCast.
#
# Formula-based alternative for the grid-node feature count (left disabled):
# vG_size = get_grid_node_input_size(N_2D=5,
#                                    N_3D=6,
#                                    N_levels=37,
#                                    N_timesteps=2,
#                                    N_forcings=5,
#                                    N_constants=5)
vG_size = 234

# Raw (pre-embedding) feature counts for each node and edge set.
input_sizes = dict(
    # Grid nodes: vertical slice of the atmosphere at grid point i.
    vG=vG_size,
    # Mesh nodes: (cos(lat), cos(lon), sin(lon)).
    vM=3,
    # Mesh edges: edge length plus the (x, y, z) difference between the 3D
    # positions of the sender and receiver nodes.
    eM=4,
    # Grid-to-mesh edges: same four features as the mesh edges.
    eG2M=4,
    # Mesh-to-grid edges: same four features as the mesh edges.
    eM2G=4,
)

# Width used for n_hidden / n_output of the GNN MLPs in the cells below.
LATENT_SPACE = 64

# Output size, excluding the input forcing, the extra input at t-1, and constants.
N_OUTPUT = 227

# Number of message-passing steps counted for the multi-mesh GNN.
N_MESSAGE_PASSING_STEPS = 4

### Embedding MLPs

These are the MLPs that embed each input feature set (grid nodes, mesh nodes, and the three edge types) into features of the same size.

In [None]:
# One embedding MLP per input feature set. The embedded features feed MLPs whose
# input sizes are multiples of LATENT_SPACE, so each embedder must be
# LATENT_SPACE wide rather than use count_mlp_params' 512 defaults.
total_embed_count = sum(
    count_mlp_params(size, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE)
    for size in input_sizes.values()
)

# Print total_embedding_count in terms of millions
print(f'Total Embedding Count: {total_embed_count / 1_000_000:.2f} million')

### Encoder Step: Grid2Mesh GNN 

This step is composed of the MLPs that update the grid2mesh edges, the mesh nodes, and the grid nodes.


In [None]:
# Author's note: checked with the GraphCast Colab notebooks and the sizes are correct.

# Grid2Mesh edge update. Input is the concatenation of the embedded edge, grid,
# and mesh features:
# (None, 3*LATENT_SPACE) -> (None, LATENT_SPACE) -> (None, LATENT_SPACE)
grid2mesh_edges = count_mlp_params(
    input_size=LATENT_SPACE * 3, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# Grid2Mesh mesh-node update. Input is the mesh features concatenated with the
# updated edge features (summed over the different input edges):
# (None, 2*LATENT_SPACE) -> (None, LATENT_SPACE) -> (None, LATENT_SPACE)
grid2mesh_mesh_nodes = count_mlp_params(
    input_size=LATENT_SPACE * 2, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# Grid2Mesh grid-node update (the original author marked this one with a "?").
grid2mesh_grid_nodes = count_mlp_params(
    input_size=LATENT_SPACE, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# Total for the Grid2Mesh GNN, reported in millions.
grid2mesh_count = sum((grid2mesh_edges, grid2mesh_mesh_nodes, grid2mesh_grid_nodes))

print(f'Grid2Mesh GNN: {grid2mesh_count / 1_000_000:.2f} million')

### Processing Step: Multi-Mesh GNN

This step is the main processing stage and is composed of a series of MLPs for updating the mesh nodes and edges. The "message passing" occurs in this step, and each message-passing step introduces a new set of MLPs.

In [None]:
# On the multi-mesh, each "message-passing" step effectively appends another NN:
# the output of one step is fed into the next. The NNs are independent of each
# other, so no parameters are shared between steps.

# Mesh edge update. Input is the concatenation of the embedded mesh edge that
# connects sender to receiver, the receiving embedded mesh node, and the sending
# embedded mesh node. The output has the latent size because the updates are
# residuals added to the original embedded edge data.
# (None, 3*LATENT_SPACE) -> (None, LATENT_SPACE) -> (None, LATENT_SPACE)
mesh_edges = count_mlp_params(
    input_size=LATENT_SPACE * 3, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# Mesh node update. Input is the mesh features concatenated with the updated
# edge features (summed over the different input edges).
# (None, 2*LATENT_SPACE) -> (None, LATENT_SPACE) -> (None, LATENT_SPACE)
mesh_nodes = count_mlp_params(
    input_size=LATENT_SPACE * 2, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# One edge MLP and one node MLP per message-passing step.
mesh_count = N_MESSAGE_PASSING_STEPS * (mesh_edges + mesh_nodes)

# Report the Mesh GNN total in millions.
print(f'Mesh GNN: {mesh_count / 1_000_000:.2f} million')

### Decoder Step: Mesh2Grid GNNs

This step involves MLPs mapping mesh nodes and mesh2grid edges to the grid nodes. 

In [None]:
# Mesh2Grid edge update: input is the concatenation of the edge features, mesh
# features, and grid features.
mesh2grid_edges = count_mlp_params(
    input_size=LATENT_SPACE * 3, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# Mesh2Grid grid-node update.
mesh2grid_grid_nodes = count_mlp_params(
    input_size=LATENT_SPACE * 2, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# Mesh2Grid mesh-node update.
# NOTE(review): this cell originally also carried the remark "Mesh node update:
# No update!", which conflicts with counting this MLP below - confirm which is
# intended.
mesh2grid_mesh_nodes = count_mlp_params(
    input_size=LATENT_SPACE, n_hidden=LATENT_SPACE, n_output=LATENT_SPACE
)

# Total for the Mesh2Grid GNN, reported in millions.
mesh2grid_count = sum((mesh2grid_edges, mesh2grid_grid_nodes, mesh2grid_mesh_nodes))

print(f'Mesh2Grid GNN: {mesh2grid_count / 1_000_000:.2f} million')

### Final Output

This is the final MLP that maps the latent grid nodes back to physical space.

In [None]:
# ------------------------------------------------------
# Final output
# ------------------------------------------------------

# Grid-node output MLP (no LayerNorm):
# (None, LATENT_SPACE) -> (None, LATENT_SPACE) -> (None, N_OUTPUT)
output_grid_nodes = count_mlp_params(
    input_size=LATENT_SPACE,
    n_hidden=LATENT_SPACE,
    n_output=N_OUTPUT,
    add_layer_norm=False,
)

# Mesh-node output MLP with the same layer sizes. Per the original author this
# exists in GraphCast and disagrees with the Neural-LAM paper.
decoder_mesh_nodes = count_mlp_params(
    input_size=LATENT_SPACE,
    n_hidden=LATENT_SPACE,
    n_output=N_OUTPUT,
    add_layer_norm=False,
)

# Grand total over every stage counted in the cells above.
total = sum([
    total_embed_count,
    grid2mesh_count,
    mesh_count,
    mesh2grid_count,
    output_grid_nodes,
    decoder_mesh_nodes,
])

# Reference value in raw parameters; not compared against `total` in this cell.
expected_total = 36.778 * 1_000_000

print(f'\nTotal Count: {total / 1_000_000:.3f} million')