In [2]:
import xarray as xr
import gcsfs

# ─── CONFIG ────────────────────────────────────────────────────────────────────
GCS_URI_6H_13L       = "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
OUT_6H_13L_ZARR      = "ERA5_2023_6h_13lvl_1440x721.zarr"
OUT_6H_13L_NC        = "ERA5_2023_6h_13lvl_1440x721.nc"
OUT_6H_13L_US_ZARR   = "ERA5_2023_6h_13l_US.zarr"
OUT_6H_13L_US_NC     = "ERA5_2023_6h_13l_US.nc"

def open_gcs_zarr(uri: str, project: str = None):
    fs = gcsfs.GCSFileSystem(project=project)
    return xr.open_zarr(fs.get_mapper(uri), consolidated=True)

# 1) load & time‐subset 6h/13l for 2023
ds = open_gcs_zarr(GCS_URI_6H_13L)
ds2023 = ds.sel(time=slice("2023-01-01","2023-12-31"))

# ─── spatially subset to continental U.S. ────────────────────────────────────
# your lon runs 0→360
us_lon_min, us_lon_max = -125 + 360, -66 + 360   # → 235, 294
us_lat_min, us_lat_max = 24.0, 50.0              # 24N → 50N

# use the actual coord names in your ds:
lon_coord = "longitude"    if "longitude"    in ds2023.coords else "lon"
lat_coord = "latitude"     if "latitude"     in ds2023.coords else "lat"

ds2023_us = ds2023.sel(
    { lon_coord: slice(us_lon_min, us_lon_max),
      lat_coord: slice(us_lat_min, us_lat_max) }
)

# 2) write out both the global 2023 slice and the U.S. subset
ds2023 .to_zarr(OUT_6H_13L_ZARR,    mode="w")
ds2023 .to_netcdf(OUT_6H_13L_NC)
ds2023_us.to_zarr(OUT_6H_13L_US_ZARR, mode="w")
ds2023_us.to_netcdf(OUT_6H_13L_US_NC)

print("Global 2023 →", OUT_6H_13L_ZARR)
print("Global 2023 →", OUT_6H_13L_NC)
print("  U.S. 2023 →", OUT_6H_13L_US_ZARR)
print("  U.S. 2023 →", OUT_6H_13L_US_NC)


Global 2023 → ERA5_2023_6h_13lvl_1440x721.zarr
Global 2023 → ERA5_2023_6h_13lvl_1440x721.nc
  U.S. 2023 → ERA5_2023_6h_13l_US.zarr
  U.S. 2023 → ERA5_2023_6h_13l_US.nc


In [5]:
import os
import pandas as pd
import xarray as xr

# ─── 1) Inspect your ERA5 dataset ──────────────────────────────────────────────
ERA5_STORE = "ERA5_2023_6h_13lvl_1440x721.zarr"
ds = xr.open_zarr(ERA5_STORE)

print("=== ERA5 DATASET SUMMARY ===")
print(ds)                        # full summary: dims, coords, data_vars
print("\nVariables in ds.data_vars:")
for name, var in ds.data_vars.items():
    print(f" • {name:30s} dims={var.dims}   shape={tuple(var.shape)}")

print("\nCoordinates in ds.coords:")
for coord in ds.coords:
    print(" •", coord, "=", ds.coords[coord].values[:5], "…")

# ─── 2) Inspect one of your cleaned arrivals CSVs ──────────────────────────────
ARRIVAL_DIR = "Datasets/Arrival_With_Weather/2023"
# pick the first CSV in the folder
fn = sorted([f for f in os.listdir(ARRIVAL_DIR) if f.lower().endswith(".csv")])[0]
df = pd.read_csv(os.path.join(ARRIVAL_DIR, fn), parse_dates=["FlightDate"])

print("\n=== SAMPLE ARRIVAL CSV:", fn, "===\n")
print("Columns:", list(df.columns))
print("\nFirst 5 rows:")
print(df.head())


=== ERA5 DATASET SUMMARY ===
<xarray.Dataset> Size: 38GB
Dimensions:                                           (time: 40, latitude: 721,
                                                       longitude: 1440,
                                                       level: 13)
Coordinates:
  * latitude                                          (latitude) float32 3kB ...
  * level                                             (level) int64 104B 50 ....
  * longitude                                         (longitude) float32 6kB ...
  * time                                              (time) datetime64[ns] 320B ...
