Works with the Jupyter kernel `Python [conda env:pangu]`, built as described in the README.

Run FengWu with 1 GPU and 20 GB of memory. The GPU must be a V100, not a GP100: the GP100 has only 16 GB of VRAM, while the V100 has 32 GB.

Make these the default Casper modules with `module save`:

```
Currently Loaded Modules:
  1) ncarenv/24.12  (S)   3) ncarcompilers/1.0.0   5) ucx/1.17.0      7) hdf5/1.12.3    9) cudnn/9.2.0.82-12
  2) intel/2024.2.1       4) cuda/12.3.2           6) openmpi/5.0.6   8) netcdf/4.9.2  10) conda/latest
```

Inferences in conda env `ainwp` differ by 0.0001 K from those in the `pangu` env (after 240 hours).
Conda env `ainwp` is supposed to replicate the realtime runs, but its output still differs from them by 0.1 K.

In [None]:
import os
from pathlib import Path

import pandas as pd
import plot_ensemble
import xarray as xr
from earth2studio.data.mpas import MPAS as MPASDataSource
from earth2studio.data.mpas import MPASHybrid
from earth2studio.data.mpas_ens import MPAS as MPASEns
from earth2studio.io import IOBackend, NetCDF4Backend, WPSBackend
from earth2studio.models.px import GraphCastOperational, GraphCastSmall, Pangu6
from earth2studio.run import deterministic
from s3_run_pangu_ecmwf import run_inference, setup_model_sessions

# Scratch root for model assets and outputs. Indexing os.environ raises
# KeyError('SCRATCH') when the variable is unset, instead of the opaque
# TypeError that Path(None) would produce.
SCRATCH = Path(os.environ["SCRATCH"])
ai_models_dir = SCRATCH / "ai-models"
# Initialization time used by the runs and output paths below.
date = pd.to_datetime("2018042400", format="%Y%m%d%H")
# date = pd.to_datetime("2024042400", format="%Y%m%d%H")
ic = "mpas"  # initial-condition label used in output filenames and plot args
fhr_end = 120  # last forecast hour passed to run_inference / plotting

In [None]:
# Choose which MPAS source feeds the models below. Exactly one variant is used:
#   "hybrid"         - hybrid cycling-experiment output via MPASHybrid (default)
#   "hwt_15-3km"     - HWT ensemble on the 15-3 km CONUS mesh via MPASEns
#   "hwt_15km"       - HWT ensemble on the 15 km mesh via MPASEns
#   "cyclingfc_diag" - diag/p_sfc files of one cycle via MPASDataSource
mpas_variant = "hybrid"

if mpas_variant == "hybrid":
    idir = Path(
        "/glade/derecho/scratch/stoedtli/pandac/stoedtli_3dhybrid-60-60-iter_O30kmI60km_benchmark_1/CyclingFC"
    )
    mpas_datasrc = MPASHybrid(
        grid_path=idir / f"{date:%Y%m%d%H}/invariant.655362.nc",
        # The %-codes are left literal here; presumably MPASHybrid expands them
        # per requested time — confirm in earth2studio.data.mpas.
        data_path=f"{idir}/%Y%m%d%H/mpasin.%Y-%m-%d_%H.%M.%S.nc",
        pressure_levels=[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000],
    )
elif mpas_variant == "hwt_15-3km":
    idir = Path("/glade/campaign/mmm/parc/schwartz")
    mpas_datasrc = MPASEns(
        grid_path=idir / "MPAS/15-3km_mesh/grid_mesh/x5.6488066.grid_CONUS.nc",
        data_dir=idir / f"HWT{date.year}/mpas",
    )
elif mpas_variant == "hwt_15km":
    # The grid and data paths are relative to the same schwartz root as the
    # 15-3 km variant, so idir must be set here rather than inherited from
    # whichever branch ran before.
    idir = Path("/glade/campaign/mmm/parc/schwartz")
    mpas_datasrc = MPASEns(
        grid_path=idir / "MPAS/15km_mesh/grid_mesh/x1.2621442.grid.nc",
        data_dir=idir / f"HWT{date.year}/mpas_15km",
    )
