In [1]:
import xarray as xr
import numpy as np
from minisom import MiniSom
import pandas as pd
from sklearn.preprocessing import RobustScaler
import pprint
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.gridspec import GridSpec 
import matplotlib.colors as mcolors
from matplotlib import colormaps as cm
import matplotlib.gridspec as gridspec 
import warnings

# Ignore all warnings
warnings.filterwarnings("ignore")

from som_2var_training import read_and_transform, build_scaler, train_som

In [2]:
# Open the tornado, wind and hail files (in that order) as xarray datasets.
ds_tor, ds_wind, ds_hail = map(
    xr.open_dataset,
    (
        './Datasets/pper_tor_1979_2023.nc',
        './Datasets/pper_wind_1979_2023.nc',
        './Datasets/pper_hail_1979_2023.nc',
    ),
)

In [3]:
# Everything train_som receives: the input file, the two variable names,
# the lon/lat bounds, and the keyword groups for the SOM and its training run.
prefs = {
    'filename': "/home/scratch/dwefer/GEFSv12/z500_pwat.nc",
    'var1': 'gh',
    'var2': 'pwat',
    'wlon': 220,
    'elon': 305,
    'nlat': 55,
    'slat': 20,
    'som_config': {
        'x': 5,
        'y': 5,
        'sigma': 0.5,
        'random_seed': 42,
    },
    'som_train': {
        'num_iteration': 10000,
        'random_order': True,
        'verbose': True,
    },
}

trained_som, trained_scaler, train_info, xr_data = train_som(prefs)


current model configuration
{'input_len': 24282, 'random_seed': 42, 'sigma': 0.5, 'x': 5, 'y': 5}
current training configuration
{'data': array([[-0.66050106, -0.68316907, -0.70958024, ..., -0.7250001 ,
        -0.7583334 , -0.7583332 ],
       [-0.71175027, -0.7709796 , -0.83132076, ..., -0.39166674,
        -0.35000005, -0.2916665 ],
       [-0.28915408, -0.28948432, -0.29526958, ..., -0.6833334 ,
        -0.8833334 , -1.0499998 ],
       ...,
       [-0.87913156, -0.8735199 , -0.8690234 , ...,  0.5749998 ,
         0.49166647,  0.375     ],
       [-0.31367952, -0.29052007, -0.26362565, ...,  0.25887743,
         0.21721077,  0.16721089],
       [-0.43595633, -0.43200427, -0.43102336, ..., -0.10833359,
        -0.19166677, -0.24999984]], dtype=float32),
 'num_iteration': 10000,
 'random_order': True,
 'verbose': True}
 [ 10000 / 10000 ] 100% - 0:00:00 left 
 quantization error: 72.93549590715341


In [None]:
plt.rcParams['figure.figsize'] = 7, 3

# Ticks go at index + 0.5 and are labelled with the bare node index.
n_som_cols = train_info['som_config']['x']
n_som_rows = train_info['som_config']['y']
x = [col + 0.5 for col in range(n_som_cols)]
y = [row + 0.5 for row in range(n_som_rows)]

# Left panel: transposed MiniSom distance map.
ax = plt.subplot(1, 2, 1)
ax.set_title("Distance between nodes")
ax.set_yticks(y, range(n_som_rows))
ax.set_xticks(x, range(n_som_cols))
distance_grid = trained_som.distance_map().T
mmp = ax.pcolor(distance_grid, cmap='bone_r')
plt.colorbar(mmp, ax=ax)

# Right panel: transposed activation response of the training samples.
ax = plt.subplot(1, 2, 2)
ax.set_title("Node frequency")
ax.set_yticks(y, range(n_som_rows))
ax.set_xticks(x, range(n_som_cols))
hit_grid = trained_som.activation_response(train_info['som_train']['data']).T
mmp = ax.pcolor(hit_grid, cmap='Blues')
plt.colorbar(mmp, ax=ax)

In [None]:
# Training samples grouped by winning node, plus the activation response
# flattened to one value per node.
training_data = train_info["som_train"]["data"]
win_map = trained_som.win_map(training_data)
freq = trained_som.activation_response(training_data).ravel()

# SOM grid dimensions taken from the preferences (5 x 5 here).
nx = prefs["som_config"]["x"]
ny = prefs["som_config"]["y"]
n_nodes = nx * ny

# Coordinate vectors of the input data and the number of grid points.
lat = xr_data['latitude'].values
lon = xr_data['longitude'].values
nlat = len(lat)
nlon = len(lon)
ngrid = nlat * nlon

# (i, j) node coordinates in row-major order: i varies slowest, j fastest.
nodes = [divmod(flat_idx, ny) for flat_idx in range(nx * ny)]