Data variables: (12/62)
    10m_u_component_of_wind                           (time, latitude, longitude) float32 166MB dask.array<chunksize=(1, 721, 1440), meta=np.ndarray>
    10m_v_component_of_wind                           (time, latitude, longitude) float32 166MB dask.array<chunksize=(1, 721, 1440), meta=np.ndarray>
    10m_wind_speed                                    (time

IndexError: list index out of range

In [23]:
import os
import numpy as np
import pandas as pd
import xarray as xr

# ─── CONFIG ───────────────────────────────────────────────────────────────────
#ERA5_STORE   = "ERA5_2023_6h_13l_US.zarr"
ERA5_STORE   = "ERA5_2023_6h_13lvl_1440x721.zarr"
MASTER_COORD = "Datasets/T_MASTER_CORD.csv"
ARRIVAL_DIR  = "Datasets/Arrival_Statistics/2023"
OUTPUT_DIR   = "Datasets/Arrival_With_Weather/2023"

VARS = [
    "2m_temperature",
    "mean_sea_level_pressure",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "total_precipitation_6hr",
    "mean_top_downward_short_wave_radiation_flux",
]

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ─── 1) load ERA5 & rename dims ───────────────────────────────────────────────
ds = xr.open_zarr(ERA5_STORE, consolidated=True)

if "latitude" in ds.dims and "longitude" in ds.dims:
    ds = ds.rename({"latitude": "lat", "longitude": "lon"})

print("dims before subsetting:", ds.dims)

# ─── 1b) wrap & sort lon/lat, then slice just over the U.S. ────────────────
# (drop this block if ERA5_STORE already points to your US‑only zarr)
us_lon_min, us_lon_max = 235.0, 294.0   # [0–360] U.S. longitudes
us_lat_min, us_lat_max =  24.0,  50.0   # U.S. latitudes

ds = (
    ds
    .assign_coords(lon=((ds.lon + 360) % 360))
    .sortby("lon")
    .sortby("lat")
    .sel(lon=slice(us_lon_min, us_lon_max),
         lat=slice(us_lat_min, us_lat_max))
)
print("dims after subsetting:", ds.dims)

# ─── 2) build your fast lookup arrays ────────────────────────────────────────
time_index = ds["time"].to_index()
lat_vals    = ds["lat"].values
lon_vals    = ds["lon"].values

# ─── 3) load master coords ───────────────────────────────────────────────────
mc = (
    pd.read_csv(MASTER_COORD, dtype=str)
      .set_index("AIRPORT_SEQ_ID")[["LATITUDE","LONGITUDE"]]
      .astype(float)
)

# ─── 4) hhmm → Timedelta helper ─────────────────────────────────────────────
def parse_hhmm(x):
    if pd.isna(x) or not str(x).strip():
        return pd.NaT
    s = str(int(float(x))).zfill(4)
    return pd.Timedelta(hours=int(s[:2]), minutes=int(s[2:]))





# ─── right after your subsetting block, before you build time_index, etc ───
# Pre‑load all the VARS into memory as NumPy arrays:
var_arrays = {}
for var in VARS:
    # this will compute the Dask array into a np.ndarray of shape (time, lat, lon)
    var_arrays[var] = ds[var].data.compute()
# Now var_arrays[var][t_idx, l_idx, o_idx] is pure NumPy fancy‑indexing

# ─── then your make_weather_lookup becomes ─────────────────────────────────
def make_weather_lookup(keys: pd.DataFrame) -> pd.DataFrame:
    sub = keys.copy()
    sub["lat"] = sub["AirportSeqID"].map(mc["LATITUDE"])
    sub["lon"] = sub["AirportSeqID"].map(mc["LONGITUDE"]) % 360.0
    sub = sub.dropna(subset=["Datetime","lat","lon"])
    if sub.empty:
        return pd.DataFrame(columns=VARS,
                            index=pd.MultiIndex.from_arrays([[],[]],
                                                           names=["AirportSeqID","Datetime"]))

    # 1) nearest‐time
    t_idx = time_index.get_indexer(sub["Datetime"], method="nearest")

    # 2) nearest‐lat/lon
    sub_lat = sub["lat"].to_numpy()
    sub_lon = sub["lon"].to_numpy()
    l_idx = np.abs(lat_vals[None,:] - sub_lat[:,None]).argmin(axis=1)
    o_idx = np.abs(lon_vals[None,:] - sub_lon[:,None]).argmin(axis=1)

    # 3) pull out each VAR from our preloaded var_arrays
    out = {}
    for var in VARS:
        arr = var_arrays[var]   # pure numpy now
        out[var] = arr[t_idx, l_idx, o_idx]

    # 4) assemble a DataFrame and re‑index
    df_lkp = pd.DataFrame(out, index=sub.index)
    df_lkp.index = pd.MultiIndex.from_frame(
        sub[["AirportSeqID","Datetime"]],
        names=["AirportSeqID","Datetime"]
    )
    return df_lkp






