In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import glob
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
#import cartopy.crs as ccrs
#import cartopy.feature as cfeature
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import torch
from collections import defaultdict
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
import os
import seaborn as sns

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D

In [3]:
# === Final cluster settings ===
# Cluster ids that are analysed, listed in the order used for array rows and panels.
final_clusters = [3, 4, 5, 6, 7, 8, 9]
n_clusters = len(final_clusters)
n_hours, n_sites = 24, 11
# Cluster id -> position of that cluster inside `final_clusters`.
cluster_idx_map = dict(zip(final_clusters, range(n_clusters)))

# Display names that override the plain cluster id for three of the clusters.
custom_labels = {3: "3(+2)", 5: "5(+1)", 7: "7(+0)"}
# One label per entry of `final_clusters`; ids without an override are shown as-is.
xtick_labels = list(
    map(lambda cluster: custom_labels[cluster] if cluster in custom_labels else str(cluster),
        final_clusters)
)

# === Obs site info ===
# Site index -> pair of name fragments for the observation data; both entries of a
# pair are the same string.
obs_sites_ncvar_name = {
    site_idx: (site, site)
    for site_idx, site in enumerate((
        "juelich", "lindenberg", "warsaw", "vienna", "bourges", "zaragoza",
        "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
    ))
}
# Site index -> name fragment for the observation cluster-label files.
# NOTE(review): "lin" and "zargoza" differ from the spellings used above; presumably
# they mirror the file names on disk -- confirm before "correcting" them.
obs_cluster_sites = dict(enumerate((
    "juelich", "lin", "warsaw", "vienna", "bourges", "zargoza",
    "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
)))

# === Icon site info ===
# Site index -> pair of name fragments for the ICON data; both entries of a pair
# are the same string.
icon_sites_ncvar_name = {
    site_idx: (site, site)
    for site_idx, site in enumerate((
        "juelich", "lin", "warsaw", "vienna", "bourges", "zaragoza",
        "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
    ))
}
# Site index -> name fragment for the ICON cluster-label files.
# NOTE(review): same "zargoza" spelling as the observation table -- confirm against disk.
icon_cluster_sites = dict(enumerate((
    "juelich", "lin", "warsaw", "vienna", "bourges", "zargoza",
    "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
)))

# === Initialize arrays ===
# Occurrence counts indexed as [cluster position, site, hour of day].
obs_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))
model_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))


def _load_merged_labels(label_file):
    """Load the cluster labels stored in ``label_file`` and merge three raw labels.

    Raw labels 0, 1 and 2 are rewritten to 7, 5 and 3 respectively, in that order.

    Args:
        label_file: path of a file readable by ``torch.load``.

    Returns:
        np.ndarray: a new array holding the merged labels.
    """
    # NOTE(review): torch.load unpickles the file, so only use it on trusted label files.
    labels = np.array(torch.load(label_file, map_location="cpu"))
    for raw_label, merged_label in ((0, 7), (1, 5), (2, 3)):
        labels[labels == raw_label] = merged_label
    return labels


# === Data processing ===
for i in range(n_sites):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]
    site_nc_icon, site_var_icon = icon_sites_ncvar_name[i]
    cluster_id_icon = icon_cluster_sites[i]

    # === Load OBS ===
    obs_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"
    # Only the time axis is needed; the context manager releases the file handle
    # instead of leaving one open dataset per site behind.
    with xr.open_dataset(obs_nc) as ds:
        raw_times = ds['time'].values
    # The first 12 characters of each time value are parsed as YYYYMMDDHHMM.
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")
    obs_clusters = _load_merged_labels(obs_cluster_file)
    # Labels are picked by position below, so a length mismatch would silently pair
    # labels with the wrong time stamps.
    if len(obs_clusters) != len(obs_datetimes):
        raise ValueError(
            f"OBS labels ({len(obs_clusters)}) and time stamps ({len(obs_datetimes)}) "
            f"differ in length for site '{site_nc}'"
        )

    # For every full hour keep the closest observation, provided it lies within
    # 900 s (15 min) of that hour; hours without such a sample are dropped.
    hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'), end=obs_datetimes.max().ceil('H'), freq='H')
    matched_indices = []
    for hour_stamp in hourly_times:
        # Compute the distance vector once per hour (it was previously built twice).
        offsets = np.abs((obs_datetimes - hour_stamp).total_seconds())
        nearest = np.argmin(offsets)
        if offsets[nearest] <= 900:
            matched_indices.append(nearest)
    matched_indices = np.array(matched_indices, dtype=int)
    obs_hourly_times = obs_datetimes[matched_indices]
    obs_hourly_clusters = obs_clusters[matched_indices]
    obs_hours = obs_hourly_times.hour

    for cl in final_clusters:
        idx = cluster_idx_map[cl]
        counts = np.bincount(obs_hours[obs_hourly_clusters == cl], minlength=n_hours)
        obs_hourly_counts[idx, i, :] = counts

    # === Load MODEL ===
    model_nc = f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_var_icon}crops_icon.nc"
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth"
    with xr.open_dataset(model_nc) as ds_model:
        model_times = pd.to_datetime(ds_model["time"].values)
    model_clusters = _load_merged_labels(model_cluster_file)
    # Every model time step is used as-is (no hourly sub-sampling on this side).
    model_hours = model_times.hour

    for cl in final_clusters:
        idx = cluster_idx_map[cl]
        counts = np.bincount(model_hours[model_clusters == cl], minlength=n_hours)
        model_hourly_counts[idx, i, :] = counts

