## Doesn't work with FV3

In [1]:
import logging
import pdb
from pathlib import Path

import cartopy
import cmaps
import holoviews as hv
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import uxarray
import xarray
from tqdm import tqdm

try:
    # Looking up the colormap on `cmaps` can raise ValueError when this cell
    # is executed again in the same kernel (per the original author's note).
    cmap = cmaps.WhiteBlueGreenYellowRed
except ValueError:
    # A re-run keeps the `cmap` bound by the earlier execution. If no earlier
    # binding exists, re-raise here instead of failing later with a NameError
    # in the plotting cells.
    if "cmap" not in globals():
        raise
import warnings

import geoviews.feature as gf
from hwt import Model, helicityThresholds, mpas, xtime

In [2]:
# Display the version string of the imported uxarray package.
uxarray.__version__

'2024.10.2.dev10+gbddd2fa3'

In [3]:
# Name of the variable selected from the dataset in later cells (uxds[v]).
v = "updraft_helicity_max"

# (lower, upper) longitude and latitude limits; used for the mesh bounding-box
# subset and for filtering storm reports.
lon_bounds = (-95, -86)
lat_bounds = (36, 45)

# Map overlay: 50m coastlines and state outlines, both in PlateCarree.
plate_carree = cartopy.crs.PlateCarree()
coastlines = gf.coastline(projection=plate_carree, scale="50m")
state_borders = gf.states(projection=plate_carree, scale="50m")
features = coastlines * state_borders

In [4]:
%%time
# Date that anchors both the model file names and the storm-report lookup.
valid_date = pd.to_datetime("20240521")
oneday = pd.to_timedelta(1, unit="day")

# Initialization times are valid_date minus 0 and 1 whole days.
lead_time_days = range(2)
# Hour offsets added to valid_date to form each file's time stamp
# (13, 19, 25, 31).
forecast_hours = range(13, 37, 6)

# Mesh (grid) file handed to uxarray together with the data files.
grid_path = "/glade/campaign/mmm/parc/schwartz/MPAS/15-3km_mesh/mpas_init/static.nc"

CPU times: user 1.38 ms, sys: 334 μs, total: 1.72 ms
Wall time: 1.72 ms


In [None]:
# `model` is an alias for the imported `mpas` object, so setting the attribute
# through either name has the same effect.
model = mpas
model.v = "updraft_helicity_max03"

# Guard against the unsupported model; the message doubles as a TODO list.
fv3_todo = (
    "can't read fv3 cubed sphere nest yet.\n"
    "TODO:\n"
    "sfc and atmos variables split between two files "
    "atmos_sos*.nest02_%Y_%m_%d_%H.tile7.nc\n"
    "time dimension called time\n"
    "grid_xt = T-cell longitude\n"
    "grid_yt = T-cell latitude\n"
)
assert model.name != "fv3", fv3_todo

# Build the list of input files. Loop order, outermost to innermost:
#   members (1 through model.nmem)
#       lead_time_days (iterable of lead times in days)
#           forecast_hours (iterable of hour offsets from valid_date)
# Layout: <top_dir>/<init YYYYmmddHH>/post/mem_<n>/diag.<valid time>.nc
fmt = "diag.%Y-%m-%d_%H.%M.%S.nc"
top_dir = Path(f"/glade/campaign/mmm/parc/schwartz/HWT2024/{model}")
ifiles = []
for mem in range(1, model.nmem + 1):
    for lead_time_day in lead_time_days:
        init_stamp = (valid_date - lead_time_day * oneday).strftime("%Y%m%d%H")
        for fhr in forecast_hours:
            diag_name = (valid_date + pd.to_timedelta(fhr, unit="hour")).strftime(fmt)
            ifiles.append(top_dir / init_stamp / "post" / f"mem_{mem}" / diag_name)
print(f"open {len(ifiles)} files")

print(ifiles[:10])

# Plain xarray view of the same files.
# NOTE(review): ds0 is not referenced by any other visible cell.
ds0 = xarray.open_mfdataset(
    ifiles,
    preprocess=xtime,
    combine_attrs="drop_conflicts",
)