In [27]:
for fn in sorted(os.listdir(ARRIVAL_DIR)):
    if not fn.lower().endswith(".csv"):
        continue
    print("→ augmenting", fn)
    df = pd.read_csv(os.path.join(ARRIVAL_DIR, fn), dtype=str, parse_dates=False)

    # 1) build true timestamps
    df["FlightDate"] = (
        pd.to_datetime(df["FlightDate"].astype(str),
                       format="%Y-%m-%d",
                       errors="coerce")
          .dt.normalize()
    )
    df["DepDelta"] = df["DepTime"].apply(parse_hhmm)
    df["ArrDelta"] = df["ArrTime"].apply(parse_hhmm)
    df["DepDatetime"] = df["FlightDate"] + df["DepDelta"]
    df["ArrDatetime"] = df["FlightDate"] + df["ArrDelta"]

    # 2) vectorized lookups
    orig = (
      df[["OriginAirportSeqID","DepDatetime"]]
        .dropna().drop_duplicates()
        .rename(columns={"OriginAirportSeqID":"AirportSeqID",
                         "DepDatetime":"Datetime"})
    )
    dest = (
      df[["DestAirportSeqID","ArrDatetime"]]
        .dropna().drop_duplicates()
        .rename(columns={"DestAirportSeqID":"AirportSeqID",
                         "ArrDatetime":"Datetime"})
    )
    orig_lkp = make_weather_lookup(orig)
    dest_lkp = make_weather_lookup(dest)

    # 3) stitch back on
    df = (
      df
      .merge(orig_lkp.rename(columns=lambda c: f"Origin_{c}"),
             left_on=["OriginAirportSeqID","DepDatetime"],
             right_index=True, how="left")
      .merge(dest_lkp.rename(columns=lambda c: f"Dest_{c}"),
             left_on=["DestAirportSeqID","ArrDatetime"],
             right_index=True, how="left")
    )

    # 4) drop your temps
    df.drop(columns=["DepDelta","ArrDelta","DepDatetime","ArrDatetime"],
            errors="ignore", inplace=True)

    # 5) now drop *any* column with no actual data:
    #    first turn blank‐strings into NaN
    df.replace(r'^\s*$', np.nan, regex=True, inplace=True)
    #    then drop columns that are all-NaN
    df.dropna(axis=1, how="all", inplace=True)

    # 6) save
    out = os.path.join(OUTPUT_DIR, fn)
    df.to_csv(out, index=False)
    print("   saved →", out)


→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_1.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_1.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_10.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_10.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_11.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_11.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_12.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_12.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_2.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_2.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_3.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_3.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_4.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_4.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_5.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_5.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_6.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_6.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_7.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_7.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_8.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_8.csv
→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_9.csv


  df.replace(r'^\s*$', np.nan, regex=True, inplace=True)


   saved → Datasets/Arrival_With_Weather/2023\On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_9.csv


Now run preprocessing to merge and format the arrival data for training

then take data and create grids

In [6]:
import os
import numpy as np
import pandas as pd
import xarray as xr

# ─── CONFIG ────────────────────────────────────────────────────────────────────
ERA5_STORE   = "ERA5_2023_6h_13lvl_1440x721.zarr"
MASTER_COORD = "Datasets/T_MASTER_CORD.csv"
ARRIVAL_DIR  = "Datasets/Arrival_Statistics/2023"
OUTPUT_DIR   = "Datasets/Arrival_With_Weather/2023"

# the ERA5 vars you want to pull out
VARS = [
    "2m_temperature",
    "mean_sea_level_pressure",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "total_precipitation_6hr",
    "toa_incident_solar_radiation",
]

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ─── 1) load ERA5 & fix coords ─────────────────────────────────────────────────
ds = xr.open_zarr(ERA5_STORE)
# if your zarr came with 'latitude'/'longitude' dims, rename them:
if "latitude" in ds.dims and "longitude" in ds.dims:
    ds = ds.rename({"latitude": "lat", "longitude": "lon"})

