In [1]:
# if necessary, install NeuralGCM and dependencies
!python --version
!pip install -q -U neuralgcm dinosaur-dycore gcsfs
!pip install matplotlib
!pip install cartopy
!pip install h5netcdf
!pip install --upgrade xarray netCDF4 h5netcdf dask


Python 3.11.10


In [2]:
import gin
gin.enter_interactive_mode()
import gcsfs
import jax
import numpy as np
import pickle
import xarray

from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils
import neuralgcm

import matplotlib.pyplot as plt
import cartopy

from PIL import Image
from IPython.display import display

# Unauthenticated client for the public Google Cloud Storage buckets read below.
gcs = gcsfs.GCSFileSystem(token="anon")

In [3]:
import xarray as xr
import numpy as np

# Small synthetic dataset with the same variable names as the NeuralGCM output,
# used to check that NetCDF writing works in this environment.
# Size: 8 float64 variables x 64*64*128 values ~= 34 MB.
time_dim = 64
lat_dim = 64
lon_dim = 128

variable_names = [
    "temperature",
    "geopotential",
    "specific_cloud_liquid_water_content",
    "u_component_of_wind",
    "specific_cloud_ice_water_content",
    "P_minus_E_cumulative",
    "specific_humidity",
    "v_component_of_wind",
]

# Seeded generator so the test file is identical between runs.
rng = np.random.default_rng(42)
converted_data = xr.Dataset(
    {
        name: (("time", "lat", "lon"), rng.random((time_dim, lat_dim, lon_dim)))
        for name in variable_names
    },
    coords={
        "time": np.arange(time_dim),
        "lat": np.linspace(-90, 90, lat_dim),
        "lon": np.linspace(-180, 180, lon_dim),
    },
)

print(converted_data)

converted_data.to_netcdf("test_minimal.nc")
print("Minimal dataset saved successfully.")


<xarray.Dataset> Size: 34MB
Dimensions:                              (time: 64, lat: 64, lon: 128)
Coordinates:
  * time                                 (time) int64 512B 0 1 2 3 ... 61 62 63
  * lat                                  (lat) float64 512B -90.0 ... 90.0
  * lon                                  (lon) float64 1kB -180.0 ... 180.0
Data variables:
    temperature                          (time, lat, lon) float64 4MB 0.4836 ...
    geopotential                         (time, lat, lon) float64 4MB 0.3903 ...
    specific_cloud_liquid_water_content  (time, lat, lon) float64 4MB 0.6764 ...
    u_component_of_wind                  (time, lat, lon) float64 4MB 0.5183 ...
    specific_cloud_ice_water_content     (time, lat, lon) float64 4MB 0.03711...
    P_minus_E_cumulative                 (time, lat, lon) float64 4MB 0.8406 ...
    specific_humidity                    (time, lat, lon) float64 4MB 0.6161 ...
    v_component_of_wind                  (time, lat, lon) float64 4MB 0.08

## Background and Motivation

<span style="color:blue">TODO later</span> Rewrite the introduction and add any relevant things here.
Droughts are natural disasters that are getting worse every year, affecting millions of people by increasing the risk of malnutrition, disease, wildfires, and forced migration. Developing early warning systems and timely interventions is crucial to mitigating the economic, social, and environmental impacts of droughts.



<span style="color:blue">Add other necessary information.</span>

## Data and Objectives

Todo: Add if something is missing
Data: \
ERA5 data \
pretrained NeuralGCM model (intermediate deterministic NeuralGCM 1.4 model) \
SST data 

Objectives: \
The aim of this project is to make a 30-year rollout prediction of drought frequency and amplitude in Spain.

## Drought specific variables

Evapotranspiration (not in NeuralGCM) \
Precipitation (not in NeuralGCM) \
Temperature \
Specific humidity \
Sea Surface Temperature (not in NeuralGCM)

In [4]:
! pip install xarray netCDF4 numpy

import xarray as xr




Subset the data to be similar to NeuralGCM

## Load a pre-trained NeuralGCM model

```{caution}
Trained model weights are licensed for non-commercial use, under the Creative Commons [Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) license (CC BY-NC-SA 4.0).
```

