In [1]:
import logging
import os
from pathlib import Path

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmaps
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
import pandas as pd
import uxarray
import xarray
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
from metpy.units import units
from tqdm import tqdm

# Unit objects for longitude (degrees east) and latitude (degrees north).
degE, degN = (
    units.parse_expression(unit_name) for unit_name in ("degree_E", "degree_N")
)

# Looking the colormap up again (e.g. when this cell is re-run) can raise
# ValueError; in that case `cmap` is simply left at its previous value.
try:
    cmap = cmaps.WhiteBlueGreenYellowRed
except ValueError:
    pass

from sklearn.neighbors import BallTree as SKBallTree

In [2]:
def mkcoord(ds):
    """Use the isobaric level values as dimensions when they are present.

    Datasets that do not contain ``t_iso_levels`` are returned untouched.
    Otherwise the ``nIsoLevelsT`` and ``nIsoLevelsZ`` dimensions are swapped
    for ``t_iso_levels`` and ``z_iso_levels`` and the new dataset is returned.
    Used as the ``preprocess`` hook when opening the data files.
    """
    if "t_iso_levels" not in ds:
        return ds
    iso_dims = {"nIsoLevelsT": "t_iso_levels", "nIsoLevelsZ": "z_iso_levels"}
    return ds.swap_dims(iso_dims)


def dec_ax(ax, extent):
    """Set the map extent of ``ax`` and add labelled gridlines.

    ``extent`` is handed straight to ``ax.set_extent``. Gridline labels are
    drawn on the left and bottom only (top and right are switched off) and use
    the cartopy longitude/latitude formatters. Returns ``None``.
    """
    # State borders are not drawn here:
    # ax.add_feature(cfeature.STATES)
    ax.set_extent(extent)

    gridliner = ax.gridlines(x_inline=False, draw_labels=True)
    gridliner.xformatter = LONGITUDE_FORMATTER
    gridliner.yformatter = LATITUDE_FORMATTER
    gridliner.right_labels = False
    gridliner.top_labels = False


def dBZfunc(dBZ, func):
    """Apply ``func`` to linearized Z, not to logarithmic dBZ.

    ``dBZ`` is converted to linear units with ``10 ** (dBZ / 10)``, ``func``
    (for example ``np.mean``) is applied to the linear values, and the result
    is converted back with ``10 * log10``.
    """
    linear = 10 ** (dBZ / 10)
    return 10 * np.log10(func(linear))


def trim_latlon(grid_path, data_paths, extent):
    """
    Trim grid file and data files to the lat/lon box described by ``extent``.

    Parameters
    ----------
    grid_path
        Path of the grid file. Its ``lonCell`` and ``latCell`` variables are
        passed through ``np.degrees`` (i.e. treated as radians) and longitude
        is wrapped to [-180, 180).
    data_paths
        Paths of the data files. They are opened together with
        ``xarray.open_mfdataset`` (``mkcoord`` as preprocess step) and
        concatenated along ``Time``.
    extent
        Sequence ``(lon0, lon1, lat0, lat1)`` in degrees. Lower bounds are
        inclusive, upper bounds exclusive.

    Returns
    -------
    (grid_ds, ds)
        ``grid_ds`` holds only ``latCell`` and ``lonCell`` (in degrees) for the
        cells inside the box; ``ds`` is the data masked with the same cell
        selection, with cells outside the box dropped.

    TODO: use uxarray.Dataset.subset
    """
    # Take the box from the argument. Previously `extent` was ignored and the
    # bounds were silently read from notebook globals of the same names.
    lon0, lon1, lat0, lat1 = extent

    # Open the Grid file
    grid_ds = xarray.open_dataset(grid_path)

    grid_ds["lonCell"] = np.degrees(grid_ds.lonCell)
    grid_ds["latCell"] = np.degrees(grid_ds.latCell)

    # Wrap longitude to [-180, 180) before comparing with the box
    # (and before computing any triangulation)
    grid_ds["lonCell"] = ((grid_ds["lonCell"] + 180) % 360) - 180

    # Open data files
    ds = xarray.open_mfdataset(
        data_paths, preprocess=mkcoord, concat_dim="Time", combine="nested"
    )

    # True for the cells whose center lies inside the box
    ibox = (
        (grid_ds.lonCell >= lon0)
        & (grid_ds.lonCell < lon1)
        & (grid_ds.latCell >= lat0)
        & (grid_ds.latCell < lat1)
    )
    # Trim grid
    grid_ds = grid_ds[["latCell", "lonCell"]].where(ibox, drop=True)

    # Trim data
    ds = ds.where(ibox, drop=True)

    return grid_ds, ds