# ─── 2) load master coords ─────────────────────────────────────────────────────
mc = (
    pd.read_csv(MASTER_COORD, dtype=str)
      .set_index("AIRPORT_SEQ_ID")[["LATITUDE", "LONGITUDE"]]
)

# ─── 3) helper to parse HHMM → Timedelta ────────────────────────────────────────
def parse_hhmm(x):
    if pd.isna(x) or not str(x).strip():
        return pd.NaT
    s = str(int(float(x))).zfill(4)
    return pd.Timedelta(hours=int(s[:2]), minutes=int(s[2:]))

# ─── 4) loop over each month’s arrival file ────────────────────────────────────
for fn in sorted(os.listdir(ARRIVAL_DIR)):
    if not fn.lower().endswith(".csv"):
        continue
    print("→ augmenting", fn)
    # --- read your CSV; FlightDate will come in as object/string because of dtype=str
    df = pd.read_csv(
        os.path.join(ARRIVAL_DIR, fn),
        parse_dates=False,     # we’ll do it ourselves
        dtype=str
        )

# 1) cast the FlightDate column to datetime64 (dropping any rogue time component):
    df["FlightDate"] = (
        pd.to_datetime(df["FlightDate"].astype(str),    # ensure it’s all string
                       format="%Y-%m-%d",
                       errors="coerce")
          .dt.normalize()   # midnight, so FlightDate + Timedelta works nicely
    )
    

    # now you can safely do .dt if you need it, but we don’t:
    # 2) parse your HHMM into Timedeltas
    df["DepDelta"] = df["DepTime"].apply(parse_hhmm)
    df["ArrDelta"] = df["ArrTime"].apply(parse_hhmm)

    # if FlightDate sometimes comes in as e.g. "2023‑01‑08 00:00:00", you can
    # strip to just the date first, *after* it’s datetime:
    df["FlightDate"] = (
        pd.to_datetime(df["FlightDate"], errors="coerce")
          .dt.floor("D")
    )

    # 3) form real timestamps
    df["DepDatetime"] = df["FlightDate"] + df["DepDelta"]
    df["ArrDatetime"] = df["FlightDate"] + df["ArrDelta"]

    # … the rest of your augmentation loop follows …


    # c) pre‑create your weather columns
    for var in VARS:
        df[f"Origin_{var}"] = np.nan
        df[f"Dest_{var}"]   = np.nan

    # d) fill them in
    for i, row in df.iterrows():
        for side, DTcol, AIDcol in [
            ("Origin","DepDatetime","OriginAirportSeqID"),
            ("Dest",  "ArrDatetime","DestAirportSeqID")
        ]:
            t   = row[DTcol]
            aid = str(row[AIDcol]).strip()
            if pd.isna(t) or aid not in mc.index:
                continue

            lat = float(mc.at[aid, "LATITUDE"])
            lon = float(mc.at[aid, "LONGITUDE"]) % 360

            for var in VARS:
                try:
                    val = ds[var].sel(
                        time=t,   method="nearest",
                        lat=lat, lon=lon
                    ).item()
                except Exception:
                    val = np.nan
                df.at[i, f"{side}_{var}"] = val

    # e) drop just the temporaries
    df.drop(columns=["DepDelta","ArrDelta","DepDatetime","ArrDatetime"],
            inplace=True)

    # f) drop _any_ column that's 100% NaN (that includes weather columns
    #    for out‑of‑US or missing airports)
    df.dropna(axis=1, how="all", inplace=True)

    # g) write it back out
    out = os.path.join(OUTPUT_DIR, fn)
    df.to_csv(out, index=False)
    print("   saved →", out)


→ augmenting On_Time_Reporting_Carrier_On_Time_Performance_(1987_present)_2023_1.csv


KeyboardInterrupt: 

In [1]:
import os
import numpy as np
import pandas as pd
import xarray as xr
import jax
import haiku as hk

from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast


# ─── CONFIG ───────────────────────────────────────────────────────────────────
ERA5_STORE    = "ERA5_2023_6h_13lvl_1440x721.zarr"
PRED_STORE    = "Preds_2023_6h_13l.zarr"
CKPT_NPZ      = "Models/graphcast_operational.npz"
ARR_DIR       = "Datasets/Arrival_Statistics/2023"
OUT_DIR       = "Datasets/Arrival_With_GNN/2023"
MASTER_COORD  = "Datasets/T_MASTER_CORD.csv"

