In [None]:
# FIXME: I thought I had GraphCast working with the MPAS HWT ensemble at one point, but maybe I never did. Results look weird:
# features do not move eastward; they only change intensity with time.

In [None]:
import datetime
import os
from pathlib import Path

import pandas as pd

# This must be done BEFORE importing any library that uses tqdm
import tqdm

# MPASEnsDataSource handles whole ensemble (all members) as opposed to earth2studio.data.mpas.MPAS
# It does not follow the earth2studio DataSource Protocol, because it expects one init_time
# and one integer lead_time. It can't be used directly (as 2nd-to-last arg) in deterministic. 
from hwt_mpas import (
    MemoryDataSource,
    MPASEnsDataSource,
)
from tqdm import notebook

# Monkeypatch: replace the attribute `tqdm.tqdm` with the notebook progress-bar
# class, so that code which looks up `tqdm.tqdm` after this line gets the
# notebook version.
# NOTE(review): `hwt_mpas` is imported above, i.e. before this patch is applied —
# confirm it does not bind `tqdm.tqdm` at import time, otherwise the patch will
# not affect it.
tqdm.tqdm = notebook.tqdm

from earth2studio.data import LandSeaMask, SurfaceGeoPotential
from earth2studio.io import NetCDF4Backend, XarrayBackend
from earth2studio.models.px import GraphCastOperational, GraphCastSmall
from earth2studio.run import deterministic

# Root directory for output files, taken from the SCRATCH environment variable.
# os.getenv returns None when the variable is unset, and Path(None) would then
# fail with an opaque TypeError, so fail early with a message that names the
# missing variable instead.
_scratch_env = os.getenv("SCRATCH")
if _scratch_env is None:
    raise RuntimeError(
        "Environment variable SCRATCH is not set; it must point to the "
        "directory where forecast output should be written."
    )
SCRATCH = Path(_scratch_env)

In [None]:
# Shell escape: print system information (kernel, hostname, architecture) for
# the machine this notebook is running on.
!uname -a

In [None]:
# --- Main Execution Block ---
# Model class to instantiate; its default package is fetched and then loaded.
model_class = GraphCastOperational
print(f"Initializing {model_class} model...")
default_package = model_class.load_default_package()
model = model_class.load_model(default_package)
print("Model initialized successfully.")

# Display the input coordinate system reported by the loaded model.
model.input_coords()

In [None]:
# NOTE: this cell used to display `ds`, but `ds` is first assigned inside the
# lead-time loop of the data-loading cell further down, so evaluating it here
# raised NameError on a fresh kernel (Restart Kernel -> Run All).
# Inspect `ds` after that cell has run instead.

In [None]:
init_time = pd.Timestamp("20230502")

# `model` is already loaded in the model-initialization cell above. Loading the
# default GraphCastOperational package a second time here was redundant and
# would silently override a different `model_class` chosen in that cell, so the
# existing `model` is reused.

# Query the model's input coordinate system once and reuse it below.
model_input_coords = model.input_coords()
lat = model_input_coords["lat"]
lon = model_input_coords["lon"]

mpas_data_source = MPASEnsDataSource(
    grid_path=Path("MPAS/15km_mesh/grid_mesh/x1.2621442.grid.nc"),
    data_dir=f"HWT{init_time.year}/mpas_15km",
)

# Load static data (z, lsm), subset to the model's lat/lon coordinate values.

# Handle Geopotential (z)
print("Fetching z")
z = (
    # Tried cache=True but AttributeError: type object 'WholeFileCacheFileSystem' has no attribute '_cat_file'
    SurfaceGeoPotential(cache=False)([None])
    .sel(lat=lat, lon=lon)
    .squeeze()
)

# Handle Land-Sea Mask (lsm)
print("Fetching lsm")
lsm = LandSeaMask(cache=False)([None]).sel(lat=lat, lon=lon).squeeze()

# Collect one array per lead time from the MPAS run initialized one day before
# `init_time`.
# NOTE(review): lead times are presumably in hours, making the two slices valid
# at init_time - 6 h and init_time — confirm against MPASEnsDataSource.
dss = []
for lead_time in [18, 24]:
    ds = mpas_data_source(
        init_time - pd.Timedelta(days=1),
        lead_time,
        members=[1],
        variables=model_input_coords["variable"],
    ).squeeze(dim="member")
    ds = ds.transpose("time", ...)
    # Overwrite the static fields with the values fetched above.
    # NOTE(review): `.loc[:, name]` assumes 'variable' is the second dimension
    # after the transpose — confirm.
    ds.loc[:, "lsm"] = lsm.reindex(lat=ds["lat"])  # one is N-S; the other is S-N
    ds.loc[:, "z"] = z.reindex(lat=ds["lat"])
    dss.append(ds)

