In [25]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from scipy.interpolate import griddata
import torch
import os
import matplotlib.cm as cm
import pandas as pd

In [2]:
# ---- single-file metrics (no dask) ----
def compute_metrics_one(ds: xr.Dataset, var="sample_data") -> xr.Dataset:
    """
    Computes per-sample metrics for one file.
    Expects ds[var] with dims ('sample','y','x') and carries along 'time' if present.
    """
    da = ds[var]  # (sample, y, x)

    mean_ = da.mean(("y", "x"))
    std_  = da.std(("y", "x"), ddof=0)
    max_  = da.max(("y", "x"))
    q90   = da.quantile(0.9, dim=("y", "x")).squeeze(drop=True)  # drop 'quantile' dim

    out = xr.Dataset(
        {
            "msg_108_mean": mean_,
            "msg_108_std":  std_,
            "msg_108_max":  max_,
            "msg_108_q90":  q90,
        }
    )

    # keep time coordinate if available (shape: sample)
    if "time" in ds:
        out = out.assign_coords(time=ds["time"])

    return out


# ---- multi-year driver (dask-free, concat along sample) ----
def compute_metrics_multi_year(nc_paths, var="sample_data", engine="netcdf4") -> xr.Dataset:
    parts = []
    offset = 0

    for p in nc_paths:
        with xr.open_dataset(p, engine=engine, mask_and_scale=False) as ds:
            # compute per-file metrics
            dsm = compute_metrics_one(ds, var=var)

            # give each file a unique running sample index to allow concat
            n = dsm.sizes["sample"]
            dsm = dsm.assign_coords(sample=np.arange(offset, offset + n))
            offset += n

            # optional: annotate source file/year for debugging
            dsm = dsm.assign_coords(source_file=os.path.basename(p))
            parts.append(dsm)

    # concatenate along the (now unique) sample dimension
    ds_metrics = xr.concat(parts, dim="sample")

    # ensure sample is an index
    ds_metrics = ds_metrics.sortby("sample")
    return ds_metrics


# ---- write / read helpers ----
def write_metrics(ds_metrics: xr.Dataset, out_path: str):
    ds_metrics.to_netcdf(out_path)

def read_metrics(in_path: str) -> xr.Dataset:
    return xr.open_dataset(in_path)

In [3]:
years = [
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2020.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2021.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2022.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2023.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2024.nc',
]

out_path = "/p/project/exaww/chatterjee1/dataset/warmworld_datasets/metrics_msgobs_my_108_randomcrops.nc"
recompute = False  # set False to reuse written file

if recompute:
    ds_metrics = compute_metrics_multi_year(years, var="sample_data")
    write_metrics(ds_metrics, out_path)
else:
    ds_metrics = read_metrics(out_path)

print(ds_metrics)
print(ds_metrics.data_vars)

<xarray.Dataset>
Dimensions:       (sample: 383928)
Coordinates:
  * sample        (sample) int64 0 1 2 3 4 ... 383924 383925 383926 383927
    time          (sample) object ...
    quantile      float64 ...
    source_file   (sample) object ...
Data variables:
    msg_108_mean  (sample) float32 ...
    msg_108_std   (sample) float32 ...
    msg_108_max   (sample) float32 ...
    msg_108_q90   (sample) float64 ...
Data variables:
    msg_108_mean  (sample) float32 ...
    msg_108_std   (sample) float32 ...
    msg_108_max   (sample) float32 ...
    msg_108_q90   (sample) float64 ...


In [4]:
def read_features(compression="multiscale"):
    if compression == "multiscale":
        fext = "500multiscale50"
    elif compression == "annealing":
        fext = "500annealing50"
    elif compression == "pca":
        fext = "pcacosine"
    else:
        raise ValueError(f"Unknown compression {compression}")

    ds_tsne = np.load(
        f"/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_obs/tsne_obs_my_{fext}.npy"
    )
    x = ds_tsne[:, 0]
    y = ds_tsne[:, 1]
    return x, y

