In [None]:
import logging

import cm1
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from cm1.input.sounding import Sounding
from cm1.utils import TMPDIR, skewt
from metpy.units import units

# Root-logger setup: only WARNING and above are emitted. force=True removes and
# closes any handlers already attached to the root logger, so this configuration
# always takes effect (basicConfig is otherwise a no-op once handlers exist).
LOG_FORMAT = "%(asctime)s - %(levelname)s: %(message)s"
logging.basicConfig(force=True, level=logging.WARNING, format=LOG_FORMAT)

### Predefined CM1 input soundings

In [None]:
# Built-in CM1 sounding cases; each gets one skew-T panel in a 4 x 2 grid.
case_names = [
    "trier",
    "jordan_allmean",
    "jordan_hurricane",
    "rotunno_emanuel",
    "dunion_MT",
    "bryan_morrison",
    "seabreeze_test",
]
dss = [Sounding().get_case(name) for name in case_names]

fig, axes = plt.subplots(ncols=2, figsize=(10, 16))
# The axes created by plt.subplots are only placeholders and are hidden;
# skewt() builds a metpy.plots.SkewT, which adds its own subplot to `fig`.
for placeholder_ax in axes.flat:
    placeholder_ax.axis("off")

for panel, ds in enumerate(dss, start=1):
    logging.info(ds.case)
    # A zero mixing ratio (Q) cuts the skew-T off, so swap zeros for a tiny value.
    ds["Q"] = ds["Q"].where(ds["Q"] != 0, 1e-9)
    skew = skewt(ds, fig=fig, subplot=(4, 2, panel))
    # Prefix the title that skewt() produced with the sounding case name.
    title = ds.case + skew.ax.get_title()
    skew.ax.set_title(title, fontsize="x-small")

plt.tight_layout()

In [None]:
# Time and location of the sounding studied in the rest of this notebook.
valid_time = pd.Timestamp("2024-05-25")
lat = 33 * units.degree_N
lon = -95 * units.degree_E
# ERA5 AWS pressure-level sounding at that time and place.
ds_aws = cm1.input.sounding.era5_aws(valid_time, lat, lon)
ds_aws

### Select the n x n block of grid points nearest a given latitude and longitude

In [None]:
# Load the ERA5 model-level data for `valid_time` into memory.
ds0 = cm1.input.era5.model_level(valid_time).load()

# n x n square of grid points around the point nearest (lat, lon).
n = 3
sel = cm1.utils.nearest_grid_block_sel(ds0, lat=lat, lon=lon, n=n)
# Flatten the block onto one stacked dimension "i" once, instead of
# re-selecting and re-stacking ds0 for every panel.
block = ds0.sel(sel).stack(i=("latitude", "longitude"))

fig, axes = plt.subplots(ncols=n, nrows=n, sharex=True, sharey=True, figsize=(n * 5, n * 5))
for i, ax in enumerate(axes.flat):
    # Hide the placeholder axes; skewt() draws its own subplot in this slot.
    ax.axis("off")
    ds = block.isel(i=i)
    skew = skewt(ds, fig=fig, subplot=(n, n, i + 1))
    title = skew.ax.get_title()
    skew.ax.set_title(title, fontsize="x-small")
    # Write a CM1 input sounding named after the first line of the title.
    # NOTE(review): assumes that line is filesystem-safe (e.g. contains no "/").
    ofile = TMPDIR / (title.split("\n")[0].replace(" ", "_") + ".txt")
    with open(ofile, "w") as fh:
        fh.write(Sounding(ds).to_txt())
        logging.warning(ofile)

In [None]:
# Display the model-level dataset loaded above for inspection.
ds0

In [None]:
import numpy as np
from scipy.spatial import cKDTree


