### The function to apply a circular neighborhood filter to an MPAS unstructured mesh is in a branch of uxarray in the forked GitHub repository [ahijevyc/uxarray](https://github.com/ahijevyc/uxarray). The branch is ahijevyc/apply_remap_func.

In [6]:
import logging
import os
import warnings
from pathlib import Path

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmaps
import geoviews.feature as gf
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
import pandas as pd
import uxarray
import xarray
from metpy.units import units
from sklearn.neighbors import BallTree as SKBallTree
from tqdm import tqdm

import G211  # 80-km CONUS grid
from util import dBZfunc, dec_ax, mkcoord, trim_ll, xtime

degE = units.parse_expression("degree_E")
degN = units.parse_expression("degree_N")

try:  # Avoid ValueError when assigning cmap again
    cmap = cmaps.WhiteBlueGreenYellowRed
except ValueError:
    pass
%matplotlib inline

## Other MPAS

In [38]:
# Date stamp for this case.
idate = pd.to_datetime("20180429T00")

# Mesh file plus every diag*.nc file found in the run directory; `xtime` is
# applied to each file as it is opened.
grid_path = "/glade/campaign/mmm/wmr/weiwang/cps/irma3/2020/tk707_conus/init.nc"
diag_dir = Path("/glade/campaign/mmm/wmr/weiwang/cps/irma3/mp6/tk707")
uxds = uxarray.open_mfdataset(grid_path, diag_dir.glob("diag*.nc"), preprocess=xtime)

# Longitude / latitude window used to subset the mesh before plotting.
lon_bounds = (-74, -64)
lat_bounds = (17, 25)
refl = uxds["refl10cm_max"].isel(valid_time=12)
var = refl.subset.bounding_box(lon_bounds, lat_bounds)

# Rasterized field overlaid with coastlines (last expression -> cell output).
var.plot.rasterize(cmap=cmap) * gf.coastline(projection=ccrs.PlateCarree())

### Apply circular neighborhood filter function `func` with radius `r` (in degrees)

In [41]:
# Reduction applied inside each neighborhood, and the neighborhood radius.
func = np.max
r = 0.25

# Filter the field onto its own grid, then plot it with coastlines on top.
filtered = var.remap.apply_func(var.uxgrid, func=func, r=r)
filtered.plot.rasterize(cmap=cmap, dynamic=True) * gf.coastline(
    projection=ccrs.PlateCarree()
)

### interpolate to 0.5 deg grid

In [42]:
import subprocess

# Write the target lat/lon grid to grd.nc with ncremap. check=True makes a
# failed run raise here instead of letting the next line open a stale or
# missing grd.nc.
# NOTE(review): latlon=360,720 presumably means 360 x 720 points, i.e. a
# 0.5 degree global grid -- confirm against the ncremap documentation.
subprocess.run(["ncremap", "-g", "grd.nc", "-G", "latlon=360,720"], check=True)

# NOTE: despite the variable name, this is the 360 x 720 grid written above.
onedeglatlonGrid = uxarray.open_grid("grd.nc")

# Apply the same neighborhood function, but evaluate it on the lat/lon grid
# restricted to the plotting window.
target_grid = onedeglatlonGrid.subset.bounding_box(lon_bounds, lat_bounds)
var.remap.apply_func(target_grid, func=func, r=r).plot.rasterize(
    cmap=cmap, dynamic=True
) * gf.coastline(projection=ccrs.PlateCarree())



In [37]:
# Run directory holding both the mesh file and one diagnostic file.
idir = Path("/glade/campaign/mmm/parc/sobash/MPAS_regional_cheyenne/2023032500")
grid_path = idir / "conus.init.nc"
diag_path = idir / "diag.2023-03-25_17.00.00.nc"

# First time step of the reflectivity field, drawn with state outlines.
refl = uxarray.open_dataset(grid_path, diag_path)["refl10cm_max"].isel(Time=0)
refl.plot.rasterize(cmap=cmap) * gf.states(projection=ccrs.PlateCarree())



## HWT MPAS