In [5]:
def plot_tsne(stats, compression):
    fig, axs = plt.subplots(1, 1, figsize=(7, 5), layout="constrained")

    x, y = read_features(compression=compression)
    if stats == 'std':
        im = axs.scatter(
        x, y, s=1, c=ds_metrics[f"msg_108_{stats}"], cmap='viridis', vmax=35, vmin=0,
        )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    elif stats == 'mean':
        im = axs.scatter(
            x, y, s=1, c=ds_metrics[f"msg_108_{stats}"], cmap='viridis', vmax=290, vmin=210,
            )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    elif stats == 'max':
        im = axs.scatter(
            x, y, s=1, c=ds_metrics[f"msg_108_{stats}"], cmap='viridis', vmax=300, vmin=270,
            )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    elif stats == 'q90':
        im = axs.scatter(
            x, y, s=1, c=ds_metrics[f"msg_108_{stats}"], cmap='viridis', vmax=300, vmin=220,
            )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    

    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/tsne_obs_my_{stats}_{compression}.png", dpi=300)
    plt.close()

In [13]:
stats = ['mean', 'std', 'max', 'q90']
for stat in stats:
    #plot_tsne(stat, compression="multiscale")
    #plot_tsne(stat, compression="annealing")
    plot_tsne(stat, compression="pca")

In [8]:
# ---------- 1) Open multi-year dataset (no dask) ----------
def open_multiyear(paths, var="sample_data", engine="netcdf4"):
    parts = []
    offset = 0
    for p in paths:
        ds = xr.open_dataset(p, engine=engine, mask_and_scale=False)
        n = int(ds.dims["sample"])
        # make a unique running sample index to avoid collisions
        ds = ds.assign_coords(sample=np.arange(offset, offset + n))
        offset += n
        parts.append(ds)
    ds_all = xr.concat(parts, dim="sample")
    ds_all = ds_all.sortby("sample")
    return ds_all

# ---------- 2) Load 2D features (must match multi-year sample order) ----------
def read_features(compression="multiscale"):
    if compression == "multiscale":
        fext = "500multiscale50"
    elif compression == "annealing":
        fext = "500annealing50"
    elif compression == "pca":
        fext = "pcacosine"
    else:
        raise ValueError(f"Unknown compression {compression}")
    arr = np.load(f"/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_obs/tsne_obs_my_{fext}.npy")
    # expect shape [N,2]
    assert arr.ndim == 2 and arr.shape[1] == 2, f"features shape {arr.shape} not Nx2"
    return arr[:, 0], arr[:, 1]

# ---------- 3) RGB conversion for a single channel ----------
img_kwargs = {
    "sample_data": {"vmin": 210, "vmax": 290, "cmap": "Spectral_r"},
    # add other variables here if you want to visualize them too
}

def channel2rgb(ds, i_sample, v):
    """Return uint8 image HxWx3 or HxWx4 for sample i of variable v."""
    da = ds[v].isel(sample=int(i_sample)).values  # (H,W)
    vmin = img_kwargs[v]["vmin"]
    vmax = img_kwargs[v]["vmax"]
    cmap = plt.get_cmap(img_kwargs[v].get("cmap", "Spectral_r"))
    normed = (np.clip(da, vmin, vmax) - vmin) / max(vmax - vmin, 1e-6)
    rgba = cmap(normed)  # (H,W,4), float in [0,1]
    # flip vertically (optional, if desired)
    rgba = rgba[::-1, :, :]
    return (rgba * 255).astype(np.uint8)  # (H,W,4)