# === Normalize per site ===
# Each (cluster, site) row is divided by its total over the 24 hours, so a row
# with samples sums to 1.  Rows without any sample stay all-zero for OBS as well
# as ICON; previously the OBS side produced NaN for such rows (0/0), which turned
# the whole site-mean of that cluster into NaN while warnings were suppressed.
obs_site_totals = obs_hourly_counts.sum(axis=2, keepdims=True)
obs_hourly_norm = np.divide(
    obs_hourly_counts,
    obs_site_totals,
    out=np.zeros_like(obs_hourly_counts, dtype=float),
    where=obs_site_totals > 0,
)
model_site_totals = model_hourly_counts.sum(axis=2, keepdims=True)
model_hourly_norm = np.divide(
    model_hourly_counts,
    model_site_totals,
    out=np.zeros_like(model_hourly_counts, dtype=float),
    where=model_site_totals > 0,
)

# === Mean and spread ===
# Statistics across sites (axis 1), leaving arrays indexed as [cluster position, hour].
# NOTE(review): all-zero rows of sites that never saw a cluster are part of these
# statistics and pull the mean profile down -- confirm that this is intended.
obs_mean = obs_hourly_norm.mean(axis=1)
obs_std = obs_hourly_norm.std(axis=1)
model_mean = model_hourly_norm.mean(axis=1)
model_std = model_hourly_norm.std(axis=1)

# === Plotting ===
# One panel per cluster on a 2 x 4 grid: site-mean diurnal profile with a
# +/- one standard deviation band, OBS in blue and ICON in orange.
n_rows, n_cols = 2, 4
fig, axes = plt.subplots(n_rows, n_cols, figsize=(22, 8), sharex=True, sharey=True)
axes = axes.flatten()
hours = np.arange(24)
tick_hours = np.arange(0, 24, 3)

for i, cl in enumerate(final_clusters):
    ax = axes[i]

    print(f"ICON cluster {cl} total count across all sites: {model_hourly_counts[cluster_idx_map[cl]].sum()}")

    # OBS
    ax.plot(hours, obs_mean[i], label='OBS', color='tab:blue')
    ax.fill_between(hours, obs_mean[i] - obs_std[i], obs_mean[i] + obs_std[i], color='tab:blue', alpha=0.3)

    # ICON -- only the curve is skipped when there are no samples; the panel is
    # still formatted below (the former `continue` left it without title or ticks).
    if model_mean[i].sum() == 0:
        print(f"[Warning] Cluster {cl} has no ICON samples")
    else:
        ax.plot(hours, model_mean[i], label='ICON', color='tab:orange')
        ax.fill_between(hours, model_mean[i] - model_std[i], model_mean[i] + model_std[i], color='tab:orange', alpha=0.3)

    ax.set_title(f'Cluster {xtick_labels[i]}')
    ax.set_xticks(tick_hours)
    ax.set_xticklabels([f"{hour:02d}:00" for hour in range(0, 24, 3)])
    ax.set_ylabel("Norm. Freq.")
    ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.4)

# The grid has more panels than clusters: hide the spare ones instead of showing
# empty frames, and give the panel above each hidden one its x tick labels back
# (sharex=True only labels the bottom row).
for spare_idx in range(len(final_clusters), len(axes)):
    axes[spare_idx].set_visible(False)
    if spare_idx - n_cols >= 0:
        axes[spare_idx - n_cols].tick_params(labelbottom=True)