def select_n_nearest(da, target_lat, target_lon, n=5, lat_dim="latitude", lon_dim="longitude"):
    """
    Select the ``n`` grid points nearest a target location.

    Grid points and target are mapped onto the unit sphere as 3-D Cartesian
    points. Nearest chord distance orders neighbors the same way as
    great-circle distance. It also does not depend on the longitude convention
    (e.g. -95 and 265 degrees map to the same point).

    Parameters
    ----------
    da : xarray.DataArray or xarray.Dataset
        Gridded data with 1-D ``lat_dim`` and ``lon_dim`` coordinates in degrees.
    target_lat, target_lon : scalar
        Target location in degrees.
    n : int, default 5
        Number of neighbors to return. Capped at the number of grid points.
    lat_dim, lon_dim : str
        Names of the latitude and longitude dimensions/coordinates in ``da``.

    Returns
    -------
    Same type as ``da``
        ``da`` indexed so that ``lat_dim`` and ``lon_dim`` are replaced by a new
        ``"station"`` dimension, ordered nearest first. The selected latitude
        and longitude values are attached as ``lat_dim``/``lon_dim`` coordinates
        along ``"station"``.
    """

    def to_cartesian(lat, lon):
        # Unit-sphere (x, y, z) coordinates, one row per input point.
        lat_rad = np.deg2rad(lat)
        lon_rad = np.deg2rad(lon)
        x = np.cos(lat_rad) * np.cos(lon_rad)
        y = np.cos(lat_rad) * np.sin(lon_rad)
        z = np.sin(lat_rad)
        return np.column_stack([x, y, z])

    # Every (lat, lon) grid point as a 3-D point, flattened in C order.
    lat_vals = da[lat_dim].values
    lon_vals = da[lon_dim].values
    lat_grid, lon_grid = np.meshgrid(lat_vals, lon_vals, indexing="ij")
    kdtree = cKDTree(to_cartesian(lat_grid.ravel(), lon_grid.ravel()))

    # cKDTree.query pads missing neighbors with the out-of-range index
    # len(points) when k exceeds the number of points. That would break
    # unravel_index, so never ask for more neighbors than exist.
    k = min(n, lat_grid.size)
    _, flat_indices = kdtree.query(to_cartesian(target_lat, target_lon), k=k)

    # query() on a single (1, 3) point returns shape (1, k), or (1,) when k == 1.
    # Flatten to 1-D before mapping back to (lat, lon) grid indices.
    lat_indices, lon_indices = np.unravel_index(np.ravel(flat_indices), lat_grid.shape)

    # Vectorized (pointwise) selection along a new "station" dimension.
    lat_indexer = xr.DataArray(lat_indices, dims="station")
    lon_indexer = xr.DataArray(lon_indices, dims="station")
    result = da.isel({lat_dim: lat_indexer, lon_dim: lon_indexer})

    # Attach the selected coordinates under the caller's dimension names.
    return result.assign_coords(
        {
            lat_dim: ("station", lat_vals[lat_indices]),
            lon_dim: ("station", lon_vals[lon_indices]),
        }
    )


from IPython.display import HTML
from matplotlib.animation import FuncAnimation

# Animate a target point that drifts in a random walk, and color the nine grid
# points nearest to it on each frame. This shows that the nine nearest
# neighbors are not always a 3x3 square.
num_frames = 8  # number of random-walk steps (animation frames)

# Seed NumPy's global RNG so the random walk drawn in update() is reproducible.
np.random.seed(0)

fig, ax = plt.subplots(figsize=(8, 6))

# Artists updated in place on every frame; both start with no data.
scatter = ax.scatter([], [], c=[], cmap="viridis", s=100, vmin=0, vmax=1)
(target_marker,) = ax.plot([], [], "r+", markersize=12)  # ax.plot returns a list; unpack its single Line2D
fig.colorbar(scatter, ax=ax, label="Z Value")

# Random-walk state, starting at the sounding location; update() advances it.
target_lat = lat
target_lon = lon


def init():
    """Animation init_func: fix the view window around (lat, lon) and draw a grid.

    Returns the artists that the per-frame update redraws.
    """
    # Window is 0.7 degrees below/left and 0.8 degrees above/right of the sounding.
    x_lo, x_hi = lon - 0.7 * units.degrees_E, lon + 0.8 * units.degrees_E
    y_lo, y_hi = lat - 0.7 * units.degrees_N, lat + 0.8 * units.degrees_N
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.grid(True, linestyle="--", alpha=0.6)
    return scatter, target_marker


def update(frame):
    """Advance the random-walk target one step and redraw its nine nearest grid points.

    Parameters
    ----------
    frame : int
        Frame number supplied by FuncAnimation. It is unused: the walk state
        lives in the module-level ``target_lat``/``target_lon``.

    Returns
    -------
    tuple
        The artists modified in this frame: ``(scatter, target_marker)``.
    """
    global target_lat, target_lon
    # Step a fixed 0.02 degrees in a random direction.
    distance = 0.02 * units.degree
    angle = np.random.uniform(0, 2 * np.pi)
    dlat = distance * np.sin(angle)
    # Stretch the longitude step by 1/cos(lat) so the step length is roughly
    # the same in every direction.
    dlon = distance * np.cos(angle) / np.cos(np.deg2rad(target_lat.m))

    # Move from the PREVIOUS position; the state persists across frames.
    target_lat += dlat
    target_lon += dlon

    nine_points = select_n_nearest(ds0.Z, target_lat=target_lat, target_lon=target_lon, n=9)
    # NOTE(review): assumes the grid longitudes are 0-360 while target_lon is
    # negative degrees east; subtracting 360 puts both on the same x axis.
    positions = np.column_stack((nine_points["longitude"] - 360, nine_points["latitude"]))
    scatter.set_offsets(positions)
    scatter.set_array(nine_points.isel(level=0).values)
    # Rescale the colormap (and the connected colorbar) to this frame's values.
    # The Z values (heights elsewhere treated as metres) lie far outside the
    # 0-1 colour limits the scatter was created with, so without this every
    # point saturates to the same colour.
    scatter.autoscale()
    target_marker.set_data([target_lon], [target_lat])

    return scatter, target_marker


# Build the animation. blit=False, so every frame is fully redrawn.
ani = FuncAnimation(
    fig,
    update,
    frames=num_frames,
    init_func=init,
    interval=100,
    blit=False,
)

# Close the static figure so the notebook shows only the rendered animation.
plt.close(fig)

# Render the animation as an HTML5 video in the notebook.
HTML(ani.to_html5_video())

