In [None]:
# ! pip install -q -U neuralgcm dinosaur gcsfs
# ! pip install --upgrade xarray zarr gcsfs
# ! pip install -q -U aiohttp

import pickle as pkl
from typing import Any

import fsspec
import gcsfs
import google.auth
import jax
import neuralgcm as gcm
import numpy as np
import xarray
import zarr
from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils
from google.cloud import storage

# Anonymous (token='anon') file-system handle used to read the checkpoint.
gcs = gcsfs.GCSFileSystem(token='anon')

# Checkpoint names, relative to gs://neuralgcm/models/. Every entry carries
# the `.pkl` suffix so that any of them can be assigned to `model_name`.
model_list = [
    'v1/deterministic_0_7_deg.pkl',
    'v1/deterministic_1_4_deg.pkl',
    'v1/deterministic_2_8_deg.pkl',
    'v1/stochastic_1_4_deg.pkl',
    'v1_precip/stochastic_precip_2_8_deg.pkl',
    'v1_precip/stochastic_evap_2_8_deg.pkl',
]

# Checkpoint actually loaded below.
model_name = 'v1/deterministic_2_8_deg.pkl'

# NOTE(security): unpickling executes arbitrary code embedded in the file.
# Only load checkpoints from a bucket you trust.
with gcs.open(f'gs://neuralgcm/models/{model_name}', 'rb') as f:
  ckpt = pkl.load(f)

# Build the model object from the unpickled checkpoint.
model = gcm.PressureLevelModel.from_checkpoint(ckpt)

# era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'

# Location of the ERA5 zarr store that is opened below.
ERA5_ZARR_URL = (
    'gs://gcp-public-data-arco-era5/ar/1959-2022-full_37-1h-0p25deg-chunk-1.zarr-v2'
)

# Open the store with the anonymous token, chunked 48 entries at a time along
# `time`, reading the consolidated metadata.
full_era5 = xarray.open_zarr(
    ERA5_ZARR_URL,
    consolidated=True,
    chunks={'time': 48},
    storage_options={'token': 'anon'},
)
# fs = gcsfs.GCSFileSystem(token="anon",project='public')

# store = gcs.get_mapper(era5_path)

# store = zarr.storage.FSStore(path=era5_path, fs=fs)

# full_era5 = xarray.open_zarr(store=store, consolidated=True, chunks=None)

# full_era5 = xarray.open_zarr(fs.get_mapper(era5_path))

#retry
# Initialize the storage client

# Get the bucket and blob

# era5 = xarray.open_zarr(
#     store,
#     chunks={'time': 48},
#     consolidated=True,
# )
# Bounds of the demo window and the stride used when selecting along `time`.
demo_start_time = '2020-02-14'
demo_end_time = '2020-02-18'
data_inner_steps = 24  # keep every 24th entry along time

# Variables the model consumes: inputs first, then forcings.
_model_vars = model.input_variables + model.forcing_variables

# Subset of those variables that the ERA5 store actually provides.
available_vars = list(set(_model_vars).intersection(full_era5.data_vars))

# Apply the '24 hours' temporal shift to the forcing variables only, then cut
# out the strided demo window and load the result into memory.
_demo_window = slice(demo_start_time, demo_end_time, data_inner_steps)
_shifted_era5 = xarray_utils.selective_temporal_shift(
    full_era5[_model_vars],
    variables=model.forcing_variables,
    time_shift='24 hours',
)
sliced_era5 = _shifted_era5.sel(time=_demo_window).compute()


# Describe the horizontal grid of the ERA5 dataset: node counts come from the
# dataset's dimension sizes, spacing and offset are inferred from its
# latitude / longitude coordinates.
_lat_nodes = full_era5.sizes['latitude']
_lon_nodes = full_era5.sizes['longitude']
_lat_spacing = xarray_utils.infer_latitude_spacing(full_era5.latitude)
_lon_offset = xarray_utils.infer_longitude_offset(full_era5.longitude)

era5_grid = spherical_harmonic.Grid(
    latitude_nodes=_lat_nodes,
    longitude_nodes=_lon_nodes,
    latitude_spacing=_lat_spacing,
    longitude_offset=_lon_offset,
)


In [None]:

# Regridder from the ERA5 grid onto the model's horizontal data grid.
_model_grid = model.data_coords.horizontal
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid,
    _model_grid,
    skipna=True,
)

# Regrid the sliced data, then fill any NaNs left on the target grid.
_regridded_era5 = xarray_utils.regrid(sliced_era5, regridder)
eval_era5 = xarray_utils.fill_nan_with_nearest(_regridded_era5)

# Forecast cadence. `inner_steps` is a number of hours (see `timedelta`).
# NOTE: per the original author's remark, this value also affects the time
# index selection further down.
inner_steps = 24  # save model outputs once every 24 hours
# Number of saved outputs needed to cover 4 days at the cadence above.
outer_steps = 4 * 24 // inner_steps

timedelta = np.timedelta64(1, 'h') * inner_steps  # time between saved outputs
times = np.arange(outer_steps) * inner_steps  # x-axis: hours since the start

# model configuration

# vari ={
#     'specific_cloud_ice_water_content': 'ciwic'
# }
# renamed_era5 = eval_era5.rename_vars(vari)

# inputs = renamed_era5.isel(time = 0)


# model.inputs_from_xarray = [var for var in model.input_variables if var in eval_era5.data_vars]

# Build the initial model state from the first time step of `eval_era5`.
# Do not overwrite variables in `eval_era5` with placeholder values here:
# `inputs_from_xarray` must receive the real regridded fields.
inputs = model.inputs_from_xarray(eval_era5.isel(time=0))
input_forcings = model.forcings_from_xarray(eval_era5.isel(time=0))
rng_key = jax.random.key(42)  # optional for deterministic models
initial_state = model.encode(inputs, input_forcings, rng_key)


# Forcings taken from the first time step only (`head` keeps the time axis).
# Presumably these are held fixed for the whole forecast -- confirm against
# the code that consumes `all_forcings`.
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))