Works with the kernel `Python [conda env:pangu]`, created as described in the README.

Run FengWu with 1 GPU and 20 GB of memory. The GPU must be a V100, not a GP100: the GP100 has only 16 GB of VRAM, while the V100 has 32 GB.

Make these the default Casper modules with `module save`:

```
Currently Loaded Modules:
  1) ncarenv/24.12  (S)   3) ncarcompilers/1.0.0   5) ucx/1.17.0      7) hdf5/1.12.3    9) cudnn/9.2.0.82-12
  2) intel/2024.2.1       4) cuda/12.3.2           6) openmpi/5.0.6   8) netcdf/4.9.2  10) conda/latest
```

Inferences in conda env:ainwp differ from those in the pangu env by 0.0001 K (after 240 hours).
conda env:ainwp is supposed to replicate the realtime runs, but its results still differ from the realtime runs by 0.1 K.

In [None]:
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
import xarray
from run_pangu import plot_ensemble
from run_pangu.s3_run_fengwu_ecmwf import (
    ai_input_grb,
    channel_subset,
    fengwu_channels,
    lat,
    lon,
    pressure_levels,
    setup_model_sessions,
    variables,
)

# Root of the ai-models working tree: inputs are staged under input/,
# forecasts are written under output/.
ai_models_dir = Path("/glade/derecho/scratch/ahijevyc/ai-models")

# Initialization time, and the state 6 h earlier (the model input is built
# from these two consecutive states).
date = pd.to_datetime("2024042500", format="%Y%m%d%H")
date_6 = date - pd.Timedelta(hours=6)

ic = "gefs"  # initial-condition source: "gefs" or "ecmwf"
fhr_end = 24  # last forecast hour to produce (outputs every 6 h)

In [None]:
def ai_input_nc(nc):
    """Read one model input state from a netCDF file as a stacked array.

    The file must contain a data variable named
    ``__xarray_dataarray_variable__`` with a ``channel`` dimension. Channels
    are extracted in this order: the surface fields u10m, v10m, t2m, msl,
    then every ``{variable}{level}`` combination (variable-major) from
    ``variables`` and ``pressure_levels``.

    Parameters
    ----------
    nc : str or path-like
        Path to the netCDF file.

    Returns
    -------
    numpy.ndarray
        The selected fields (each squeezed) stacked along a new leading
        channel axis.
    """
    sfc_param = ["u10m", "v10m", "t2m", "msl"]
    pl_param = [f"{f}{p}" for f in variables for p in pressure_levels]
    # Use a context manager so the file handle is released. ``.values`` loads
    # each field into memory, so the arrays stay valid after the file closes.
    with xarray.open_dataset(nc) as ds:
        fields = ds["__xarray_dataarray_variable__"]
        fields_all = [fields.sel(channel=p).squeeze().values for p in sfc_param + pl_param]
    return np.stack(fields_all)


def run_fengwu(input, data_mean, data_std, date, ic, fhr_end, odir, clobber=False, ort_session=None):
    """Run FengWu autoregressively in 6-hour steps and write each step to netCDF.

    Each step feeds the model the two most recent states and keeps the first
    69 output channels as the newest state. The de-normalized prediction is
    subset to ``channel_subset``, lat 60..20 and lon 220..300, and written to
    ``{odir}/fengwu_{ic}_pred_{fhr:03d}.nc``.

    Parameters
    ----------
    input : numpy.ndarray
        Normalized model input: two states concatenated along axis 1
        (69 channels each, per the slicing below), with a leading batch axis.
    data_mean, data_std : numpy.ndarray
        Normalization statistics used to de-normalize each prediction.
    date : pandas.Timestamp
        Initialization time, stored as the ``init_time`` coordinate.
    ic : str
        Initial-condition source label used in the output file names.
    fhr_end : int
        Last forecast hour; outputs are produced every 6 h starting at hour 6.
    odir : str or path-like
        Output directory. This function does not create it.
    clobber : bool, optional
        If True, overwrite existing output files; otherwise keep them.
    ort_session : optional
        Model session whose ``run(None, {"input": ...})`` returns the
        prediction first. Defaults to the notebook-global ``ort_session_6``.
    """
    if ort_session is None:
        ort_session = ort_session_6
    n_chan = 69  # channels per state; the model input holds two states

    def _fname(fhr):
        return f"{odir}/fengwu_{ic}_pred_{fhr:03d}.nc"

    fhrs = range(6, fhr_end + 1, 6)
    # Forecast hours whose output still has to be written.
    todo = {fhr for fhr in fhrs if clobber or not os.path.exists(_fname(fhr))}
    if not todo:
        return
    last_fhr = max(todo)

    state = input
    for fhr in fhrs:
        if fhr > last_fhr:
            break  # nothing left to write; skip the remaining model steps
        # Always step the model, even when this hour is already on disk: each
        # step depends on the previous one, so skipping it would make every
        # later hour start from a stale state and be labeled with the wrong lead.
        output = ort_session.run(None, {"input": state})[0]
        state = np.concatenate((state[:, n_chan:], output[:, :n_chan]), axis=1)
        if fhr not in todo:
            continue

        output_filename = _fname(fhr)
        if os.path.exists(output_filename):
            os.remove(output_filename)  # normally only reached with clobber=True
        print(f"Processing {date:%Y-%m-%d} - {fhr} hour")

        # Undo the (x - mean) / std normalization applied to the inputs.
        output = (output[0, :n_chan] * data_std) + data_mean

        # Create prediction timedelta
        pred_timedelta = pd.Timedelta(hours=fhr)

        # Create xarray DataArrays with proper dimensions
        da_output = xarray.DataArray(
            data=np.expand_dims(np.expand_dims(output, axis=0), axis=0),
            coords={
                "init_time": [date],
                "prediction_timedelta": [pred_timedelta],
                "channel": fengwu_channels,
                "lat": lat,
                "lon": lon,
            },
            dims=["init_time", "prediction_timedelta", "channel", "lat", "lon"],
        ).sel(lat=slice(60, 20), lon=slice(220, 300), channel=channel_subset)

        # Save as netCDF
        da_output.to_netcdf(output_filename)

