In [26]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from scipy.interpolate import griddata
import torch
import os
import matplotlib.cm as cm
from matplotlib.cm import ScalarMappable
import matplotlib.colors as mcolors

In [12]:
file = "/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops.nc"
ds = xr.open_dataset(file)

In [13]:
ds

In [14]:
ds.sample_data.shape

(74268, 128, 128)

In [15]:
def compute_metrics(ds):
    """
    Computes metrics for each sample.
    """


    # radar reflectivity statistics
    da_ze_mean = ds.sample_data.mean(["y", "x"])  # TODO: in dB?
    da_ze_std = ds.sample_data.std(["y", "x"])
    da_ze_max = ds.sample_data.max(["y", "x"])
    da_ze_q90 = ds.sample_data.quantile(0.9, ["y", "x"])


    # combine to single dataset
    ds_metrics = xr.Dataset(
        {
            "ze_mean": da_ze_mean,
            "ze_std": da_ze_std,
            "ze_max": da_ze_max,
            "ze_q90": da_ze_q90,
        }
    )

    return ds_metrics


def write_metrics(ds_metrics):
    ds_metrics.to_netcdf(
        "/p/project/exaww/chatterjee1/dataset/metrics_msgobs_108_randomcrops.nc"
    )


def read_metrics():
    return xr.open_dataset(
        "/p/project/exaww/chatterjee1/dataset/metrics_msgobs_108_randomcrops.nc"
    )

In [16]:
recompute = False

if recompute:
    ds_metrics = compute_metrics(ds)
    write_metrics(ds_metrics)
else:
    ds_metrics = read_metrics()

print(ds_metrics.data_vars)

Data variables:
    ze_mean  (sample) float32 ...
    ze_std   (sample) float32 ...
    ze_max   (sample) float32 ...
    ze_q90   (sample) float64 ...


In [5]:
ds_metrics.ze_std.max(),ds_metrics.ze_std.min()

(<xarray.DataArray 'ze_std' ()>
 array(34.8343544)
 Coordinates:
     quantile  float64 ...,
 <xarray.DataArray 'ze_std' ()>
 array(0.71462017)
 Coordinates:
     quantile  float64 ...)

In [6]:
ds_metrics.ze_q90.max(),ds_metrics.ze_q90.min()

(<xarray.DataArray 'ze_q90' ()>
 array(304.37316895)
 Coordinates:
     quantile  float64 ...,
 <xarray.DataArray 'ze_q90' ()>
 array(215.91322327)
 Coordinates:
     quantile  float64 ...)

In [17]:
def read_features(compression="multiscale"):
    if compression == "multiscale":
        fext = "500multiscale50"
    elif compression == "annealing":
        fext = "500annealing50"
    elif compression == "pca":
        fext = "pcacosine"
    else:
        raise ValueError(f"Unknown compression {compression}")

    ds_tsne = np.load(
        f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/features_obs/tsne_msg_{fext}.npy"
    )
    x = ds_tsne[:, 0]
    y = ds_tsne[:, 1]
    return x, y

In [18]:
def plot_tsne(stats, compression):
    fig, axs = plt.subplots(1, 1, figsize=(7, 5), layout="constrained")

    x, y = read_features(compression=compression)
    if stats == 'std':
        im = axs.scatter(
        x, y, s=1, c=ds_metrics[f"ze_{stats}"], cmap='viridis', vmax=35, vmin=0,
        )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    elif stats == 'mean':
        im = axs.scatter(
            x, y, s=1, c=ds_metrics[f"ze_{stats}"], cmap='viridis', vmax=290, vmin=210,
            )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    elif stats == 'max':
        im = axs.scatter(
            x, y, s=1, c=ds_metrics[f"ze_{stats}"], cmap='viridis', vmax=300, vmin=270,
            )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    elif stats == 'q90':
        im = axs.scatter(
            x, y, s=1, c=ds_metrics[f"ze_{stats}"], cmap='viridis', vmax=300, vmin=220,
            )
        axs.set_title(f"{stats}_128x128")
        fig.colorbar(im, ax=axs, label=f"10.8 BT {stats}")
    

    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/tsne_{stats}_{compression}_june.png", dpi=300)
    plt.close()

