In [1]:
import cdsapi
import datetime
import functools
from graphcast import autoregressive, casting, checkpoint, data_utils as du, graphcast, normalization, rollout
import haiku as hk 
import isodate
import jax
import math
import numpy as np
import pandas as pd
from pysolar.radiation import get_radiation_direct
from pysolar.solar import get_altitude
import pytz
import scipy
from typing import Dict
import xarray
import gc

In [2]:
client = cdsapi.Client() # Making a connection to CDS, to fetch data. 
# The fields to be fetched from the single-level source. 
singlelevelfields = [
                        '10m_u_component_of_wind',
                        '10m_v_component_of_wind',
                        '2m_temperature',
                        'geopotential',
                        'land_sea_mask',
                        'mean_sea_level_pressure',
                        'toa_incident_solar_radiation', 
                        'total_precipitation'
                    ]

# The fields to be fetched from the pressure-level source. 
pressurelevelfields = [
                        'u_component_of_wind',
                        'v_component_of_wind',
                        'geopotential',
                        'specific_humidity',
                        'temperature',
                        'vertical_velocity'
                    ]

# The 13 pressure levels.
pressure_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

# Initializing other required constants.
pi = math.pi

# There is a gap of 6 hours between each graphcast prediction.
gap = 6

# Predicting for 4 timestamps.
predictions_steps = 4

watts_to_joules = 3600

# Timestamp of the first prediction.
first_prediction = datetime.datetime(2024, 1, 1, 18, 0)  

# Latitude range.
lat_range = range(-180, 181, 1) 

# Longitude range. 
lon_range = range(0, 360, 1) 

# A utility function used for ease of coding. 
# Converting the variable to a datetime object.
def toDatetime(dt) -> datetime.datetime:
    if isinstance(dt, datetime.date) and isinstance(dt, datetime.datetime):
        return dt
    
    elif isinstance(dt, datetime.date) and not isinstance(dt, datetime.datetime):
        return datetime.datetime.combine(dt, datetime.datetime.min.time())
    
    elif isinstance(dt, str):
        if 'T' in dt:
            return isodate.parse_datetime(dt)
        else:
            return datetime.datetime.combine(isodate.parse_date(dt), datetime.datetime.min.time())

2025-03-05 17:43:45,027 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.


## Get inputs