elif mpas_variant == "cyclingfc_diag":
    idir = (
        Path(
            "/glade/derecho/scratch/stoedtli/pandac/stoedtli_3dhybrid-60-60-iter_O30kmI60km_benchmark_1/CyclingFC"
        )
        / f"{date:%Y%m%d%H}"
    )
    mpas_datasrc = MPASDataSource(
        grid_path=idir / "invariant.655362.nc",
        data_path=[
            idir / f"diag.{date:%Y-%m-%d_%H.%M.%S}.nc",
            idir / f"p_sfc.{date:%Y-%m-%d_%H.%M.%S}.nc",
        ],
    )
else:
    raise ValueError(f"unknown mpas_variant {mpas_variant!r}")

In [None]:
# Load the prognostic model used by the deterministic run below, from its
# default package. GraphCastSmall and Pangu6 are also imported and can be
# substituted for model_class.
model_class = GraphCastOperational
model = model_class.load_model(model_class.load_default_package())

In [None]:
# Quick look at variable "t1000" from the selected MPAS source, 6 h before
# `date`. The timestamp is derived from `date` instead of hard-coded, so it
# follows `date` when the case changes (for the default date it is
# 2018-04-23 18 UTC, as before).
mpas_datasrc(
    [date - pd.Timedelta(hours=6)],
    ["t1000"],
).plot(col="time", robust=True, figsize=(15, 8))

In [None]:
# Run `nsteps` model steps from `date` and write the output with WPSBackend
# under $SCRATCH/tmp.
nsteps = 1
# ofile is only used by the NetCDF4Backend alternative commented out below;
# it is removed first so a NetCDF run starts from a fresh file.
ofile = SCRATCH / f"tmp/test_GRAP:{date:%Y-%m-%d_%H}.nc"
ofile.unlink(missing_ok=True)
io = WPSBackend(SCRATCH / "tmp", map_source=model.__class__.__name__, static_fields=["lsm", "z"])
#io = NetCDF4Backend(ofile, backend_kwargs={"mode": "w"})
try:
    deterministic([date], nsteps, model, mpas_datasrc, io)
finally:
    # Close the backend even if inference fails, so output is not left open.
    io.close()

In [None]:
# Interactive check: display metpy's Earth-radius constant. Re is not used by
# any other cell of this notebook.
from metpy.constants import Re
Re

In [None]:
# Build the two model sessions passed to run_inference in the cells below.
# NOTE(review): by their names these are presumably the Pangu 24-h and 6-h
# ONNX Runtime sessions — confirm in s3_run_pangu_ecmwf.setup_model_sessions.
ort_session_24, ort_session_6 = setup_model_sessions(ai_models_dir)

In [None]:
# Inspect the current `da`.
# NOTE(review): in top-to-bottom order `da` is first assigned in the inference
# cell that follows, so on a fresh kernel this cell raises NameError.
da

In [None]:
# Run Pangu-Weather from the single initial condition in `ds` and write one
# netCDF file per forecast returned by run_inference.
# NOTE(review): `ds` is not assigned in any cell of this notebook; it must be
# created interactively before running this cell.
# Output goes under ai_models_dir, so it follows $SCRATCH instead of a
# hard-coded user's scratch directory.
output_dir = ai_models_dir / "output" / "panguweather" / f"{date:%Y%m%d%H}"
# to_netcdf does not create missing parent directories.
output_dir.mkdir(parents=True, exist_ok=True)

da = ds.rename(variable="channel")
inferences = run_inference(da, ort_session_24, ort_session_6, fhr_end)
for fcst in inferences:
    # Lead time in hours as a plain float so the :03.0f format is unambiguous.
    fhr = float(fcst.prediction_timedelta.squeeze() / pd.to_timedelta("1h"))
    output_filename = output_dir / f"pangu_{ic}_hybrid_pred_{fhr:03.0f}.nc"
    print(output_filename)
    # zlib compression can reduce by 1/3 but takes long time (versus almost instantaneous)
    fcst.to_netcdf(output_filename)