In [None]:
idate = pd.to_datetime("20230429T00")
base_path = (
    Path("/glade/campaign/mmm/parc/schwartz/HWT2023/mpas")
    / idate.strftime("%Y%m%d%H")
    / "post/mem_5"
)
grid_path = "/glade/campaign/mmm/parc/schwartz/MPAS/15-3km_mesh/mpas_init/static.nc"
# Names of the diagnostic files valid 3 to 9 hours after `idate`, every 3 hours.
# The step is given as a Timedelta rather than the "3H" string alias, which
# recent pandas releases deprecate.
var_names = [
    d.strftime("diag.%Y-%m-%d_%H.%M.%S.nc")
    for d in pd.date_range(
        start=idate + pd.Timedelta(hours=3),
        end=idate + pd.Timedelta(hours=9),
        freq=pd.Timedelta(hours=3),
    )
]
# Full paths to the data files.
data_paths = [base_path / name for name in var_names]

# Open all files as one dataset, concatenated along "Time" in list order.
uxds = uxarray.open_mfdataset(
    grid_path, data_paths, concat_dim="Time", combine="nested", use_dual=False
)
# Last expression -> the grid summary is shown as the cell output.
uxds.uxgrid

## Trim global mesh to lat/lon box

In [None]:
# Lat/lon window used to trim the mesh and frame the plots.
lat0, lat1 = 41, 42.5
lon0, lon1 = -75, -70

# lat0, lat1 = 35, 43
# lon0, lon1 = -83, -66

lat_bounds = (lat0, lat1)
lon_bounds = (lon0, lon1)
grid_ds, ds = trim_ll(grid_path, data_paths, lon_bounds, lat_bounds)

# Figure width shared by the plotting cells.
figw = 12

# Colour scale shared by the plots: values 0..50 span the colormap.
norm = mpl.colors.Normalize(vmin=0, vmax=50)

# Triangulation of the cell centres.
triang = tri.Triangulation(grid_ds.lonCell, grid_ds.latCell)

# Boxes wider than 15 degrees of longitude use PlateCarree; narrower ones use
# a Lambert conformal projection centred on 82W.
projection = (
    ccrs.PlateCarree()
    if lon1 - lon0 > 15
    else ccrs.LambertConformal(central_longitude=-82)
)

# Reflectivity at the first time step, loaded into memory.
da = ds["refl10cm_max"].isel(Time=0).compute()

## Upscale to NCEP grid 211 (80 km) as in uxarray (SKBallTree)
- rectangular region of influence
- circular region of influence

In [None]:
warnings.filterwarnings("ignore", message="Approximating coordinate system")

# (lat, lon) pairs in radians, the layout the haversine metric expects.
latlon = np.deg2rad(np.c_[grid_ds["latCell"], grid_ds["lonCell"]])
g211_lon = G211.lon.ravel()
g211_lat = G211.lat.ravel()
X = np.deg2rad(np.c_[g211_lat, g211_lon])

# For each MPAS cell, the index of its nearest G211 point (a single column).
idxs = SKBallTree(X, metric="haversine").query(latlon, return_distance=False)
nearest_g211 = idxs[:, 0]

fig, axes = plt.subplots(nrows=2, figsize=(figw, 12), subplot_kw=dict(projection=projection))

# optionally coarsen for speed; the window is clamped to at least 1 because a
# small trimmed mesh would otherwise give a window of 0
coarsen = dict(nCells=max(1, int(np.sqrt(ds.nCells.size) / 50)), boundary="trim")
logging.warning(coarsen)

# Build the value -> RGBA lookup once instead of once per plotted point.
to_rgba = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba

# --- Panel 1: MPAS cells grouped by their nearest G211 point ---------------
ax = axes[0]
ax.set_title("pts in G211 grid cells")
for ig, (glon, glat) in enumerate(zip(g211_lon, g211_lat)):
    # Only G211 points inside the lat/lon box are drawn.
    if glon < lon0 or glon >= lon1:
        continue
    if glat < lat0 or glat >= lat1:
        continue
    # Positions of the MPAS cells whose nearest G211 point is this one
    # (vectorized; a Python loop over every cell per grid point is very slow).
    j = np.flatnonzero(nearest_g211 == ig)
    if j.size == 0:
        continue

    # Member cells (thinned by `coarsen`); each call gets the next colour in
    # the default cycle, so neighbouring groups are distinguishable.
    ax.scatter(
        grid_ds["lonCell"].isel(nCells=j).coarsen(**coarsen).mean(),
        grid_ds["latCell"].isel(nCells=j).coarsen(**coarsen).mean(),
        transform=ccrs.PlateCarree(),
        marker=".",
        alpha=0.8,
    )
    # G211 point coloured by the upscaled (max) value of its member cells.
    c = dBZfunc(da[j], np.max)
    ax.plot(
        glon,
        glat,
        color=to_rgba(c),
        transform=ccrs.PlateCarree(),
        marker="o",
        alpha=1,
    )