In [None]:
model_dir = Path(
    "/glade/derecho/scratch/zxhua/AI_global_forecast_model_for_education/FengWu/model"
)
ort_session_6 = setup_model_sessions(model_dir)

# Per-channel normalization statistics. Two trailing singleton axes are
# appended, presumably so they broadcast over (channel, lat, lon) fields --
# confirm against the saved array shapes.
data_mean = np.load(model_dir.joinpath("data_mean.npy"))[:, None, None]
data_std = np.load(model_dir.joinpath("data_std.npy"))[:, None, None]

In [None]:
# Build normalized inputs and run FengWu for the control (c00) and each
# perturbed member, skipping members whose outputs all exist already.
for mem in ["c00"] + [f"p{m:02d}" for m in range(1, 31)]:  # gefs has 30 perturbed members; ecmwf has 50
    ens = int(mem[1:])  # int() removes leading zeros
    assert mem.startswith("p") or mem == "c00"
    odir = ai_models_dir / f"output/fengwu/{date:%Y%m%d%H}/{mem}_att2"
    if all(Path(f"{odir}/fengwu_{ic}_pred_{fhr:03d}.nc").exists() for fhr in range(6, fhr_end + 1, 6)):
        print('.', end='')
        continue

    if ic == "ecmwf":
        assert date > pd.to_datetime("20250209"), "started saving ecmwf after 20250209"
        prior_nc = f"/glade/derecho/scratch/sobash/fengwu_realtime/{date:%Y%m%d%H}/ens{ens}/pangu_ens{ens}_init_{date_6:%Y%m%d%H}.nc"
        current_nc = f"/glade/derecho/scratch/sobash/fengwu_realtime/{date:%Y%m%d%H}/ens{ens}/pangu_ens{ens}_init_{date:%Y%m%d%H}.nc"
        # ai_input_nc is the netCDF reader defined in this notebook.
        input_prior = ai_input_nc(prior_nc)
        input_current = ai_input_nc(current_nc)
    elif ic == "gefs":
        from run_pangu import s1_get_gefs, s2_make_ic_gefs

        assert ens <= 30
        s1_get_gefs.download_time(date_6)
        s1_get_gefs.download_time(date)

        s2_make_ic_gefs.process_member(date_6.strftime("%Y%m%d%H"), mem)
        s2_make_ic_gefs.process_member(date.strftime("%Y%m%d%H"), mem)
        prior_grb = (
            ai_models_dir / f"input/{date_6:%Y%m%d%H}/{mem}/ge{mem}.t{date_6:%H}z.pgrb.0p25.f000"
        )
        current_grb = (
            ai_models_dir / f"input/{date:%Y%m%d%H}/{mem}/ge{mem}.t{date:%H}z.pgrb.0p25.f000"
        )
        input_prior = ai_input_grb(prior_grb)
        input_current = ai_input_grb(current_grb)
        # TODO: maybe run with ai-models-fengwu, an ai-models plugin
        if False:
            # Separate name so enabling this does not redirect the forecast
            # output directory `odir` used below.
            npy_dir = current_grb.parent / "input_data"
            os.makedirs(npy_dir, exist_ok=True)
            np.save(npy_dir / "input1.npy", input_prior)
            np.save(npy_dir / "input2.npy", input_current)
    else:
        # Without this, the previous member's inputs would be silently reused.
        raise ValueError(f"unsupported ic {ic!r}; expected 'ecmwf' or 'gefs'")

    # Normalize input data
    input_current_after_norm = (input_current - data_mean) / data_std
    input_prior_after_norm = (input_prior - data_mean) / data_std
    # Model input: prior and current states stacked along the channel axis,
    # plus a leading batch axis of size 1, as float32.
    input_fengwu = np.concatenate((input_prior_after_norm, input_current_after_norm), axis=0)[
        np.newaxis, :, :, :
    ]
    input_fengwu = input_fengwu.astype(np.float32)

    os.makedirs(odir, exist_ok=True)
    run_fengwu(input_fengwu, data_mean, data_std, date, ic, fhr_end, odir)