In [None]:
# Run Pangu-Weather for every ensemble member in `ds`, skipping members whose
# 6-hourly output files (6..fhr_end h) all exist already.
# NOTE(review): `ds` is not assigned in any cell of this notebook; it must be
# created interactively before running this cell.
output_dir = ai_models_dir / "output" / "panguweather" / f"{date:%Y%m%d%H}"
# to_netcdf does not create missing parent directories.
output_dir.mkdir(parents=True, exist_ok=True)

for member in ds.member.data:
    expected = (
        output_dir / f"pangu_{ic}{member}_pred_{fhr:03.0f}.nc"
        for fhr in range(6, fhr_end + 1, 6)
    )
    # First missing file, or None when the member is complete.
    missing = next((path for path in expected if not path.exists()), None)
    if missing is None:
        print(f"all {date} {ic}{member} exist")
        continue
    print(missing, "no exists")

    # Select the member only when it actually needs to be run.
    da = ds.sel(member=member).rename(variable="channel").rename(Time="time")
    inferences = run_inference(da, ort_session_24, ort_session_6, fhr_end)
    for fcst in inferences:
        fhr = float(fcst.prediction_timedelta.squeeze() / pd.to_timedelta("1h"))
        output_filename = output_dir / f"pangu_{ic}{member}_pred_{fhr:03.0f}.nc"
        print(output_filename)
        # zlib compression can reduce by 1/3 but takes long time (versus almost instantaneous)
        fcst.to_netcdf(output_filename)

In [None]:
# Gather each member's forecast files and open them as one dataset nested
# over (member, prediction_timedelta), keeping only channel "z500".
# Filenames carry a zero-padded 3-digit lead time, so a lexicographic sort
# orders each member's files by lead time.
# NOTE(review): relies on `ds` and on `output_dir` from an earlier cell.
ifiles = [
    sorted(output_dir.glob(f"pangu_{ic}{member}_pred_???.nc"))
    for member in ds.member.data
]
# combine="nested" needs a rectangular list of lists; fail with a clear
# message if some member is missing forecast files.
nfiles_per_member = {len(files) for files in ifiles}
if len(nfiles_per_member) != 1:
    raise ValueError(
        f"members have unequal numbers of forecast files: {sorted(nfiles_per_member)}"
    )
# Despite its name, `da` is a Dataset holding data variable "z".
da = (
    xr.open_mfdataset(
        ifiles,
        combine="nested",
        concat_dim=["member", "prediction_timedelta"],
    )
    .rename(lat="latitude", lon="longitude")
    .assign_coords(member=ds.member)
    .sel(channel="z500")
    .rename(__xarray_dataarray_variable__="z", prediction_timedelta="step")
    .squeeze(dim="init_time")
)
da

In [None]:
import argparse

# plot_ensemble.plot_forecast_grid takes an argparse-style object (presumably
# reading args.ic and args.model — confirm in plot_ensemble). Build an
# *instance*: the original bound the Namespace class itself, so these
# attributes were set on argparse.Namespace for the rest of the session.
args = argparse.Namespace(ic=ic.upper(), model="panguweather")
fig = plot_ensemble.plot_forecast_grid(
    args, da.sortby("member"), plotdays=range(1, fhr_end // 24 + 1)
)

In [None]:
# Quick look at one member: z500 over a lat/lon box, one panel per lead time.
member = 4
# open_mfdataset documents a path string or sequence of paths; pass a sorted
# list instead of a one-shot glob generator, so the file order is deterministic.
pred_files = sorted(output_dir.glob(f"pangu_{ic}{member}_pred_???.nc"))
inferences = xr.open_mfdataset(pred_files)
# NOTE(review): slice(60, 20) selects 60..20 only if lat is stored descending — confirm.
inferences = inferences.sel(lat=slice(60, 20), lon=slice(220, 300))
inferences.sel(channel="z500").squeeze().__xarray_dataarray_variable__.plot(
    col="prediction_timedelta", col_wrap=6
)

In [None]:
# Display `inferences` as last assigned — after the quick-look cell above, the
# lat/lon subset of the single member's forecast files.
inferences