Pre-trained model checkpoints from the NeuralGCM paper are [available for download](https://console.cloud.google.com/storage/browser/gresearch/neuralgcm/04_30_2024) on Google Cloud Storage:

- Deterministic models:
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl`
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl`
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl`
- Stochastic models:
    - `gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_stochastic_1_4_deg.pkl`

## We need to train it ourselves using ERA5 inputs

## Load and modify the model configuration string

Ensure that the new variables (precipitation, soil moisture, evapotranspiration) are available in your dataset and properly preprocessed. The data should be regridded to match NeuralGCM's native grid and provided in the correct units. Refer to NeuralGCM's data preparation guidelines for detailed instructions.

In [5]:
# Load a pre-trained NeuralGCM checkpoint and patch its gin config so the
# decoder also emits cumulative precipitation minus evaporation.
# SECURITY: pickle.load can execute arbitrary code; only load checkpoints from a
# trusted source such as this official bucket.
gcs = gcsfs.GCSFileSystem(token='anon')
model_name = 'neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl'
with gcs.open(f'gs://gresearch/neuralgcm/04_30_2024/{model_name}', 'rb') as checkpoint_file:
    ckpt = pickle.load(checkpoint_file)

# Units of every decoder output; 'diagnostics' adds the P_minus_E_cumulative output.
new_inputs_to_units_mapping = dict(
    u='meter / second',
    v='meter / second',
    t='kelvin',
    z='m**2 s**-2',
    sim_time='dimensionless',
    tracers={
        tracer: 'dimensionless'
        for tracer in (
            'specific_humidity',
            'specific_cloud_liquid_water_content',
            'specific_cloud_ice_water_content',
        )
    },
    diagnostics={
        'surface_pressure': 'kg / (meter s**2)',
        'P_minus_E_cumulative': 'kg / (meter**2)',
        # Add new diagnostic variables if any
    },
)

# Gin bindings appended after the checkpoint's own config (later bindings win).
config_lines = (
    ckpt['model_config_str'],
    f'DimensionalLearnedPrimitiveToWeatherbenchDecoder.inputs_to_units_mapping = {new_inputs_to_units_mapping}',
    'DimensionalLearnedPrimitiveToWeatherbenchDecoder.diagnostics_module = @NodalModelDiagnosticsDecoder',
    'StochasticPhysicsParameterizationStep.diagnostics_module = @PrecipitationMinusEvaporationDiagnostics',
    'PrecipitationMinusEvaporationDiagnostics.method = "cumulative"',
    'PrecipitationMinusEvaporationDiagnostics.moisture_species =  ("specific_humidity", "specific_cloud_liquid_water_content", "specific_cloud_ice_water_content")',
)
new_model_config_str = '\n'.join(config_lines)
del config_lines

ckpt['model_config_str'] = new_model_config_str
model = neuralgcm.PressureLevelModel.from_checkpoint(ckpt)

## Subset the data and compute some variables
<span style="color:red">TODO </span> SST still NAN for Spain

<span style="color:blue">TODO later </span> Add short info about AMO and SST Anomalies?

## <span style="color:red">Need to change start time and end time</span>
Chosen data: start_time = 1940-01-01, end_time = 1989-12-31

In [6]:
# Two candidate regional boxes, built from (start, stop) pairs in degrees.
lat_bounds = [slice(start, stop) for start, stop in ((34, 45), (34, 51))]
lon_bounds = [slice(start, stop) for start, stop in ((-25, 19), (-20, 10))]

## Load ERA5 data from GCP/Zarr

See {doc}`datasets` for details. Leave this part


In [7]:
# Public ARCO-ERA5 Zarr store, opened lazily (no dask chunking).
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
era5_store = gcs.get_mapper(era5_path)
full_era5 = xarray.open_zarr(era5_store, chunks=None)

# Evaluation window and sampling stride.
start_time, end_time = '2022-06-21', '2022-8-25'
data_inner_steps = 24  # keep every 24th hourly sample


<span style="color:green"></span>Incorporate the processed SST data or derived indices into your drought prediction model as predictors or covariates.

In [8]:
print(full_era5["mean_total_precipitation_rate"])
# List every ERA5 variable whose name mentions wind.
for var_name in filter(lambda name: "wind" in name, full_era5):
    print(var_name)

<xarray.DataArray 'mean_total_precipitation_rate' (time: 1323648,
                                                   latitude: 721,
                                                   longitude: 1440)> Size: 5TB
[1374264299520 values with dtype=float32]
Coordinates:
  * latitude   (latitude) float32 3kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 6kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
  * time       (time) datetime64[ns] 11MB 1900-01-01 ... 2050-12-31T23:00:00
Attributes:
    long_name:   Mean total precipitation rate
    short_name:  mtpr
    units:       kg m**-2 s**-1
100m_u_component_of_wind
100m_v_component_of_wind
10m_u_component_of_neutral_wind
10m_u_component_of_wind
10m_v_component_of_neutral_wind
10m_v_component_of_wind
10m_wind_gust_since_previous_post_processing
instantaneous_10m_wind_gust
mean_direction_of_wind_waves
mean_period_of_wind_waves
mean_wave_period_based_on_first_moment_for_wind_waves
mean_wave_period_based_on_second_moment_for_

<span style="color:green">NEW</span> Add SST anomalies to the input variables of the NeuralGCM model
<span style="color:green">TODO</span> Error with too much data -> sliced or maybe on cluster

In [9]:
import pandas as pd

# Inspect the full ERA5 archive (nothing has been modified at this point).
# `.sizes` gives the dimension -> length mapping (`Dataset.dims` is being changed
# to return only the dimension names).
print("ERA5 dimensions:", full_era5.sizes)
print("ERA5 coordinates:", full_era5.coords)

# Step 1: region (and, optionally, time range) of interest.
# NOTE: ERA5 latitude is descending (90 -> -90) and its longitude runs 0..360;
# the NeuralGCM output has ascending latitude and 0..360 longitude. A signed
# longitude slice such as (-20, 10) only selects the intended box after
# longitudes are converted to -180..180, and a latitude slice must follow the
# coordinate's ordering.
#time_bounds = slice('2000-01-01' ,'2020-12-31')
lat_bounds = slice(51, 34)  # 51°N down to 34°N (descending, as in ERA5)
lon_bounds = slice(-20, 10)  # 20°W to 10°E (signed longitudes)




Updated Coordinates: Coordinates:
  * latitude   (latitude) float32 3kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
  * level      (level) int64 296B 1 2 3 5 7 10 20 ... 875 900 925 950 975 1000
  * longitude  (longitude) float32 6kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
  * time       (time) datetime64[ns] 11MB 1900-01-01 ... 2050-12-31T23:00:00


Regrid to NeuralGCM's native resolution:

In [10]:
# Keep only the model's inputs and forcings, shift the forcing variables
# (xarray_utils.selective_temporal_shift, 24 hours), sample every
# `data_inner_steps` hours of the evaluation window, and load into memory.
_era5_model_vars = full_era5[model.input_variables + model.forcing_variables]
_era5_shifted = xarray_utils.selective_temporal_shift(
    _era5_model_vars,
    variables=model.forcing_variables,
    time_shift='24 hours',
)
sliced_era5 = _era5_shifted.sel(time=slice(start_time, end_time, data_inner_steps)).compute()
del _era5_model_vars, _era5_shifted


In [11]:

# Describe the native ERA5 grid, then build a conservative regridder onto the
# model's horizontal grid (NaNs skipped).
era5_grid_kwargs = dict(
    latitude_nodes=full_era5.sizes['latitude'],
    longitude_nodes=full_era5.sizes['longitude'],
    latitude_spacing=xarray_utils.infer_latitude_spacing(full_era5.latitude),
    longitude_offset=xarray_utils.infer_longitude_offset(full_era5.longitude),
)
era5_grid = spherical_harmonic.Grid(**era5_grid_kwargs)
del era5_grid_kwargs

regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid, model.data_coords.horizontal, skipna=True
)

# Regrid, then fill NaNs with xarray_utils.fill_nan_with_nearest.
eval_era5 = xarray_utils.fill_nan_with_nearest(
    xarray_utils.regrid(sliced_era5, regridder)
)


<span style="color:green">NEW</span>: Ensure that the combined dataset adheres to NeuralGCM’s expected format.

In [12]:
# Step 4: Initialize the model state and define forecast parameters

import json

# Convert attributes to UTF-8 safe strings
def enforce_utf8_attrs(ds):
    """Normalise dataset and variable attributes to plain JSON-compatible values.

    Each ``attrs`` dict is round-tripped through JSON so only built-in types
    (str, int, float, bool, None, list, dict) remain; intended as a workaround
    for NetCDF write errors caused by attribute values. numpy scalars/arrays
    are converted with ``.tolist()`` and ``bytes`` are decoded as UTF-8
    (invalid sequences replaced), instead of making ``json.dumps`` raise.
    Tuples become lists as a side effect of the JSON round trip.

    Args:
        ds: an ``xarray.Dataset`` (anything exposing ``attrs``, ``variables``
            and ``ds[name].attrs``). Modified in place.

    Returns:
        The same ``ds`` object, for chaining.

    Raises:
        TypeError: if an attribute value has no JSON representation even
            after the conversions above.
    """
    def _to_builtin(value):
        # `default` hook for json.dumps: called only for non-JSON types.
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8", errors="replace")
        tolist = getattr(value, "tolist", None)
        if callable(tolist):
            return tolist()
        raise TypeError(
            f"attribute value of type {type(value).__name__} is not JSON serialisable"
        )

    def _normalise(attrs):
        return json.loads(json.dumps(attrs, default=_to_builtin))

    ds.attrs = _normalise(ds.attrs)
    for var in ds.variables:
        ds[var].attrs = _normalise(ds[var].attrs)
    return ds



inner_steps = 24  # Save model outputs once every 24 hours
outer_steps = 64  # Number of saved outputs: 64 days
timedelta = np.timedelta64(1, 'h') * inner_steps
times = (np.arange(outer_steps) * inner_steps)  # Time axis in hours

# Initialize model state from the first ERA5 snapshot
inputs = model.inputs_from_xarray(eval_era5.isel(time=0))
input_forcings = model.forcings_from_xarray(eval_era5.isel(time=0))
rng_key = jax.random.key(42)  # Optional for deterministic models
initial_state = model.encode(inputs, input_forcings, rng_key)

# Use persistence for forcing variables (SST and sea ice cover)
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))

# Step 5: Make forecast
final_state, predictions = model.unroll(
    initial_state,
    all_forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=True,
)

# Convert predictions to xarray dataset
predictions_ds = model.data_to_xarray(predictions, times=times)

# Step 6: Post-process for saving. Work on a (shallow) copy: assigning filled
# variables into the copy leaves `predictions_ds` untouched, so later plots and
# analyses still see the raw model output instead of -999 sentinels.
forecast_dataset = predictions_ds.copy()
print(forecast_dataset)

FILL_VALUE = -999  # missing-value sentinel used only for the saved copy
for var in forecast_dataset.data_vars:
    print(f"Processing variable: {var}")
    try:
        forecast_dataset[var] = forecast_dataset[var].fillna(FILL_VALUE)
        print(f"Finished processing variable: {var}")
    except Exception as e:  # best effort: report and continue with the next variable
        print(f"Skipping variable: {var} due to error: {e}")

print("All variables processed. Missing values filled with -999 where applicable.")

# Saving options tried in this notebook (see the following cells):
#   forecast_dataset.to_netcdf('neuralgcm_forecast_2022.nc')
#   enforce_utf8_attrs(forecast_dataset).to_netcdf('neuralgcm_forecast_2022.nc')
#   forecast_dataset.to_zarr('neuralgcm_forecast.zarr')
#   cluster workspace: /pfs/work7/workspace/scratch/kd5572-my_workspace/neuralgcm_forecast_2022.nc

<xarray.Dataset> Size: 545MB
Dimensions:                              (time: 64, level: 37, longitude: 128,
                                          latitude: 64)
Coordinates:
  * longitude                            (longitude) float64 1kB 0.0 ... 357.2
  * latitude                             (latitude) float64 512B -87.86 ... 8...
  * level                                (level) int64 296B 1 2 3 ... 975 1000
  * time                                 (time) int64 512B 0 24 48 ... 1488 1512
Data variables:
    specific_cloud_liquid_water_content  (time, level, longitude, latitude) float32 78MB ...
    specific_cloud_ice_water_content     (time, level, longitude, latitude) float32 78MB ...
    v_component_of_wind                  (time, level, longitude, latitude) float32 78MB ...
    geopotential                         (time, level, longitude, latitude) float32 78MB ...
    P_minus_E_cumulative                 (time, longitude, latitude) float32 2MB ...
    specific_humidity         

In [13]:
# Load the forecast into memory, then check which variables fail to load.
forecast_dataset = forecast_dataset.compute()

print("Inspecting time step 0")
subset = forecast_dataset.isel(time=0)
print(subset)

# Load each variable on its own so that a failure points at that variable.
# (Dropping one variable and loading the rest flags every variable *except*
# the faulty one.)
problematic_vars = []
for var in subset.data_vars:
    print(f"Testing variable: {var}")
    try:
        subset[var].load()
        print(f"Variable {var} loaded successfully.")
    except Exception as e:
        print(f"Error with variable {var}: {e}")
        problematic_vars.append(var)

print(f"Problematic variables: {problematic_vars}")

for var in problematic_vars:
    subset[var] = subset[var].fillna(-999)  # Replace NaNs with the sentinel
    subset[var].attrs = {}  # Remove attributes

# Strip all attributes from the full forecast (workaround for NetCDF write errors).
forecast_dataset.attrs = {}
for var in forecast_dataset.data_vars:
    forecast_dataset[var].attrs = {}

#forecast_dataset.isel(time=0).to_netcdf("time0_cleaned.nc")

hi
Inspecting time step 0
<xarray.Dataset> Size: 9MB
Dimensions:                              (level: 37, longitude: 128,
                                          latitude: 64)
Coordinates:
  * longitude                            (longitude) float64 1kB 0.0 ... 357.2
  * latitude                             (latitude) float64 512B -87.86 ... 8...
  * level                                (level) int64 296B 1 2 3 ... 975 1000
    time                                 int64 8B 0
Data variables:
    specific_cloud_liquid_water_content  (level, longitude, latitude) float32 1MB ...
    specific_cloud_ice_water_content     (level, longitude, latitude) float32 1MB ...
    v_component_of_wind                  (level, longitude, latitude) float32 1MB ...
    geopotential                         (level, longitude, latitude) float32 1MB ...
    P_minus_E_cumulative                 (longitude, latitude) float32 33kB ...
    specific_humidity                    (level, longitude, latitude) float32 

In [14]:
from joblib import dump

# Remove every attribute before serialising (workaround for write errors).
forecast_dataset.attrs = {}
for var in forecast_dataset.data_vars:
    forecast_dataset[var].attrs = {}

# Dataset.load() loads in place and returns the same object.
loaded_data = forecast_dataset.load()
print(type(loaded_data))

print(loaded_data)

u
u
u
u
u
u
u
u
u
<class 'xarray.core.dataset.Dataset'>
<xarray.Dataset> Size: 545MB
Dimensions:                              (time: 64, level: 37, longitude: 128,
                                          latitude: 64)
Coordinates:
  * longitude                            (longitude) float64 1kB 0.0 ... 357.2
  * latitude                             (latitude) float64 512B -87.86 ... 8...
  * level                                (level) int64 296B 1 2 3 ... 975 1000
  * time                                 (time) int64 512B 0 24 48 ... 1488 1512
Data variables:
    specific_cloud_liquid_water_content  (time, level, longitude, latitude) float32 78MB ...
    specific_cloud_ice_water_content     (time, level, longitude, latitude) float32 78MB ...
    v_component_of_wind                  (time, level, longitude, latitude) float32 78MB ...
    geopotential                         (time, level, longitude, latitude) float32 78MB ...
    P_minus_E_cumulative                 (time, longitude, 

In [15]:
from joblib import dump

# Fallback serialisation with joblib; report failures instead of raising.
try:
    dump(loaded_data, 'forecast_dataset_test.joblib')
except Exception as e:
    print(f"Joblib error: {e}")
else:
    print("Joblib dump successful.")


Joblib dump successful.


In [None]:
def save_file(dataset, output_path, time_chunk=1):
    """Write ``dataset`` to a single NetCDF-4 file using small on-disk chunks.

    The earlier version cut the data into tiny lat/lon tiles, discarded part
    of each tile to stay under 1 KB, and wrote every tile to the same path
    under the same variable names, so the file never held the full dataset.
    Here the whole dataset is written once and chunking is delegated to
    HDF5 through the NetCDF ``chunksizes`` encoding.

    Args:
        dataset: ``xarray.Dataset`` to write. Variables with a ``time``
            dimension are chunked along it; other dimensions are stored as a
            single chunk each.
        output_path: destination file; overwritten if it exists.
        time_chunk: number of time steps per on-disk chunk (default 1).

    Raises:
        ValueError: if ``time_chunk`` is smaller than 1.
    """
    if time_chunk < 1:
        raise ValueError(f"time_chunk must be >= 1, got {time_chunk}")
    print("Starting the process")

    encoding = {}
    for name, variable in dataset.data_vars.items():
        # Scalars and empty arrays cannot carry a chunk shape.
        if variable.ndim == 0 or 0 in variable.shape:
            continue
        encoding[name] = {
            "chunksizes": tuple(
                min(time_chunk, size) if dim == "time" else size
                for dim, size in zip(variable.dims, variable.shape)
            )
        }

    dataset.to_netcdf(output_path, mode="w", engine="netcdf4", encoding=encoding)
    print(f"Dataset successfully written to {output_path}")


output_path = "forecast_dataset_chunked.nc"

save_file(forecast_dataset, output_path)
print("Saving process finished!")

# Re-open the file once, inside the try, so a read error is reported instead of
# aborting the cell before the handler is reached.
try:
    loaded_data = xr.open_dataset(output_path)
    #loaded_data = xr.open_dataset("/hptc_cluster/oi6277/forecast_dataset_chunked.nc")
    print("File is loaded:")
    print(loaded_data)
except Exception as e:
    print(f"Error: {e}")

In [None]:
# Check if there are any NaN values in the entire dataset
nan_exists = forecast_dataset.isnull().any()
print("NaN exists in dataset:", nan_exists)


In [None]:
# Serialise the whole forecast with pickle (only reload pickle files you created
# yourself: unpickling can execute arbitrary code).
output_file = 'dataset.pkl'
with open(output_file, mode='wb') as pickle_file:
    pickle.dump(forecast_dataset, pickle_file)
print(f"Dataset successfully saved as {output_file}")


In [None]:
# Alternatives tried earlier:
#   forecast_dataset.to_dataframe().to_csv("neural.csv")
#   forecast_dataset.to_netcdf("neuralgcm_2022_28.nc", "w", compute=True, engine="netcdf4")
# Load into memory, then write the forecast to NetCDF.
forecast_dataset.load().to_netcdf(path='neuralgcm_forecast_2022.nc')

In [None]:
# Write the forecast to NetCDF, reporting (rather than raising) any failure.
try:
    forecast_dataset.to_netcdf('neuralgcm_forecast_2022.nc')
except Exception as e:
    print("Error while writing NetCDF file:", e)
else:
    print("File written successfully!")


## Make the forecast

See {doc}`trained_models` for details.

## Compare forecast to ERA5

See [WeatherBench2](https://sites.research.google/weatherbench/) for more comprehensive evaluations and archived NeuralGCM forecasts.

Can stay like this

In [None]:
# Selecting ERA5 targets from exactly the same time slice as the forecast:
# eval_era5 is sampled every `data_inner_steps` hours and the forecast every
# `inner_steps` hours, so thin by their ratio and keep `outer_steps` samples.
era5_targets = (
    eval_era5
    .thin(time=(inner_steps // data_inner_steps))
    .isel(time=slice(outer_steps))
)
if era5_targets.sizes['time'] < outer_steps:
    raise ValueError(
        f"ERA5 window has only {era5_targets.sizes['time']} samples, "
        f"but the forecast has {outer_steps} steps; extend end_time."
    )
target_trajectory = model.inputs_from_xarray(era5_targets)
target_data_ds = model.data_to_xarray(target_trajectory, times=times)
combined_ds = xarray.concat([target_data_ds, predictions_ds], 'model')
combined_ds.coords['model'] = ['ERA5', 'NeuralGCM']

In [None]:
# Visualize ERA5 vs NeuralGCM at 850 hPa for about four evenly spaced lead
# times; faceting all `outer_steps` times would give one row per day.
time_stride = max(1, combined_ds.sizes['time'] // 4)
combined_ds.temperature.sel(level=850).isel(time=slice(None, None, time_stride)).plot(
    x='longitude', y='latitude', row='time', col='model', robust=True, aspect=2, size=2
)

print(predictions_ds.P_minus_E_cumulative.dims)
print(predictions_ds.P_minus_E_cumulative.coords)

# Cumulative P - E at time=24 (hours since initialisation).
predictions_ds.P_minus_E_cumulative.sel(time=24).plot(x='longitude', y='latitude', robust=True, aspect=2, size=2)


In [None]:
import cartopy.crs as ccrs

# Map of cumulative precipitation minus evaporation at time=24 with coastlines,
# country borders and shaded land.
fig, ax = plt.subplots(figsize=(10, 5), subplot_kw=dict(projection=ccrs.PlateCarree()))

predictions_ds.P_minus_E_cumulative.sel(time=24).plot(
    ax=ax,
    transform=ccrs.PlateCarree(),
    x='longitude',
    y='latitude',
    robust=True,
)

ax.coastlines()
ax.add_feature(cartopy.feature.BORDERS, linestyle=':')
ax.add_feature(cartopy.feature.LAND, alpha=0.5, edgecolor='black', facecolor='lightgray')

ax.set_title("Cumulative P minus E at time=24", fontsize=12)
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")

plt.show()

## Data Analysis

In [None]:
#!pip install numpy==2.0
!pip install numba

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from scipy.stats import norm
#from numba import jit

In [None]:

# NOTE: despite the variable name, this file is CMIP6 model output
# (GISS-E2-1-G historical monthly `tos`, 2001-01..2014-12 per the filename),
# not NOAA ERSST observations.
ersst_path = "./data/tos_Omon_GISS-E2-1-G_historical_r1i1p5f1_gn_200101-201412.nc"
ersst_data = xr.open_dataset(filename_or_obj=ersst_path)

## Temperature

In [None]:
# Near-surface (1000 hPa) forecast temperature over Iberia for ~4 lead times.
# predictions_ds has ascending latitude and 0..360 longitude, so longitudes are
# converted to -180..180 before selecting the signed 20°W–10°E box, and the
# latitude slice runs south -> north. There is no `model` dimension here.
iberia_lat = slice(34, 51)
iberia_lon = slice(-20, 10)
forecast_temperature = predictions_ds['temperature']
time_stride = max(1, forecast_temperature.sizes['time'] // 4)
(
    forecast_temperature
    .sel(level=1000)
    .isel(time=slice(None, None, time_stride))
    .assign_coords(longitude=((forecast_temperature['longitude'].values + 180) % 360) - 180)
    .sortby('longitude')
    .sel(latitude=iberia_lat, longitude=iberia_lon)
    .plot(x='longitude', y='latitude', row='time', robust=True, aspect=2, size=2)
);

## SST from NeuralGCM

In [None]:
# SST is a forcing input to NeuralGCM (held fixed by persistence in this
# forecast), so plot it from the regridded ERA5 forcing data at the initial time.
# The key must be a string; a bare `sea_surface_temperature` is an undefined name.
eval_era5['sea_surface_temperature'].isel(time=0).plot(x='longitude', y='latitude', aspect=2, size=2.5);

In [None]:
# Install and import necessary libraries
!pip install dask
import xarray as xr
import dask

lat_bounds = slice(34, 51)  # Latitude bounds (51°N to 34°N)
lon_bounds = slice(-20, 10)  # Longitude bounds (-20°W to 10°E)

# Extract SST with lazy loading
sst = predictions_ds['temperature'].sel(
    time=slice(outer_steps),
    latitude=lat_bounds,
    longitude=lon_bounds)
print("1 done - SST extracted")

# Align time indexing for consistent processing
sst['time'] = sst.indexes['time']
print(sst)
print("done")

# Calculate climatological mean SST
climatological_mean_sst = sst.mean(dim='time').persist()  # Persist in memory for repeated use
print("2 done - Climatological mean computed")
print(climatological_mean_sst)
# Compute SST anomalies
sst_anomalies = (sst - climatological_mean_sst).persist()  # Persist anomalies for further analysis
print("3 done - SST anomalies computed")
print(sst_anomalies)

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
print(climatological_mean_sst.isel(level=0).values)
print(climatological_mean_sst.isel(level=-1))

print(climatological_mean_sst['level'])
# Subset the data
subset_sst = climatological_mean_sst.sel(latitude=lat_bounds, longitude=lon_bounds)

# Lowest pressure level (1000 hPa), laid out as (latitude rows, longitude columns).
# The model stores (longitude, latitude), and its latitude is ascending, so the
# previous imshow(..., origin='upper') call drew a transposed, flipped map;
# pcolormesh with explicit coordinates is correct for either ordering.
surface_field = subset_sst.isel(level=-1).transpose('latitude', 'longitude')

plt.figure()
ax = plt.axes(projection=ccrs.PlateCarree())
#ax.set_global()
ax.coastlines()
sst_plot = ax.pcolormesh(
    surface_field['longitude'],
    surface_field['latitude'],
    surface_field,
    transform=ccrs.PlateCarree(),
    cmap='coolwarm',
    shading='auto',
)

# This field is the period mean (not an anomaly) of temperature in kelvin.
plt.colorbar(sst_plot, ax=ax, orientation='horizontal', label='Mean 1000 hPa temperature (K)')
plt.title('Mean near-surface temperature (SST proxy)')
plt.show()


## SST Pattern  
<span style="color:green"></span> Compute anomalies, trends, or indices such as the Atlantic Multidecadal Oscillation (AMO) to understand SST variations over time. \

In [None]:
# Subset the CMIP6 SST (`tos`) over the tighter Iberian box.
lat_bounds = [slice(34, 45), slice(34, 51)]
lon_bounds = [slice(-25, 19), slice(-20, 10)]
start_time = '2005-01-01'
end_time = '2022-06-16'  # NOTE: the file only covers 2001-01..2014-12

# Convert longitudes to -180..180 (a no-op if already in that range) so the
# signed bounds also select the intended box on a 0..360 grid.
tos = ersst_data['tos']
tos = tos.assign_coords(lon=((tos['lon'].values + 180) % 360) - 180).sortby('lon')
sst_subset = tos.sel(time=slice(start_time, end_time),
                     lat=lat_bounds[0],
                     lon=lon_bounds[0])
print(sst_subset.values)

# Time-mean SST over the period (used for the mean-SST map).
climatology = sst_subset.mean(dim='time')

print(climatology)
print(climatology.dims)
print(climatology.shape)

# Anomalies relative to each calendar month's mean, which removes the seasonal
# cycle (subtracting the all-time mean would leave it in the anomalies).
monthly_climatology = sst_subset.groupby('time.month').mean('time')
sst_anomalies = (sst_subset.groupby('time.month') - monthly_climatology).drop_vars('month')

sst_anomalies.to_netcdf('sst_anomalies.nc')

# cos(latitude)-weighted regional mean anomaly. NOTE: the true AMO index is the
# detrended mean over the whole North Atlantic (roughly 0–70°N, 80°W–0°); this
# Iberian box mean is only a proxy, kept under the `amo_index` name.
weights = np.cos(np.deg2rad(sst_anomalies['lat']))
amo_index = sst_anomalies.weighted(weights).mean(dim=['lat', 'lon'])

print(amo_index)

In [None]:
region = [-24, 26, 35, 45]  # [lon_min, lon_max, lat_min, lat_max] for set_extent
plate_carree = cartopy.crs.PlateCarree()

plt.figure()
ax = plt.axes(projection=plate_carree)
ax.set_global()
ax.set_extent(region, crs=plate_carree)
ax.gridlines(color='gray', linestyle='--')
ax.coastlines()

# Time-mean SST on the file's lat/lon grid.
temp_cartopy = ax.pcolormesh(
    sst_subset['lon'],
    sst_subset['lat'],
    climatology,
    transform=plate_carree,
    shading='auto',
    cmap='bwr',
)
colorbar = plt.colorbar(temp_cartopy, ax=ax, orientation='horizontal', label='Mean Temperature')
colorbar.set_label("°C", size=12, rotation=0)
plt.title("Mean Sea Surface Temperature")
plt.show()

In [None]:
# Dimension names first, then the full anomaly DataArray (one per line).
print(sst_anomalies.dims, sst_anomalies, sep="\n")

In [None]:
plate_carree = cartopy.crs.PlateCarree()

plt.figure()
ax = plt.axes(projection=plate_carree)
ax.set_global()
ax.set_extent(region, crs=plate_carree)
ax.gridlines(color='gray', linestyle='--')
ax.coastlines()

# First time step of the SST anomalies.
first_anomaly = sst_anomalies[0, :, :]
temp_cartopy = ax.pcolormesh(
    sst_subset['lon'], sst_subset['lat'], first_anomaly,
    transform=plate_carree, shading='auto', cmap='bwr',
)
colorbar = plt.colorbar(temp_cartopy, ax=ax, orientation='horizontal', label='SST Anomalies')
colorbar.set_label("°C", size=12, rotation=0)
plt.title("SST Anomalies")
plt.show()

In [None]:
# Regional mean SST anomaly ("AMO index") against sample number, with a zero line.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(amo_index, label="AMO INDEX", color="b")
ax.axhline(0, color="k", linestyle="--", linewidth=0.8, label="Zero Anomaly")
ax.set_title("Atlantic Multidecadal Oscillation (AMO) Index")
ax.set_xlabel("Time")
ax.set_ylabel("AMO Index")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

## RMSE SST
`true_values` come from the SST file and `predicted_values` from NeuralGCM.

In [None]:
def compute_rmse(true_values, predicted_values):
    """Root-mean-square error between two array-likes, ignoring NaN pairs.

    Inputs are converted with ``np.asarray`` and compared element-wise by
    position (xarray coordinates are NOT aligned); they must broadcast.

    Returns:
        The RMSE as a float, or NaN if there is no valid (non-NaN) pair.
    """
    diff = np.asarray(true_values, dtype=float) - np.asarray(predicted_values, dtype=float)
    if diff.size == 0 or np.all(np.isnan(diff)):
        return float('nan')
    return float(np.sqrt(np.nanmean(diff ** 2)))


# `true_values` (SST from file) and `predicted_values` (NeuralGCM) are not
# defined yet in this notebook; compute the RMSE once both exist.
if 'true_values' in globals() and 'predicted_values' in globals():
    rmse = compute_rmse(true_values, predicted_values)
    print(f"RMSE: {rmse}")

## SPEI

In [None]:
# Function to calculate a simplified SPEI directly from the forecast
def calculate_spei(predictions_ds, scale=3):
    """Compute a simplified SPEI-like drought index from a NeuralGCM forecast.

    Steps:
      1. Potential evapotranspiration (PET) from 1000 hPa temperature with a
         simplified Thornthwaite formula (the heat index is computed per
         sample instead of from 12 monthly means); PET is 0 where T <= 0 °C.
      2. Water balance ``D = P_minus_E_cumulative - PET``.
      3. Standardise ``D`` with rolling statistics over ``scale`` time steps.

    This is not the standard SPEI, which fits a log-logistic distribution to
    accumulated P - PET; treat the result as a relative anomaly index.
    NOTE(review): the Thornthwaite form gives a monthly rate while
    ``P_minus_E_cumulative`` is a cumulative amount, so the two terms of ``D``
    are not on the same basis — confirm before interpreting magnitudes.

    Args:
        predictions_ds: dataset with ``temperature`` in kelvin (with a
            ``level`` dimension containing 1000), ``P_minus_E_cumulative``,
            and a ``time`` dimension.
        scale: rolling window length in time steps (>= 1).

    Returns:
        DataArray of the index. The first ``scale - 1`` steps are NaN, as are
        points where the rolling standard deviation is zero.

    Raises:
        ValueError: if ``scale`` is smaller than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale}")

    def thornthwaite(temp):
        """Vectorised simplified Thornthwaite PET for temperatures in kelvin.

        Computed with array operations (not np.vectorize, whose output dtype
        is inferred from the first element and becomes int when that element
        returns the literal 0, truncating every PET value).
        """
        temp_celsius = temp - 273.15  # Convert from Kelvin to Celsius
        warm = temp_celsius > 0  # False for NaN, like the scalar version
        # Clip so fractional powers never see negative bases; use a heat index
        # of 1 where T <= 0 °C to avoid 0/0 (those cells are set to 0 below).
        t_pos = temp_celsius.clip(min=0)
        heat_index = ((t_pos / 5) ** 1.514).where(warm, 1.0)
        a = (6.75e-7 * heat_index ** 3) - (7.71e-5 * heat_index ** 2) + (1.792e-2 * heat_index) + 0.49239
        pet = 16 * ((10 * t_pos / heat_index) ** a)
        return pet.where(warm, 0.0)

    def compute_spei(D, scale=3):
        """Standardise D with a trailing rolling window of ``scale`` steps."""
        rolling = D.rolling(time=scale, center=False)
        rolling_mean = rolling.mean()
        rolling_std = rolling.std()
        # A zero spread would give +/-inf; mark those points as undefined instead.
        return (rolling_mean - rolling_mean.mean(dim="time")) / rolling_std.where(rolling_std > 0)

    # Step 1: PET from near-surface temperature
    temperature = predictions_ds['temperature'].sel(level=1000)
    PET = thornthwaite(temperature)

    # Step 2: water balance
    D = predictions_ds['P_minus_E_cumulative'] - PET

    # Step 3: standardised index
    return compute_spei(D, scale=scale)

SPEI_result = calculate_spei(predictions_ds)
print(SPEI_result)

# Average over the Iberian study region only (34–45°N, 20°W–10°E), weighted by
# cos(latitude), rather than over the whole globe. The model longitude runs
# 0..360, so convert it to -180..180 before selecting the signed box.
spei_signed_lon = SPEI_result.assign_coords(
    longitude=((SPEI_result['longitude'].values + 180) % 360) - 180
).sortby('longitude')
spei_iberia = spei_signed_lon.sel(latitude=slice(34, 45), longitude=slice(-20, 10))
weights = np.cos(np.deg2rad(spei_iberia['latitude']))
spei_avg = spei_iberia.weighted(weights).mean(dim=['latitude', 'longitude'])
time = spei_avg['time']
spei_values = spei_avg.values



In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Forecast initialisation date. Each SPEI value is dated by adding the model's
# time offsets (hours) to it, so the axis stays correct if the output interval
# (`inner_steps`) changes, instead of assuming exactly one value per day.
start_date = '2022-06-21'  # Replace with your actual start date
time = pd.Timestamp(start_date) + pd.to_timedelta(spei_avg['time'].values, unit='h')

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(time, spei_values, label='SPEI (Spatial Average)', color='blue')
ax.axhline(0, color='black', linestyle='--', linewidth=0.8, label='Neutral')  # Reference line
ax.axhline(-1, color='orange', linestyle='--', label='Moderate Drought')
ax.axhline(-2, color='red', linestyle='--', label='Severe Drought')
ax.axhline(1, color='green', linestyle='--', label='Moderate Wet')
ax.axhline(2, color='darkgreen', linestyle='--', label='Severe Wet')

# Formatting
ax.set_title("SPEI Time Series (Spatial Average)")
ax.set_xlabel("Date")
ax.set_ylabel("SPEI Value")
ax.legend()
ax.grid()

# Automatically format date labels on the x-axis
fig.autofmt_xdate()

plt.show()


## Adding later

In [None]:
from PIL import Image
from IPython.display import display

# Show the study-region image (images/Spain_range_1.png), keeping it in `img`.
display(img := Image.open('images/Spain_range_1.png'))

In [None]:
from PIL import Image
from IPython.display import display

# NOTE: duplicate of the previous cell (same image).
img = Image.open(fp='images/Spain_range_1.png')
display(img)

In [None]:
# NOTE: duplicate display of images/Spain_range_1.png.
img = Image.open(fp='images/Spain_range_1.png'); display(img)

In [None]:
import signal
import numpy as np

# Parameters for the 30-year rollout. These were commented out, so the
# computation of `outer_steps` below raised NameError.
years = 30
days_per_year = 365  # Exclude leap years for simplicity
inner_steps = 24  # Save model outputs every 24 hours
hours_per_day = 24
outer_steps = (days_per_year * years * hours_per_day) // inner_steps  # Total saved steps for 30 years
timedelta = np.timedelta64(inner_steps, 'h')  # Time interval between model outputs
times = np.arange(outer_steps) * inner_steps  # Time axis in hours

# `model` and `eval_era5` come from the cells above.

class GracefulExit:
    """Context manager that closes a file and clears a running flag on exit.

    ``cleanup()`` may also be called directly (e.g. from a signal handler);
    calling it repeatedly is harmless because the file is only closed while it
    is still open. ``__exit__`` returns None, so exceptions raised inside the
    ``with`` block propagate.
    """

    def __init__(self, file):
        self.file, self.is_running = file, True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.cleanup()

    def cleanup(self):
        print("Shutting down gracefully...")
        already_closed = self.file.closed
        if not already_closed:
            self.file.close()
        self.is_running = False

import signal


# Signal handler to trigger cleanup
def handle_signal(signum, frame):
    """Clean up the active rollout context (if any), then stop with SystemExit.

    The previous handler only called ``cleanup()`` and returned, so execution
    continued after the signal. It also raised NameError if the signal arrived
    before ``graceful_exit_context`` had been created.

    Raises:
        SystemExit: always, with code ``128 + signum`` (shell convention).
    """
    context = globals().get("graceful_exit_context")
    if context is not None:
        context.cleanup()
    raise SystemExit(128 + signum)


# Only SIGTERM is handled (e.g. a batch scheduler stopping the job). SIGINT is
# left alone so Jupyter's "Interrupt kernel" still raises KeyboardInterrupt.
try:
    signal.signal(signal.SIGTERM, handle_signal)
except ValueError as e:  # handlers can only be installed from the main thread
    print(f"Could not register SIGTERM handler: {e}")

# Segment files: 30_year_rollout_predictions_000.nc, _001.nc, ...
output_prefix = "30_year_rollout_predictions"

# A single unroll over all `outer_steps` would hold every output of the 30 years
# in memory at once. Instead, roll out about one year of saved steps per call,
# carry the model state forward, and write each segment to its own file.
# (The previous version also opened the .nc target as a text file, which
# truncated it, and then never wrote to it.)
segment_steps = max(1, (365 * 24) // inner_steps)
segment_files = []

try:
    # Initialize model state
    print("Initializing model state...")
    inputs = model.inputs_from_xarray(eval_era5.isel(time=0))
    input_forcings = model.forcings_from_xarray(eval_era5.isel(time=0))
    rng_key = jax.random.key(42)  # Optional for deterministic models
    state = model.encode(inputs, input_forcings, rng_key)

    # Use persistence for forcing variables (e.g., SST and sea ice cover)
    print("Using persistent forcing variables...")
    all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))

    print(f"Starting 30-year rollout with {outer_steps} steps...")
    for first_step in range(0, outer_steps, segment_steps):
        n_steps = min(segment_steps, outer_steps - first_step)
        # Assumes unroll returns the state advanced by `n_steps` steps, so the
        # next segment continues where this one ended (start_with_input=True
        # makes each segment's first output its initial state) — TODO confirm
        # against the NeuralGCM unroll documentation.
        state, segment_predictions = model.unroll(
            state,
            all_forcings,
            steps=n_steps,
            timedelta=timedelta,
            start_with_input=True,
        )
        segment_ds = model.data_to_xarray(
            segment_predictions, times=times[first_step:first_step + n_steps]
        )
        # Strip attributes before writing, as elsewhere in this notebook.
        segment_ds.attrs = {}
        for name in segment_ds.data_vars:
            segment_ds[name].attrs = {}
        segment_path = f"{output_prefix}_{len(segment_files):03d}.nc"
        segment_ds.to_netcdf(segment_path)
        segment_files.append(segment_path)
        print(f"Wrote {segment_path} ({n_steps} steps)")

    final_state = state
    # Final message
    print("30-year rollout completed successfully.")

except Exception as e:
    print(f"An error occurred: {e}")
    print(f"{len(segment_files)} segment file(s) were written before the error.")
finally:
    print("Execution stopped.")