In [3]:
%%time
# Getting the single and pressure level values.
def getSingleAndPressureValues():
    # SINGLE LEVELS
    dataset = 'reanalysis-era5-single-levels'

    request_1 = {
                'product_type': 'reanalysis',
                'variable': singlelevelfields[:6],
                # 'variable': singlelevelfields[:5],
                'grid': '1.0/1.0',
                'year': [2024],
                'month': [1],
                'day': [1],
                'time': [
                    '00:00', '01:00', '02:00', '03:00',
                    '04:00', '05:00', '06:00', '07:00', 
                    '08:00', '09:00', '10:00', '11:00',
                    '12:00'
                ],
                'data_format': 'netcdf'
            }
    
    request_2 = {
                'product_type': 'reanalysis',
                'variable': singlelevelfields[6:],
                # 'variable': singlelevelfields[5:],
                'grid': '1.0/1.0',
                'year': [2024],
                'month': [1],
                'day': [1],
                'time': [
                    '00:00', '01:00', '02:00', '03:00',
                    '04:00', '05:00', '06:00', '07:00', 
                    '08:00', '09:00', '10:00', '11:00',
                    '12:00'
                ],
                'data_format': 'netcdf'
            }
    
    client.retrieve(
            dataset,
            request_1,
            'single-level_1.nc'
        )
    
    client.retrieve(
            dataset,
            request_2,
            'single-level_2.nc'
        )    
    
    # read data
    singlelevel_1 = xarray.open_dataset('single-level_1.nc', engine = 'netcdf4').to_dataframe()
    singlelevel_2 = xarray.open_dataset('single-level_2.nc', engine = 'netcdf4').to_dataframe()
    
    # drop useless columns
    singlelevel_1 = singlelevel_1.drop(['number', 'expver'], axis=1)
    singlelevel_2 = singlelevel_2.drop(['number', 'expver'], axis=1)
    
    # merge tables
    singlelevel = singlelevel_1.merge(
                                      singlelevel_2, 
                                      left_index=True, 
                                      right_index=True,
                                      how='inner'
                                  )
    
    del singlelevel_1, singlelevel_2
    gc.collect()

    singlelevel = singlelevel.rename(
                                     columns = {
                                         col:singlelevelfields[ind] for ind, col in enumerate(singlelevel.columns.values.tolist())
                                     }
                                 )
    
    singlelevel = singlelevel.rename(
                                     columns = {
                                         'geopotential': 'geopotential_at_surface'
                                     }
                                 )
    
    # Calculating the sum of the last 6 hours of rainfall. 
    singlelevel = singlelevel.sort_index()
    
    singlelevel['total_precipitation_6hr'] = (
                                        singlelevel.groupby(level=[0,1])['total_precipitation']
                                                   .rolling(window=6, min_periods=1)
                                                   .sum()
                                                   .reset_index(level=[0,1], drop=True)
                                    )
    
    singlelevel.pop('total_precipitation')


    
    # PRESSURE LEVELS
    dataset = 'reanalysis-era5-pressure-levels'
    
    request = {
            'product_type': 'reanalysis',
            'variable': pressurelevelfields,
            'grid': '1.0/1.0',
            'year': [2024],
            'month': [1],
            'day': [1],
            'time': ['06:00', '12:00'],
            'pressure_level': pressure_levels,
            'data_format': 'netcdf'
        }

    client.retrieve(
        dataset,
        request,
        'pressure-level.nc'
    )
    
    pressurelevel = xarray.open_dataset('pressure-level.nc', engine = 'netcdf4').to_dataframe()\

    # drop useless columns
    pressurelevel = pressurelevel.drop(['number', 'expver'], axis=1)
    
    pressurelevel = pressurelevel.rename(
                                         columns = {
                                             col:pressurelevelfields[ind] for ind, col in enumerate(pressurelevel.columns.values.tolist())
                                         }
                                     )

    # rename axis
    singlelevel = singlelevel.rename_axis(index={'valid_time': 'time'})
    pressurelevel = pressurelevel.rename_axis(index={'valid_time': 'time'})

    return singlelevel, pressurelevel



# Adding sin and cos of the year progress. 
def addYearProgress(secs, data):
    progress = du.get_year_progress(secs)
    data['year_progress_sin'] = math.sin(2 * pi * progress)
    data['year_progress_cos'] = math.cos(2 * pi * progress)

    return data



# Adding sin and cos of the day progress.
def addDayProgress(secs, lon:str, data:pd.DataFrame):
    lons = data.index.get_level_values(lon).unique()
    progress:np.ndarray = du.get_day_progress(secs, np.array(lons))
    prxlon = {lon:prog for lon, prog in list(zip(list(lons), progress.tolist()))}
    data['day_progress_sin'] = data.index.get_level_values(lon).map(lambda x: math.sin(2 * pi * prxlon[x]))
    data['day_progress_cos'] = data.index.get_level_values(lon).map(lambda x: math.cos(2 * pi * prxlon[x]))
    
    return data



# Adding day and year progress. 
def integrateProgress(data:pd.DataFrame):
    for dt in data.index.get_level_values('time').unique():
        seconds_since_epoch = toDatetime(dt).timestamp()
        data = addYearProgress(seconds_since_epoch, data)
        data = addDayProgress(
                            seconds_since_epoch,
                            'longitude' if 'longitude' in data.index.names else 'lon',
                            data
                        )

    return data



# Adding batch field and renaming some others.
def formatData(data:pd.DataFrame) -> pd.DataFrame:
    data = data.rename_axis(index = {'latitude': 'lat', 'longitude': 'lon'})
    if 'batch' not in data.index.names:
        data['batch'] = 0
        data = data.set_index('batch', append = True)
    
    return data

if __name__ == '__main__':
    values:Dict[str, xarray.Dataset] = {}
    
    single, pressure = getSingleAndPressureValues()
    values['inputs'] = pd.merge(pressure, single, left_index = True, right_index = True, how = 'inner')
    values['inputs'] = integrateProgress(values['inputs'])
    values['inputs'] = formatData(values['inputs'])