# --- Panel 2: MPAS cells within rptdist of each G211 point -----------------
ax = axes[1]
ax.set_title("pts within circular neighborhood of G211 centers")
tree = SKBallTree(latlon, metric="haversine")
rptdist = 40  # km
r = np.deg2rad(rptdist / 111.1)  # km -> degrees (111.1 km per degree) -> radians
idxs = tree.query_radius(X, r)

mm = ax.tripcolor(
    triang,
    da,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
)
fig.colorbar(mm, ax=axes, orientation="horizontal", shrink=0.6)

stride = 1  # draw every `stride`-th G211 point; raise it to thin the plot
for i, idx in enumerate(tqdm(idxs)):
    if i % stride > 0:
        continue
    if len(idx) == 0:
        continue

    c = dBZfunc(da[idx], np.max)
    # Member cells, faintly, in the colour of the neighborhood maximum.
    ax.scatter(
        grid_ds["lonCell"][idx].coarsen(**coarsen).mean(),
        grid_ds["latCell"][idx].coarsen(**coarsen).mean(),
        c=[to_rgba(c)],
        transform=ccrs.PlateCarree(),
        marker=".",
        alpha=0.1,
    )
    # G211 point coloured by the same value through cmap/norm.
    ax.scatter(
        g211_lon[i],
        g211_lat[i],
        c=c,
        transform=ccrs.PlateCarree(),
        marker="o",
        alpha=1,
        cmap=cmap,
        norm=norm,
    )
    # Outline of the circular neighborhood.
    ax.tissot(
        rad_km=rptdist,
        lons=g211_lon[i],
        lats=g211_lat[i],
        alpha=0.1,
        facecolor="white",
        edgecolor="black",
    )

for ax in axes:
    dec_ax(ax, (*lon_bounds, *lat_bounds))

plt.suptitle(f"{len(grid_ds.lonCell)} cells")

## Upscale to 30-km mesh

In [None]:
%%time
fig, axes = plt.subplots(
    nrows=2,
    figsize=(figw, 12),
    subplot_kw=dict(projection=projection),
)

ax = axes[0]
ax.set_title(f"original {len(grid_ds.lonCell)} cell mesh")
cc = ax.scatter(
    grid_ds["lonCell"],
    grid_ds["latCell"],
    c=da,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
    marker=".",
    alpha=0.5,
)

fig.colorbar(mm, ax=axes, orientation="horizontal", shrink=0.6)

# Upscale to coarse mesh
coarse_mesh = xarray.open_dataset(
    "/glade/campaign/mmm/parc/schwartz/MPAS_regional/15km_mesh_regional/mpas_init/regional_15km_mesh_2000km.static.nc"
)
coarse_mesh["lonCell"] = np.degrees(coarse_mesh.lonCell)
coarse_mesh["latCell"] = np.degrees(coarse_mesh.latCell)
# before computing the triangulation
coarse_mesh["lonCell"] = ((coarse_mesh["lonCell"] + 180) % 360) - 180

ibox = (
    (coarse_mesh.lonCell >= lon0)
    & (coarse_mesh.lonCell < lon1)
    & (coarse_mesh.latCell >= lat0)
    & (coarse_mesh.latCell < lat1)
)
logging.info("Trim latCell and lonCell")
coarse_mesh = coarse_mesh[["latCell", "lonCell"]].where(ibox, drop=True)

X = np.c_[coarse_mesh.latCell.values.ravel(), coarse_mesh.lonCell.values.ravel()]
# List MPAS indices closest to each coarse mesh cell
idxs = SKBallTree(np.deg2rad(X), metric="haversine").query(latlon, return_distance=False)

coarse_var = np.empty(coarse_mesh["lonCell"].shape)
coarse_var.fill(np.nan)

axes[1].set_title(f"colored by mean value in {len(coarse_mesh.lonCell)} cell mesh")