# ---------- 4) Plotting ----------
def images2d(ds, x, y, i_samples, img_fn, img_kwargs, out_prefix,
             subsample=1, zoom=0.2, figsize=(10,10)):
    assert len(x) == len(y) == ds.dims["sample"], (
        f"Length mismatch: x={len(x)}, y={len(y)}, ds.samples={ds.dims['sample']}"
    )

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.scatter(x=x, y=y, lw=0, s=1, color="k", alpha=0.2)

    for i in i_samples:
        # build the image and subsample both H and W
        img = img_fn(ds, i, **img_kwargs)
        if subsample > 1:
            img = img[::subsample, ::subsample, ...]
        # HxW or HxWx3/4 is fine for OffsetImage
        ab = AnnotationBbox(
            OffsetImage(img, zoom=zoom),
            (x[int(i)], y[int(i)]),
            xycoords="data",
            frameon=True,
            box_alignment=(0.5, 0.5),
            pad=0,
            bboxprops=dict(edgecolor="#eeeeee", lw=1, facecolor="none"),
        )
        ax.add_artist(ab)

    ax.set_aspect("equal")
    ax.set_xticks([]); ax.set_yticks([])

    out_path = f"/p/project1/exaww/chatterjee1/plots/continuous/{out_prefix}_multi_year.png"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"wrote {out_path}")

In [9]:
# ---------- 5) Wire everything ----------
years = [
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2020.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2021.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2022.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2023.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2024.nc',
]

# 1) open multi-year ds
ds = open_multiyear(years, var="sample_data")

# 2) read the 2D projection (must have been computed over the same multi-year ordering)
x, y = read_features(compression="pca")

# 3) pick random samples to overlay
N = ds.dims["sample"]
n_images = 1000
rng = np.random.default_rng(42)
i_samples = rng.choice(N, size=min(n_images, N), replace=False)

# 4) render
images2d(
    ds=ds,
    x=x, y=y,
    i_samples=i_samples,
    img_fn=channel2rgb,
    img_kwargs=dict(v="sample_data"),
    out_prefix="images2d_view_msg_108",
    subsample=1,
    zoom=0.20,
    figsize=(10,10),
)

wrote /p/project1/exaww/chatterjee1/plots/continuous/images2d_view_msg_108_multi_year.png


## Cluster the continuum

In [16]:
# ---------- Multi-year dataset ----------
def open_multiyear(paths, var="sample_data", engine="netcdf4"):
    parts = []
    offset = 0
    for p in paths:
        ds = xr.open_dataset(p, engine=engine, mask_and_scale=False)
        n = int(ds.dims["sample"])
        ds = ds.assign_coords(sample=np.arange(offset, offset + n))  # unique running index
        offset += n
        parts.append(ds)
    ds_all = xr.concat(parts, dim="sample").sortby("sample")
    return ds_all

# ---------- 2D features ----------
def read_features(compression="multiscale"):
    if compression == "multiscale":
        fext = "500multiscale50"
    elif compression == "annealing":
        fext = "500annealing50"
    elif compression == "pca":
        fext = "pcacosine"
    else:
        raise ValueError(f"Unknown compression {compression}")
    arr = np.load(f"/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_obs/tsne_obs_my_{fext}.npy")
    assert arr.ndim == 2 and arr.shape[1] == 2, f"features shape {arr.shape} not Nx2"
    return arr[:, 0], arr[:, 1]

def read_cluster_labels_simple(path):
    obj = torch.load(path, map_location="cpu")   # "cpu" is the safe spelling
    labels = np.asarray(obj).astype(int).ravel() # ensure 1D int array
    return labels

# ---------- Channel -> RGBA image ----------
img_kwargs = {
    "sample_data": {"vmin": 210, "vmax": 290, "cmap": "Spectral_r"},
}

def channel2rgb(ds, i_sample, v):
    da = ds[v].isel(sample=int(i_sample)).values  # (H,W)
    cfg = img_kwargs[v]
    vmin, vmax = cfg["vmin"], cfg["vmax"]
    cmap = plt.get_cmap(cfg.get("cmap", "Spectral_r"))
    normed = (np.clip(da, vmin, vmax) - vmin) / max(vmax - vmin, 1e-6)
    rgba = cmap(normed)            # (H,W,4) floats
    rgba = rgba[::-1, :, :]        # optional flip
    return (rgba * 255).astype(np.uint8)

