In [None]:
import datetime
import os
import pdb
from pathlib import Path
from typing import Any, Iterable

import numpy as np
import pandas as pd
import torch

# --- MONKEY-PATCH TQDM FOR JUPYTER ---
# Rebind ``tqdm.tqdm`` to the notebook widget implementation. A module that does
# ``from tqdm import tqdm`` captures whatever object is bound at its own import
# time, so the patch must run BEFORE importing any library that uses tqdm.
# The ``isort: split`` marker keeps import sorters from hoisting the imports
# below back above the patch.
import tqdm
from tqdm import notebook

tqdm.tqdm = notebook.tqdm

# isort: split
import xarray as xr
from hwt_mpas import MPASDataSource
from netCDF4 import Dataset

from earth2studio.data import GEFS_FX, GFS, DataSource, LandSeaMask, SurfaceGeoPotential
from earth2studio.io import IOBackend, NetCDF4Backend
from earth2studio.io.zarr import ZarrBackend
from earth2studio.models.px import GraphCastOperational, GraphCastSmall
from earth2studio.run import deterministic
from earth2studio.utils.type import CoordSystem

# Root of the user's scratch space; later cells build input/output paths under it.
_scratch_env = os.getenv("SCRATCH")
if not _scratch_env:
    # Path(None) would fail with an opaque TypeError (and Path("") silently means
    # the current directory); report what is actually missing instead.
    raise RuntimeError(
        "Environment variable SCRATCH is not set; point it at a writable scratch directory."
    )
SCRATCH = Path(_scratch_env)

In [None]:
class MemoryDataSource(DataSource):
    """A data source that serves a single in-memory ``xarray.DataArray`` state.

    The same stored array is returned for every requested time. When the array
    has a ``variable`` dimension, it is subset and reordered to the requested
    variables, so callers receive exactly the channels they asked for, in their
    order.
    """

    def __init__(self, data: xr.DataArray):
        super().__init__()
        self.data = data

    def __call__(self, init_time, variable, **kwargs):
        """Return the stored state, restricted to ``variable`` when possible.

        Parameters
        ----------
        init_time
            Ignored; the same state is returned for any requested time.
        variable
            A variable name or a sequence of names.
        **kwargs
            Ignored.

        Raises
        ------
        KeyError
            If a requested variable is absent from the stored array's
            ``variable`` coordinate.
        """
        if "variable" not in self.data.dims:
            return self.data
        # Normalise to a 1-D list of plain strings so that a single name does not
        # drop the ``variable`` dimension.
        requested = [str(v) for v in np.atleast_1d(variable)]
        return self.data.sel(variable=requested)