In [None]:
import xarray as xr
# Join the per-lead-time arrays along "time", then select latitudes by the
# model's own latitude coordinate values (which also fixes their order).
dss = xr.concat(dss, dim="time")
dss = dss.sel({"lat": model.input_coords()["lat"]})
dss

In [None]:
import xarray as xr
from earth2studio.data import DataSource
from typing import List
import datetime
import numpy as np

class MemoryDataSource(DataSource):
    """
    A simple data source that serves slices of a single in-memory
    xarray.DataArray.

    Notes
    -----
    Defining this class rebinds the name ``MemoryDataSource`` that the
    notebook's import cell also imports from ``hwt_mpas``; from this cell
    onward the notebook uses this local definition.
    """

    def __init__(self, data: xr.DataArray):
        """
        Initializes the data source.

        Parameters
        ----------
        data : xr.DataArray
            The data to hold in memory. Must have 'time' and 'variable'
            coordinates. This DataArray should contain ALL time slices
            needed for the initialization.

        Raises
        ------
        ValueError
            If `data` lacks a 'time' or a 'variable' coordinate.
        """
        super().__init__()
        self.data = data
        if "time" not in self.data.coords:
            raise ValueError("DataArray must have a 'time' coordinate.")
        # __call__ selects by variable label, so that coordinate is required too.
        if "variable" not in self.data.coords:
            raise ValueError("DataArray must have a 'variable' coordinate.")

    def __call__(
        self,
        time: List[datetime.datetime],  # This 'time' IS the valid_time
        variable: List[str],
        **kwargs
    ) -> xr.DataArray:
        """
        Selects and returns data for the requested valid time(s).

        Parameters
        ----------
        time : List[datetime.datetime]
            Valid times to select (usually of length 1). Every requested time
            is returned, not just the first one.
        variable : List[str]
            Variable labels to select along the 'variable' coordinate.
        **kwargs
            Accepted for call compatibility; not used.

        Returns
        -------
        xr.DataArray
            The selection from the held data. Because the times are passed as
            a list, 'time' is kept as a dimension of the result.

        Raises
        ------
        ValueError
            If `time` is empty.
        KeyError
            If a requested valid time or variable is not present in the data.
        """
        valid_times = list(time)
        if not valid_times:
            raise ValueError("At least one valid time must be requested.")

        # Select time and variable in two steps so that a failed lookup is
        # attributed to the coordinate that actually caused it.
        try:
            time_slice = self.data.sel(time=valid_times)
        except KeyError as err:
            raise KeyError(
                f"Could not find valid_time {valid_times} in the MemoryDataSource. "
                f"Available times: {self.data.time.values}"
            ) from err

        try:
            # NB: use .coords["variable"]; DataArray.variable is a different
            # attribute (the underlying Variable object), not this coordinate.
            return time_slice.sel(variable=variable)
        except KeyError as err:
            raise KeyError(
                f"Could not find variable {variable} in the MemoryDataSource. "
                f"Available variables: {self.data.coords['variable'].values}"
            ) from err

In [None]:
ic = MemoryDataSource(dss)
nsteps = 5
ofile = SCRATCH / f"tmp/mpas_graphcast_{init_time:%Y%m%d%H}.nc"

# The "tmp" subdirectory under SCRATCH may not exist yet; create it so that
# opening the output file cannot fail for that reason.
ofile.parent.mkdir(parents=True, exist_ok=True)
# Remove any output left over from a previous run.
ofile.unlink(missing_ok=True)

io = NetCDF4Backend(ofile, backend_kwargs={"mode": "w"})
try:
    deterministic([init_time], nsteps, model, ic, io)
finally:
    # Close the backend even if the forecast raises, so the file handle is not
    # left open in the kernel.
    io.close()

In [None]:
# Display the initial-condition array selected at the model's latitude values.
dss.sel({"lat": model.input_coords()["lat"]})