In [22]:
import xarray as xr
import re
from collections import defaultdict

In [23]:
## Load the water-tag climatologies
# BK run: F_1850 water tags, climatology file covering 0006-0025
bk = xr.open_dataset(
    '/net/paleonas.wustl.edu/volume1/blkshare/ajthompson/postproc/'
    'f.e12.F_1850_CAM5.wiso.f19.0ka.002.watertags.2.cam.h0.0006-0025.climo.nc'
)

# aaf run: BRCP85 water tags, monthly climatology covering 2105-2124
# NOTE(review): this was previously labelled "2090 tags", but the file name says
# 2100watertags / 2105-2124 -- confirm which label is intended.
aaf = xr.open_dataset(
    '/RAID/datasets/f.ie12.BRCP85C5CN.f19_g16.LME.004_2100watertags.004/archive/atm/hist/climatology/'
    'f.ie12.BRCP85C5CN.f19_g16.LME.004_2100watertags.004.cam.h0.2105-2124_monthly_climatology_cat.nc'
)

In [28]:
## Define a function to average new regions to match old tags

def average_regions_to_new_tag(
    ds: xr.Dataset,
    regions=("EURO", "NASA", "INDA", "SASA"),
    new_region="ERAS",
    *,
    weights: dict | None = None,   # e.g., {"EURO": 1, "NASA": 2, "INDA": 1, "SASA": 1}
    require_all: bool = False,     # if True, only average when *all* regions for a (prefix, sep, tail) group are present
    keep_nonregion_vars: bool = True,
    dtype="float32",
    skipna=True
) -> xr.Dataset:
    """
    Collapse per-region variables into a new averaged region (e.g., ERAS).

    Matches variables shaped like:
      [<prefix>][<sep>]<REGION><tail>
    where <prefix> is optional (can include underscores), <sep> is an optional single underscore,
    <REGION> is one of `regions`, and <tail> is the remaining suffix (e.g., '18OI', 'V', 'r', etc.).

    Examples:
      'NASA18OI'           -> prefix='',  sep='',  region='NASA', tail='18OI'
      'PRECRC_NASA18Or'    -> prefix='PRECRC', sep='_', region='NASA', tail='18Or'

    Output variable is named:
      f"{prefix}{sep}{new_region}{tail}"

    Parameters
    ----------
    ds : xr.Dataset
        Dataset holding the per-region variables.
    regions : sequence of str (or a single str)
        Region codes to combine. Duplicates are ignored; order is preserved.
    new_region : str
        Region code used in the names of the averaged variables.
    weights : dict, optional
        Mapping region -> weight. Regions absent from the mapping get weight 0.
        With `skipna=True` the weights are renormalised over the regions that
        actually have data at each point; points whose total valid weight is 0
        become NaN. If None, a plain (unweighted) mean is taken.
    require_all : bool
        If True, a (prefix, sep, tail) group is averaged only when every region
        is present; skipped groups are reported in the output attrs.
    keep_nonregion_vars : bool
        If True, variables that did not match the region pattern are carried
        over. Matched per-region variables are never carried over, even when
        their group was skipped because of `require_all`.
    dtype : str or numpy dtype
        dtype the averaged variables are cast to.
    skipna : bool
        Whether NaNs are ignored when averaging across regions.

    Returns
    -------
    xr.Dataset
        New dataset with the averaged variables and a copy of `ds.attrs`. When
        groups were skipped, the attribute f"{new_region}_missing_regions_info"
        holds str({output_variable_name: [missing regions]}).

    Raises
    ------
    ValueError
        If `regions` is empty, or (raised by xarray) if the variables of a
        group do not share identical coordinates.

    Notes
    -----
    The region code is located at its first occurrence in the variable name, so
    an unrelated variable whose name happens to contain a region code followed
    by at least one character is also treated as a region variable.
    """
    if not regions:
        raise ValueError("Provide at least one region prefix in `regions`.")

    # A bare string would otherwise be iterated character by character.
    if isinstance(regions, str):
        regions = (regions,)
    # Drop duplicates (keeping caller order) so a repeated region is neither
    # double-counted in the mean nor able to break the `require_all` check.
    regions = tuple(dict.fromkeys(regions))

    # Longest codes first: regex alternation takes the first alternative that
    # matches, so a code that is a prefix of another must come after it.
    region_alt = "|".join(re.escape(r) for r in sorted(regions, key=len, reverse=True))
    # Capture optional prefix (lazy), optional underscore sep, region, and tail
    # - prefix: any text (possibly empty)
    # - sep: optional single underscore between prefix and region (captured to preserve in output)
    # - region: one of provided regions
    # - tail: required non-empty suffix
    pat = re.compile(rf"^(?P<prefix>.*?)(?P<sep>_)?(?P<region>{region_alt})(?P<tail>.+)$")

    # Group DataArrays by (prefix, sep, tail)
    groups = defaultdict(dict)   # (prefix, sep, tail) -> {region: DataArray}
    region_vars = []             # names of all region-specific variables (to drop later)

    for vname, da in ds.data_vars.items():
        m = pat.match(vname)
        if m:
            key = (m.group("prefix") or "", m.group("sep") or "", m.group("tail"))
            groups[key][m.group("region")] = da
            region_vars.append(vname)

    new_vars = {}
    missing_summary = {}

    for (prefix, sep, tail), reg_map in groups.items():
        new_name = f"{prefix}{sep}{new_region}{tail}"

        # Regions in canonical (caller) order, split by availability. A group
        # only exists if at least one region matched, so `present` is non-empty.
        present = [r for r in regions if r in reg_map]
        missing = [r for r in regions if r not in reg_map]

        if require_all and missing:
            # Keyed by the full output name: the tail alone is not unique
            # across prefixes and would let groups overwrite each other.
            missing_summary[new_name] = sorted(missing)
            continue

        # Align before averaging; join="exact" raises if coordinates differ.
        das_aligned = xr.align(*(reg_map[r] for r in present), join="exact")
        stacked = xr.concat(das_aligned, dim="__region__")

        if weights is None:
            averaged = stacked.mean(dim="__region__", skipna=skipna)
        else:
            w = xr.DataArray(
                [float(weights.get(r, 0.0)) for r in present],
                dims="__region__",
            )
            # xarray's weighted mean divides by the weight of the non-NaN
            # members only, so skipped NaNs no longer bias the result low.
            averaged = stacked.weighted(w).mean(dim="__region__", skipna=skipna)

        averaged = averaged.astype(dtype)
        averaged.name = new_name
        new_vars[new_name] = averaged

    # Assemble output dataset
    ds_out_parts = []
    if keep_nonregion_vars:
        region_var_names = set(region_vars)   # built once, not per variable
        keep_vars = [v for v in ds.data_vars if v not in region_var_names]
        ds_out_parts.append(ds[keep_vars])
    else:
        ds_out_parts.append(xr.Dataset(coords=ds.coords))

    ds_out_parts.append(xr.Dataset(new_vars))
    ds_out = xr.merge(ds_out_parts)
    ds_out.attrs = dict(ds.attrs)

    if missing_summary:
        ds_out.attrs[f"{new_region}_missing_regions_info"] = str(missing_summary)

    return ds_out



In [None]:
# Average the four per-region tags of the aaf run into a single ERAS tag
eras_ds = average_regions_to_new_tag(
    aaf,
    new_region="ERAS",
    regions=("EURO", "NASA", "INDA", "SASA"),
    # Equal weighting; pass e.g. {"EURO":1, "NASA":1, "INDA":1, "SASA":1} to weight regions.
    weights=None,
    # Only build an average when all four regions exist for a variable group.
    require_all=True,
    # Return only the new averaged variables; other variables are left out.
    keep_nonregion_vars=False,
)

In [32]:
# Display the averaged dataset to inspect the new ERAS variables
eras_ds