# ---------- Colors per cluster (extend as needed) ----------
def default_cluster_colors(K=20):
    palette = [
        '#0F3C5F', '#48714F', '#C49138', '#FDB9C2', '#393d76', '#e1bc3a',
        '#af362b', '#19cbf7', '#000000', '#0008fa', '#3e8245', '#7c646e',
        '#34065c', '#afa8ed', '#fc030f', '#699be0', '#999999', '#dffc03',
        '#dea4da', '#cc7445'
    ]
    if K <= len(palette):
        return {i: palette[i] for i in range(K)}
    # repeat if more clusters than colors provided
    return {i: palette[i % len(palette)] for i in range(K)}

# ---------- 2D images with cluster-colored boxes ----------
def images2d_clusters(ds, x, y, i_samples, img_fn, img_kwargs, labels, out_prefix,
                      subsample=1, zoom=0.2, figsize=(10,10), colors=None):
    N = ds.dims["sample"]
    assert len(x) == len(y) == N, f"len(x)={len(x)}, len(y)={len(y)}, N={N}"
    assert len(labels) == N, f"len(labels)={len(labels)} must equal N={N}"
    if colors is None:
        colors = default_cluster_colors(K=int(labels.max())+1)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.scatter(x=x, y=y, lw=0, s=1, color="k", alpha=0.2)

    for i in i_samples:
        i = int(i)
        img = img_fn(ds, i, **img_kwargs)
        if subsample > 1:
            img = img[::subsample, ::subsample, ...]
        c = colors.get(int(labels[i]), "#000000")
        ab = AnnotationBbox(
            OffsetImage(img, zoom=zoom),
            (float(x[i]), float(y[i])),
            xycoords="data",
            frameon=True,
            box_alignment=(0.5, 0.5),
            pad=0,
            bboxprops=dict(edgecolor=c, lw=2, facecolor="none"),
        )
        ax.add_artist(ab)

    ax.set_aspect("equal"); ax.set_xticks([]); ax.set_yticks([])
    out_dir = "/p/project1/exaww/chatterjee1/plots/continuous/"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{out_prefix}_10_cluster_multiyear.png")
    plt.savefig(out_path, dpi=100, bbox_inches="tight")
    plt.close()
    print(f"wrote {out_path}")

In [17]:
years = [
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2020.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2021.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2022.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2023.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2024.nc',
]

# 1) open multi-year ds (ordering must match what you used for feature extraction)
ds = open_multiyear(years, var="sample_data")

# 2) read 2D projection
x, y = read_features(compression="pca")

# 3) load cluster labels
common_path = "/p/scratch/exaww/chatterjee1/nn_obs/all_nc/features_obs/"
labels_path = os.path.join(common_path, "obs_cluster_10_labels_my.pth")
cluster_labels = read_cluster_labels_simple(labels_path)
print(cluster_labels.shape, cluster_labels.dtype)

# 4) choose which samples to annotate
N = ds.dims["sample"]
assert len(x) == len(y) == N, "x/y length must match dataset samples"
assert len(cluster_labels) == N, "labels length must match dataset samples"

rng = np.random.default_rng(42)
i_samples = rng.choice(N, size=min(1000, N), replace=False)

# 5) plot
images2d_clusters(
    ds=ds,
    x=x, y=y,
    i_samples=i_samples,
    img_fn=channel2rgb,
    img_kwargs=dict(v="sample_data"),
    labels=cluster_labels,
    out_prefix="images2d_view_mas_108_multi_year_cluster",
    subsample=1,
    zoom=0.2,
    figsize=(10,10),
    colors=default_cluster_colors(K=int(cluster_labels.max())+1),
)

(383928,) int64
wrote /p/project1/exaww/chatterjee1/plots/continuous/images2d_view_mas_108_multi_year_cluster_10_cluster_multiyear.png


## color bar

In [33]:
years = [
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2020.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2021.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2022.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2023.nc',
    '/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2024.nc',
]