os.makedirs(OUT_DIR, exist_ok=True)

# ─── A) Load & unflatten .npz params ───────────────────────────────────────────
def unflatten_params(flat):
    tree = {}
    for path, arr in flat.items():
        d = tree
        *keys, leaf = path
        for k in keys:
            d = d.setdefault(k, {})
        d[leaf] = arr
    return tree

npz = np.load(CKPT_NPZ, allow_pickle=True)
flat = {
    tuple(k.split("/")): npz[k]
    for k in npz.files if k.startswith("params:")
}
nested = unflatten_params(flat)
params = hk.data_structures.to_immutable_dict(nested)
state  = {}  # no 'state:' in this archive

# ─── B) Rebuild model & task configs ──────────────────────────────────────────
model_config = graphcast.ModelConfig(
    resolution=0.25,
    mesh_size=20,
    latent_size=64,
    gnn_msg_steps=1,
    hidden_layers=2,
    radius_query_fraction_edge_length=0.6
)
task_config = graphcast.TaskConfig(
    input_variables   = graphcast.TASK.input_variables,
    forcing_variables = graphcast.TASK.forcing_variables,
    target_variables  = graphcast.TASK.target_variables,
    pressure_levels   = graphcast.PRESSURE_LEVELS[13],
    input_duration    = graphcast.TASK.input_duration
)

model = graphcast.GraphCast(model_config, task_config)
rng   = jax.random.PRNGKey(0)

# ─── C) Prepare an empty output Zarr for predictions ──────────────────────────
ds_in = xr.open_zarr(ERA5_STORE, chunks={"time": 1})
template = xr.Dataset(
    {
     var: (("time","lat","lon"),
           np.full((len(ds_in.time), len(ds_in.lat), len(ds_in.lon)), np.nan))
     for var in task_config.target_variables
    },
    coords={"time": ds_in.time, "lat": ds_in.lat, "lon": ds_in.lon}
)
template.to_zarr(PRED_STORE, mode="w")

# ─── D) Stream through each time-step, run GNN, and write slice ──────────────
ds_pred = xr.open_zarr(PRED_STORE, consolidated=False, mode="a")
for ti in range(len(ds_in.time)):
    t = ds_in.time.isel(time=ti)
    di = ds_in[task_config.input_variables].isel(time=ti).expand_dims(batch=1, time=1)
    df = ds_in[task_config.forcing_variables].isel(time=ti).expand_dims(batch=1, time=1)
    dt = xr.zeros_like(ds_in[task_config.target_variables].isel(time=ti))\
           .expand_dims(batch=1, time=1)

    preds, _ = model.apply(params, state, rng, di, dt, df, is_training=False)
    p0 = preds.isel(batch=0, time=0)

    for var in task_config.target_variables:
        ds_pred[var][ti, :, :] = p0[var].data

ds_pred.close()

# ─── E) Load master coords ────────────────────────────────────────────────────
mc = pd.read_csv(MASTER_COORD, dtype=str).set_index("AIRPORT_SEQ_ID")
mc["LAT"] = mc["LATITUDE"].astype(float)
mc["LON"] = (mc["LONGITUDE"].astype(float) % 360)

# ─── F) Augment arrival CSVs with GraphCast preds ─────────────────────────────
def parse_hhmm(x):
    if pd.isna(x) or str(x).strip()=="":
        return pd.NaT
    s = str(int(float(x))).zfill(4)
    return pd.Timedelta(hours=int(s[:2]), minutes=int(s[2:]))

preds_ds = xr.open_zarr(PRED_STORE, chunks={"time":1})