2025-03-05 17:43:53,606 INFO Request ID is 16231e42-97bb-4430-ba5a-703c9192cf78
2025-03-05 17:43:53,749 INFO status has been updated to accepted
2025-03-05 17:44:02,426 INFO status has been updated to running
2025-03-05 17:44:07,632 INFO status has been updated to successful


a3d238c72422729e714428e4117e00e3.nc:   0%|          | 0.00/9.26M [00:00<?, ?B/s]

2025-03-05 17:44:10,771 INFO Request ID is 03482800-b2ee-4f39-9494-caa59ad15f27
2025-03-05 17:44:10,888 INFO status has been updated to accepted
2025-03-05 17:44:16,235 INFO status has been updated to running
2025-03-05 17:44:19,726 INFO status has been updated to successful


cb051f353118e9e20bc7187bbc06d21e.nc:   0%|          | 0.00/1.57M [00:00<?, ?B/s]

2025-03-05 17:44:26,239 INFO Request ID is 0ef960e1-cef8-41f7-84ef-96d3b0d247c4
2025-03-05 17:44:26,358 INFO status has been updated to accepted
2025-03-05 17:44:40,181 INFO status has been updated to successful


c9670bee49fd495f1e01b95bb00aa8d9.nc:   0%|          | 0.00/19.0M [00:00<?, ?B/s]

CPU times: user 19.7 s, sys: 2.76 s, total: 22.5 s
Wall time: 1min 8s


In [4]:
2

2

## Get targets

In [5]:
# Includes the packages imported and constants assigned. 
# The functions created for the inputs also go here. 
predictionFields = [
                        'u_component_of_wind',
                        'v_component_of_wind',
                        'geopotential',
                        'specific_humidity',
                        'temperature',
                        'vertical_velocity',
                        '10m_u_component_of_wind',
                        '10m_v_component_of_wind',
                        '2m_temperature',
                        'mean_sea_level_pressure',
                        'total_precipitation_6hr'
                    ]


# Creating an array full of nan values.
def nans(*args) -> list:
    return np.full((args), np.nan)


# Adding or subtracting time.
def deltaTime(dt, **delta) -> datetime.datetime:
    return dt + datetime.timedelta(**delta)


def getTargets(dt, data:pd.DataFrame):
    # rename axis
    data = data.rename_axis(index = {'pressure_level': 'level'})
    
    # Creating an array consisting of unique values of each index.
    lat = sorted(data.index.get_level_values('lat').unique().tolist())
    lon = sorted(data.index.get_level_values('lon').unique().tolist())
    levels = sorted(data.index.get_level_values('level').unique().tolist())
    batch = data.index.get_level_values('batch').unique().tolist()
    
    time = [deltaTime(dt, hours = days * gap) for days in range(4)]

    # Creating an empty dataset using latitude, longitude, the pressure levels and each prediction timestamp.  
    target = xarray.Dataset(
        {
            field: (
                ['lat', 'lon', 'level', 'time'],
                nans(len(lat), len(lon), len(levels), len(time))
            ) for field in predictionFields
        }, 
        coords = {
            'lat': lat, 
            'lon': lon,
            'level': levels,
            'time': time, 
            'batch': batch
        }
    )

    return target.to_dataframe()
    

if __name__ == '__main__':
    # The code for creating inputs will be here. 
    values['targets'] = getTargets(first_prediction, values['inputs'])

## Get forcings

In [6]:
%%time
# Includes the packages imported and constants assigned.
# The functions created for the inputs and targets also go here. 
# Adding a timezone to datetime.datetime variables. 
def addTimezone(dt, tz = pytz.UTC) -> datetime.datetime:
    dt = toDatetime(dt)
    if dt.tzinfo == None:
        return pytz.UTC.localize(dt).astimezone(tz)
    else:
        return dt.astimezone(tz)



# Getting the solar radiation value wrt longitude, latitude and timestamp. 
def getSolarRadiation(longitude, latitude, dt):  
    altitude_degrees = get_altitude(latitude, longitude, addTimezone(dt))
    solar_radiation = get_radiation_direct(dt, altitude_degrees) if altitude_degrees > 0 else 0

    return solar_radiation * watts_to_joules