In [10]:
stats = ['mean', 'std', 'max', 'q90']
for stat in stats:
    #plot_tsne(stat, compression="multiscale")
    #plot_tsne(stat, compression="annealing")
    plot_tsne(stat, compression="pca")

In [11]:
img_kwargs = {
    "sample_data": {"vmin": 210, "vmax": 290, "cmap": 'viridis'},
}

In [28]:
def images2d(
    x, y, i_samples, f, kwargs, file, subsample=2, zoom=0.1, figsize=(10, 10)
):

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    ax.scatter(
        x=x,
        y=y,
        lw=0,
        s=1,
        color="k",
        alpha=0.2,
    )

    for i_sample in i_samples:
        img = f(i_sample, **kwargs)[:, ::subsample]
        #print(img.size)
        imagebox = OffsetImage(img, zoom=zoom)
        imagebox.image.axes = ax
        ab = AnnotationBbox(
            imagebox,
            [x[i_sample], y[i_sample]],
            xycoords="data",
            frameon=True,
            box_alignment=(0.5, 0.5),
            pad=0,
            bboxprops=dict(edgecolor="#eeeeee", lw=1, facecolor="none"),
        )
        ax.add_artist(ab)

    ax.set_aspect("equal")
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_june.png", dpi=600, bbox_inches="tight")
    plt.close()
    
def images2d_colorbar(
    x, y, i_samples, f, kwargs, file, subsample=2, zoom=0.1, figsize=(10, 10)
):
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    ax.scatter(
        x=x,
        y=y,
        lw=0,
        s=1,
        color="k",
        alpha=0.2,
    )

    norm = None
    cmap = None
    sample_img = None

    # Plot image boxes
    for i_sample in i_samples:
        img = f(i_sample, **kwargs)[:, ::subsample]
        sample_img = img  # Save for colorbar
        if norm is None:
            norm = mcolors.Normalize(vmin=img_kwargs[v]["vmin"], vmax=img_kwargs[v]["vmax"])
            cmap = plt.get_cmap('Greys')  # Change if you use a different colormap

        imagebox = OffsetImage(img, zoom=zoom, cmap=cmap, norm=norm)
        imagebox.image.axes = ax
        ab = AnnotationBbox(
            imagebox,
            [x[i_sample], y[i_sample]],
            xycoords="data",
            frameon=True,
            box_alignment=(0.5, 0.5),
            pad=0,
            bboxprops=dict(edgecolor="#eeeeee", lw=1, facecolor="none"),
        )
        ax.add_artist(ab)

    ax.set_aspect("equal")
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Add colorbar
    if sample_img is not None:
        sm = ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.01, shrink=0.6, aspect=30)
        cbar.ax.tick_params(labelsize=8)

    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_june.png", dpi=600, bbox_inches="tight")
    plt.close()
    
def channel2rgb(i_sample, v):
    """
    Convert a specific channel to RGB with a fixed vmin and vmax.
    """

    norm = plt.Normalize(
        vmin=img_kwargs[v]["vmin"], vmax=img_kwargs[v]["vmax"]
    )
    colormap = plt.cm.Greys
    # Apply colormap to normalized data
    rgb = colormap(norm(ds[v].isel(sample=i_sample).values))

    # Flip vertically
    rgb = rgb[::-1, :, :]

    # Convert to uint8 format
    return (rgb * 255).astype("uint8")


def channel2gray(i_sample, v):
    """
    Convert a specific channel to grayscale with a fixed vmin and vmax.
    """
    norm = plt.Normalize(
        vmin=img_kwargs[v]["vmin"], vmax=img_kwargs[v]["vmax"]
    )

    # Normalize data between 0 and 1
    gray = norm(ds[v].isel(sample=i_sample).values)

    # Flip vertically
    gray = gray[::-1, :]

    # Convert to uint8 (0-255) grayscale
    return (gray * 255).astype("uint8")  # Shape: (H, W)

In [29]:
n_images = 1000
i_samples = np.random.choice(len(ds.sample_data), size=n_images, replace=False)
subsample = 1
zoom = 0.2
figsize = (10, 10)

x, y = read_features(compression='pca')

# individual channel with fixed value ranges
for v in ["sample_data"]:
    images2d_colorbar(
        x=x,
        y=y,
        i_samples=i_samples,
        f=channel2rgb,
        kwargs=dict(v=v),
        file=f"continuous{v}",
        subsample=subsample,
        zoom=zoom,
        figsize=figsize,
    )