# Define a custom IO class that subsets the data before writing to NetCDF.
class SubsetNetCDF4Backend(IOBackend):
    """IO backend that writes only a lat/lon window of each array to NetCDF.

    Wraps an earth2studio ``NetCDF4Backend``. Coordinates passed to
    ``add_array`` and tensors passed to ``write`` are cut down to
    ``lat_slice`` x ``lon_slice`` before being forwarded. The slices are
    label-based (coordinate values, e.g. degrees) and follow the same rules as
    ``xarray.DataArray.sel(lat=lat_slice, lon=lon_slice)``. In particular, a
    slice whose bounds run against the coordinate's sort order selects nothing.

    Parameters
    ----------
    file_name : str
        Path of the NetCDF file written by the wrapped backend.
    lat_slice, lon_slice : slice
        Label slices selecting the window to keep.
    backend_kwargs : dict, optional
        Forwarded to ``NetCDF4Backend`` (e.g. ``{"mode": "w"}``). ``None`` (the
        default) forwards an empty dict, as the previous ``{}`` default did.
    """

    def __init__(
        self,
        file_name: str,
        lat_slice: slice,
        lon_slice: slice,
        backend_kwargs: dict | None = None,
    ):
        self.file_name = file_name
        self.lat_slice = lat_slice
        self.lon_slice = lon_slice
        # Copy so a caller-owned dict is never aliased.
        self.backend_kwargs = dict(backend_kwargs) if backend_kwargs is not None else {}
        self.writer = None

    def _window(self, coords: CoordSystem) -> tuple[dict, slice, slice]:
        """Return ``(subset_coords, lat_positions, lon_positions)`` for ``coords``.

        ``pandas.Index.slice_indexer`` is the lookup xarray performs for a label
        slice on a pandas-backed index. This therefore reproduces
        ``.sel(lat=..., lon=...)`` without allocating any data array.
        """
        lat_pos = pd.Index(coords["lat"]).slice_indexer(
            self.lat_slice.start, self.lat_slice.stop, self.lat_slice.step
        )
        lon_pos = pd.Index(coords["lon"]).slice_indexer(
            self.lon_slice.start, self.lon_slice.stop, self.lon_slice.step
        )
        # Keep the key (dimension) order of ``coords``; only lat/lon are cut.
        subset = {k: np.asarray(v) for k, v in coords.items()}
        subset["lat"] = subset["lat"][lat_pos]
        subset["lon"] = subset["lon"][lon_pos]
        return subset, lat_pos, lon_pos

    def add_array(
        self, coords: CoordSystem, array_name: str | list[str], **kwargs: dict[str, Any]
    ) -> None:
        """Register ``array_name`` with the wrapped writer using windowed coordinates."""
        subset_coords, _, _ = self._window(coords)
        # Open the wrapped backend once and reuse it. Building a new one per call
        # would leak the previous handle and reopen the file (with mode "w",
        # discarding arrays added earlier).
        if self.writer is None:
            self.writer = NetCDF4Backend(self.file_name, self.backend_kwargs)
        self.writer.add_array(subset_coords, array_name, **kwargs)

    def write(
        self,
        x: torch.Tensor | list[torch.Tensor],
        coords: CoordSystem,
        array_name: str | list[str],
    ) -> None:
        """Write the lat/lon window of each tensor in ``x`` under ``array_name``.

        ``x`` is either one tensor with one name, or parallel lists of tensors
        and names. Each tensor's axes must follow the key order of ``coords``.

        Raises
        ------
        RuntimeError
            If called before ``add_array``.
        ValueError
            If a tensor's shape does not match ``coords``, or the numbers of
            tensors and names differ.
        """
        if self.writer is None:
            raise RuntimeError("add_array must be called before write.")

        if isinstance(x, list):
            tensors, names = x, array_name
        else:
            tensors, names = [x], [array_name]

        subset_coords, lat_pos, lon_pos = self._window(coords)
        dims = list(coords)
        expected_shape = tuple(len(v) for v in coords.values())
        index = [slice(None)] * len(dims)
        index[dims.index("lat")] = lat_pos
        index[dims.index("lon")] = lon_pos
        index = tuple(index)

        for tensor, name in zip(tensors, names, strict=True):
            if tuple(tensor.shape) != expected_shape:
                raise ValueError(
                    f"Tensor for '{name}' has shape {tuple(tensor.shape)}; "
                    f"coords imply {expected_shape}."
                )
            # Slice on the tensor's own device and move only the window to the
            # CPU, instead of copying the whole global field first.
            self.writer.write(tensor[index].cpu(), subset_coords, name)

    def close(self):
        """Close the wrapped writer if one was opened; safe to call more than once."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None


def _check_forecast_file(path, expected_data_vars=85):
    """Raise ``ValueError`` if the forecast file at ``path`` looks incomplete or corrupt.

    Checks, in order:
    - the file has exactly ``expected_data_vars`` data variables;
    - no dimension has length 0 (e.g. an empty lat/lon window);
    - the lat/lon maximum of ``z500`` never exceeds 1e30 (presumably unwritten
      NetCDF fill values; NOTE(review): confirm).

    Errors from opening or reading the file (unreadable file, missing ``z500``,
    ...) propagate unchanged.
    """
    with xr.open_dataset(path) as ds:
        n_vars = len(ds.data_vars)
        if n_vars != expected_data_vars:
            raise ValueError(
                f"Incorrect number of data variables. Expected {expected_data_vars}, found {n_vars}."
            )
        # ``Dataset.sizes`` is the name -> length mapping; iterating
        # ``Dataset.dims`` as a mapping is deprecated in recent xarray.
        for dim_name, dim_size in ds.sizes.items():
            if dim_size == 0:
                raise ValueError(f"Dimension '{dim_name}' has size 0.")
        # ``.any()`` also works when the reduction is 0-d; builtin ``any`` would
        # raise TypeError trying to iterate it.
        if bool((ds.z500.squeeze().max(dim=["lat", "lon"]) > 1e30).any()):
            raise ValueError(f"bad data in {path}")


def _build_initial_state(gefs_state, model_variables, vars_to_zero_fill, z, lsm, lat, lon):
    """Regrid a GEFS state to the model grid and complete it with extra fields.

    Variables in ``vars_to_zero_fill`` are added as zeros, ``z`` and ``lsm`` are
    appended as given, and the result is ordered as ``model_variables`` with the
    ``lead_time`` dimension squeezed away.

    Raises
    ------
    ValueError
        If the assembled state contains missing values.
    """
    # Append a copy of the lon=0 column at lon=360 so that linear interpolation
    # has a right-hand neighbour for target longitudes just below 360.
    wrapped = gefs_state.sel(lon=0).assign_coords(lon=360)
    periodic = xr.concat([gefs_state, wrapped], dim="lon")
    regridded = periodic.interp(lat=lat, lon=lon, method="linear")

    pieces = [regridded]
    for var_name in vars_to_zero_fill:
        zero_array = xr.zeros_like(regridded.isel(variable=0))
        zero_array["variable"] = var_name
        pieces.append(zero_array)
    pieces.extend([z, lsm])
    initial_state = xr.concat(pieces, dim="variable", coords="minimal")

    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not bool(initial_state.notnull().all()):
        raise ValueError("Initial state contains missing values after regridding.")
    return initial_state.sel(variable=model_variables).squeeze(dim="lead_time", drop=True)


def run(
    init_time,
    model,
    z,
    lsm,
    members=("gec00", *(f"gep{p:02d}" for p in range(1, 31))),
    output_root="/glade/derecho/scratch/ahijevyc/ai-models/output/graphcast",
    lat_slice=slice(20, 60),
    lon_slice=slice(220, 300),
):
    """Run a 240 h forecast with ``model`` from each GEFS member's lead-0 state.

    For every member, the GEFS state at lead time 0 is fetched and regridded to
    the model grid. It is then completed with zero-filled and static (``z``,
    ``lsm``) fields and forecast with ``deterministic``. Only the
    ``lat_slice`` x ``lon_slice`` window is written, to
    ``<output_root>/<YYYYMMDDHH>/<member>.nc``.

    An existing file that passes ``_check_forecast_file`` is kept and its member
    is skipped. A file that fails the check is deleted and regenerated.

    Parameters
    ----------
    init_time
        Initialisation time. It must support ``isoformat()``, ``date()`` and
        ``%Y%m%d%H`` formatting (e.g. ``pandas.Timestamp``).
    model
        earth2studio prognostic model whose inputs include ``z`` and ``lsm``.
    z, lsm
        Static fields on the model grid, concatenated into the state as-is.
    members
        GEFS member names. The default is ``gec00`` plus ``gep01``-``gep30``.
    output_root
        Directory under which one folder per initialisation time is created.
    lat_slice, lon_slice
        Label slices of the window written to disk.
    """
    forecast_length = 240  # hours
    forecast_step_hours = 6
    nsteps = forecast_length // forecast_step_hours

    output_dir = os.path.join(output_root, f"{init_time:%Y%m%d%H}")
    os.makedirs(output_dir, exist_ok=True)
    print(f"Ensemble forecast outputs will be saved in: {output_dir}")

    input_coords = model.input_coords()
    model_variables = input_coords["variable"]
    lat = input_coords["lat"]
    lon = input_coords["lon"]
    # Fields filled with zeros in the initial state instead of being fetched.
    vars_to_zero_fill = [v for v in model_variables if v.startswith("w") or v == "tp06"]
    # z and lsm come from the static arguments, not from GEFS.
    vars_to_fetch = [v for v in model_variables if v not in vars_to_zero_fill]
    vars_to_fetch.remove("z")
    vars_to_fetch.remove("lsm")

    for member in members:
        output_filepath = os.path.join(output_dir, f"{member}.nc")

        if os.path.exists(output_filepath):
            try:
                _check_forecast_file(output_filepath)
            except Exception as e:
                # Any failure (unreadable, truncated, wrong content) means re-run.
                print(
                    f"Found invalid or incomplete file for member '{member}', removing. Error: {e}"
                )
                os.remove(output_filepath)
            else:
                print(
                    f"Valid and complete forecast file already exists for member '{member}', skipping."
                )
                continue

        print(f"Fetching initial conditions for {member} at {init_time.isoformat()}...")
        gefs_state = GEFS_FX(member=member)(
            init_time, [datetime.timedelta(hours=0)], vars_to_fetch
        )

        print(f"Regridding initial state for {member}...")
        initial_state = _build_initial_state(
            gefs_state, model_variables, vars_to_zero_fill, z, lsm, lat, lon
        )
        in_memory_source = MemoryDataSource(initial_state)

        print(f"Running forecast and subsetting for member '{member}'...")
        subset_writer = SubsetNetCDF4Backend(
            file_name=output_filepath,
            lat_slice=lat_slice,
            lon_slice=lon_slice,
            backend_kwargs={"mode": "w"},
        )
        try:
            deterministic([init_time], nsteps, model, in_memory_source, subset_writer)
        finally:
            # Release the file handle even if the forecast fails. A partially
            # written file stays on disk for the check above to examine on a
            # later run.
            subset_writer.close()

        print(f"Successfully created forecast file: {output_filepath}")
        print(f"--- Finished forecast for member '{member}' ---")

    print(
        f"\n✅ All ensemble member forecasts for {init_time.date()} have been successfully generated."
    )

In [None]:
# --- Main Execution Block ---
print("Initializing GraphCast model...")
model_class = GraphCastOperational
model = model_class.load_model(model_class.load_default_package())
print("Model initialized successfully.")


def _load_or_fetch_static(path, name, fetch):
    """Return the static field cached at ``path``, fetching and caching it if absent.

    ``fetch`` is a zero-argument callable returning an ``xarray.DataArray``. Its
    result is written to ``path`` with ``to_netcdf`` before being returned.
    """
    if os.path.exists(path):
        print(f"Loading {name} from local file: {path}")
        return xr.open_dataarray(path)
    print(f"Fetching {name} and saving to: {path}")
    data = fetch()
    data.to_netcdf(path)
    return data


def _fetch_z(lat, lon, when):
    """Fetch surface geopotential at ``when`` on the ``lat``/``lon`` grid, labelled ``"z"``."""
    z_data = (
        SurfaceGeoPotential(cache=False)([when])
        .sel(lat=lat, lon=lon)
        .squeeze(dim="time")
    )
    # Relabel so the field matches the model's "z" input variable.
    z_data["variable"] = ["z"]
    return z_data


def _fetch_lsm(lat, lon, when):
    """Fetch the land-sea mask at ``when`` on the ``lat``/``lon`` grid."""
    return LandSeaMask(cache=False)([when]).sel(lat=lat, lon=lon).squeeze(dim="time")


# Load static data (z, lsm), fetching and saving as local files if not present.
print("Loading static data (z, lsm)...")
lat = model.input_coords()["lat"]
lon = model.input_coords()["lon"]
static_data_dir = "/glade/derecho/scratch/ahijevyc/ai-models/static_data"
os.makedirs(static_data_dir, exist_ok=True)
z_filepath = os.path.join(static_data_dir, "graphcast_z_0.25deg.nc")
lsm_filepath = os.path.join(static_data_dir, "graphcast_lsm_0.25deg.nc")
static_data_time = datetime.datetime(2023, 1, 1)

z = _load_or_fetch_static(z_filepath, "z", lambda: _fetch_z(lat, lon, static_data_time))
lsm = _load_or_fetch_static(
    lsm_filepath, "lsm", lambda: _fetch_lsm(lat, lon, static_data_time)
)

print("Static data loaded successfully.")

# pd.date_range already yields Timestamps: daily, with both end dates included.
for init_time in pd.date_range("20230424", "20230531"):
    print(f"\n{'='*20} Starting Run for {init_time.date()} {'='*20}")
    run(init_time, model, z, lsm)

In [None]:
class GFSFill(GFS):
    """GFS data source that serves selected variables from in-memory arrays.

    Variables named in ``custom_arrays`` are answered with the stored arrays,
    expanded over the requested times. All other variables are fetched from GFS
    as usual. The result lists variables in the order they were requested.
    """

    def __init__(self, custom_arrays: dict[str, xr.DataArray], *args, **kwargs):
        """Store ``custom_arrays`` (variable name -> array) and initialise GFS.

        Remaining positional and keyword arguments are passed to ``GFS``.
        """
        super().__init__(*args, **kwargs)
        self.custom_arrays = custom_arrays
        self.custom_vars = list(custom_arrays.keys())

    def _custom_array(self, name: str, time: np.ndarray) -> xr.DataArray:
        """Return the stored array for ``name``, labelled ``name`` and expanded over ``time``."""
        da = self.custom_arrays[name]
        # Label the array with the name it was requested under so it can be
        # selected by that name after concatenation.
        if "variable" in da.dims:
            da = da.assign_coords(variable=[name])
        else:
            da = da.drop_vars("variable", errors="ignore").expand_dims(variable=[name])
        return da.expand_dims({"time": time})

    def __call__(
        self,
        time: np.ndarray,
        variable: np.ndarray,
    ) -> xr.DataArray:
        """Return data for ``variable`` at ``time``, in the requested variable order.

        Errors raised by ``GFS.__call__`` for non-custom variables propagate.
        An empty ``variable`` yields an empty ``xr.DataArray()``.
        """
        # Treat a bare string as one name; ``in`` on a str would match substrings.
        requested = [str(v) for v in np.atleast_1d(variable)]
        other_vars = [v for v in requested if v not in self.custom_arrays]

        data_to_concat = []
        if other_vars:
            data_to_concat.append(super().__call__(time, np.array(other_vars)))
        data_to_concat.extend(
            self._custom_array(v, time) for v in requested if v in self.custom_arrays
        )

        if not data_to_concat:
            return xr.DataArray()
        if len(data_to_concat) == 1:
            combined = data_to_concat[0]
        else:
            combined = xr.concat(data_to_concat, dim="variable")
        # GFS results come first and custom arrays last; restore the caller's
        # order, which downstream coordinate checks compare exactly.
        return combined.sel(variable=requested)

# Compare GFS init with and without surface geopotential and land mask
init_time = pd.Timestamp("20240501")

model = GraphCastOperational.load_model(GraphCastOperational.load_default_package())
ds = GFS()
# cache=False to avoid AttributeError: type object 'WholeFileCacheFileSystem'
# has no attribute '_cat_file'. Did you mean: 'cat_file'?
# dummy time list for required positional argument 'time'
# squeeze 'time' to avoid ValueError: Dimension time already exists.
zsl = SurfaceGeoPotential(cache=False)([0]).squeeze(dim="time")
lsm = LandSeaMask(cache=False)([0]).squeeze(dim="time")

# --- Instantiate Custom Data Source ---
# NOTE(review): the model's input variables include "z" (run() removes "z" from
# its GEFS fetch list), so a "zsl" key may never be requested and the two runs
# may differ only in lsm. Confirm the intended key.
custom_data = {"zsl": zsl, "lsm": lsm}
ds_filled = GFSFill(custom_arrays=custom_data)

# --- Run Forecast ---
nsteps = 8


def _run_to_netcdf(source, ofile, init_time, nsteps, model):
    """Run ``nsteps`` of ``model`` from ``source`` at ``init_time`` into a fresh NetCDF ``ofile``."""
    if os.path.exists(ofile):
        os.remove(ofile)
    io = NetCDF4Backend(ofile, backend_kwargs={"mode": "w"})
    try:
        deterministic([init_time], nsteps, model, source, io)
    finally:
        # Release the file even if the forecast fails.
        io.close()


_run_to_netcdf(ds, SCRATCH / "GFS.nc", init_time, nsteps, model)
_run_to_netcdf(ds_filled, SCRATCH / "GFSFill.nc", init_time, nsteps, model)

print("Forecast run complete.")

In [None]:
# Fetch the raw GFS initial condition for every model input variable, for inspection.
# NOTE(review): this needs ``ds`` to still be the ``GFS()`` data source from the
# previous cell; the MPAS cell below rebinds ``ds`` to an array, after which this fails.
ds(init_time, model.input_coords()["variable"])

In [None]:
init_time = pd.Timestamp("20230502")
model = GraphCastOperational.load_model(GraphCastOperational.load_default_package())
mpas_data_source = MPASDataSource(
    grid_path=Path("MPAS/15km_mesh/grid_mesh/x1.2621442.grid.nc"),
    data_dir=f"HWT{init_time.year}/mpas_15km",
)
ds = mpas_data_source(init_time, 0, members=[1], variables=model.input_coords()["variable"]).squeeze(dim="member")
ds = ds.transpose("time", ...)
# Overwrite the static fields with ``lsm``/``zsl`` from the GFS comparison cell.
# ``.values`` bypasses coordinate alignment. This assumes ``variable`` is the
# second dimension and that both arrays share lat/lon ordering.
# NOTE(review): confirm both assumptions.
ds.loc[:, "lsm"] = lsm.values
ds.loc[:, "z"] = zsl.values

ic = MemoryDataSource(ds)
nsteps = 8
# f-string: without the ``f`` prefix the file name would contain the literal
# text "{init_time:%Y%m%d%H}".
ofile = SCRATCH / f"graphcast_mpas_{init_time:%Y%m%d%H}.nc"
if os.path.exists(ofile):
    os.remove(ofile)
io = NetCDF4Backend(ofile, backend_kwargs={"mode": "w"})
try:
    deterministic([init_time], nsteps, model, ic, io)
finally:
    # Release the file even if the forecast fails.
    io.close()

In [None]:
# Show the land-sea mask channel of the MPAS initial state.
ds.sel({"variable": "lsm"})

In [None]:
ifile = SCRATCH.joinpath("ai-models", "output", "graphcast", f"{init_time:%Y%m%d%H}", "gep01.nc")
print(ifile)
ds = xr.open_dataset(ifile)
# True wherever the lat/lon maximum of z500 stays below the 1e30 "bad data" threshold.
ds.z500.squeeze().max(dim=["lat", "lon"]) < 1e30

In [None]:
# Lat/lon maximum of z100 for the gep01 member.
ds = xr.open_dataset(
    SCRATCH.joinpath("ai-models", "output", "graphcast", f"{init_time:%Y%m%d%H}", "gep01.nc")
)
ds["z100"].squeeze().max(dim=["lat", "lon"])

In [None]:
# Open the gep23 output (a zarr store) lazily and display its summary.
ds = xr.open_zarr(
    SCRATCH.joinpath("ai-models", "output", "graphcast", f"{init_time:%Y%m%d%H}", "gep23")
)
ds

In [None]:
# Per-variable lat/lon maxima of the zarr store, loaded into memory for display.
ds.squeeze().max(["lat", "lon"]).load()