# Calculating the solar radiation values for timestamps to be predicted. 
def integrateSolarRadiation(data:pd.DataFrame):
    dates = list(data.index.get_level_values('time').unique())
    coords = [[lat, lon] for lat in lat_range for lon in lon_range]
    values = []
    
    # For each data, getting the solar radiation value at a particular coordinate.     
    for dt in dates:
        values.extend(
                      list(
                          map(
                              lambda coord: {
                                  'time': dt,
                                  'lon': coord[1],
                                  'lat': coord[0],
                                  'toa_incident_solar_radiation': getSolarRadiation(coord[1], coord[0], dt)
                              },
                              coords
                          )
                      )
                  )
  
    # Setting indices.
    values = pd.DataFrame(values).set_index(keys = ['lat', 'lon', 'time'])
      
    # The forcings dataset will now contain the solar radiation values.
    return pd.merge(data, values, left_index = True, right_index = True, how = 'inner')



def getForcings(data:pd.DataFrame):
    # Since forcings data does not contain batch as an index, it is dropped.
    # So are all the columns, since forcings data only has 5, which will be created.
    forcingdf = data.reset_index(level = 'level', drop = True).drop(labels = predictionFields, axis = 1)
    
    # Keeping only the unique indices.
    forcingdf = pd.DataFrame(index = forcingdf.index.drop_duplicates(keep = 'first'))

    # Adding the sin and cos of day and year progress.
    # Functions are included in the creation of inputs data section.
    forcingdf = integrateProgress(forcingdf)

    # Integrating the solar radiation values.
    forcingdf = integrateSolarRadiation(forcingdf)

    return forcingdf



if __name__ == '__main__':
    # The code for creating inputs and targets will be here. 
    values['forcings'] = getForcings(values['targets'])

CPU times: user 15min 8s, sys: 2.31 s, total: 15min 10s
Wall time: 15min 12s


In [7]:
2

2

## Postprocessing inputs, targets, forcings (transform to xarray)

In [8]:
# Includes the packages imported and constants assigned. 
# The functions created for the inputs, targets and forcings also go here. 
# A dictionary created, containing each coordinate a data variable requires.
class AssignCoordinates:
    coordinates = {
                    '2m_temperature': ['batch', 'lon', 'lat', 'time'],
                    'mean_sea_level_pressure': ['batch', 'lon', 'lat', 'time'],
                    '10m_v_component_of_wind': ['batch', 'lon', 'lat', 'time'],
                    '10m_u_component_of_wind': ['batch', 'lon', 'lat', 'time'],
                    'total_precipitation_6hr': ['batch', 'lon', 'lat', 'time'],
                    'temperature': ['batch', 'lon', 'lat', 'level', 'time'],
                    'geopotential': ['batch', 'lon', 'lat', 'level', 'time'],
                    'u_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
                    'v_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
                    'vertical_velocity': ['batch', 'lon', 'lat', 'level', 'time'],
                    'specific_humidity': ['batch', 'lon', 'lat', 'level', 'time'],
                    'toa_incident_solar_radiation': ['batch', 'lon', 'lat', 'time'],
                    'year_progress_cos': ['batch', 'time'],
                    'year_progress_sin': ['batch', 'time'],
                    'day_progress_cos': ['batch', 'lon', 'time'],
                    'day_progress_sin': ['batch', 'lon', 'time'],
                    'geopotential_at_surface': ['lon', 'lat'],
                    'land_sea_mask': ['lon', 'lat'],
                }

def modifyCoordinates(data:xarray.Dataset):
    
    # Parsing through each data variable and removing unneeded indices.     
    for var in list(data.data_vars):
        varArray:xarray.DataArray = data[var]
        nonIndices = list(set(list(varArray.coords)).difference(set(AssignCoordinates.coordinates[var])))
        data[var] = varArray.isel(**{coord: 0 for coord in nonIndices})
    data = data.drop_vars('batch')

    return data

def makeXarray(data:pd.DataFrame) -> xarray.Dataset:
    
    # Converting to xarray.
    data = data.to_xarray()
    data = modifyCoordinates(data)

    return data

if __name__ == '__main__':

    # The code for creating inputs, targets and forcings will be here. 
    values = {value:makeXarray(values[value]) for value in values}