## HWT 2023 MPAS

In [3]:
# Forecast initialization time; it also names the run directory.
idate = pd.to_datetime("20230429T00")
base_path = (
    Path("/glade/campaign/mmm/parc/schwartz/HWT2023/mpas")
    / idate.strftime("%Y%m%d%H")
    / "post/mem_5"
)
grid_path = "/glade/campaign/mmm/parc/schwartz/MPAS/15-3km_mesh/mpas_init/static.nc"
# Paths to Data Variable files: one diag file every 3 hours, from 3 h to 18 h
# after initialization.
var_names = [
    d.strftime("diag.%Y-%m-%d_%H.%M.%S.nc")
    for d in pd.date_range(
        start=idate + pd.Timedelta(hours=3),
        end=idate + pd.Timedelta(hours=18),
        # Lowercase hour alias: uppercase "H" is deprecated in recent pandas,
        # and the lowercase form is parsed identically by older versions.
        freq="3h",
    )
]
data_paths = [base_path / name for name in var_names]

# Open grid and data together, concatenating the data files along Time.
uxds = uxarray.open_mfdataset(
    grid_path, data_paths, concat_dim="Time", combine="nested", use_dual=False
)
print(uxds.source_datasets)

[PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2023/mpas/2023042900/post/mem_5/diag.2023-04-29_03.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2023/mpas/2023042900/post/mem_5/diag.2023-04-29_06.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2023/mpas/2023042900/post/mem_5/diag.2023-04-29_09.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2023/mpas/2023042900/post/mem_5/diag.2023-04-29_12.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2023/mpas/2023042900/post/mem_5/diag.2023-04-29_15.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2023/mpas/2023042900/post/mem_5/diag.2023-04-29_18.00.00.nc')]




In [None]:
%%time
uxds["refl10cm_max"].isel(Time=2).plot.rasterize(
    method="polygon",
    width=800,
    height=400,
    dynamic=True,
    cmap=cmap,
    exclude_antimeridian=True,
)



## Trim global mesh to lat/lon box

In [None]:
# Bounding box (degrees) that the global mesh is limited to
lon0, lon1 = -75, -70
lat0, lat1 = 41, 42.5

# Alternative, larger box:
# lat0, lat1 = 35, 43
# lon0, lon1 = -83, -66

extent = [lon0, lon1, lat0, lat1]
grid_ds, ds = trim_latlon(grid_path, data_paths, extent)

# Figure width shared by the plotting cells
figw = 12

triang = tri.Triangulation(grid_ds.lonCell, grid_ds.latCell)

# Lambert conformal by default; plate carree once the box is wider than
# 15 degrees of longitude.
if lon1 - lon0 > 15:
    projection = ccrs.PlateCarree()
else:
    projection = ccrs.LambertConformal(central_longitude=-82)

# Color scale shared by the plots (0 to 50)
norm = mpl.colors.Normalize(vmin=0, vmax=50)

# First time step of the trimmed field, loaded into memory
var = ds["refl10cm_max"].isel(Time=0).compute()

## Upscale to 30-km mesh

In [None]:
%%time
fig, axes = plt.subplots(
    nrows=2,
    figsize=(figw, 12),
    subplot_kw=dict(projection=projection),
)

ax = axes[0]
ax.set_title(f"original {len(grid_ds.lonCell)} cell mesh")
cc = ax.scatter(
    grid_ds["lonCell"],
    grid_ds["latCell"],
    c=var,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
    marker=".",
    alpha=0.5,
)

# fig.colorbar(mm, ax=axes, orientation="horizontal", shrink=0.6)

tmpdir = Path(os.getenv("TMPDIR"))
# Upscale to coarse mesh
coarse_mesh = xarray.open_dataset(
    tmpdir.parents[1]
    / "nystrom/pandac/nystrom_3dvar_O30kmIE60km_ColdStart_TEST/CyclingFC/2022020100/static.655362.nc"
)
coarse_mesh["lonCell"] = np.degrees(coarse_mesh.lonCell)
coarse_mesh["latCell"] = np.degrees(coarse_mesh.latCell)
# before computing the triangulation
coarse_mesh["lonCell"] = ((coarse_mesh["lonCell"] + 180) % 360) - 180

ibox = (
    (coarse_mesh.lonCell >= lon0)
    & (coarse_mesh.lonCell < lon1)
    & (coarse_mesh.latCell >= lat0)
    & (coarse_mesh.latCell < lat1)
)
logging.info("Trim latCell and lonCell")
coarse_mesh = coarse_mesh[["latCell", "lonCell"]].where(ibox, drop=True)

X = np.c_[coarse_mesh.latCell.values.ravel(), coarse_mesh.lonCell.values.ravel()]
# List MPAS indices closest to each coarse mesh cell
idxs = SKBallTree(np.deg2rad(X), metric="haversine").query(
    np.deg2rad(np.c_[grid_ds["latCell"], grid_ds["lonCell"]]), return_distance=False
)

coarse_var = np.empty(coarse_mesh["lonCell"].shape)
coarse_var.fill(np.nan)

axes[1].set_title(f"colored by mean value in {len(coarse_mesh.lonCell)} cell mesh")

for idx in np.unique(idxs):
    i = idxs.squeeze() == [idx]
    if any(i):
        coarse_var[idx] = dBZfunc(var.values[i], np.mean)
        c = dBZfunc(var.values[i], np.mean)
        cc = axes[1].scatter(
            grid_ds["lonCell"][i],
            grid_ds["latCell"][i],
            c=[cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(c)],
            transform=ccrs.PlateCarree(),
            marker=".",
            alpha=0.5,
        )


for ax in axes:
    dec_ax(ax, extent)

coarse_mesh[var.name] = coarse_var

## Plot TC centers with radius of influence and location of maximum
### Optional filter by hemisphere or quadrant about the center point

### HWT 2023 MPAS

In [None]:
# Forecast initialization time; it also names the run directory.
idate = pd.to_datetime("20230531T00")
base_path = (
    Path("/glade/campaign/mmm/parc/schwartz/HWT2023/mpas")
    / idate.strftime("%Y%m%d%H")
    / "post/mem_4"
)
grid_path = "/glade/campaign/mmm/parc/schwartz/MPAS/15-3km_mesh/mpas_init/static.nc"
# Paths to Data Variable files: one diag file every 6 hours, from 36 h to 48 h
# after initialization.
var_names = [
    d.strftime("diag.%Y-%m-%d_%H.%M.%S.nc")
    for d in pd.date_range(
        start=idate + pd.Timedelta(hours=36),
        end=idate + pd.Timedelta(hours=48),
        # Lowercase hour alias: uppercase "H" is deprecated in recent pandas,
        # and the lowercase form is parsed identically by older versions.
        freq="6h",
    )
]
data_paths = [base_path / name for name in var_names]
# Tropical storm ARLENE: one (valid_time, lon, lat) row per data file above
track = pd.DataFrame(
    data=[
        (pd.to_datetime("20230601T12"), -87.5, 28.6),
        (pd.to_datetime("20230601T18"), -87.5, 28.6),
        (pd.to_datetime("20230602T00"), -87.6, 28.7),
    ],
    columns=["valid_time", "lon", "lat"],
)
track

In [None]:
# Box (degrees) for this case, then trim grid and data to it
lon0, lon1 = -93, -80
lat0, lat1 = 24, 32
extent = [lon0, lon1, lat0, lat1]
grid_ds, ds = trim_latlon(grid_path, data_paths, extent)

In [None]:
# Open the untrimmed grid and the data files for this case with uxarray,
# concatenating the data files along Time.
uxds = uxarray.open_mfdataset(
    grid_path,
    data_paths,
    combine="nested",
    concat_dim="Time",
    use_dual=False,
)
print(uxds.source_datasets)

In [None]:
uxda = uxds["refl10cm_max"]
# Mask of faces whose face_lon/face_lat fall inside the box
# (lower bounds inclusive, upper bounds exclusive)
ibox = (
    (lon0 <= uxds.uxgrid.face_lon) & (uxds.uxgrid.face_lon < lon1)
) & (
    (lat0 <= uxds.uxgrid.face_lat) & (uxds.uxgrid.face_lat < lat1)
)

In [None]:
# Disabled on purpose: rasterized polygon view of the box-masked field.
if False:
    raster_opts = dict(
        method="polygon",
        width=800,
        height=400,
        dynamic=True,
        cmap=cmap,
        exclude_antimeridian=True,
    )
    uxds.where(ibox)["refl10cm_max"].isel(Time=2).plot.rasterize(**raster_opts)

In [None]:
%%time
# List of TC centers
latlon = np.vstack(
    (np.deg2rad(grid_ds.latCell.values), np.deg2rad(grid_ds.lonCell.values))
).T
tree = SKBallTree(latlon, metric="haversine")

rptdist = 300  # km
r = np.deg2rad(rptdist / 111.0)
idxs = tree.query_radius(np.deg2rad(track[["lat","lon"]]), r)
idxs

In [None]:
%%time
t = 2
fig, ax = plt.subplots(
    figsize=(figw, 6),
    subplot_kw=dict(projection=projection),
)
var = ds["refl10cm_max"].isel(Time=t, nCells=idxs[0]).compute()

ax.set_title(f"{track.iloc[t]['valid_time']} {len(grid_ds.lonCell)} cells")
cc = ax.scatter(
    grid_ds["lonCell"].isel(nCells=idxs[0]),
    grid_ds["latCell"].isel(nCells=idxs[0]),
    c=var,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
    marker=".",
    alpha=0.5,
)
ax.tissot(
    rad_km=rptdist,
    lons=[track.iloc[t]["lon"]],
    lats=[track.iloc[t]["lat"]],
    alpha=0.1,
    facecolor="white",
    edgecolor="black",
)

dec_ax(ax, extent)
ax.add_feature(cfeature.STATES)

fig.colorbar(cc, ax=axes, orientation="horizontal", shrink=0.6)

In [None]:
# Spot check: a latitude from the trimmed grid next to the face longitude and
# latitude at a fixed index of the untrimmed uxarray grid.
(
    grid_ds["latCell"][6080],
    uxds.uxgrid.face_lon[5081942],
    uxds.uxgrid.face_lat[5081942],
)

In [None]:
# Ball tree obtained from the uxarray grid of `uxds`.
# NOTE(review): `uxtree` is not used by the cells below, which query the
# scikit-learn tree built from the trimmed mesh — confirm whether it is needed.
uxtree = uxds.uxgrid.get_ball_tree()

In [None]:
fig, ax = plt.subplots(
    figsize=(figw, 6), subplot_kw=dict(projection=ccrs.PlateCarree())
)
# Load the field once; it is used for the background scatter and again for
# every center in the loop below.
var = ds["refl10cm_max"].isel(Time=0).compute()
cc = ax.scatter(
    grid_ds["lonCell"],
    grid_ds["latCell"],
    c=var,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
    marker=".",
    alpha=0.5,
)

ax.coastlines(resolution="50m")
# One pass per track point: highlight the selected cells around the center and
# mark where the field is largest among them.
for lat, lon, idx in zip(track["lat"], track["lon"], idxs):
    if len(idx) == 0:
        continue
    lons = grid_ds["lonCell"][idx]
    lats = grid_ds["latCell"][idx]
    # indices in hemispheres and quadrants about center point
    # (longitude difference is wrapped to [-180, 180) before the sign test)
    ieast = ((lons - lon) + 180) % 360 - 180 >= 0
    inorth = lats >= lat
    ne = ieast & inorth
    se = ieast & ~inorth
    sw = ~ieast & ~inorth
    nw = ~ieast & inorth
    # Region to search; combine the masks above to change it.
    qfilt = sw | ne
    if not bool(qfilt.any()):
        # Nothing falls in the chosen region, so there is no maximum to mark.
        continue
    lons = lons[qfilt]
    lats = lats[qfilt]
    ax.scatter(
        lons,
        lats,
        transform=ccrs.PlateCarree(),
        marker=".",
        color="white",
        edgecolor=None,
        alpha=0.5,
        label=None,
    )
    local = var[idx][qfilt]
    x = int(local.argmax())
    # Label with the maximum of the selected cells — the value at the "x" —
    # rather than the maximum of the whole field.
    peak = float(local[x])
    ax.scatter(
        lons[x],
        lats[x],
        transform=ccrs.PlateCarree(),
        marker="x",
        label=f"{peak:.1f}",
    )
# NOTE(review): the extra list nesting around lons/lats looks deliberate
# (2-D input so that centers are paired rather than gridded) — confirm against
# the cartopy tissot documentation before simplifying.
ax.tissot(
    rad_km=rptdist,
    lons=[track["lon"]],
    lats=[track["lat"]],
    alpha=0.3,
    facecolor="white",
    edgecolor="black",
)

ax.legend(title="max")
gl = ax.gridlines(draw_labels=True)
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.top_labels = gl.right_labels = False
gl.xlines = gl.ylines = False
fig.colorbar(cc, ax=ax, orientation="horizontal", pad=0.04, shrink=0.8)
ax.set_title(
    f"{grid_ds.attrs.get('title', '')} ({len(grid_ds.lonCell)} cells)",
    fontweight="bold",
    fontsize=14,
)