In [1]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from scipy.interpolate import griddata
import torch
import os
import matplotlib.cm as cm

In [3]:
# Open the crop dataset. `ds` is a notebook-level global: the channel2rgb /
# channel2gray helpers defined further down read it directly.
file = "/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_icon.nc"
ds = xr.open_dataset(file)
# Bare last expression: shows the dataset summary as the cell output.
ds

In [5]:
def compute_metrics(ds, var="model_108"):
    """
    Computes per-sample summary statistics of one 2-D field.

    Every statistic reduces over the spatial dimensions "y" and "x", so the
    results keep whatever other dimensions ``ds[var]`` has.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset that contains ``var`` with dimensions "y" and "x".
    var : str, optional
        Name of the variable to summarise. Defaults to "model_108", the
        variable this function always used before the parameter existed.

    Returns
    -------
    xarray.Dataset
        Dataset with the variables "ze_mean", "ze_std", "ze_max" and
        "ze_q90" (0.9 quantile).

    Notes
    -----
    The "ze_" prefix of the output names is kept unchanged because
    ``plot_tsne`` looks the statistics up as ``ds_metrics[f"ze_{stats}"]``.
    The plots in this notebook label these values "10.8 BT", so the values
    are taken as-is (no dB conversion is applied).
    """
    field = ds[var]
    spatial_dims = ["y", "x"]

    # spatial statistics of the selected field, one value per remaining index
    da_ze_mean = field.mean(spatial_dims)
    da_ze_std = field.std(spatial_dims)
    da_ze_max = field.max(spatial_dims)
    da_ze_q90 = field.quantile(0.9, spatial_dims)

    # combine to single dataset
    ds_metrics = xr.Dataset(
        {
            "ze_mean": da_ze_mean,
            "ze_std": da_ze_std,
            "ze_max": da_ze_max,
            "ze_q90": da_ze_q90,
        }
    )

    return ds_metrics


def write_metrics(
    ds_metrics,
    path="/p/project1/exaww/chatterjee1/dataset/metrics_icon_108_randomcrops.nc",
):
    """
    Writes the metrics dataset to a netCDF file.

    Parameters
    ----------
    ds_metrics : object with a ``to_netcdf(path)`` method
        The metrics to store (an xarray.Dataset in this notebook).
    path : str, optional
        Destination file. Defaults to the cache location this notebook has
        always used, so existing calls without ``path`` are unaffected.
    """
    ds_metrics.to_netcdf(path)


def read_metrics(
    path="/p/project1/exaww/chatterjee1/dataset/metrics_icon_108_randomcrops.nc",
):
    """
    Opens a previously written metrics file.

    Parameters
    ----------
    path : str, optional
        netCDF file to open. Defaults to the cache location this notebook has
        always used, so existing calls without arguments are unaffected.

    Returns
    -------
    xarray.Dataset
        The dataset returned by ``xr.open_dataset(path)``.
    """
    return xr.open_dataset(path)

In [6]:
# Switch between the cached metrics and a fresh computation:
#   False -> load the metrics file from disk
#   True  -> recompute from `ds` and overwrite the metrics file
recompute = False

if not recompute:
    ds_metrics = read_metrics()
else:
    ds_metrics = compute_metrics(ds)
    write_metrics(ds_metrics)

# List the available statistics.
print(ds_metrics.data_vars)

Data variables:
    ze_mean  (sample) float32 ...
    ze_std   (sample) float32 ...
    ze_max   (sample) float32 ...
    ze_q90   (sample) float64 ...


In [8]:
# Inspect the per-sample mean (shown as the cell output).
ds_metrics.ze_mean