In [9]:
2

2

## Forecasting

In [10]:
gencast_mini = 'GenCast 1p0deg Mini _2019.npz'
gencast_1deg = 'GenCast 1p0deg _2019.npz'
gencast_025deg = 'GenCast 0p25deg _2019.npz'
gencast_operational = 'GenCast 0p25deg Operational _2022.npz'


graphcast_small = 'GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz'
graphcast_operational = 'GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz'
graphcast_34_lvls = 'GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz'

In [13]:
%%time
# Includes the packages imported and constants assigned. 
# The functions created for the inputs, targets and forcings also go here. 

with open(
    f'model/params/{graphcast_small}', 'rb'
) as model:
    ckpt = checkpoint.load(model, graphcast.CheckPoint)
    params = ckpt.params
    state = {}
    model_config = ckpt.model_config
    task_config = ckpt.task_config





#!!!!!!!!DELETE!!!!!!!!!!!!!!!!!!!!!!

existing_array = params['grid2mesh_gnn/~_networks_builder/encoder_nodes_grid_nodes_mlp/~/linear_0']['w']

zeros_array = np.zeros((288, 512), dtype=np.float32)  

new_array = np.vstack((existing_array, zeros_array))

params['grid2mesh_gnn/~_networks_builder/encoder_nodes_grid_nodes_mlp/~/linear_0']['w'] = new_array

del existing_array, zeros_array, new_array
gc.collect()




existing_array = params['grid2mesh_gnn/~_networks_builder/encoder_nodes_mesh_nodes_mlp/~/linear_0']['w']

zeros_array = np.zeros((288, 512), dtype=np.float32)  

new_array = np.vstack((existing_array, zeros_array))

params['grid2mesh_gnn/~_networks_builder/encoder_nodes_mesh_nodes_mlp/~/linear_0']['w'] = new_array

del existing_array, zeros_array, new_array
gc.collect()

#!!!!!!!!DELETE!!!!!!!!!!!!!!!!!!!!!!





with open(r'model/stats/diffs_stddev_by_level.nc', 'rb') as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()


with open(r'model/stats/mean_by_level.nc', 'rb') as f:
    mean_by_level = xarray.load_dataset(f).compute()


with open(r'model/stats/stddev_by_level.nc', 'rb') as f:
    stddev_by_level = xarray.load_dataset(f).compute()


def construct_wrapped_graphcast(
    model_config:graphcast.ModelConfig,
    task_config:graphcast.TaskConfig
):
    predictor = graphcast.GraphCast(model_config, task_config)
    
    predictor = casting.Bfloat16Cast(predictor)
    
    predictor = normalization.InputsAndResiduals(
        predictor, diffs_stddev_by_level = 
        diffs_stddev_by_level, mean_by_level = 
        mean_by_level, stddev_by_level = stddev_by_level
    )
    
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing = True)
    
    return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    predictor = construct_wrapped_graphcast(model_config, task_config)
    
    return predictor(inputs, targets_template = targets_template, forcings = forcings)


def with_configs(fn):
    return functools.partial(
                             fn, 
                             model_config = model_config,
                             task_config = task_config
                         )


def with_params(fn):
    return functools.partial(
                             fn, 
                             params = params, 
                             state = state
                         )


def drop_state(fn):
    return lambda **kw: fn(**kw)[0]


run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))


class Predictor:
    @classmethod
    def predict(
        cls,
        inputs,
        targets,
        forcings
    ) -> xarray.Dataset:
        predictions = rollout.chunked_prediction(
                                                 run_forward_jitted,
                                                 rng = jax.random.PRNGKey(0),
                                                 inputs = inputs,
                                                 targets_template = targets,
                                                 forcings = forcings
                                             )
        
        return predictions


# if __name__ == '__main__':
#     # The code for creating inputs, targets, forcings & processing will be here. 
#     predictions = Predictor.predict(
#                                     values['inputs'],
#                                     values['targets'],
#                                     values['forcings']
#                                 )
#     predictions.to_dataframe().to_csv('predictions.csv', sep = ',')




# The code for creating inputs, targets, forcings & processing will be here. 
predictions = Predictor.predict(
                                values['inputs'],
                                values['targets'],
                                values['forcings']
                            )
