In [1]:
# if necessary, install NeuralGCM and dependencies
!pip install -q -U neuralgcm dinosaur-dycore gcsfs

# Forecasting quick start

This notebook uses ERA5 data and pretrained NeuralGCM model to make a weather forecast.

The forecast is made in 3 steps:
1. Slice of ERA5 data is regridded to model resolution
2. NeuralGCM model state is initialized and rolled out
3. Predictions and reference trajectory are combined for visualization

By default the notebook uses intermediate deterministic NeuralGCM 1.4° model. Other available checkpoints include deterministic 0.7°, 2.8° and stochastic 1.4° NeuralGCM variations.

```{tip}
You can run this notebook yourself in [Google Colab](https://colab.research.google.com/github/google-research/neuralgcm/blob/main/docs/inference_demo.ipynb). We recommend using a GPU or TPU runtime due to high memory and compute requirements.
```

In [2]:
import gcsfs
import jax
import numpy as np
import pickle
import xarray

from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils
import neuralgcm

gcs = gcsfs.GCSFileSystem(token='anon')

## Add the SST pattern that was send to us https://esgf-ui.ceda.ac.uk/cog/search/cmip6-ceda/ 
How to: \
tos in die Suchleiste
<span style="color:red">~~Download erledigt~~</span>

In [8]:
! pip install xarray netCDF4 numpy

import xarray as xr

ersst_path = "./data/tos_Omon_GISS-E2-1-G_historical_r1i1p5f1_gn_200101-201412.nc"
ersst_data = xr.open_dataset(ersst_path)

  pid, fd = os.forkpty()




Subset the data to be similar to NeuralGCM

## Load a pre-trained NeuralGCM model

```{caution}
Trained model weights are licensed for non-commercial use, under the Creative Commons [Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) license (CC BY-NC-SA 4.0).
```

Pre-trained model checkpoints from the NeuralGCM paper are [available for download](https://console.cloud.google.com/storage/browser/gresearch/neuralgcm/04_30_2024) on Google Cloud Storage:

- Deterministic models:
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl`
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl`
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl`
- Stochastic models:
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_stochastic_1_4_deg.pkl`

<span style="color:green">NEW</span> 
## Need to train it on our own using the inputs from era5

Checklist: \
~~Find out all variables~~ \
Training

In [9]:
drought_variables = [
    "precipitation",      # To track precipitation deficits
    "evaporation",        # Water loss through evaporation
    "soil_moisture",      # Soil water content for agricultural impacts
    "temperature",        # High temperatures linked to drought
    "specific_humidity",  # Tracks atmospheric moisture
    "surface_pressure",   # Indicator of regional pressure systems
    "sea_surface_temp",   # SST anomalies linked to teleconnections
]

In [10]:
model_name = 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl'  #@param ['neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl', 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl', 'neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl', 'neural_gcm_dynamic_forcing_stochastic_1_4_deg.pkl'] {type: "string"}

with gcs.open(f'gs://gresearch/neuralgcm/04_30_2024/{model_name}', 'rb') as f:
  ckpt = pickle.load(f)

model = neuralgcm.PressureLevelModel.from_checkpoint(ckpt)

##
# TODO: Train on our own, Decide on training and testing data and do the 30 year roll out
# 

  ckpt = pickle.load(f)


## Load ERA5 data from GCP/Zarr

See {doc}`datasets` for details. Leave this part


Select out a few days of data:

## Need to change start time and end time

Checklist: \
~~Select spain as a region~~ \
<span style="color:red">Decide on start and end time</span> \


In [14]:
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_era5 = xarray.open_zarr(gcs.get_mapper(era5_path), chunks=None)

start_date = '2020-02-14'
end_date = '2020-02-18'
data_inner_steps = 24  # process every 24th hour

# Define the latitude and longitude bounds for Spain
# Die Werte müssen eigentlich größer sein, da im Mittelmeer auch Abläufe sind, die das beeinflussen. Für die Fläche 

lat_bounds = slice(36, 44)  # Latitude range for Spain mainland
lon_bounds = slice(-10, 4)  # Longitude range for Spain mainland

<span style="color:green">NEW</span>Subset the data before adding it to era5 input for the model

In [15]:

# Subset the data
sst_subset = ersst_data['tos'].sel(time=slice(start_date, end_date),
                                   lat=lat_bounds,
                                   lon=lon_bounds)

<span style="color:green">NEW</span> Compute anomalies, trends, or indices such as the Atlantic Multidecadal Oscillation (AMO) to understand SST variations over time.

In [18]:
# Calculate the climatology (mean over the period)
climatology = sst_subset.groupby('time').mean('time')

# Compute SST anomalies
sst_anomalies = sst_subset.groupby('time') - climatology

# Calculate the AMO index (example)
amo_index = sst_anomalies.mean(dim=['lat', 'lon'])

print(climatology)
print(sst_anomalies)
print(amo_index)

ValueError: time must not be empty

In [19]:
# TODO Plot them to have something to see? Low Priority

<span style="color:green">NEW</span>Incorporate the processed SST data or derived indices into your drought prediction model as predictors or covariates.

In [32]:
precipitation_data = full_era5['mean_total_precipitation_rate'] #idk if it is precipitation TODO: print variable names to know
temperature_data = full_era5['temperature']