In [None]:
import metpy.calc as mpcalc
from scipy.ndimage import gaussian_filter

# Relative vorticity at one model level, with smoothed heights at that level
# contoured on top and the sounding location marked.
SMOOTH_SIGMA = 5  # Gaussian filter width, in grid points
level = 60  # model level used for both winds and heights


def _gaussian_smooth(da, sigma=SMOOTH_SIGMA, keep_units=True):
    """Return ``da`` Gaussian-smoothed (over all of its axes) as a new DataArray.

    scipy's gaussian_filter works on bare arrays and drops xarray/MetPy
    metadata; attaching MetPy parsing beforehand does not help. The dims,
    coords and attrs are therefore copied back from ``da``. With
    ``keep_units`` the smoothed array is multiplied by ``da.metpy.units``.
    """
    smoothed = gaussian_filter(da.values, sigma=sigma)
    if keep_units:
        smoothed = smoothed * da.metpy.units
    return xr.DataArray(smoothed, dims=da.dims, coords=da.coords, attrs=da.attrs)


# Window of +/-25 degrees longitude and +/-15 degrees latitude around the sounding.
# NOTE(review): assumes grid longitudes are 0-360 (hence lon + 360) and that
# latitude is stored north-to-south (hence the descending slice).
ds = ds0.sel(
    longitude=slice(lon.m + 335, lon.m + 385),
    latitude=slice(lat.m + 15, lat.m - 15),
)

# Wind components at the chosen level, smoothed, with units re-attached.
u = ds["U"].metpy.sel(level=level)
v = ds["V"].metpy.sel(level=level)
u_smooth = _gaussian_smooth(u)
v_smooth = _gaussian_smooth(v)

# Grid spacing from the 1-D latitude/longitude coordinates.
dx, dy = mpcalc.lat_lon_grid_deltas(ds["longitude"], ds["latitude"])

# Relative vorticity of the smoothed wind (units of 1/s).
fld = mpcalc.vorticity(u_smooth, v_smooth, dx=dx, dy=dy)
fld_str = "vorticity"
# Store the 2-D (latitude x longitude) field on the subset dataset.
ds[fld_str] = (("latitude", "longitude"), fld.data)

# Heights at the same MODEL level (not 500 hPa), converted to decameters.
z_dam = ds["Z"].sel(level=level) / 10.0
z_dam_smooth = _gaussian_smooth(z_dam, keep_units=False)

fig, ax = plt.subplots(figsize=(8, 6))
ds[fld_str].plot(ax=ax, robust=True)
contours = z_dam_smooth.plot.contour(ax=ax, colors="k", levels=range(30, 3000, 6))
ax.clabel(contours, fmt="%d")  # label contours in whole dam
ax.plot(
    lon.m + 360,
    lat.m,
    "X",
    markersize=10,
    color="yellow",
    markeredgewidth=2,
    markerfacecolor="none",
    label="sounding",
)
ax.legend();

In [None]:
# Load and display the ERA5 pressure-level data for the same valid time.
cm1.input.era5.pressure_level(valid_time).load()

In [None]:
# Map the surface geopotential height field carried by the model-level dataset.
# NOTE(review): the trailing note presumably means the land-sea mask (LSM) is no
# longer among the model-level invariants on the Gaussian grid — confirm.
ds0.surface_geopotential_height.plot()  # LSM went away from Gaussian grid model level invariants.

In [None]:
# Round trip: serialize the "trier" case to text, show the text, and parse
# it back into a Sounding.
trier_txt = Sounding().get_case("trier").to_txt()
print(trier_txt)
Sounding.from_txt(trier_txt)

## Compare pressure-level and model-level soundings
* There are fewer pressure levels than model levels.
* Near-surface winds (10u and 10v) are not included in the pressure-level sounding.

In [None]:
# Side-by-side skew-Ts: the ERA5 AWS pressure-level sounding (left) and the
# nearest column of the model-level data (right).
# NOTE(review): assumes ds_aws and ds0 share a longitude convention for the
# nearest-neighbor lookup — confirm.
model_column = ds0.sel(
    latitude=ds_aws.coords["latitude"],
    longitude=ds_aws.coords["longitude"],
    method="nearest",
)

fig, axes = plt.subplots(ncols=2, sharey=True, figsize=(14, 7))
# Hide the placeholder axes; each skew-T adds its own subplot to `fig`.
for ax in axes.flat:
    ax.axis("off")

aws_skew = ds_aws.plot(fig=fig, subplot=(1, 2, 1))
aws_skew.ax.set_title(f"ERA5 AWS pressure level  {aws_skew.ax.get_title()}", fontsize="x-small")

skew = skewt(model_column, fig=fig, subplot=(1, 2, 2))
skew.ax.set_title(f"model level  {skew.ax.get_title()}", fontsize="x-small")

In [None]:
# Standalone skew-T of the ERA5 AWS pressure-level sounding.
ds_aws.plot()

In [None]:
# ERA5 pressure-level sounding at (lat, lon) for valid_time, via cm1's helper.
ds = cm1.input.sounding.era5_pressure_level(valid_time, lat, lon)
ds