# Unstructured-grid view: the mesh file plus the same data files.
uxds = uxarray.open_mfdataset(
    grid_path,
    ifiles,
    preprocess=xtime,
)
uxds

open 40 files
[PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052100/post/mem_1/diag.2024-05-21_13.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052100/post/mem_1/diag.2024-05-21_19.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052100/post/mem_1/diag.2024-05-22_01.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052100/post/mem_1/diag.2024-05-22_07.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052000/post/mem_1/diag.2024-05-21_13.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052000/post/mem_1/diag.2024-05-21_19.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052000/post/mem_1/diag.2024-05-22_01.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052000/post/mem_1/diag.2024-05-22_07.00.00.nc'), PosixPath('/glade/campaign/mmm/parc/schwartz/HWT2024/mpas/2024052100/post/mem_2/diag.2024-05-21_1

In [None]:
# Display the variable named by `v` from the opened dataset.
uxds[v]

In [None]:
%%time
hv.extension("matplotlib")

# Maximum over all valid times inside the lon/lat window.
# Author's note: subset.bounding_box was fixed so it does not lose the mem and
# initial_time coordinate indices.
window = uxds[v].subset.bounding_box(lon_bounds, lat_bounds)
uh_24hmax = window.metpy.quantify().max(dim="valid_time")
uh_24hmax.name += "_24h"


# Marker style for each storm-report type; `rpt_type` picks the one to use.
rpt_type = "hail"
opts = {
    "torn": {"color": "red", "marker": "v"},
    "hail": {"color": "green", "marker": "^"},
    "wind": {"color": "blue", "marker": "s"},
}
yymmdd = valid_date.strftime("%y%m%d")
rpts = pd.read_csv(f"https://www.spc.noaa.gov/climo/reports/{yymmdd}_rpts_{rpt_type}.csv")

# Time is an integer: hours in the thousands and hundreds places, minutes in
# the tens and ones places. Dividing by 100 and truncating yields the hour.
rpts["h"] = (rpts["Time"] / 100).astype(int)
# SPC reports run from 12 UTC on the current day to 12 UTC the next day, so an
# hour below 12 belongs to the next day: add 24.
next_day = rpts["h"] < 12
rpts.loc[next_day, "h"] += 24
rpts["minutes"] = rpts["Time"] % 100
hour_offset = pd.to_timedelta(rpts["h"], unit="h")
minute_offset = pd.to_timedelta(rpts["minutes"], unit="minute")
rpts["valid_time"] = valid_date + hour_offset + minute_offset

# Keep only the reports that fall inside the lon/lat window.
irange = rpts["Lat"].between(*lat_bounds) & rpts["Lon"].between(*lon_bounds)
logging.warning(f"{irange.sum()}/{len(rpts)} {rpt_type} rpts in bounds")
rpts = rpts[irange]


scatter = hv.Scatter(rpts[["Lon", "Lat"]]).opts(s=25, edgecolor="none", **opts[rpt_type])

title = valid_date

# One panel per (initial_time, member), overlaid with the map features and the
# report markers; the layout below uses model.nmem columns.
p_list = []
for initial_time in uh_24hmax.initial_time.data:
    for mem in uh_24hmax.mem.data:
        member_field = uh_24hmax.sel(mem=mem).sel(initial_time=initial_time)
        panel = member_field.plot.polygons(
            rasterize=True,
            title=f"initial_time {initial_time} mem_{mem}",
            cmap=cmap,
            clim=(0, 250),
            backend="matplotlib",
        ).opts(
            fontsize={
                "title": "xx-small",
                "xticks": "xx-small",
                "yticks": "xx-small",
                "labels": "xx-small",
            },
            backend_opts={"colorbar.ax.get_yticklabels().fontsize": "xx-small"},
        )
        p_list.append(panel * features * scatter)
layout = hv.Layout(p_list).cols(model.nmem)
layout.opts(fig_size=64)

In [None]:
%%time
# Remap the 24-h maximum field onto its own grid's face centers.
# NOTE(review): with power=0 the inverse-distance weights presumably all become
# equal, making this a plain average of the k=169 nearest faces - confirm
# against the uxarray remap documentation.
da = uh_24hmax.remap.inverse_distance_weighted(
    uh_24hmax.uxgrid,
    remap_to="face centers",
    power=0,
    k=169,
)

# One panel per (initial_time, member), overlaid with `features` and `scatter`
# (both defined in earlier cells); the layout uses model.nmem columns.
p_list = [
    da.sel(mem=mem)
    .sel(initial_time=initial_time)
    .plot.polygons(
        # Boolean flag, consistent with the other plotting cells.
        rasterize=True,
        title=f"initial_time {initial_time} mem_{mem}",
        cmap=cmap,
        clim=(0, 160),
        backend="matplotlib",
    )
    .opts(
        fontsize={
            "title": "xx-small",
            "xticks": "xx-small",
            "yticks": "xx-small",
            "labels": "xx-small",
        },
        backend_opts={"colorbar.ax.get_yticklabels().fontsize": "xx-small"},
    )
    * features
    * scatter
    for initial_time in uh_24hmax.initial_time.data
    for mem in uh_24hmax.mem.data
]
layout = hv.Layout(p_list).cols(model.nmem)
layout.opts(fig_size=64)

In [None]:
%%time
# Coarser regional mesh, cut down to the same lon/lat window as the data.
coarse_mesh_path = (
    Path("/glade/campaign/mmm/parc/schwartz")
    / "MPAS_regional"
    / "15km_mesh_regional"
    / "mpas_init"
    / "regional_15km_mesh_2000km.static.nc"
)
coarse_grid_full = uxarray.open_grid(coarse_mesh_path)
coarse_mesh = coarse_grid_full.subset.bounding_box(lon_bounds, lat_bounds)

# Remap onto the coarse mesh, then average within a neighborhood.
# NOTE(review): r = 20.0 / 111.0 looks like 20 km expressed in degrees
# (~111 km per degree) - confirm the unit neighborhood_filter expects.
on_coarse = uh_24hmax.remap.inverse_distance_weighted(coarse_mesh)
da = on_coarse.neighborhood_filter(r=20.0 / 111.0, func=np.mean)

# One panel per (initial_time, member); the layout uses model.nmem columns.
p_list = []
for initial_time in uh_24hmax.initial_time.data:
    for mem in uh_24hmax.mem.data:
        member_field = da.sel(mem=mem).sel(initial_time=initial_time)
        panel = member_field.plot.polygons(
            rasterize=True,
            title=f"initial_time {initial_time} mem_{mem}",
            cmap=cmap,
            clim=(0, 160),
            backend="matplotlib",
        ).opts(fontsize="x-small")
        p_list.append(panel * features * scatter)
layout = hv.Layout(p_list).cols(model.nmem)
layout.opts(fig_size=100)

In [None]:
# Fourth entry (index 3) of the unit-quantified helicity thresholds.
uh_thresh = helicityThresholds.metpy.quantify()[3]

# Neighborhood radius; the field name below labels it in km, and it is divided
# by 111 when handed to neighborhood_filter.
r = 20.0

# Per-member exceedance of the threshold, remapped to the coarse mesh and
# spread with a neighborhood maximum.
exceeds = uh_24hmax >= uh_thresh
exceeds_coarse = exceeds.remap.inverse_distance_weighted(coarse_mesh)
fy = exceeds_coarse.neighborhood_filter(r=r / 111.0, func=np.max)

# Average over members to get the ensemble fraction.
fy = fy.mean(dim="mem")
fy.name = f"ens prob of {uh_thresh.values}+ {uh_24hmax.name} w/in {r} km"

# Cap at 1 and verify.
values = np.minimum(fy.values, 1)
fy.values = values
assert fy.max() <= 1

# One panel per initial_time, laid out in a single row.
p_list = []
for initial_time in fy.initial_time.data:
    panel = fy.sel(initial_time=initial_time).plot.polygons(
        rasterize=True,
        title=f"initial_time {initial_time}",
        cmap=cmap,
        # norm=matplotlib.colors.BoundaryNorm(
        #    boundaries=np.arange(-0.5 / model.nmem, 1 + 0.5 / model.nmem, 1 / model.nmem),
        #    ncolors=256,
        # ),
        backend="matplotlib",
    ).opts(fontsize={"title": "x-small", "labels": "x-small"}, alpha=0.6)
    p_list.append(panel * features * scatter)
layout = hv.Layout(p_list).cols(fy.initial_time.size)
layout

In [None]:
# Optional branch, disabled by default. When enabled it rebinds `fy` to a 0/1
# field (probability >= prob_thresh) and moves it onto the grid of uh_24hmax.
# NOTE(review): later cells pair `fy` with `coarse_mesh`; enabling this branch
# would leave `fy` on a different grid - confirm before turning it on.
verify_fine_mesh = False
if verify_fine_mesh:
    prob_thresh = 0.4
    # Threshold the probability and wrap it with the coarse mesh.
    fy = uxarray.UxDataArray((fy >= prob_thresh).astype(int), uxgrid=coarse_mesh)
    fy.name = f"ensemble probability of {uh_thresh}+ {uh_24hmax.name} >= {prob_thresh}"
    # Nearest-neighbor remap onto the grid of uh_24hmax, then re-wrap with it.
    fy = fy.remap.nearest_neighbor(uh_24hmax.uxgrid)
    fy = uxarray.UxDataArray(fy, uxgrid=uh_24hmax.uxgrid)

    # One rasterized panel per initial_time, overlaid with the map features.
    p_list = [
        fy.sel(initial_time=initial_time).plot.rasterize(
            method="polygon",
            exclude_antimeridian=True,
            backend="matplotlib",
        )
        * features
        for initial_time in fy.initial_time.data
    ]
    layout = hv.Layout(p_list).cols(fy.initial_time.size)
    layout.opts(fig_size=250)

In [None]:
# Display the current `fy` array.
fy

In [None]:
# Collapse the initialization-time dimension (mean here; max is another option).
fy = fy.mean(dim="initial_time")


# Ball tree over the coarse-mesh face centers, used to find the faces near
# each report.
uxtree = coarse_mesh.get_ball_tree(coordinates="face centers")
rptdist = 40  # km
# One array of nearby face indices per report.
# NOTE(review): rptdist / 111.0 looks like a km-to-degree conversion - confirm
# the unit query_radius expects.
report_points = rpts[["Lon", "Lat"]]
idx = uxtree.query_radius(report_points, rptdist / 111.0)


# Observed-yes field: same shape as fy, False everywhere, then True on every
# face index returned for a report.
oy = fy.copy()
oy.values[:] = False
oy.name = rpt_type
for near_faces in idx:
    oy.values.put(near_faces, True)

In [None]:
# Two-colour map of the observed-yes field: white for 0, the report-type colour
# for 1, with report markers and map features on top.
oy_polygons = oy.plot.polygons(
    rasterize=True,
    cmap=["white", opts[rpt_type]["color"]],
    backend="matplotlib",
)
oy_polygons.opts(fig_size=100, alpha=0.25) * scatter * features

In [None]:
# Any probability at or above this tiny threshold counts as a "yes" forecast.
prob_thresh = 1e-9

# Observed-no is the complement of observed-yes.
on = 1 - oy
fcst_yes = fy >= prob_thresh
fcst_no = fy < prob_thresh

# 2x2 contingency table, one 0/1 field per outcome.
hit = oy * fcst_yes
miss = oy * fcst_no
correct_null = on * fcst_no
fa = on * fcst_yes

# Single categorical field: 0 = hit, 1 = miss, 2 = false alarm, 3 = correct null.
# NOTE(review): where fy is NaN both comparisons are False, so every term is 0
# and the face is coded 0 (hit) - confirm fy contains no NaN.
hmfn = hit * 0 + miss * 1 + fa * 2 + correct_null * 3
hmfn.name = "hit/miss/fa/null"

# Four colours in the same order as the category codes.
hmfn_polygons = hmfn.plot.polygons(
    rasterize=True,
    cmap=["lightgreen", "blue", "red", "white"],
    clim=(0, 4),
    backend="matplotlib",
)
hmfn_polygons.opts(alpha=0.5) * scatter * features