In [None]:
import datetime
import os
from pathlib import Path

import pandas as pd

# Swap in the notebook (widget) progress bar BEFORE importing any library that
# uses tqdm: a module that runs `from tqdm import tqdm` binds whichever class
# is installed at the moment it is imported, so patching afterwards would not
# reach it.
import tqdm
from tqdm import notebook

tqdm.tqdm = notebook.tqdm

# Imported only after the patch above, in case hwt_mpas (or anything it pulls
# in) binds tqdm at import time.
from hwt_mpas import (  # handles whole ensemble (all members) as opposed to earth2studio.data.mpas.MPAS
    MemoryDataSource,
    MPASEnsDataSource,
)

from earth2studio.data import LandSeaMask, SurfaceGeoPotential
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import GraphCastOperational, GraphCastSmall
from earth2studio.run import deterministic

import tempfile

# Root directory for scratch output. Read from the SCRATCH environment
# variable; when that variable is unset, fall back to the system temporary
# directory instead of failing with an opaque `TypeError` from `Path(None)`.
_scratch_env = os.getenv("SCRATCH")
if _scratch_env is None:
    _scratch_env = tempfile.gettempdir()
    print(f"SCRATCH is not set; falling back to {_scratch_env}")
SCRATCH = Path(_scratch_env)

In [None]:
# Shell escape: print kernel/OS details of the host this notebook runs on.
!uname -a

In [None]:
# --- Model initialisation ---
# Select the model class once so the log message and both loader calls agree.
model_class = GraphCastSmall
print(f"Initializing {model_class} model...")
# Resolve the class's default package first, then build the model from it.
default_package = model_class.load_default_package()
model = model_class.load_model(default_package)
print("Model initialized successfully.")

In [None]:
# Display the coordinates the loaded model expects as input; a later cell
# reads the "variable", "lat" and "lon" entries of this mapping.
model.input_coords()

In [None]:
# Forecast initialisation time: 2 May 2023 (no hour given, so 00:00).
init_time = pd.Timestamp("20230502")

# NOTE(review): this rebinds `model`, replacing the GraphCastSmall instance
# loaded in an earlier cell with GraphCastOperational -- confirm that is intended.
operational_package = GraphCastOperational.load_default_package()
model = GraphCastOperational.load_model(operational_package)

# MPAS ensemble reader. Both paths are relative; how they are resolved is up
# to MPASEnsDataSource (presumably against the working directory -- TODO confirm).
mpas_data_source = MPASEnsDataSource(
    grid_path=Path("MPAS/15km_mesh/grid_mesh/x1.2621442.grid.nc"),
    data_dir=f"HWT{init_time.year}/mpas_15km",
)

# Request exactly the variables the model lists as inputs, with members=[1].
# The second positional argument (0) is presumably a lead time / forecast hour
# -- TODO confirm against MPASEnsDataSource.
requested_variables = model.input_coords()["variable"]
ds = mpas_data_source(init_time, 0, members=[1], variables=requested_variables)
# Drop the length-1 "member" axis, then move "time" to the leading dimension.
ds = ds.squeeze(dim="member").transpose("time", ...)
# Static fields (z, lsm): fetched on every run with caching disabled
# (cache=False), then subset to the model's lat/lon coordinates.
print("Loading static data (z, lsm)...")
grid_coords = model.input_coords()
lat, lon = grid_coords["lat"], grid_coords["lon"]

# Surface geopotential (z)
print("Fetching z")
# Tried cache=True but AttributeError: type object 'WholeFileCacheFileSystem' has no attribute '_cat_file'
geopotential_source = SurfaceGeoPotential(cache=False)
z = geopotential_source([None]).sel(lat=lat, lon=lon).squeeze()

# Land-sea mask (lsm)
print("Fetching lsm")
land_sea_source = LandSeaMask(cache=False)
lsm = land_sea_source([None]).sel(lat=lat, lon=lon).squeeze()

print("Static data loaded successfully.")
# Overwrite the "lsm" and "z" entries of ds with the static fields. `.loc[:, name]`
# spans the whole leading ("time") axis and picks label `name` on the second axis
# (presumably "variable" -- TODO confirm).
for static_name, static_field in (("lsm", lsm), ("z", z)):
    ds.loc[:, static_name] = static_field.values

# Wrap the in-memory dataset so it can be passed to the run as its data source.
ic = MemoryDataSource(ds)
# Number of forecast steps passed to deterministic().
nsteps = 4

# Output file under SCRATCH/tmp, named after the initialisation time.
ofile = SCRATCH / f"tmp/graphcast_mpas_{init_time:%Y%m%d%H}.nc"
# Create the output directory if needed; the backend is not relied on to do so.
ofile.parent.mkdir(parents=True, exist_ok=True)
# Start from a clean file; missing_ok avoids the exists()/remove() race.
ofile.unlink(missing_ok=True)

io = NetCDF4Backend(ofile, backend_kwargs={"mode": "w"})
try:
    deterministic([init_time], nsteps, model, ic, io)
finally:
    # Always release the NetCDF handle, even if the run fails; otherwise the
    # file stays open in the live kernel.
    io.close()