axes[0].legend(loc='upper right')
fig.suptitle("OBS vs ICON Diurnal Cluster Profiles (Hourly)", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/diurnal_obs_vs_icon_clusters_merged.png")
plt.show()

ICON cluster 3 total count across all sites: 2598.0
ICON cluster 4 total count across all sites: 15704.0
ICON cluster 5 total count across all sites: 1256.0
ICON cluster 6 total count across all sites: 4830.0
ICON cluster 7 total count across all sites: 2364.0
ICON cluster 8 total count across all sites: 19668.0
ICON cluster 9 total count across all sites: 3905.0


In [4]:
# === Final merged clusters ===
# Cluster ids that are counted, in the order of the bars on the x axis.
final_clusters = [3, 4, 5, 6, 7, 8, 9]
n_clusters, n_sites = len(final_clusters), 11
# Cluster id -> slot of that cluster in the count vectors.
cluster_idx_map = dict(zip(final_clusters, range(n_clusters)))

# Tick labels: three clusters get a custom display name, the rest show their id.
custom_labels = {3: "3(+2)", 5: "5(+1)", 7: "7(+0)"}
xtick_labels = [
    custom_labels[cluster] if cluster in custom_labels else str(cluster)
    for cluster in final_clusters
]

# === Site mapping (as before) ===
# Site index -> name fragment inserted into the observation label-file path.
# NOTE(review): "lin" and "zargoza" look abbreviated/misspelled; presumably they
# mirror the file names on disk -- confirm before changing them.
obs_cluster_sites = dict(enumerate([
    "juelich",
    "lin",
    "warsaw",
    "vienna",
    "bourges",
    "zargoza",
    "sirta",
    "cabauw",
    "nuremberg",
    "aurillac",
    "dresden",
]))

# Site index -> name fragment inserted into the ICON label-file path
# (independent dict with the same entries as the observation table).
icon_cluster_sites = dict(enumerate([
    "juelich",
    "lin",
    "warsaw",
    "vienna",
    "bourges",
    "zargoza",
    "sirta",
    "cabauw",
    "nuremberg",
    "aurillac",
    "dresden",
]))

# === Initialize count containers ===
# One running total per entry of `final_clusters`, summed over all sites.
obs_cluster_counts = np.zeros(n_clusters)
model_cluster_counts = np.zeros(n_clusters)

# Raw label -> merged label, applied in exactly this order.
label_merges = ((0, 7), (1, 5), (2, 3))

# === Loop over all sites and accumulate cluster counts ===
for i in range(n_sites):
    obs_id, icon_id = obs_cluster_sites[i], icon_cluster_sites[i]

    # === Load OBS clusters ===
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{obs_id}_cluster_10_labels.pth"
    obs_clusters = np.array(torch.load(obs_cluster_file, map_location="cpu"))
    for raw_label, merged_label in label_merges:
        obs_clusters[obs_clusters == raw_label] = merged_label

    # === Load ICON clusters ===
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{icon_id}_cluster_10_labels.pth"
    model_clusters = np.array(torch.load(model_cluster_file, map_location="cpu"))
    for raw_label, merged_label in label_merges:
        model_clusters[model_clusters == raw_label] = merged_label

    # === Count occurrences only for final clusters ===
    # Labels outside `final_clusters` are not counted anywhere.
    for cl in final_clusters:
        slot = cluster_idx_map[cl]
        obs_cluster_counts[slot] += np.count_nonzero(obs_clusters == cl)
        model_cluster_counts[slot] += np.count_nonzero(model_clusters == cl)

# === Normalize to frequencies ===
# Share of each cluster among all counted samples, separately for OBS and ICON.
obs_freq = obs_cluster_counts / np.sum(obs_cluster_counts)
model_freq = model_cluster_counts / np.sum(model_cluster_counts)

# === Plotting ===
# Grouped bar chart: for every cluster an OBS bar (left, blue) next to an ICON
# bar (right, orange).
x = np.arange(n_clusters)
bar_width = 0.35
half_width = bar_width / 2

fig, ax = plt.subplots(figsize=(10, 6))

for shift, freq, series_label, series_color in (
    (-half_width, obs_freq, 'OBS', 'tab:blue'),
    (half_width, model_freq, 'ICON', 'tab:orange'),
):
    ax.bar(x + shift, freq, width=bar_width, label=series_label, color=series_color)

ax.set(
    xlabel="Cluster",
    ylabel="Normalized Frequency",
    title="Cluster Occurrence Frequency (All Sites Combined)",
)
ax.set_xticks(x)
ax.set_xticklabels(xtick_labels)
ax.legend()
ax.grid(True, axis='y', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/cluster_freq_all_sites_merged.png", dpi=300)
plt.show()