# 1) open multi-year ds (ordering must match what you used for feature extraction)
ds = open_multiyear(years, var="sample_data")

def plot_cluster_images_with_colorbar(
    ds,
    cluster_labels,
    cluster_indices,
    f,                   # callable like channel2rgb(ds, i_sample, v=...)
    kwargs,              # e.g., dict(v="sample_data")
    file,
    n_samples_per_cluster=10,
    figsize_per=(2.0, 2.0),   # (width,height) per cell in inches
):
    K = int(np.max(cluster_labels)) + 1
    # determine the max columns we will actually need (in case some clusters < requested)
    n_cols = n_samples_per_cluster

    # build figure
    fig_w = max(6, n_cols * figsize_per[0])
    fig_h = max(6, K * figsize_per[1])
    fig, axes = plt.subplots(K, n_cols, figsize=(fig_w, fig_h), squeeze=False)

    # colorbar setup (matches your Spectral_r mapping)
    vmin = img_kwargs["sample_data"]["vmin"]
    vmax = img_kwargs["sample_data"]["vmax"]
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    sm = cm.ScalarMappable(cmap=plt.cm.Spectral_r, norm=norm)
    sm.set_array([])

    rng = np.random.default_rng(45)  # reproducible sampling

    for cluster_id in range(K):
        idx_all = cluster_indices.get(cluster_id, np.array([], dtype=int))
        # choose up to n_samples_per_cluster without replacement (if enough exist)
        n_take = min(n_samples_per_cluster, len(idx_all))
        if n_take > 0:
            picked = rng.choice(idx_all, size=n_take, replace=False)
        else:
            picked = np.array([], dtype=int)

        # plot selected samples in the row
        for j in range(n_cols):
            ax = axes[cluster_id, j]
            if j < n_take:
                i_sample = int(picked[j])
                img = f(ds, i_sample, **kwargs)  # pass ds explicitly
                ax.imshow(img)                   # RGBA or RGB is fine
                ax.set_xticks([]); ax.set_yticks([])
                # border color by cluster
                bbox_color = colors_per_class1.get(str(cluster_id), "#000000")
                for spine in ax.spines.values():
                    spine.set_color(bbox_color)
                    spine.set_linewidth(2)
            else:
                # hide unused cells in this row
                ax.axis("off")

        # row label on first used column (or first col regardless)
        axes[cluster_id, 0].set_ylabel(f"Cluster {cluster_id}", fontsize=12, rotation=90, labelpad=12)

    # single colorbar on the right
    cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation="vertical")
    cbar.set_label("BT", fontsize=14)
    cbar.ax.tick_params(labelsize=10)

    plt.tight_layout(rect=[0, 0, 0.9, 1])
    out_path = f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_clustered_raw_multi_year.png"
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"wrote {out_path}")

In [35]:
# Color mapping for clusters
colors_per_class1 = {
    str(i): color for i, color in enumerate([
        '#0F3C5F', '#48714F', '#C49138', '#FDB9C2', '#393d76', '#e1bc3a', 
        '#af362b', '#19cbf7', '#000000', '#0008fa', '#3e8245', '#7c646e', 
        '#34065c', '#afa8ed', '#fc030f', '#699be0', '#999999', '#dffc03', 
        '#dea4da', '#cc7445'
    ])
}

K = int(np.max(cluster_labels)) + 1
cluster_indices = {i: np.where(cluster_labels == i)[0] for i in range(K)}

plot_cluster_images_with_colorbar(
    ds=ds,
    cluster_labels=cluster_labels,
    cluster_indices=cluster_indices,
    f=channel2rgb,
    kwargs=dict(v="sample_data"),
    file="cluster_images3",
    n_samples_per_cluster=20,   # 10 per class
    figsize_per=(2.0, 2.0),
)

  plt.tight_layout(rect=[0, 0, 0.9, 1])