In [None]:
# Settings consumed by plot_ensemble.plot_forecast_grid. Build a Namespace
# *instance*: setting attributes on the argparse.Namespace class itself would
# leak them into every Namespace created in this process.
args = argparse.Namespace(ic=ic.upper(), model="fengwu")
# NOTE(review): the forecast loop writes to "{mem}_att2" directories, but this
# pattern only matches plain "{mem}" directories (e.g. c00/, p01/). Confirm
# which run is meant to be plotted.
ifiles = sorted((ai_models_dir / f"output/fengwu/{date:%Y%m%d%H}").glob(f"[cp][0-9][0-9]/fengwu_{ic}_*"))
def daymultiple(f):
    """Return True when the forecast hour in file name ``f`` is a whole number of days.

    The hour is the three characters just before a three-character suffix,
    e.g. ``fengwu_gefs_pred_024.nc`` -> 24.
    """
    hours = int(f.name[-6:-3])
    return not hours % 24


# Keep only whole-day lead times, then assemble the z500 ensemble into one
# dataset. The renames presumably match the names plot_ensemble expects.
ifiles = list(filter(daymultiple, ifiles))
print(len(ifiles))
da = (
    xarray.open_mfdataset(ifiles, decode_timedelta=True, preprocess=plot_ensemble.parsemem)
    .sel(channel="z500", init_time=date)
    .squeeze()
    .rename(
        {
            "__xarray_dataarray_variable__": "z",
            "lat": "latitude",
            "lon": "longitude",
            "prediction_timedelta": "step",
        }
    )
)
da  # display in the notebook

In [None]:
# Plot the z500 ensemble at lead times of 1-8 days.
# NOTE(review): with fhr_end = 24 the forecast loop in this notebook only
# produces day 1; confirm how plot_forecast_grid handles the missing days.
fig = plot_ensemble.plot_forecast_grid(args, da, plotdays=[1, 2, 3, 4, 5, 6, 7, 8])

In [None]:
# Echo the initial-condition source selected in the settings cell.
ic

In [None]:
# Control-member (c00) 2-m temperature, keeping every 4th forecast step.
# NOTE(review): reads "c00" (not "c00_att2") and hard-codes "ecmwf" regardless
# of `ic` -- presumably to compare with the ECMWF realtime run loaded next; confirm.
da = (
    xarray.open_mfdataset((ai_models_dir / f"output/fengwu/{date:%Y%m%d%H}/c00").glob("*_ecmwf_*"), decode_timedelta=True)
    .sel(channel="t2m", init_time=date)
    .squeeze()["__xarray_dataarray_variable__"]
    .isel(prediction_timedelta=slice(None, None, 4))
)
da.sel(lat=slice(60, 20), lon=slice(220, 300)).plot(col="prediction_timedelta", col_wrap=5)

In [None]:
# Realtime FengWu 2-m temperature for the control member, keeping every 4th
# forecast step, for comparison with the c00 field `da`.
# ens0 is hard-coded (c00 maps to ens0 via int(mem[1:])) instead of reusing
# `ens`: after the member loop finishes, `ens` holds the last member processed,
# which does not correspond to c00.
da_old = (
    xarray.open_mfdataset(
        Path(
            f"/glade/derecho/scratch/sobash/fengwu_realtime/{date:%Y%m%d%H}/ens0/fengwu_forecast_data"
        ).glob("fengwu*.nc"),
        decode_timedelta=True,
    )
    .sel(channel="t2m")
    .squeeze()["__xarray_dataarray_variable__"]
    .isel(prediction_timedelta=slice(None, None, 4))
)
da_old.plot(col="prediction_timedelta", col_wrap=5)

In [None]:
# Map of (this workflow - realtime) 2-m temperature at each selected lead time.
(da - da_old).plot(col="prediction_timedelta", col_wrap=5)

In [None]:
# Largest absolute difference between this workflow and the realtime run.
# A plain max of the signed difference would hide large negative differences.
abs(da - da_old).max().load()