def inv_mean(node):
    """Return the mean of the samples won by ``node``, mapped back through the scaler.

    Parameters
    ----------
    node : tuple
        Key into the module-level ``win_map`` (an ``(i, j)`` node coordinate).

    Returns
    -------
    numpy.ndarray
        1-D vector: the sample mean passed through
        ``trained_scaler.inverse_transform``. If the node won no samples,
        a NaN vector with one entry per training feature is returned so that
        stacking the per-node results still works.
    """
    # .get avoids inserting a new key when win_map is a defaultdict.
    members = win_map.get(node, ())
    if len(members) == 0:
        # np.mean of an empty list is a scalar NaN, which inverse_transform
        # cannot accept; return a correctly sized placeholder instead.
        n_features = train_info["som_train"]["data"].shape[1]
        return np.full(n_features, np.nan)

    v = np.mean(members, axis=0)
    # inverse_transform works on 2-D input, so wrap the vector and unwrap the row.
    return trained_scaler.inverse_transform([v])[0]

# One row per node: the inverse-scaled mean sample vector.
means = np.vstack(list(map(inv_mean, nodes)))

# The first ngrid columns are reshaped as the z500 field, the rest as pwat.
node_field_shape = (n_nodes, nlat, nlon)
z500_avg = means[:, :ngrid].reshape(node_field_shape)
pwat_avg = means[:, ngrid:].reshape(node_field_shape)

# Wrap both stacks as labelled arrays indexed by flat node number.
node_dims = ("node", "latitude", "longitude")
da_z500 = xr.DataArray(
    z500_avg,
    dims=node_dims,
    coords={"node": np.arange(n_nodes), "latitude": lat, "longitude": lon},
    name="gh",
)
da_pwat = xr.DataArray(
    pwat_avg,
    dims=node_dims,
    coords={"node": np.arange(n_nodes), "latitude": lat, "longitude": lon},
    name="pwat",
)

# Map bounds as [west, east, south, north], derived from the data coordinates.
# NOTE(review): the -340/-380 offsets look like a 0-360 -> -180..180 shift
# combined with a 20 degree inset on each side -- confirm.
lon_values = xr_data["longitude"].values
lat_values = xr_data["latitude"].values
extent = [
    float(lon_values.min() - 340),
    float(lon_values.max() - 380),
    float(lat_values.min() + 3),
    float(lat_values.max() - 3),
]

# Map projection for the axes and the CRS the data coordinates are given in.
proj = ccrs.LambertConformal()
pc = ccrs.PlateCarree()

# Shared colour scales so every node panel is directly comparable.
norm_z500 = mcolors.Normalize(vmin=5400.0, vmax=float(da_z500.max()))
norm_pwat = mcolors.Normalize(vmin=float(da_pwat.min()), vmax=float(da_pwat.max()))