wrote /p/project1/exaww/chatterjee1/plots/continuous/cluster_images3_clustered_raw_multi_year.png


## stacked per year FoOc

In [28]:
def plot_cluster_frequencies_per_year(ds, cluster_labels, colors_per_class1, out_prefix="cluster_freq_per_year"):
    labels = np.asarray(cluster_labels).astype(int).ravel()
    N = ds.dims["sample"]
    assert len(labels) == N, f"labels length {len(labels)} must equal samples {N}"

    # --- parse time strings to datetime ---
    # format is like '20200401001242' → YYYYMMDDHHMMSS
    years = pd.to_datetime(ds["time"].values, format="%Y%m%d%H%M%S").year
    unique_years = np.unique(years)

    K = int(labels.max()) + 1
    counts = np.zeros((len(unique_years), K), dtype=int)

    # --- count per (year, cluster) ---
    for yi, yr in enumerate(unique_years):
        mask = (years == yr)
        if mask.any():
            cl = labels[mask]
            counts[yi] = np.bincount(cl, minlength=K)

    # --- stacked bar plot ---
    x = np.arange(len(unique_years))
    fig, ax = plt.subplots(figsize=(max(8, len(unique_years)*1.2), 5))

    cum = np.zeros(len(unique_years), dtype=float)
    for k in range(K):
        c = counts[:, k]
        ax.bar(x, c, bottom=cum,
               color=colors_per_class1.get(str(k), "#000000"),
               edgecolor="white", linewidth=0.5, label=f"C{k}")
        cum += c

    ax.set_xticks(x)
    ax.set_xticklabels([str(int(y)) for y in unique_years], rotation=0)
    ax.set_ylabel("Count")
    ax.set_title("Cluster frequency per year (stacked)")
    ax.grid(axis="y", linestyle=":", alpha=0.4)
    ax.legend(ncol=min(5, K), fontsize=8, loc="upper left", bbox_to_anchor=(1.01, 1.0))

    out_dir = "/p/project1/exaww/chatterjee1/plots/continuous"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{out_prefix}.png")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"wrote {out_path}")

In [29]:
plot_cluster_frequencies_per_year(
    ds, cluster_labels, colors_per_class1,
    out_prefix="cluster_freq_per_year_10_multiyear"
)

wrote /p/project1/exaww/chatterjee1/plots/continuous/cluster_freq_per_year_10_multiyear.png


## Cluster freq

In [41]:
def plot_cluster_frequencies(cluster_labels, colors_per_class1, out_prefix="cluster_freq_overall"):
    labels = np.asarray(cluster_labels).astype(int).ravel()
    K = int(labels.max()) + 1
    counts = np.bincount(labels, minlength=K)
    total = counts.sum()
    perc = 100.0 * counts / max(total, 1)

    # colors per bar
    bar_colors = [colors_per_class1.get(str(i), "#000000") for i in range(K)]

    fig, ax = plt.subplots(figsize=(10, 4))
    x = np.arange(K)
    ax.bar(x, counts, color=bar_colors, edgecolor="k", linewidth=0.8)
    ax.set_xlabel("Cluster")
    ax.set_ylabel("Count")
    ax.set_title("Cluster frequency (overall)")

    # annotate with %
    for i, (c, p) in enumerate(zip(counts, perc)):
        ax.text(i, c, f"{p:.1f}%", ha="center", va="bottom", fontsize=8, rotation=0)

    ax.set_xticks(x)
    ax.set_xticklabels([str(i) for i in range(K)])
    ax.grid(axis="y", linestyle=":", alpha=0.4)

    out_dir = "/p/project1/exaww/chatterjee1/plots/continuous"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{out_prefix}.png")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"wrote {out_path}")

In [42]:
plot_cluster_frequencies(cluster_labels, colors_per_class1, out_prefix="obscluster_freq_overall_10c")

wrote /p/project1/exaww/chatterjee1/plots/continuous/obscluster_freq_overall_10c.png


In [43]:
x = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_lincrops.nc')
x