model_inputs = {
    'precipitation': precipitation_data,
    'temperature': temperature_data,
    'sst_anomalies': sst_anomalies,
    # TODO: Add other variables as needed
}

NameError: name 'sst_anomalies' is not defined

In [31]:
print(full_era5.var)

print(full_era5["mean_total_precipitation_rate"])
for i in full_era5:
    if "pre" in i:
        print(i)

<bound method DatasetAggregations.var of <xarray.Dataset> Size: 4PB
Dimensions:                                                          (
                                                                      time: 1323648,
                                                                      latitude: 721,
                                                                      longitude: 1440,
                                                                      level: 37)
Coordinates:
  * latitude                                                         (latitude) float32 3kB ...
  * level                                                            (level) int64 296B ...
  * longitude                                                        (longitude) float32 6kB ...
  * time                                                             (time) datetime64[ns] 11MB ...
Data variables: (12/273)
    100m_u_component_of_wind                                         (time, latitude, longitude) flo

<span style="color:green">NEW</span> Add SST anomalies to the input variables of the NeuralGCM model

In [24]:
# Add SST anomalies to the input variables of the NeuralGCM model
era5_with_sst = xr.Dataset(
    {
        'precipitation': full_era5['mean_total_precipitation_rate'],
        'temperature': full_era5['temperature'],
        'humidity': full_era5['specific_humidity'],
        'sst_anomalies': sst_anomalies_grid,
        # TODO: ADD remining inputs
    },
    coords={
        'time': full_era5['time'],
        'latitude': full_era5['latitude'],
        'longitude': full_era5['longitude'],
    }
)

<xarray.Dataset> Size: 4PB
Dimensions:                                                          (
                                                                      time: 1323648,
                                                                      latitude: 721,
                                                                      longitude: 1440,
                                                                      level: 37)
Coordinates:
  * latitude                                                         (latitude) float32 3kB ...
  * level                                                            (level) int64 296B ...
  * longitude                                                        (longitude) float32 6kB ...
  * time                                                             (time) datetime64[ns] 11MB ...
Data variables: (12/273)
    100m_u_component_of_wind                                         (time, latitude, longitude) float32 5TB ...
    100m_v_component_of_wind

KeyError: "No variable named 'precipitation'. Variables on the dataset include ['100m_u_component_of_wind', '100m_v_component_of_wind', '10m_u_component_of_neutral_wind', '10m_u_component_of_wind', '10m_v_component_of_neutral_wind', ..., 'wave_spectral_directional_width_for_wind_waves', 'wave_spectral_kurtosis', 'wave_spectral_peakedness', 'wave_spectral_skewness', 'zero_degree_level']"

Regrid to NeuralGCM's native resolution: <span style="color:red">Rewrite it to use the sst data</span>.


In [None]:
era5_grid = spherical_harmonic.Grid(
    latitude_nodes=full_era5.sizes['latitude'],
    longitude_nodes=full_era5.sizes['longitude'],
    latitude_spacing=xarray_utils.infer_latitude_spacing(full_era5.latitude),
    longitude_offset=xarray_utils.infer_longitude_offset(full_era5.longitude),
)
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid, model.data_coords.horizontal, skipna=True
)
eval_era5 = xarray_utils.regrid(sliced_era5, regridder)
eval_era5 = xarray_utils.fill_nan_with_nearest(eval_era5)

<span style="color:green">NEW</span>:Ensure that the combined dataset adheres to NeuralGCM’s expected format.

In [None]:
model_inputs = {
    variable: era5_with_sst[variable].values for variable in era5_with_sst.data_vars
}

## Make the forecast

See {doc}`trained_models` for details.

In [None]:
inner_steps = 24  # save model outputs once every 24 hours
outer_steps = 4 * 24 // inner_steps  # total of 4 days
timedelta = np.timedelta64(1, 'h') * inner_steps
times = (np.arange(outer_steps) * inner_steps)  # time axis in hours

# initialize model state
# TODO: change the input with the new one


inputs = model.inputs_from_xarray(eval_era5.isel(time=0))
input_forcings = model.forcings_from_xarray(eval_era5.isel(time=0))
rng_key = jax.random.key(42)  # optional for deterministic models
initial_state = model.encode(inputs, input_forcings, rng_key)

# use persistence for forcing variables (SST and sea ice cover)
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))

# make forecast
final_state, predictions = model.unroll(
    initial_state,
    all_forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=True,
)
predictions_ds = model.data_to_xarray(predictions, times=times)

## Compare forecast to ERA5

See [WeatherBench2](https://sites.research.google/weatherbench/) for more comprehensive evaluations and archived NeuralGCM forecasts.

Can stay like this

In [None]:
# Selecting ERA5 targets from exactly the same time slice
target_trajectory = model.inputs_from_xarray(
    eval_era5
    .thin(time=(inner_steps // data_inner_steps))
    .isel(time=slice(outer_steps))
)
target_data_ds = model.data_to_xarray(target_trajectory, times=times)

combined_ds = xarray.concat([target_data_ds, predictions_ds], 'model')
combined_ds.coords['model'] = ['ERA5', 'NeuralGCM']

In [None]:
# Visualize ERA5 vs NeuralGCM trajectories
combined_ds.specific_humidity.sel(level=850).plot(
    x='longitude', y='latitude', row='time', col='model', robust=True, aspect=2, size=2
);

## Data Analysis