def plot_node_block(node_indices, nrows, ncols, title, dpi=220):
    """Draw paired Z500 / PWAT maps for a group of SOM nodes in one figure.

    Each node occupies two adjacent axes (Z500 on the left, PWAT on the
    right), so the axes grid is ``nrows`` x ``2 * ncols``. Panels beyond
    ``len(node_indices)`` are hidden. The function reads the module-level
    ``da_z500``, ``da_pwat``, ``norm_z500``, ``norm_pwat``, ``freq``, ``ny``,
    ``extent``, ``proj`` and ``pc``.

    Parameters
    ----------
    node_indices : sequence of int
        Flat node indices (positions along the ``node`` dimension) to draw.
        Only the first ``nrows * ncols`` entries are used.
    nrows, ncols : int
        Number of rows and of node *pairs* per row.
    title : str
        Figure-level title.
    dpi : int, optional
        Figure resolution.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    fig, axs = plt.subplots(
        nrows, 2 * ncols,
        figsize=(2.2 * (2 * ncols), 2.0 * nrows),
        dpi=dpi,
        subplot_kw={"projection": proj},
        constrained_layout=True,
        squeeze=False,  # always a 2-D array, even for a single row
    )

    # (data, colormap, norm, panel label) for the left and right axes of a pair.
    layers = (
        (da_z500, "coolwarm", norm_z500, "Z500"),
        (da_pwat, "gist_earth_r", norm_pwat, "PWAT"),
    )
    # Last drawn mesh per layer; stays None if nothing was drawn.
    mappables = [None, None]

    for k in range(nrows * ncols):
        r, c = divmod(k, ncols)
        pair = (axs[r, 2 * c], axs[r, 2 * c + 1])

        if k >= len(node_indices):
            # blank out unused panels
            for ax in pair:
                ax.set_visible(False)
            continue

        node_idx = node_indices[k]
        # Flat index -> (i, j) SOM coordinate, used only for the panel title.
        i, j = divmod(node_idx, ny)

        for slot, (ax, (data, cmap, norm, label)) in enumerate(zip(pair, layers)):
            ax.set_extent(extent, crs=pc)
            ax.add_feature(cfeature.BORDERS, linewidth=0.6)
            ax.add_feature(cfeature.STATES, linewidth=0.4)

            node_field = data.isel(node=node_idx)
            # An all-NaN field (e.g. a node without samples) has nothing to
            # shade or contour, so leave the base map empty.
            if bool(node_field.notnull().any()):
                mappables[slot] = node_field.plot.pcolormesh(
                    ax=ax, transform=pc,
                    cmap=cmap,
                    add_colorbar=False,
                    norm=norm,
                )
                node_field.plot.contour(
                    ax=ax, transform=pc,
                    colors="k", linewidths=0.6,
                    add_colorbar=False,
                )

            # Set the title last: the xarray plot calls write their own title.
            ax.set_title(f"node {i},{j}  n={int(freq[node_idx])}\n{label}", fontsize=8)

    # Collect the visible axes once, before colorbar axes are added to the figure.
    visible_axes = [ax for ax in fig.axes if ax.get_visible()]
    colorbar_specs = (
        (mappables[0], 0.02, "500 hPa heights (m)"),
        (mappables[1], 0.07, "PWAT (kg m$^{-2}$)"),
    )
    for mappable, pad, cb_label in colorbar_specs:
        if mappable is None:
            # Nothing was drawn for this layer (e.g. empty node_indices).
            continue
        cb = fig.colorbar(mappable, ax=visible_axes, orientation="horizontal",
                          pad=pad, fraction=0.05, extend="both")
        cb.set_label(cb_label)

    fig.suptitle(title, y=1.02, fontsize=12)
    plt.show()


# Draw the 25 nodes as three figures with three node pairs per row. The row
# count is derived from each group's size, so a short last group does not
# leave a completely empty row of panels.
PAIRS_PER_ROW = 3
for first_node, last_node in ((0, 11), (12, 20), (21, 24)):
    block_nodes = list(range(first_node, last_node + 1))
    plot_node_block(
        block_nodes,
        nrows=int(np.ceil(len(block_nodes) / PAIRS_PER_ROW)),
        ncols=PAIRS_PER_ROW,
        title=f"SOM nodes {first_node}–{last_node}: z500 | PWAT",
    )

In [None]:
# Map bounds as [west, east, south, north], derived from the data coordinates.
# NOTE(review): the -340/-380 offsets look like a 0-360 -> -180..180 shift
# combined with a 20 degree inset on each side -- confirm.
extent = [
    float(xr_data.longitude.values.min() - 340),
    float(xr_data.longitude.values.max() - 380),
    float(xr_data.latitude.values.min() + 3),
    float(xr_data.latitude.values.max() - 3),
]

node_idx = 0

fig2 = plt.figure(figsize=(8, 7), dpi=200)
ax = fig2.add_subplot(1, 1, 1, projection=ccrs.LambertConformal())
ax.set_extent(extent, crs=ccrs.PlateCarree())

# Filled field for the selected node.
pcm = da_pwat.isel(node=node_idx).plot.pcolormesh(
    ax=ax,
    x="longitude", y="latitude",
    cmap="gist_earth_r",
    shading="auto",
    transform=ccrs.PlateCarree(),
    add_colorbar=False
)

# Black contour lines of the same field, labelled below.
pct = da_pwat.isel(node=node_idx).plot.contour(
    ax=ax,
    x="longitude", y="latitude",
    colors="k",
    linewidths=0.6,
    transform=ccrs.PlateCarree(),
    add_colorbar=False
)

# Set the title after plotting: the xarray plot calls label the axes with
# their own slice title, which would otherwise replace this one.
ax.set_title("Node (0, 0) Mean PWAT", fontsize=12)

ax.clabel(pct, fmt="%.1f", inline=True, fontsize=7)
ax.add_feature(cfeature.BORDERS, linewidth=0.6)
ax.add_feature(cfeature.STATES, linewidth=0.4)
ax.coastlines()

cb = plt.colorbar(pcm, orientation="horizontal", pad=0.05, aspect=40, extend="both")
cb.set_label("PWAT (kg m$^{-2}$)")

plt.tight_layout()
plt.show()


In [None]:
# Map bounds as [west, east, south, north], derived from the data coordinates.
# NOTE(review): the -340/-380 offsets look like a 0-360 -> -180..180 shift
# combined with a 20 degree inset on each side -- confirm.
extent = [
    float(xr_data.longitude.values.min() - 340),
    float(xr_data.longitude.values.max() - 380),
    float(xr_data.latitude.values.min() + 3),
    float(xr_data.latitude.values.max() - 3),
]

node_idx = 0


fig1 = plt.figure(figsize=(8, 7), dpi=200)
ax_z500 = fig1.add_subplot(1, 1, 1, projection=ccrs.LambertConformal())
ax_z500.set_extent(extent, crs=ccrs.PlateCarree())


# Filled field for the selected node.
zcm = da_z500.isel(node=node_idx).plot.pcolormesh(
    ax=ax_z500,
    x="longitude", y="latitude",
    cmap="coolwarm",
    shading="auto",
    transform=ccrs.PlateCarree(),
    add_colorbar=False
)

# Black contour lines every 45 units from 4500 up to (not including) 6300.
zct = da_z500.isel(node=node_idx).plot.contour(
    ax=ax_z500,
    x="longitude", y="latitude",
    colors="k",
    linewidths=0.7,
    levels=np.arange(4500, 6300, 45),
    transform=ccrs.PlateCarree(),
    add_colorbar=False
)

# Set the title after plotting: the xarray plot calls label the axes with
# their own slice title, which would otherwise replace this one.
ax_z500.set_title("Node 0 Upper Level Flow", fontsize=12)

ax_z500.clabel(zct, fmt="%.0f", inline=True, fontsize=7)
ax_z500.add_feature(cfeature.BORDERS, linewidth=0.6)
ax_z500.add_feature(cfeature.STATES, linewidth=0.4)
ax_z500.coastlines()

cb = plt.colorbar(zcm, orientation="horizontal", pad=0.05, aspect=40, extend="both")
cb.set_label("500 hPa height (m)")

plt.tight_layout()
plt.savefig("./Figures/node00z500.png", bbox_inches="tight")
plt.show()


In [None]:
# Same grouping by winning node as win_map, but requesting sample indices
# (return_indices=True) so they can be used to select time steps.
win_map_idx = trained_som.win_map(
    train_info['som_train']['data'],
    return_indices=True,
)

In [None]:
# Five fill colours, one per band between consecutive entries of `levels`.
colors = ["#f4d71a", "#eb973d", "#bf4c78", "#5811a7", "#170f88"]
levels = [0.1, 0.5, 2.0, 5.0, 10.0, 100]

# One panel per node key present in win_map_idx, in sorted order.
nodes = sorted(win_map_idx)
n_nodes = len(nodes)

# Three panels per row; round the row count up.
ncols = 3
nrows = -(-n_nodes // ncols)

dpi = 300

plt.rcParams["figure.figsize"] = (5 * ncols, 4.5 * nrows)
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    dpi=dpi,
    subplot_kw={"projection": ccrs.LambertConformal()},
)

# Guarantee 2-D indexing even when there is a single row.
axes = np.atleast_2d(axes)

for k, node in enumerate(nodes):
    row, col = divmod(k, ncols)
    ax = axes[row, col]

    # Time steps of the training data assigned to this node.
    idx_list = win_map_idx[node]
    nodesel = xr_data.isel(time=idx_list)

    # shift back a day bc pper is 12z–12z, and dataset is 0z
    shifted_time = nodesel.time - np.timedelta64(1, "D")
    time_sel = shifted_time.astype("datetime64[D]").astype(str)

    # Mean hail field over the matched days.
    ds_hail_sel = ds_hail.sel(time=time_sel, method="nearest")
    hail_mean = ds_hail_sel["p_perfect_hail"].mean(dim="time")  # (y, x)

    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.STATES, linewidth=0.5)
    ax.add_feature(cfeature.LAKES)
    ax.set_extent([-122, -67, 20, 50], crs=ccrs.PlateCarree())

    ax.set_title(f"Node {node} – Hail (n={len(idx_list)})", fontsize=9)

    cm_hail = ax.contourf(
        ds_hail.lon,
        ds_hail.lat,
        hail_mean,
        colors=colors,
        levels=levels,
        transform=ccrs.PlateCarree(),
    )

    fig.colorbar(cm_hail, ax=ax, orientation="horizontal", pad=0.02, aspect=30)


# Hide the grid cells left over after the last node.
for spare_ax in axes.flat[n_nodes:]:
    spare_ax.set_visible(False)

plt.tight_layout()
plt.show()
# fig.savefig("./Figures/hail_by_node_3col.png", dpi=dpi, bbox_inches="tight")