In [4]:
def read_features(compression="multiscale"):
    """
    Load the first two columns of a saved ``tsne_icon_my_<suffix>.npy`` array.

    Parameters
    ----------
    compression : {"multiscale", "annealing", "pca"}
        Selects which file variant is read.

    Returns
    -------
    tuple
        ``(x, y)``: column 0 and column 1 of the loaded array.

    Raises
    ------
    ValueError
        If ``compression`` is not one of the supported names.
    """
    known_variants = (
        ("multiscale", "500multiscale50"),
        ("annealing", "500annealing50"),
        ("pca", "pcacosine"),
    )
    for variant, suffix in known_variants:
        if compression == variant:
            fext = suffix
            break
    else:
        raise ValueError(f"Unknown compression {compression}")

    embedding = np.load(
        f"/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_icon/tsne_icon_my_{fext}.npy"
    )
    return embedding[:, 0], embedding[:, 1]

In [9]:
def plot_tsne(stats, compression):
    """
    Scatter the 2-D embedding coloured by one per-sample statistic and save it.

    Reads the notebook-level ``ds_metrics`` (variable ``ze_<stats>``) for the
    colours and ``read_features(compression)`` for the point coordinates.

    Parameters
    ----------
    stats : {"std", "mean", "max", "q90"}
        Statistic used for colouring; also selects the fixed colour limits.
    compression : str
        Embedding variant, forwarded to ``read_features``.

    Raises
    ------
    ValueError
        If ``stats`` is not one of the supported statistics. Previously an
        unknown name silently produced and saved an empty figure.
    """
    # Fixed (vmin, vmax) colour limits for each supported statistic.
    color_limits = {
        "std": (0, 35),
        "mean": (210, 290),
        "max": (270, 300),
        "q90": (220, 300),
    }
    if stats not in color_limits:
        raise ValueError(
            f"Unknown stats {stats!r}; expected one of {sorted(color_limits)}"
        )
    vmin, vmax = color_limits[stats]

    # Read the inputs before creating the figure so a failure here does not
    # leave an orphaned open figure behind.
    x, y = read_features(compression=compression)

    fig, axs = plt.subplots(1, 1, figsize=(7, 5), layout="constrained")
    im = axs.scatter(
        x, y, s=1, c=ds_metrics[f"ze_{stats}"], cmap='viridis', vmax=vmax, vmin=vmin,
    )
    axs.set_title(f"{stats}_88x88")
    fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")

    out_path = f"/p/project1/exaww/chatterjee1/plots/continuous/tsne_icon_my_{stats}_{compression}.png"
    # Same convention as images2d: make sure the target directory exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

In [10]:
# One t-SNE figure per statistic. Only the "pca" embedding variant is plotted;
# the "multiscale" and "annealing" variants are currently disabled.
stats = ['mean', 'std', 'max', 'q90']
for stat in stats:
    #plot_tsne(stat, compression="multiscale")
    #plot_tsne(stat, compression="annealing")
    plot_tsne(stat, compression="pca")

In [11]:
# Display settings per dataset variable: colour-scale limits and colormap name.
img_kwargs = dict(
    model_108=dict(vmin=210, vmax=290, cmap="Spectral_r"),
)