for idx in np.unique(idxs):
    i = idxs.squeeze() == [idx]
    if any(i):
        coarse_var[idx] = dBZfunc(da.values[i], np.mean)
        c = dBZfunc(da.values[i], np.mean)
        cc = axes[1].scatter(
            grid_ds["lonCell"][i],
            grid_ds["latCell"][i],
            c=[cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(c)],
            transform=ccrs.PlateCarree(),
            marker=".",
            alpha=0.5,
        )


for ax in axes:
    dec_ax(ax, (*lon_bounds, *lat_bounds))

coarse_mesh[da.name] = coarse_var

## Plot TC centers with radius of influence and location of maximum
### optional filter by hemisphere or quadrant about center point

### HWT 2023 MPAS

In [None]:
idate = pd.to_datetime("20230531T00")
base_path = (
    Path("/glade/campaign/mmm/parc/schwartz/HWT2023/mpas")
    / idate.strftime("%Y%m%d%H")
    / "post/mem_4"
)
grid_path = "/glade/campaign/mmm/parc/schwartz/MPAS/15-3km_mesh/mpas_init/static.nc"
# Names of the diagnostic files valid 36 to 48 hours after `idate`, every
# 6 hours. The step is given as a Timedelta rather than the "6H" string alias,
# which recent pandas releases deprecate.
var_names = [
    d.strftime("diag.%Y-%m-%d_%H.%M.%S.nc")
    for d in pd.date_range(
        start=idate + pd.Timedelta(hours=36),
        end=idate + pd.Timedelta(hours=48),
        freq=pd.Timedelta(hours=6),
    )
]
# Full paths to the data files.
data_paths = [base_path / name for name in var_names]
# Tropical storm ARLENE: one (time, lon, lat) entry per data file above.
track = [
    (pd.to_datetime("20230601T12"), -87.5, 28.6),
    (pd.to_datetime("20230601T18"), -87.5, 28.6),
    (pd.to_datetime("20230602T00"), -87.6, 28.7),
]
# Lat/lon box around the track, used to trim the mesh.
lat0, lat1 = 24, 32
lon0, lon1 = -93, -80
grid_ds, ds = trim_ll(grid_path, data_paths, (lon0, lon1), (lat0, lat1))

In [None]:
# Open the data files as one dataset, concatenated along "Time" in list order.
open_kwargs = dict(concat_dim="Time", combine="nested", use_dual=False)
uxds = uxarray.open_mfdataset(grid_path, data_paths, **open_kwargs)
print(uxds.source_datasets)

In [None]:
# Subset with the box defined alongside the track (lon0/lon1, lat0/lat1).
# lon_bounds/lat_bounds are only assigned in an earlier cell, for a different
# region, so using them here would plot the wrong area.
storm_refl = (
    uxds["refl10cm_max"].subset.bounding_box((lon0, lon1), (lat0, lat1)).isel(Time=2)
)
storm_refl.plot.rasterize(
    method="polygon",
    width=800,
    height=400,
    dynamic=True,
    cmap=cmap,
    exclude_antimeridian=True,
) * gf.states(projection=ccrs.PlateCarree(), scale="50m")

In [None]:
%%time
# List of TC centers
X = np.array([[lat, lon] for t, lon, lat in track])  # lat, lon
latlon = np.vstack((np.deg2rad(grid_ds.latCell.values), np.deg2rad(grid_ds.lonCell.values))).T
tree = SKBallTree(latlon, metric="haversine")

rptdist = 300  # km
r = np.deg2rad(rptdist / 111.0)
idxs = tree.query_radius(np.deg2rad(X), r)
idxs

In [None]:
# Display the grid dataset returned by trim_ll.
grid_ds

In [None]:
%%time
t = 2
fig, ax = plt.subplots(
    figsize=(figw, 6),
    subplot_kw=dict(projection=projection),
)
var = ds["refl10cm_max"].isel(Time=t, nCells=idxs[0]).compute()

ax.set_title(f"{track[t][0]} {len(grid_ds.lonCell)} cells")
cc = ax.scatter(
    grid_ds["lonCell"].isel(nCells=idxs[0]),
    grid_ds["latCell"].isel(nCells=idxs[0]),
    c=var,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
    marker=".",
    alpha=0.5,
)
ax.tissot(
    rad_km=rptdist,
    lons=[track[t][1]],
    lats=[track[t][2]],
    alpha=0.1,
    facecolor="none",
    edgecolor="black",
)