In [50]:
len(ds.sample_data)

74268

# plot classes 

## tsne

In [16]:
# Load cluster labels
common_path = "/p/project/exaww/chatterjee1/mcspss_continuous/analysis/features_obs/"
file_path = os.path.join(common_path, "obs_cluster_10_labels_new.pth")

if os.path.exists(file_path):
    print("File exists!")
else:
    print("File NOT found!")
    
    
data = torch.load(file_path, map_location="cpu")  # Ensure it's loaded to CPU

print("Type of loaded data:", type(data))

if isinstance(data, torch.Tensor):
    print("Shape of tensor:", data.shape)
elif isinstance(data, dict):
    print("Keys in the dictionary:", data.keys())
elif isinstance(data, list):
    print("Length of list:", len(data))    

cluster_labels = torch.load(os.path.join(common_path, "obs_cluster_10_labels.pth"), map_location="cpu")

# Color mapping for clusters
colors_per_class1 = {
    str(i): color for i, color in enumerate([
        '#0F3C5F', '#48714F', '#C49138', '#FDB9C2', '#393d76', '#e1bc3a', 
        '#af362b', '#19cbf7', '#000000', '#0008fa', '#3e8245', '#7c646e', 
        '#34065c', '#afa8ed', '#fc030f', '#699be0', '#999999', '#dffc03', 
        '#dea4da', '#cc7445'
    ])
}

# Function to plot images with cluster-based color boxes
def images2d(
    x, y, i_samples, f, kwargs, file, cluster_labels, subsample=2, zoom=0.1, figsize=(10, 10)
):
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    ax.scatter(
        x=x,
        y=y,
        lw=0,
        s=1,
        color="k",
        alpha=0.2,
    )

    for i_sample in i_samples:
        img = f(i_sample, **kwargs)[:, ::subsample]

        # Get the cluster label and corresponding color
        cluster_id = str(cluster_labels[i_sample])
        bbox_color = colors_per_class1.get(cluster_id, "#000000")  # Default black if missing

        imagebox = OffsetImage(img, zoom=zoom)
        imagebox.image.axes = ax
        ab = AnnotationBbox(
            imagebox,
            [x[i_sample], y[i_sample]],
            xycoords="data",
            frameon=True,
            box_alignment=(0.5, 0.5),
            pad=0,
            bboxprops=dict(edgecolor=bbox_color, lw=2, facecolor="none"),  # Set box color
        )
        ax.add_artist(ab)

    ax.set_aspect("equal")
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_10_cluster_may.png", dpi=600, bbox_inches="tight")
    plt.close()

# Load t-SNE embeddings
x, y = read_features(compression='pca')

# Randomly sample images
n_images = 1000
i_samples = np.random.choice(len(cluster_labels), size=n_images, replace=False)
subsample = 1
zoom = 0.2
figsize = (10, 10)

# Plot images with cluster color coding
for v in ["sample_data"]:
    images2d(
        x=x,
        y=y,
        i_samples=i_samples,
        f=channel2rgb,
        kwargs=dict(v=v),
        file=f"images2d_view{v}",
        cluster_labels=cluster_labels,  # Pass cluster labels
        subsample=subsample,
        zoom=zoom,
        figsize=figsize,
    )

File exists!
Type of loaded data: <class 'numpy.ndarray'>


  data = torch.load(file_path, map_location="cpu")  # Ensure it's loaded to CPU
  cluster_labels = torch.load(os.path.join(common_path, "obs_cluster_10_labels.pth"), map_location="cpu")


## Cluster wise 10 random images

### WITHOUT COLORBAR

In [18]:

# Number of clusters
K = len(np.unique(cluster_labels))  # Assuming clusters are labeled from 0 to K-1
n_samples_per_cluster = 10  # Number of images per cluster

# Collect indices for each cluster
cluster_indices = {i: np.where(cluster_labels == i)[0] for i in range(K)}