In [15]:
def images2d(
    x, y, i_samples, f, kwargs, file, subsample=2, zoom=0.1, figsize=(10, 10)
):
    """
    Draw thumbnail images at their positions in a 2-D point cloud and save it.

    All points ``(x, y)`` are drawn as a faint background scatter; for every
    index in ``i_samples`` a thumbnail is placed at ``(x[i], y[i])``.

    Parameters
    ----------
    x, y : indexable
        Point coordinates, indexed by the values in ``i_samples``.
    i_samples : iterable of int
        Indices of the samples that get a thumbnail.
    f : callable
        Called as ``f(i_sample, **kwargs)``; must return an image array whose
        first two axes are rows and columns.
    kwargs : dict
        Extra keyword arguments forwarded to ``f``.
    file : str
        Stem of the output file name.
    subsample : int, optional
        Keep every ``subsample``-th pixel along BOTH image axes. (Previously
        only the columns were thinned, which distorted the thumbnails for any
        ``subsample > 1``; ``subsample=1`` is unaffected.)
    zoom : float, optional
        Zoom factor passed to ``OffsetImage``.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # faint background cloud of all points
    ax.scatter(
        x=x,
        y=y,
        lw=0,
        s=1,
        color="k",
        alpha=0.2,
    )

    for i_sample in i_samples:
        # thin rows and columns alike so the thumbnail keeps its aspect ratio
        img = f(i_sample, **kwargs)[::subsample, ::subsample]
        imagebox = OffsetImage(img, zoom=zoom)
        imagebox.image.axes = ax
        ab = AnnotationBbox(
            imagebox,
            [x[i_sample], y[i_sample]],
            xycoords="data",
            frameon=True,
            box_alignment=(0.5, 0.5),
            pad=0,
            bboxprops=dict(edgecolor="#eeeeee", lw=1, facecolor="none"),
        )
        ax.add_artist(ab)

    ax.set_aspect("equal")
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    out_path = f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_icon_multi_year.png"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # save/close this specific figure rather than whatever is "current"
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {out_path}")
    
def channel2rgb(i_sample, v):
    """
    Convert a specific channel to RGB with a fixed vmin and vmax.

    Reads the notebook-level ``ds`` and ``img_kwargs``.

    Parameters
    ----------
    i_sample : int
        Index along the "sample" dimension of ``ds[v]``.
    v : str
        Variable name; must be a key of ``img_kwargs``.

    Returns
    -------
    numpy.ndarray
        uint8 colour image (the colormap output scaled to 0-255), flipped
        vertically.
    """
    settings = img_kwargs[v]
    norm = plt.Normalize(vmin=settings["vmin"], vmax=settings["vmax"])
    # Honour the colormap configured for this variable; fall back to the
    # previously hard-coded Spectral_r when no "cmap" entry is present.
    colormap = plt.get_cmap(settings.get("cmap", "Spectral_r"))
    # Apply colormap to normalized data
    rgb = colormap(norm(ds[v].isel(sample=i_sample).values))

    # Flip vertically
    rgb = rgb[::-1, :, :]

    # Convert to uint8 format
    return (rgb * 255).astype("uint8")

def channel2gray(i_sample, v):
    """
    Convert a specific channel to grayscale with a fixed vmin and vmax.

    Reads the notebook-level ``ds`` and ``img_kwargs``.

    Parameters
    ----------
    i_sample : int
        Index along the "sample" dimension of ``ds[v]``.
    v : str
        Variable name; must be a key of ``img_kwargs``.

    Returns
    -------
    array
        uint8 grayscale image, flipped vertically. Values below ``vmin`` map
        to 0 and values above ``vmax`` map to 255.
    """
    # clip=True keeps the normalised values inside [0, 1]. Without it, data
    # outside [vmin, vmax] gives values < 0 or > 1 that wrap around in the
    # uint8 cast below (e.g. slightly-too-warm pixels turning dark).
    norm = plt.Normalize(
        vmin=img_kwargs[v]["vmin"], vmax=img_kwargs[v]["vmax"], clip=True
    )

    # Normalize data between 0 and 1
    gray = norm(ds[v].isel(sample=i_sample).values)

    # Flip vertically
    gray = gray[::-1, :]

    # Convert to uint8 (0-255) grayscale
    # NOTE(review): NaN pixels are not handled here — confirm the data has none.
    return (gray * 255).astype("uint8")  # Shape: (H, W)

In [18]:
# Number of samples. Dataset.sizes is the dimension-name -> length mapping;
# looking lengths up through Dataset.dims is deprecated in xarray.
N = ds.sizes["sample"]
n_images = 1000
# Fixed seed so the same thumbnails are drawn on every run.
rng = np.random.default_rng(42)
i_samples = rng.choice(N, size=min(n_images, N), replace=False)
subsample = 1
zoom = 0.3
figsize = (10, 10)

# NOTE(review): assumes row i of the embedding belongs to sample i of `ds`.
x, y = read_features(compression='pca')

# individual channel with fixed value ranges
for v in ["model_108"]:
    images2d(
        x=x,
        y=y,
        i_samples=i_samples,
        f=channel2rgb,
        kwargs=dict(v=v),
        file=f"images2d_view1_icon{v}",
        subsample=subsample,
        zoom=zoom,
        figsize=figsize,
    )

wrote /p/project1/exaww/chatterjee1/plots/continuous/images2d_view1_iconmodel_108_icon_multi_year.png


## Cluster sample images with color bar

In [28]:
# Display settings per dataset variable: colour-scale limits and colormap name.
img_kwargs = dict(
    model_108=dict(vmin=210, vmax=290, cmap="Spectral_r"),
)

# Color mapping for clusters: keys are the cluster ids as strings ("0".."19").
_cluster_palette = (
    '#0F3C5F', '#48714F', '#C49138', '#FDB9C2', '#393d76',
    '#e1bc3a', '#af362b', '#19cbf7', '#000000', '#0008fa',
    '#3e8245', '#7c646e', '#34065c', '#afa8ed', '#fc030f',
    '#699be0', '#999999', '#dffc03', '#dea4da', '#cc7445',
)
colors_per_class1 = dict(
    zip(map(str, range(len(_cluster_palette))), _cluster_palette)
)

def channel2rgb(ds, i_sample, v):
    """
    Render one sample of variable ``v`` as a uint8 H x W x 4 (RGBA) image.

    Colour limits and colormap come from the notebook-level ``img_kwargs``;
    data outside [vmin, vmax] is clipped to the ends of the colour scale and
    the image is flipped vertically.

    NOTE: this definition shadows the earlier ``channel2rgb(i_sample, v)`` in
    this notebook. ``images2d`` calls its ``f`` as ``f(i_sample, **kwargs)``
    and therefore needs the earlier two-argument version.

    Raises
    ------
    KeyError
        If ``v`` is not in ``ds`` or has no entry in ``img_kwargs``.
    ValueError
        If ``ds[v]`` has no "sample" dimension or the selected sample is not
        two-dimensional.
    """
    # input validation
    if v not in ds.variables:
        raise KeyError(f"Variable '{v}' not found in dataset.")
    variable = ds[v]
    if "sample" not in variable.dims:
        raise ValueError(f"Variable '{v}' does not have 'sample' dim. Dims: {variable.dims}")
    if v not in img_kwargs:
        raise KeyError(f"img_kwargs missing entry for '{v}'.")

    field = variable.isel(sample=int(i_sample)).values  # expect (H, W)
    if field.ndim != 2:
        raise ValueError(f"Expected 2D after isel, got shape {field.shape} for '{v}'.")

    settings = img_kwargs[v]
    lower = settings["vmin"]
    upper = settings["vmax"]
    palette = plt.get_cmap(settings.get("cmap", "Spectral_r"))

    # guard against a zero (or negative) value range before dividing
    span = float(upper) - float(lower)
    if span < 1e-6:
        span = 1e-6
    unit = (np.clip(field, lower, upper) - lower) / span

    # colormap output is float RGBA in [0, 1]; flip rows, then scale to bytes
    image = np.flipud(palette(unit))
    return (image * 255).astype(np.uint8)

def read_cluster_labels_simple(path):
    """
    Load cluster labels saved with torch and return them as a flat int array.

    The object is loaded onto the CPU, converted to a numpy array, cast to
    ``int`` and flattened to one dimension.
    """
    loaded = torch.load(path, map_location="cpu")
    as_int = np.asarray(loaded).astype(int)
    return np.ravel(as_int)

def plot_cluster_images_with_colorbar(
    ds,
    cluster_labels,
    cluster_indices,
    f,
    kwargs,
    file,
    n_samples_per_cluster=10,
    figsize_per=(2.0, 2.0)
):
    """
    Plot a grid of example images, one row per cluster, with a shared colorbar.

    Rows are the cluster ids ``0 .. K-1`` with ``K = max(cluster_labels) + 1``.
    Each row shows up to ``n_samples_per_cluster`` samples drawn without
    replacement (fixed seed 45) from ``cluster_indices[cluster_id]``; unused
    cells are switched off. The figure is saved as
    ``<file>_clusteredrowrgbicon_my.png``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a "sample" dimension; passed on to ``f``.
    cluster_labels : array-like of int
        One label per sample; its length must equal the number of samples.
    cluster_indices : dict
        Maps cluster id -> sample indices belonging to that cluster.
    f : callable
        Called as ``f(ds, i_sample, **kwargs)``; the result goes to ``imshow``.
    kwargs : dict
        Forwarded to ``f``. Must contain ``"v"``, a key of the notebook-level
        ``img_kwargs`` (used for the colorbar limits and colormap).
    file : str
        Stem of the output file name.
    n_samples_per_cluster : int, optional
        Number of columns in the grid.
    figsize_per : tuple, optional
        (width, height) per grid cell; the figure is at least 6 x 6.

    Raises
    ------
    ValueError
        Missing "sample" dimension, label/sample count mismatch, no
        non-negative label, or ``n_samples_per_cluster < 1``.
    KeyError
        ``kwargs["v"]`` missing or not configured in ``img_kwargs``.
    IndexError
        ``cluster_indices`` contains an index outside the dataset.
    """
    # ---- validate everything before a figure is created -------------------
    if "sample" not in ds.dims:
        raise ValueError(f"Dataset missing 'sample' dimension. Dims: {ds.dims}")
    # Dataset.sizes is the name -> length mapping; length lookup through
    # Dataset.dims is deprecated in xarray.
    n_total = int(ds.sizes["sample"])

    cluster_labels = np.asarray(cluster_labels).astype(int).ravel()
    if cluster_labels.size != n_total:
        raise ValueError(f"cluster_labels length {cluster_labels.size} != dataset samples {n_total}")

    K = int(cluster_labels.max()) + 1
    if K < 1:
        raise ValueError("cluster_labels contains no non-negative cluster id.")
    n_cols = int(n_samples_per_cluster)
    if n_cols < 1:
        raise ValueError("n_samples_per_cluster must be >= 1")

    v = kwargs.get("v", None)
    if v is None or v not in img_kwargs:
        raise KeyError("kwargs must include 'v' present in img_kwargs.")

    # Accept plain lists as well as arrays, then ensure indices are in bounds.
    indices_by_cluster = {cid: np.asarray(idxs) for cid, idxs in cluster_indices.items()}
    for cid, idxs in indices_by_cluster.items():
        if np.any((idxs < 0) | (idxs >= n_total)):
            raise IndexError(f"cluster_indices[{cid}] contains out-of-range indices.")

    # ---- figure -----------------------------------------------------------
    fig_w = max(6, n_cols * figsize_per[0])
    fig_h = max(6, K * figsize_per[1])
    fig, axes = plt.subplots(K, n_cols, figsize=(fig_w, fig_h), squeeze=False)

    # colorbar setup
    vmin = img_kwargs[v]["vmin"]
    vmax = img_kwargs[v]["vmax"]
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    sm = cm.ScalarMappable(cmap=plt.get_cmap(img_kwargs[v].get("cmap", "Spectral_r")), norm=norm)
    sm.set_array([])

    rng = np.random.default_rng(45)

    for cluster_id in range(K):
        idx_all = indices_by_cluster.get(cluster_id, np.array([], dtype=int))
        n_take = min(n_cols, len(idx_all))
        picked = rng.choice(idx_all, size=n_take, replace=False) if n_take > 0 else np.array([], dtype=int)

        # frame colour identifies the cluster
        bbox_color = colors_per_class1.get(str(cluster_id), "#000000")
        for j in range(n_cols):
            ax = axes[cluster_id, j]
            if j < n_take:
                img = f(ds, int(picked[j]), **kwargs)
                ax.imshow(img)
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_color(bbox_color)
                    spine.set_linewidth(2)
            else:
                ax.axis("off")

        axes[cluster_id, 0].set_ylabel(f"Cluster {cluster_id}", fontsize=12, rotation=90, labelpad=20)

    # Lay out the image grid first (leaving the right 10 % free) and only then
    # add the manually placed colorbar axes: tight_layout cannot handle axes
    # created with add_axes and warns if they already exist.
    fig.tight_layout(rect=[0, 0, 0.9, 1])

    cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation="vertical")
    cbar.set_label("BT", fontsize=14)
    cbar.ax.tick_params(labelsize=10)

    out_dir = "/p/project1/exaww/chatterjee1/plots/continuous"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{file}_clusteredrowrgbicon_my.png")
    fig.savefig(out_path, dpi=600, bbox_inches="tight")
    # NOTE(review): unlike the other plot helpers in this notebook the figure
    # is not closed here — confirm whether inline display is intended.
    print(f"wrote {out_path}")

In [31]:
# Re-open the crop dataset and load one cluster label per sample.
file_nc = "/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_icon.nc"
ds = xr.open_dataset(file_nc)

labels_path = "/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_icon/icon_cluster_10_labels_multi_year_.pth"
cluster_labels = read_cluster_labels_simple(labels_path)

# Sample indices of every cluster id 0 .. K-1 (cluster_labels is 1-D).
K = int(cluster_labels.max()) + 1
cluster_indices = {
    cid: np.flatnonzero(cluster_labels == cid) for cid in range(K)
}

plot_cluster_images_with_colorbar(
    ds,
    cluster_labels,
    cluster_indices,
    f=channel2rgb,
    kwargs={"v": "model_108"},
    file="cluster_images2_rgb",
    n_samples_per_cluster=20,
    figsize_per=(2.0, 2.0),
)

  plt.tight_layout(rect=[0, 0, 0.9, 1])


wrote /p/project1/exaww/chatterjee1/plots/continuous/cluster_images2_rgb_clusteredrowrgbicon_my.png


## Cluster frequencies

In [32]:
def plot_cluster_frequencies(cluster_labels, colors_per_class1, out_prefix="cluster_freq_overall"):
    """
    Save a bar chart of how many samples fall into each cluster.

    Bars cover the cluster ids ``0 .. max(label)``; each bar is annotated with
    its share in percent.

    Parameters
    ----------
    cluster_labels : array-like of int
        One label per sample. Negative labels (this notebook uses -1 for
        rejected samples) are ignored; percentages refer to the remaining
        labels. Previously a negative label made ``np.bincount`` fail.
    colors_per_class1 : dict
        Maps the cluster id as a string to a bar colour; missing ids are
        drawn in black.
    out_prefix : str, optional
        Output file name without the ".png" extension.

    Raises
    ------
    ValueError
        If there is no non-negative label to count.
    """
    labels = np.asarray(cluster_labels).astype(int).ravel()
    labels = labels[labels >= 0]
    if labels.size == 0:
        raise ValueError("cluster_labels contains no non-negative labels to count.")

    counts = np.bincount(labels)  # length is max(label) + 1
    K = counts.size
    perc = 100.0 * counts / labels.size

    # colors per bar
    bar_colors = [colors_per_class1.get(str(i), "#000000") for i in range(K)]

    fig, ax = plt.subplots(figsize=(10, 4))
    x = np.arange(K)
    ax.bar(x, counts, color=bar_colors, edgecolor="k", linewidth=0.8)
    ax.set_xlabel("Cluster")
    ax.set_ylabel("Count")
    ax.set_title("Cluster frequency (overall)")

    # annotate with %
    for i, (c, p) in enumerate(zip(counts, perc)):
        ax.text(i, c, f"{p:.1f}%", ha="center", va="bottom", fontsize=8, rotation=0)

    ax.set_xticks(x)
    ax.set_xticklabels([str(i) for i in range(K)])
    ax.grid(axis="y", linestyle=":", alpha=0.4)

    out_dir = "/p/project1/exaww/chatterjee1/plots/continuous"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{out_prefix}.png")
    # lay out, save and close this specific figure
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"wrote {out_path}")

In [33]:
# Bar chart of the cluster sizes for the labels loaded above.
plot_cluster_frequencies(cluster_labels, colors_per_class1, out_prefix="iconcluster_freq_overall_10c")

wrote /p/project1/exaww/chatterjee1/plots/continuous/iconcluster_freq_overall_10c.png


## vMF soft labels

In [38]:
icon_dir = "/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_icon/"
obs_dir = "/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_obs/"

# Soft-assignment labels and their confidences.
labels = np.asarray(torch.load(os.path.join(icon_dir, "icon_labels_soft_vmf.pth")))
conf = np.asarray(torch.load(os.path.join(icon_dir, "icon_labels_soft_vmf_confidence.pth")))

print("N:", labels.size)
# A label of -1 is counted as "rejected".
print("rejected frac:", np.round(np.mean(labels == -1), 4))

# Only non-negative labels can be counted with bincount.
valid = labels >= 0
print("label counts (valid only):", np.bincount(labels[valid]))

print("conf mean/median:", np.round(conf.mean(), 4), np.round(np.median(conf), 4))
print("conf p10/p90:", np.round(np.percentile(conf, [10, 90]), 4))

N: 18300
rejected frac: 0.0
label counts (valid only): [ 278 1017  227  271 1333  696 2484  542 1526 9926]
conf mean/median: 0.1607 0.161
conf p10/p90: [0.1298 0.1903]


In [39]:
# Features (rows of X) and centroids (rows of C) as float32.
X = np.asarray(torch.load(os.path.join(icon_dir, "trainfeat_new_multiyear.pth")), np.float32)
C = np.asarray(torch.load(os.path.join(obs_dir, "obs_final_10_centroids_multiyear.pth")), np.float32)

# Scale every row to unit length (the floor avoids division by zero), so the
# dot products below are cosine similarities.
X = X / np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)
C = C / np.maximum(np.linalg.norm(C, axis=1, keepdims=True), 1e-12)

labels_soft = np.asarray(torch.load(os.path.join(icon_dir, "icon_labels_soft_vmf.pth")))

# Plain assignment: the centroid with the highest similarity.
S = np.matmul(X, C.T)
labels_plain = np.argmax(S, axis=1).astype(np.int32)

# Compare both labelings where the soft label is non-negative.
valid = labels_soft >= 0
print("Disagreement rate (valid only):", np.round(np.mean(labels_plain[valid] != labels_soft[valid]), 4))

Disagreement rate (valid only): 0.0354


In [40]:
# Using labels from the soft method:
K = C.shape[0]
S = X @ C.T
for k in range(K):
    idx = np.where(labels_soft == k)[0]
    if idx.size == 0: 
        print(f"k={k}: no ICON points"); 
        continue
    mean_cos_icon = float(S[idx, k].mean())
    print(f"k={k}: mean cos (ICON→centroid) = {mean_cos_icon:.3f}")

k=0: mean cos (ICON→centroid) = 0.442
k=1: mean cos (ICON→centroid) = 0.344
k=2: mean cos (ICON→centroid) = 0.287
k=3: mean cos (ICON→centroid) = 0.346
k=4: mean cos (ICON→centroid) = 0.391
k=5: mean cos (ICON→centroid) = 0.396
k=6: mean cos (ICON→centroid) = 0.402
k=7: mean cos (ICON→centroid) = 0.331
k=8: mean cos (ICON→centroid) = 0.355
k=9: mean cos (ICON→centroid) = 0.413