dec_ax(ax, (lon0, lon1, lat0, lat1))
ax.add_feature(cfeature.STATES)

fig.colorbar(mm, ax=axes, orientation="horizontal", shrink=0.6)

In [None]:
# Show one cell latitude from grid_ds next to the face longitude/latitude at a
# hard-coded index of the uxarray grid.
# NOTE(review): both indices are hard-coded; presumably they are meant to
# refer to the same cell in the trimmed and full meshes -- confirm.
grid_ds.latCell[6080], uxds.uxgrid.face_lon[5081942], uxds.uxgrid.face_lat[5081942]

In [None]:
# Ball tree provided by the uxarray grid.
uxtree = uxds.uxgrid.get_ball_tree()
# TC centres with the two columns of X swapped (X rows are (lat, lon)).
xy = X[:, ::-1]

In [None]:
fig, ax = plt.subplots(figsize=(figw, 6), subplot_kw=dict(projection=ccrs.PlateCarree()))
# Field at the first time step; every centre below is evaluated against it.
var = ds["refl10cm_max"].isel(Time=0)
cc = ax.scatter(
    grid_ds["lonCell"],
    grid_ds["latCell"],
    c=var,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
    marker=".",
    alpha=0.5,
)
ax.coastlines(resolution="50m")
# X holds one (lat, lon) centre per row; idxs holds the matching cell indices.
for (lat, lon), idx in zip(X, idxs):
    if len(idx) == 0:
        continue
    lons = grid_ds["lonCell"][idx]
    lats = grid_ds["latCell"][idx]
    # indices in hemispheres and quadrants about center point
    # (the longitude difference is wrapped to [-180, 180) before the sign test)
    ieast = ((lons - lon) + 180) % 360 - 180 >= 0
    inorth = lats >= lat
    ne = ieast & inorth
    se = ieast & ~inorth
    sw = ~ieast & ~inorth
    nw = ~ieast & inorth
    # Selection to keep; swap in se / nw or a hemisphere mask to change it.
    qfilt = sw | ne
    if not bool(qfilt.any()):
        # Nothing selected around this centre; argmax below would fail.
        continue
    lons = lons[qfilt]
    lats = lats[qfilt]
    ax.scatter(
        lons,
        lats,
        transform=ccrs.PlateCarree(),
        marker=".",
        color="white",
        edgecolor=None,
        alpha=0.5,
        label=None,
    )
    # Location and value of the maximum inside the filtered neighborhood; the
    # legend entry reports this local maximum, not the maximum of the whole field.
    selected = var[idx][qfilt]
    x = selected.argmax().compute()
    ax.scatter(
        lons[x],
        lats[x],
        transform=ccrs.PlateCarree(),
        marker="x",
        label=float(selected.max()),
    )
# NOTE(review): the extra list level makes lons/lats 2-D; presumably this makes
# cartopy pair each lon with its lat instead of forming a lon x lat grid --
# confirm before simplifying.
ax.tissot(
    rad_km=rptdist,
    lons=[X[:, 1]],
    lats=[X[:, 0]],
    alpha=0.3,
    facecolor="none",
    edgecolor="black",
)

ax.legend(title="max")
# Colorbar built from the shared norm/cmap and attached to this figure's axes,
# rather than taken from an artist that belongs to another cell's figure.
fig.colorbar(
    cm.ScalarMappable(norm=norm, cmap=cmap),
    ax=ax,
    orientation="horizontal",
    pad=0.04,
    shrink=0.8,
)
ax.set_title(
    f"{grid_ds.attrs.get('title', '')} ({len(grid_ds.lonCell)} cells)",
    fontweight="bold",
    fontsize=14,
)

In [None]:
# Open a dataset with uxarray, passing the same path as both the grid and the
# data argument.
f = "/glade/campaign/mmm/dpm/nystrom/MPAS/Climo_15km/2010/2010082500/CNTL"
uxcoarse_mesh = uxarray.open_dataset(f, f)
# Store the array computed in the coarse-mesh upscaling cell under the name of
# the current `var`.
# NOTE(review): coarse_var was sized for the trimmed coarse mesh used in that
# cell, and no dimension names are given here -- confirm that its length and
# dimension match this dataset's cells.
uxcoarse_mesh[var.name] = coarse_var