# predictions.to_dataframe().to_csv('predictions.csv', sep = ',')

  num_target_steps = targets_template.dims["time"]
  scan_length = targets_template.dims['time']
  num_inputs = inputs.dims['time']


TypeError: scan body function carry input and carry output must have the same pytree structure, but they differ:

The input carry component inputs[0] is a <class 'xarray.core.dataset.Dataset'> with pytree metadata _HashableCoords({'time': <xarray.IndexVariable 'time' (time: 2)> Size: 16B
array(['2024-01-01T06:00:00.000000000', '2024-01-01T12:00:00.000000000'],
      dtype='datetime64[ns]'), 'lat': <xarray.IndexVariable 'lat' (lat: 181)> Size: 1kB
array([-90., -89., -88., -87., -86., -85., -84., -83., -82., -81., -80., -79.,
       -78., -77., -76., -75., -74., -73., -72., -71., -70., -69., -68., -67.,
       -66., -65., -64., -63., -62., -61., -60., -59., -58., -57., -56., -55.,
       -54., -53., -52., -51., -50., -49., -48., -47., -46., -45., -44., -43.,
       -42., -41., -40., -39., -38., -37., -36., -35., -34., -33., -32., -31.,
       -30., -29., -28., -27., -26., -25., -24., -23., -22., -21., -20., -19.,
       -18., -17., -16., -15., -14., -13., -12., -11., -10.,  -9.,  -8.,  -7.,
        -6.,  -5.,  -4.,  -3.,  -2.,  -1.,   0.,   1.,   2.,   3.,   4.,   5.,
         6.,   7.,   8.,   9.,  10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,
        18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,
        30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,
        42.,  43.,  44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,
        54.,  55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
        66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,  77.,
        78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,  88.,  89.,
        90.]), 'lon': <xarray.IndexVariable 'lon' (lon: 360)> Size: 3kB
array([  0.,   1.,   2., ..., 357., 358., 359.], shape=(360,))}) but the corresponding component of the carry output is a <class 'xarray.core.dataset.Dataset'> with pytree metadata _HashableCoords({'lat': <xarray.IndexVariable 'lat' (lat: 181)> Size: 1kB
array([-90., -89., -88., -87., -86., -85., -84., -83., -82., -81., -80., -79.,
       -78., -77., -76., -75., -74., -73., -72., -71., -70., -69., -68., -67.,
       -66., -65., -64., -63., -62., -61., -60., -59., -58., -57., -56., -55.,
       -54., -53., -52., -51., -50., -49., -48., -47., -46., -45., -44., -43.,
       -42., -41., -40., -39., -38., -37., -36., -35., -34., -33., -32., -31.,
       -30., -29., -28., -27., -26., -25., -24., -23., -22., -21., -20., -19.,
       -18., -17., -16., -15., -14., -13., -12., -11., -10.,  -9.,  -8.,  -7.,
        -6.,  -5.,  -4.,  -3.,  -2.,  -1.,   0.,   1.,   2.,   3.,   4.,   5.,
         6.,   7.,   8.,   9.,  10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,
        18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,
        30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,
        42.,  43.,  44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,
        54.,  55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
        66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,  77.,
        78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,  88.,  89.,
        90.]), 'lon': <xarray.IndexVariable 'lon' (lon: 360)> Size: 3kB
array([  0.,   1.,   2., ..., 357., 358., 359.], shape=(360,)), 'level': <xarray.IndexVariable 'level' (level: 13)> Size: 104B
array([  50.,  100.,  150.,  200.,  250.,  300.,  400.,  500.,  600.,  700.,
        850.,  925., 1000.]), 'time': <xarray.IndexVariable 'time' (time: 2)> Size: 16B
array(['2024-01-01T06:00:00.000000000', '2024-01-01T12:00:00.000000000'],
      dtype='datetime64[ns]')}), so the pytree node metadata does not match.

Revise the function so that the carry output has the same pytree structure as the carry input.

In [14]:
2

2

ValueError: 'grid2mesh_gnn/~_networks_builder/encoder_nodes_grid_nodes_mlp/~/linear_0/w' with retrieved shape (186, 512) does not match shape=[474, 512] dtype=dtype(bfloat16)