for fn in os.listdir(ARR_DIR):
    if not fn.lower().endswith(".csv"):
        continue
    df = pd.read_csv(os.path.join(ARR_DIR, fn), dtype=str,
                     parse_dates=["FlightDate"])
    df["DepDT"] = df["DepTime"].apply(parse_hhmm) + df["FlightDate"]
    df["ArrDT"] = df["ArrTime"].apply(parse_hhmm) + df["FlightDate"]

    for var in task_config.target_variables:
        df[f"Origin_pred_{var}"] = np.nan
        df[f"Dest_pred_{var}"]   = np.nan

    for i, row in df.iterrows():
        for side, dtcol, key in [
            ("Origin","DepDT","OriginAirportSeqID"),
            ("Dest",  "ArrDT","DestAirportSeqID")
        ]:
            dt = row[dtcol]
            aid = str(row.get(key,""))
            if pd.isna(dt) or aid not in mc.index:
                continue
            lat, lon = mc.at[aid,"LAT"], mc.at[aid,"LON"]
            for var in task_config.target_variables:
                try:
                    val = preds_ds[var].sel(
                        time=dt, method="nearest",
                        lat=lat, method="nearest",
                        lon=lon, method="nearest"
                    ).item()
                except Exception:
                    val = np.nan
                df.at[i, f"{side}_pred_{var}"] = val

    df.drop(columns=["DepDT","ArrDT"], inplace=True)
    df.dropna(axis=1, how="all", inplace=True)
    out = os.path.join(OUT_DIR, fn)
    df.to_csv(out, index=False)
    print("Saved GNN‑augmented →", out)


KeyboardInterrupt: 

use gridded xarrays for GNN predictions and arrival data augmentation

In [None]:
import os
import numpy as np
import pandas as pd
import xarray as xr

import jax
import haiku as hk
import graphcast
from graphcast.checkpoint import load as load_checkpoint
from graphcast.model_utils import (
    dataset_to_stacked,
    stacked_to_dataset,
    restore_leading_axes,
)

# ─── CONFIG ────────────────────────────────────────────────────────────────────
CKPT_PATH     = "Models/graphcast_operational.npz"
INPUT_STORE   = "Datasets/Gridded_Station_Inputs_2024.nc"   # your station‐gridded .nc
ARRIVAL_DIR   = "Datasets/Arrival_With_Weather/2024"      # CSVs you want to augment
OUTPUT_DIR    = "Datasets/Arrival_With_GNN/2024"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ─── 1) load the checkpoint from your .npz ─────────────────────────────────────
#    this will give you ckpt.params, ckpt.model_config, ckpt.task_config
ckpt = load_checkpoint(CKPT_PATH, graphcast.CheckPoint)
params      = ckpt.params
model_cfg   = ckpt.model_config
task_cfg    = ckpt.task_config

# ─── 2) build your GraphCast model & JAX/Haiku wrappers ────────────────────────
model = graphcast.GraphCast(model_cfg, task_cfg)

# we need a pure‑function that takes a stacked input array and returns the stacked pred
def _forward_fn(stacked_inputs):
    # Haiku modules expect: apply(params, state, rng, *args)
    return model(stacked_inputs)

# wrap it into a pair of (init, apply) functions with state
network = hk.without_apply_rng(hk.transform_with_state(_forward_fn))

# ─── 3) load your pre‐gridded station dataset ──────────────────────────────────
ds_grid = xr.open_dataset(INPUT_STORE)

# slice out only the variables GraphCast needs
vars_in  = list(task_cfg.input_variables) + list(task_cfg.forcing_variables)
vars_out = list(task_cfg.target_variables)

ds_in  = ds_grid[vars_in]
ds_tmpl = ds_grid[vars_out]  # template for un‑stacking later

# ─── 4) run the GNN once over the entire grid ─────────────────────────────────
# stack everything into a single JAX array of shape [batch=1,time,lat,lon,channels]
stacked = dataset_to_stacked(ds_in, sizes=ds_grid.sizes)
stacked = stacked.expand_dims("batch", 0)

# apply the network
(pred_flat, net_state), _ = network.apply(params, {}, None, stacked)

# restore the axes so we have [batch,time,lat,lon,channels]
pred_restored = restore_leading_axes(pred_flat)

# unstack back into an xarray.Dataset with your target variable names
ds_pred = stacked_to_dataset(pred_restored, template_dataset=ds_tmpl)

# drop the dummy batch dim, now ds_pred dims = (time,lat,lon)
ds_pred = ds_pred.isel(batch=0)


# ─── 5) helper to turn HHMM → timedelta ───────────────────────────────────────
def _parse_hhmm(x):
    if pd.isna(x) or not str(x).strip():
        return pd.NaT
    s = str(int(float(x))).zfill(4)
    hh, mm = int(s[:2]), int(s[2:])
    return pd.Timedelta(hours=hh, minutes=mm)