# Function to plot images row-wise per cluster
def plot_cluster_images(x, y, cluster_indices, f, kwargs, file, subsample=2, zoom=0.1):
    fig, axes = plt.subplots(K, n_samples_per_cluster, figsize=(2 * n_samples_per_cluster, 2 * K))

    for cluster_id in range(K):
        # Randomly sample 10 indices from this cluster (if available)
        indices = np.random.choice(cluster_indices[cluster_id], size=min(n_samples_per_cluster, len(cluster_indices[cluster_id])), replace=False)

        for j, i_sample in enumerate(indices):
            ax = axes[cluster_id, j]  # Select subplot
            
            # Get image
            img = f(i_sample, **kwargs)[:, ::subsample]
            
            # Plot the image
            ax.imshow(img)
            ax.set_xticks([])
            ax.set_yticks([])

            # Set border color based on cluster
            bbox_color = colors_per_class1.get(str(cluster_id), "#000000")
            ax.spines['top'].set_color(bbox_color)
            ax.spines['bottom'].set_color(bbox_color)
            ax.spines['left'].set_color(bbox_color)
            ax.spines['right'].set_color(bbox_color)
            ax.spines['top'].set_linewidth(2)
            ax.spines['bottom'].set_linewidth(2)
            ax.spines['left'].set_linewidth(2)
            ax.spines['right'].set_linewidth(2)

        # Label the row with the cluster ID
        axes[cluster_id, 0].set_ylabel(f"Cluster {cluster_id}", fontsize=12, rotation=90, labelpad=20)

    plt.tight_layout()
    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_clusteredraw_may.png", dpi=600, bbox_inches="tight")
    plt.close()

# Call the function to plot
plot_cluster_images(
    x, y, cluster_indices, 
    f=channel2rgb, 
    kwargs=dict(v="sample_data"),
    file="cluster_images",
    subsample=1, 
    zoom=0.2
)

### WITH COLORBAR

In [11]:
# === Final Clusters ===
final_clusters = [3, 4, 5, 6, 7, 8, 9]
cluster_idx_map = {cl: i for i, cl in enumerate(final_clusters)}
custom_labels = {
    3: "3(+2)",
    5: "5(+1)",
    7: "7(+0)"
}

# === Color Mapping ===
colors_per_class1 = {
    str(i): color for i, color in enumerate([
        '#0F3C5F', '#48714F', '#C49138', '#FDB9C2', '#393d76', '#e1bc3a',
        '#af362b', '#19cbf7', '#000000', '#0008fa', '#3e8245', '#7c646e',
        '#34065c', '#afa8ed', '#fc030f', '#699be0', '#999999', '#dffc03',
        '#dea4da', '#cc7445'
    ])
}

# === Load TSNE Features ===
def read_features(compression="multiscale"):
    if compression == "multiscale":
        fext = "500multiscale50"
    elif compression == "annealing":
        fext = "500annealing50"
    elif compression == "pca":
        fext = "pcacosine"
    else:
        raise ValueError(f"Unknown compression {compression}")
    ds_tsne = np.load(f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/features_obs/tsne_msg_{fext}.npy")
    return ds_tsne[:, 0], ds_tsne[:, 1]

# === RGB Conversion ===
def channel2rgb(i_sample, v):
    norm = plt.Normalize(vmin=img_kwargs[v]["vmin"], vmax=img_kwargs[v]["vmax"])
    colormap = plt.cm.Greys
    rgb = colormap(norm(ds[v].isel(sample=i_sample).values))
    rgb = rgb[::-1, :, :]
    return (rgb * 255).astype("uint8")

# === Input Data and Labels ===
x, y = read_features(compression='pca')
ds = xr.open_dataset("/p/project1/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops.nc")

# Load and merge cluster labels
labels_raw = torch.load("/p/project/exaww/chatterjee1/mcspss_continuous/analysis/features_obs/obs_cluster_10_labels.pth", map_location="cpu")
labels_raw = np.array(labels_raw)

labels_raw[labels_raw == 0] = 7
labels_raw[labels_raw == 1] = 5
labels_raw[labels_raw == 2] = 3

# === Filter only final cluster labels ===
mask_final = np.isin(labels_raw, final_clusters)
x = x[mask_final]
y = y[mask_final]
labels_final = labels_raw[mask_final]

# === Regroup indices per final cluster ===
cluster_indices = {cl: np.where(labels_final == cl)[0] for cl in final_clusters}

# === Visualization Setup ===
img_kwargs = {
    "sample_data": {"vmin": 210, "vmax": 290, "cmap": 'viridis'},
}
n_samples_per_cluster = 10

# === Plot Function ===
def plot_cluster_images_with_colorbar(x, y, cluster_indices, f, kwargs, file, subsample=2, zoom=0.1):
    fig, axes = plt.subplots(len(final_clusters), n_samples_per_cluster,
                             figsize=(2 * n_samples_per_cluster, 2 * len(final_clusters)))
    norm = plt.Normalize(vmin=img_kwargs["sample_data"]["vmin"], vmax=img_kwargs["sample_data"]["vmax"])
    sm = cm.ScalarMappable(cmap=plt.cm.Greys, norm=norm)
    sm.set_array([])

    for row_idx, cl in enumerate(final_clusters):
        indices = cluster_indices[cl]
        chosen = np.random.choice(indices, size=min(n_samples_per_cluster, len(indices)), replace=False)

        for j, i_sample in enumerate(chosen):
            ax = axes[row_idx, j]
            img = f(i_sample, **kwargs)
            ax.imshow(img)
            ax.set_xticks([])
            ax.set_yticks([])

            bbox_color = colors_per_class1.get(str(cl), "#000000")
            for spine in ax.spines.values():
                spine.set_color(bbox_color)
                spine.set_linewidth(2)

        label = custom_labels.get(cl, str(cl))
        axes[row_idx, 0].set_ylabel(f"Cluster {label}", fontsize=12, rotation=90, labelpad=20)

    cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation="vertical")
    cbar.set_label("BT", fontsize=20)
    cbar.ax.tick_params(labelsize=16)

    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_clusteredraw_merged_3.png", dpi=600)
    plt.show()