# ─── 6) loop over your arrival files and lookup the predictions ───────────────
for fn in os.listdir(ARRIVAL_DIR):
    if not fn.lower().endswith(".csv"):
        continue

    print("Augmenting", fn)
    df = pd.read_csv(os.path.join(ARRIVAL_DIR, fn),
                     parse_dates=["FlightDate"],
                     dtype=str)

    # build actual datetime columns for dep/arr
    df["DepDelta"] = df["DepTime"].apply(_parse_hhmm)
    df["ArrDelta"] = df["ArrTime"].apply(_parse_hhmm)
    df["DepDatetime"] = df["FlightDate"] + df["DepDelta"]
    df["ArrDatetime"] = df["FlightDate"] + df["ArrDelta"]

    # prepare new prediction columns
    for var in vars_out:
        df[f"Origin_pred_{var}"] = np.nan
        df[f"Dest_pred_{var}"]   = np.nan

    # for each flight, do a nearest‑neighbor lookup in ds_pred
    for side, dtcol, apicol, latcol, loncol in [
        ("Origin", "DepDatetime", "OriginAirportSeqID", "Origin_Lat",  "Origin_Lon"),
        ("Dest",   "ArrDatetime", "DestAirportSeqID",   "Dest_Lat",    "Dest_Lon"),
    ]:
        # you need to have pre‑computed Origin_Lat/Lon and Dest_Lat/Lon
        # by merging in your master coord table earlier.
        mask = ~df[dtcol].isna()
        idx = df.loc[mask, [dtcol, apicol, latcol, loncol]].index
        for i in idx:
            t   = df.at[i, dtcol]
            lat = float(df.at[i, latcol])
            lon = float(df.at[i, loncol]) % 360.0
            for var in vars_out:
                try:
                    v = ds_pred[var].sel(time=t, lat=lat, lon=lon,
                                         method="nearest").item()
                except Exception:
                    v = np.nan
                df.at[i, f"{side}_pred_{var}"] = v

    # drop helper cols
    df.drop(columns=["DepDelta","ArrDelta"], inplace=True)

    # write out
    out_path = os.path.join(OUTPUT_DIR, fn)
    df.to_csv(out_path, index=False)
    print(" → wrote", out_path)


or maybe augment historic and predicted at the same time?

In [None]:
import pickle
# assume weather_dfs is your dict of historic station‑interpolated DataFrames

preds_ds = xr.open_zarr(PRED_STORE, chunks={"time":1})

for fn in os.listdir(ARR_DIR):
    if not fn.lower().endswith(".csv"): 
        continue

    # 1) load the arrival + historic columns
    df = pd.read_csv(os.path.join(ARR_DIR, fn), parse_dates=["FlightDate"], dtype=str)
    df["DepDT"] = df["DepTime"].apply(parse_hhmm) + df["FlightDate"]
    df["ArrDT"] = df["ArrTime"].apply(parse_hhmm) + df["FlightDate"]

    # 2) initialize predicted‐weather columns
    for var in task_config.target_variables:
        df[f"Origin_pred_{var}"] = np.nan
        df[f"Dest_pred_{var}"]   = np.nan

    # 3) fill both historic *and* predicted
    for i, row in df.iterrows():
        for side, dtcol, apcol in [
            ("Origin","DepDT","OriginAirportID"),
            ("Dest",  "ArrDT","DestAirportID")
        ]:
            dt  = row[dtcol]
            aid = str(row[apcol]).strip()

            # (a) historic lookup (as before)
            if not pd.isna(dt) and aid in weather_dfs:
                rec = weather_dfs[aid].iloc[
                    weather_dfs[aid].index.get_indexer([dt], method="nearest")[0]
                ]
                for feat in FEATURES:
                    df.at[i, f"{side}_{feat}"] = rec[feat]

            # (b) predicted lookup
            if not pd.isna(dt) and aid in mc.index:
                lat, lon = mc.at[aid, "LAT"], mc.at[aid, "LON"]
                for var in task_config.target_variables:
                    try:
                        val = preds_ds[var].sel(
                            time=dt, method="nearest",
                            lat=lat, method="nearest",
                            lon=lon, method="nearest"
                        ).item()
                    except KeyError:
                        val = np.nan
                    df.at[i, f"{side}_pred_{var}"] = val

    # 4) clean up & save
    df.drop(columns=["DepDT","ArrDT"], inplace=True)
    df.dropna(axis=1, how="all", inplace=True)
    out = os.path.join(OUT_DIR, fn)
    df.to_csv(out, index=False)
    print("Saved both historic + predicted →", out)