# === Execute ===
plot_cluster_images_with_colorbar(
    x, y, cluster_indices,
    f=channel2rgb,
    kwargs=dict(v="sample_data"),
    file="cluster_images",
    subsample=1,
    zoom=0.2
)

  plt.tight_layout(rect=[0, 0, 0.9, 1])


## gray scale

In [30]:
def plot_cluster_images_with_colorbar(x, y, cluster_indices, f, kwargs, file, subsample=2, zoom=0.1):
    fig, axes = plt.subplots(K, n_samples_per_cluster, figsize=(2 * n_samples_per_cluster, 2 * K))
    
    # Set up a single colorbar
    norm = plt.Normalize(vmin=img_kwargs["sample_data"]["vmin"], vmax=img_kwargs["sample_data"]["vmax"])
    sm = cm.ScalarMappable(cmap=plt.cm.gray, norm=norm)
    sm.set_array([])  # Required for colorbar

    for cluster_id in range(K):
        indices = np.random.choice(cluster_indices[cluster_id], size=min(n_samples_per_cluster, len(cluster_indices[cluster_id])), replace=False)

        for j, i_sample in enumerate(indices):
            ax = axes[cluster_id, j]
            img = f(i_sample, **kwargs)  # Returns 2D grayscale image

            ax.imshow(img, cmap="gray", vmin=norm.vmin, vmax=norm.vmax)
            ax.set_xticks([])
            ax.set_yticks([])

            # Add cluster color border
            bbox_color = colors_per_class1.get(str(cluster_id), "#000000")
            for spine in ax.spines.values():
                spine.set_color(bbox_color)
                spine.set_linewidth(2)

        axes[cluster_id, 0].set_ylabel(f"Cluster {cluster_id}", fontsize=12, rotation=90, labelpad=20)

    # Add a single colorbar on the right side
    cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])  # [left, bottom, width, height]
    fig.colorbar(sm, cax=cbar_ax, orientation="vertical", label="Feature Intensity")

    plt.tight_layout(rect=[0, 0, 0.9, 1])  # Leave space for colorbar
    plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/{file}_clusteredrowgray_march.png", dpi=600, bbox_inches="tight")
    plt.close()

# Call the function
plot_cluster_images_with_colorbar(
    x, y, cluster_indices,
    f=channel2gray, 
    kwargs=dict(v="sample_data"),
    file="cluster_images_gray",
    subsample=1, 
    zoom=0.2
)

  plt.tight_layout(rect=[0, 0, 0.9, 1])  # Leave space for colorbar
