In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import glob
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
#import cartopy.crs as ccrs
#import cartopy.feature as cfeature
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import torch
from collections import defaultdict
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
import os
import seaborn as sns

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D


## Hourly diurnal

In [8]:
# === Hourly diurnal cycle of cluster occurrence (observations only) ===
# For every site: count how often each cluster label occurs in each hour of the
# day, normalise every (cluster, site) profile to sum to 1, then plot the
# across-site mean and +/- 1 std per cluster.

# === Site info ===
# index -> (name used in the NetCDF file name, second name that is unused here)
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
}
# index -> name used in the cluster-label file name.
# NOTE(review): "lin" and "zargoza" differ from the NetCDF names above;
# presumably they mirror the label files on disk - confirm before renaming.
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
}

n_clusters = 10
n_hours = 24
n_sites = len(obs_sites_ncvar_name)

# Cluster × Site × Hour array
cluster_hour_counts = np.zeros((n_clusters, n_sites, n_hours))

for i in range(n_sites):
    site_nc, _ = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    # Only the time axis is needed; close the file handle right after reading it.
    with xr.open_dataset(nc_file) as ds:
        raw_times = ds['time'].values

    # The first 12 characters of each raw time value are parsed as YYYYmmddHHMM.
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")
    obs_hours = np.asarray(obs_datetimes.hour)
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))

    for cl in range(n_clusters):
        cluster_hour_counts[cl, i, :] = np.bincount(obs_hours[cluster_labels == cl], minlength=n_hours)

# Normalize across hours per site.
# A cluster that never occurs at a site would give 0/0 = NaN (silently, because
# warnings are disabled) and blank the across-site mean; such profiles are set
# to zero instead.
site_totals = cluster_hour_counts.sum(axis=2, keepdims=True)
cluster_hour_norm = np.divide(
    cluster_hour_counts,
    site_totals,
    out=np.zeros_like(cluster_hour_counts),
    where=site_totals > 0,
)

# === Compute mean & spread across sites ===
cluster_hour_mean = cluster_hour_norm.mean(axis=1)
cluster_hour_std = cluster_hour_norm.std(axis=1)

# === Plotting ===
hour_axis = np.arange(n_hours)
fig, axes = plt.subplots(2, 5, figsize=(20, 8), sharex=True, sharey=True)

axes = axes.flatten()
for cl in range(n_clusters):
    ax = axes[cl]
    mean = cluster_hour_mean[cl]
    std = cluster_hour_std[cl]

    ax.plot(hour_axis, mean, label=f'Cluster {cl}', color='tab:blue')
    ax.fill_between(hour_axis, mean - std, mean + std, color='tab:blue', alpha=0.3)

    ax.set_title(f'Cluster {cl}')
    ax.set_xticks(np.arange(0, 24, 3))
    ax.set_xlabel('Hour of Day')
    ax.set_ylabel('Norm. Frequency')

fig.suptitle(f'Diurnal Cycle of Clusters (Obs, {n_sites} Sites)', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/diurnal_cluster_spread.png")
plt.show()

# Diurnal cycle at 15-minute resolution

In [7]:
# === 15-minute diurnal cycle of cluster occurrence (normalised) ===
# Same analysis as the hourly cell, but on 96 quarter-hour bins. Each
# (cluster, site) profile is normalised to sum to 1 before averaging over sites.

# === Site info ===
# index -> (name used in the NetCDF file name, second name that is unused here)
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
}
# index -> name used in the cluster-label file name.
# NOTE(review): "lin" and "zargoza" differ from the NetCDF names above;
# presumably they mirror the label files on disk - confirm before renaming.
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
}

n_clusters = 10
n_sites = len(obs_sites_ncvar_name)
n_bins = 96  # 96 x 15-minute intervals in 24 hours

# Initialize storage: (cluster, site, 15min-bin)
cluster_diurnal = np.zeros((n_clusters, n_sites, n_bins))

# 15-minute bin labels ("00:00" ... "23:45"); not used further in this cell.
bin_labels = pd.date_range("00:00", "23:45", freq="15min").strftime("%H:%M")

for i in range(n_sites):
    site_nc, _ = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    # Only the time axis is needed; close the file handle right after reading it.
    with xr.open_dataset(nc_file) as ds:
        raw_times = ds['time'].values
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")

    # Timestamps are shifted back by 12 minutes before binning.
    # NOTE(review): presumably this maps the recorded time onto the nominal
    # 15-minute slot - confirm the origin of the 12-minute offset.
    obs_datetimes_adj = obs_datetimes - pd.Timedelta(minutes=12)
    # 15-min bin index: hour * 4 + minute // 15
    bins = np.asarray(obs_datetimes_adj.hour * 4 + obs_datetimes_adj.minute // 15)
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))

    for cl in range(n_clusters):
        cluster_diurnal[cl, i, :] = np.bincount(bins[cluster_labels == cl], minlength=n_bins)

# === Normalize within each site ===
# A cluster that never occurs at a site would give 0/0 = NaN (silently, because
# warnings are disabled) and blank the across-site mean; such profiles are set
# to zero instead.
site_totals = cluster_diurnal.sum(axis=2, keepdims=True)
cluster_diurnal_norm = np.divide(
    cluster_diurnal,
    site_totals,
    out=np.zeros_like(cluster_diurnal),
    where=site_totals > 0,
)

# === Mean and spread across sites ===
mean_diurnal = cluster_diurnal_norm.mean(axis=1)
std_diurnal = cluster_diurnal_norm.std(axis=1)

# === Plotting ===
fig, axes = plt.subplots(2, 5, figsize=(22, 8), sharex=True, sharey=True)
axes = axes.flatten()

x = np.arange(n_bins)

# One tick every 3 hours (12 quarter-hour bins): 0, 12, 24, ..., 84
xtick_pos = np.arange(0, 96, 12)
xtick_labels = [f"{h:02d}:00" for h in range(0, 24, 3)]  # 8 labels

for cl in range(n_clusters):
    ax = axes[cl]
    mean = mean_diurnal[cl]
    std = std_diurnal[cl]

    ax.plot(x, mean, color='tab:blue', label='Mean')
    ax.fill_between(x, mean - std, mean + std, color='tab:blue', alpha=0.3, label='±1 STD')

    ax.set_title(f'Cluster {cl}')
    ax.set_xticks(xtick_pos)
    ax.set_xticklabels(xtick_labels, rotation=45)
    ax.set_ylabel("Norm. Frequency")

    ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.4)

fig.suptitle('Diurnal Cycle of Clusters (15-min Bins, Obs Across Sites)', fontsize=18)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# Saved under its own "_norm" name: the absolute-frequency cell writes
# "diurnal_cluster_15min_abs.png", which previously overwrote this figure.
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/diurnal_cluster_15min_norm.png", dpi=300)
plt.show()

In [8]:
# === 15-minute diurnal cycle of cluster occurrence (absolute counts) ===
# Counts per quarter-hour bin are averaged over sites without normalisation.

# === Site info ===
# index -> (name used in the NetCDF file name, second name that is unused here)
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
}
# index -> name used in the cluster-label file name.
# NOTE(review): "lin" and "zargoza" differ from the NetCDF names above;
# presumably they mirror the label files on disk - confirm before renaming.
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
}

n_clusters = 10
n_sites = len(obs_sites_ncvar_name)
n_bins = 96  # 96 x 15-minute intervals in 24 hours

# Initialize storage: (cluster, site, 15min-bin)
cluster_diurnal = np.zeros((n_clusters, n_sites, n_bins))

# 15-minute bin labels ("00:00" ... "23:45"); not used further in this cell.
bin_labels = pd.date_range("00:00", "23:45", freq="15min").strftime("%H:%M")

for i in range(n_sites):
    site_nc, _ = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    # Only the time axis is needed; close the file handle right after reading it.
    with xr.open_dataset(nc_file) as ds:
        raw_times = ds['time'].values
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")

    # Timestamps are shifted back by 12 minutes before binning.
    # NOTE(review): presumably this maps the recorded time onto the nominal
    # 15-minute slot - confirm the origin of the 12-minute offset.
    obs_datetimes_adj = obs_datetimes - pd.Timedelta(minutes=12)
    # 15-min bin index: hour * 4 + minute // 15
    bins = np.asarray(obs_datetimes_adj.hour * 4 + obs_datetimes_adj.minute // 15)
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))

    for cl in range(n_clusters):
        cluster_diurnal[cl, i, :] = np.bincount(bins[cluster_labels == cl], minlength=n_bins)

# === Mean and spread across sites (raw counts, deliberately not normalised) ===
mean_diurnal = cluster_diurnal.mean(axis=1)
std_diurnal = cluster_diurnal.std(axis=1)

# === Plotting ===
fig, axes = plt.subplots(2, 5, figsize=(22, 8), sharex=True, sharey=True)
axes = axes.flatten()

x = np.arange(n_bins)

# One tick every 3 hours (12 quarter-hour bins): 0, 12, 24, ..., 84
xtick_pos = np.arange(0, 96, 12)
xtick_labels = [f"{h:02d}:00" for h in range(0, 24, 3)]  # 8 labels

for cl in range(n_clusters):
    ax = axes[cl]
    mean = mean_diurnal[cl]
    std = std_diurnal[cl]

    ax.plot(x, mean, color='tab:blue', label='Mean')
    ax.fill_between(x, mean - std, mean + std, color='tab:blue', alpha=0.3, label='±1 STD')

    ax.set_title(f'Cluster {cl}')
    ax.set_xticks(xtick_pos)
    ax.set_xticklabels(xtick_labels, rotation=45)
    ax.set_ylabel("Abs. Frequency")

    ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.4)

fig.suptitle('Diurnal Cycle of Clusters (15-min Bins, Obs Across Sites)', fontsize=18)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/diurnal_cluster_15min_abs.png", dpi=300)
plt.show()

In [4]:
# Inspect the x-tick positions used by the 15-minute plots.
# NOTE(review): the stored output (0, 11.875, ..., 95) equals
# np.linspace(0, 95, 9), not the np.arange(0, 96, 12) assigned in the plotting
# cells - the output is stale; re-run this cell to refresh it.
xtick_pos

array([ 0.   , 11.875, 23.75 , 35.625, 47.5  , 59.375, 71.25 , 83.125,
       95.   ])

## Diurnal - Hourly OBS vs Model

In [12]:
# === OBS vs ICON hourly diurnal cluster profiles ===
# Observations are sub-sampled to the sample nearest each full hour (within
# 900 s), then OBS and ICON cluster labels are histogrammed by hour of day,
# normalised per (cluster, site) and summarised across sites.

# === Obs site info ===
# index -> (name used in the OBS NetCDF file name, second name unused here)
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg","nuremberg"),
    9: ("aurillac","aurillac"),
    10: ("dresden","dresden"),
}
# index -> name used in the OBS cluster-label file name.
# NOTE(review): "lin" and "zargoza" differ from the NetCDF names above;
# presumably they mirror the label files on disk - confirm before renaming.
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

# === Icon site info ===
# index -> (first name unused here, name used in the ICON NetCDF file name)
icon_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lin", "lin"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg","nuremberg"),
    9: ("aurillac","aurillac"),
    10: ("dresden","dresden"),
}
# index -> name used in the ICON cluster-label file name
icon_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

n_clusters = 10
n_hours = 24
n_sites = 11

obs_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))
model_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))

for i in range(n_sites):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    site_nc_icon, site_var_icon = icon_sites_ncvar_name[i]
    cluster_id_icon = icon_cluster_sites[i]

    # === Load OBS ===
    obs_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    # Only the time axis is needed; close the file handle right after reading it.
    with xr.open_dataset(obs_nc) as ds:
        raw_times = ds['time'].values
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")
    obs_clusters = np.array(torch.load(obs_cluster_file, map_location="cpu"))

    hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'),
                                 end=obs_datetimes.max().ceil('H'),
                                 freq='H')

    # For every full hour keep the closest observation, but only if it lies
    # within 900 s (15 min) of that hour.
    matched_indices = []
    for h in hourly_times:
        time_diffs = np.abs((obs_datetimes - h).total_seconds())
        min_idx = np.argmin(time_diffs)
        if time_diffs[min_idx] <= 900:
            matched_indices.append(min_idx)

    obs_hourly_times = obs_datetimes[matched_indices]
    obs_hourly_clusters = obs_clusters[matched_indices]
    obs_hours = obs_hourly_times.hour

    for cl in range(n_clusters):
        counts = np.bincount(obs_hours[obs_hourly_clusters == cl], minlength=24)
        obs_hourly_counts[cl, i, :] = counts

    # === Load MODEL ===
    model_nc = f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_var_icon}crops_icon.nc"
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth"

    with xr.open_dataset(model_nc) as ds_model:
        model_times = pd.to_datetime(ds_model["time"].values)
    model_clusters = np.array(torch.load(model_cluster_file, map_location="cpu"))
    model_hours = model_times.hour

    for cl in range(n_clusters):
        counts = np.bincount(model_hours[model_clusters == cl], minlength=24)
        model_hourly_counts[cl, i, :] = counts

# === Normalize within each site ===
# Profiles with no samples are set to zero for BOTH data sets; previously only
# ICON was guarded, so an empty OBS profile became NaN and blanked the OBS mean.
obs_totals = obs_hourly_counts.sum(axis=2, keepdims=True)
obs_hourly_norm = np.divide(
    obs_hourly_counts, obs_totals,
    out=np.zeros_like(obs_hourly_counts), where=obs_totals > 0,
)
model_totals = model_hourly_counts.sum(axis=2, keepdims=True)
model_hourly_norm = np.divide(
    model_hourly_counts, model_totals,
    out=np.zeros_like(model_hourly_counts), where=model_totals > 0,
)

# === Mean and spread ===
obs_mean = obs_hourly_norm.mean(axis=1)
obs_std = obs_hourly_norm.std(axis=1)

model_mean = model_hourly_norm.mean(axis=1)
model_std = model_hourly_norm.std(axis=1)

# === Plotting ===
fig, axes = plt.subplots(2, 5, figsize=(22, 8), sharex=True, sharey=True)
axes = axes.flatten()

hour_axis = np.arange(24)

for cl in range(n_clusters):
    ax = axes[cl]

    print(f"ICON cluster {cl} total count across all sites: {model_hourly_counts[cl].sum()}")

    # OBS
    ax.plot(hour_axis, obs_mean[cl], label='OBS', color='tab:blue')
    ax.fill_between(hour_axis, obs_mean[cl] - obs_std[cl], obs_mean[cl] + obs_std[cl],
                    color='tab:blue', alpha=0.3)

    # ICON - skipped when the cluster has no ICON samples. Only the ICON curve
    # is skipped: the panel still gets its title, ticks and grid below (a
    # `continue` here used to leave such panels undecorated).
    if model_mean[cl].sum() == 0:
        print(f"[Warning] Cluster {cl} has no ICON samples")
    else:
        ax.plot(hour_axis, model_mean[cl], label='ICON', color='tab:orange')
        ax.fill_between(hour_axis, model_mean[cl] - model_std[cl], model_mean[cl] + model_std[cl],
                        color='tab:orange', alpha=0.3)

    ax.set_title(f'Cluster {cl}')
    ax.set_xticks(np.arange(0, 24, 3))
    ax.set_xticklabels([f"{hh:02d}:00" for hh in range(0, 24, 3)])
    ax.set_ylabel("Norm. Freq.")
    ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.4)

axes[0].legend(loc='upper right')
fig.suptitle("OBS vs ICON Diurnal Cluster Profiles (Hourly)", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/diurnal_obs_vs_icon_clusters_new.png")
plt.show()

ICON cluster 0 total count across all sites: 1225.0
ICON cluster 1 total count across all sites: 300.0
ICON cluster 2 total count across all sites: 1517.0
ICON cluster 3 total count across all sites: 1081.0
ICON cluster 4 total count across all sites: 15704.0
ICON cluster 5 total count across all sites: 956.0
ICON cluster 6 total count across all sites: 4830.0
ICON cluster 7 total count across all sites: 1139.0
ICON cluster 8 total count across all sites: 19668.0
ICON cluster 9 total count across all sites: 3905.0


## Diurnal - Hourly obs - icon - smoothed

In [7]:
# === Obs site info ===
# Site index -> (name, name) pair; both tuple entries are the same string.
obs_sites_ncvar_name = {
    idx: (name, name)
    for idx, name in enumerate(
        ("juelich", "lindenberg", "warsaw", "vienna",
         "bourges", "zaragoza", "sirta", "cabauw")
    )
}
# Site index -> name used in the cluster-label file names.
obs_cluster_sites = dict(
    enumerate(
        ("juelich", "lin", "warsaw", "vienna",
         "bourges", "zargoza", "sirta", "cabauw")
    )
)

# === Icon site info ===
# Same layout as the OBS tables; only site 1 differs ("lin" vs "lindenberg").
icon_sites_ncvar_name = {
    idx: (name, name)
    for idx, name in enumerate(
        ("juelich", "lin", "warsaw", "vienna",
         "bourges", "zaragoza", "sirta", "cabauw")
    )
}
icon_cluster_sites = dict(
    enumerate(
        ("juelich", "lin", "warsaw", "vienna",
         "bourges", "zargoza", "sirta", "cabauw")
    )
)

def smooth_rolling_circular(arr, window=3):
    """Circular (wrap-around) centred moving average along the last axis.

    Every output element is the unweighted mean of the ``2 * window + 1``
    samples centred on it; indices wrap around the ends of the last axis, so
    the first and last elements are treated as neighbours.

    Parameters
    ----------
    arr : array_like
        Input data with at least one dimension. Smoothing is applied along the
        last axis; all leading axes are handled independently.
    window : int, default 3
        Half-width of the averaging window. ``0`` returns the data unchanged
        (as floats). Values larger than the axis length wrap around it more
        than once.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``arr``.

    Raises
    ------
    ValueError
        If ``window`` is negative.
    """
    if window < 0:
        raise ValueError("window must be non-negative")

    arr = np.asarray(arr)
    n = arr.shape[-1]
    if n == 0:
        # Nothing to smooth along an empty axis.
        return arr.astype(float)

    # Build the wrapped index explicitly. Slicing with ``arr[..., -window:]``
    # is wrong for window == 0 (``-0:`` selects the whole axis) and for
    # window > n (the slice is silently truncated).
    wrapped = np.arange(-window, n + window) % n
    padded = arr[..., wrapped]

    kernel = np.ones(2 * window + 1) / (2 * window + 1)
    # 'valid' on the padded axis (n + 2*window) returns exactly n samples.
    return np.apply_along_axis(lambda row: np.convolve(row, kernel, mode='valid'), -1, padded)


# Half-width of the circular smoothing window: window = 1 averages each hour
# with its two neighbours (a 3-point window).
window = 1

n_clusters = 10
n_hours = 24
n_sites = len(obs_sites_ncvar_name)

obs_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))
model_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))

for i in range(n_sites):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    site_nc_icon, site_var_icon = icon_sites_ncvar_name[i]
    cluster_id_icon = icon_cluster_sites[i]

    # === Load OBS ===
    obs_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    # Only the time axis is needed; close the file handle right after reading it.
    with xr.open_dataset(obs_nc) as ds:
        raw_times = ds['time'].values
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")
    obs_clusters = np.array(torch.load(obs_cluster_file, map_location="cpu"))

    hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'),
                                 end=obs_datetimes.max().ceil('H'),
                                 freq='H')

    # For every full hour keep the closest observation, but only if it lies
    # within 900 s (15 min) of that hour.
    matched_indices = []
    for h in hourly_times:
        time_diffs = np.abs((obs_datetimes - h).total_seconds())
        min_idx = np.argmin(time_diffs)
        if time_diffs[min_idx] <= 900:
            matched_indices.append(min_idx)

    obs_hourly_times = obs_datetimes[matched_indices]
    obs_hourly_clusters = obs_clusters[matched_indices]
    obs_hours = obs_hourly_times.hour

    for cl in range(n_clusters):
        counts = np.bincount(obs_hours[obs_hourly_clusters == cl], minlength=24)
        obs_hourly_counts[cl, i, :] = counts

    # === Load MODEL ===
    model_nc = f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_var_icon}crops_icon.nc"
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth"

    with xr.open_dataset(model_nc) as ds_model:
        model_times = pd.to_datetime(ds_model["time"].values)
    model_clusters = np.array(torch.load(model_cluster_file, map_location="cpu"))
    model_hours = model_times.hour

    for cl in range(n_clusters):
        counts = np.bincount(model_hours[model_clusters == cl], minlength=24)
        model_hourly_counts[cl, i, :] = counts

# === Normalize within each site ===
# Profiles with no samples would give 0/0 = NaN (silently, because warnings are
# disabled) and propagate into the means and the smoothing; they are set to
# zero instead.
obs_totals = obs_hourly_counts.sum(axis=2, keepdims=True)
obs_hourly_norm = np.divide(
    obs_hourly_counts, obs_totals,
    out=np.zeros_like(obs_hourly_counts), where=obs_totals > 0,
)
model_totals = model_hourly_counts.sum(axis=2, keepdims=True)
model_hourly_norm = np.divide(
    model_hourly_counts, model_totals,
    out=np.zeros_like(model_hourly_counts), where=model_totals > 0,
)

# === Mean and spread ===
obs_mean = obs_hourly_norm.mean(axis=1)
obs_std = obs_hourly_norm.std(axis=1)

model_mean = model_hourly_norm.mean(axis=1)
model_std = model_hourly_norm.std(axis=1)

# === rolling smoothness ===
obs_mean_smooth = smooth_rolling_circular(obs_mean, window=window)
obs_std_smooth = smooth_rolling_circular(obs_std, window=window)

model_mean_smooth = smooth_rolling_circular(model_mean, window=window)
model_std_smooth = smooth_rolling_circular(model_std, window=window)

# === Plotting ===
fig, axes = plt.subplots(2, 5, figsize=(22, 8), sharex=True, sharey=True)
axes = axes.flatten()

hour_axis = np.arange(24)

for cl in range(n_clusters):
    ax = axes[cl]

    # The curves use the smoothed means so that each line sits in the middle of
    # its own +/- std band (the raw means were plotted on smoothed bands before).
    # OBS
    ax.plot(hour_axis, obs_mean_smooth[cl], label='OBS', color='tab:blue')
    ax.fill_between(hour_axis, obs_mean_smooth[cl] - obs_std_smooth[cl], obs_mean_smooth[cl] + obs_std_smooth[cl],
                    color='tab:blue', alpha=0.3)

    # ICON
    ax.plot(hour_axis, model_mean_smooth[cl], label='ICON', color='tab:orange')
    ax.fill_between(hour_axis, model_mean_smooth[cl] - model_std_smooth[cl], model_mean_smooth[cl] + model_std_smooth[cl],
                    color='tab:orange', alpha=0.3)

    ax.set_title(f'Cluster {cl}')
    ax.set_xticks(np.arange(0, 24, 3))
    ax.set_xticklabels([f"{hh:02d}:00" for hh in range(0, 24, 3)])
    ax.set_ylabel("Norm. Freq.")
    ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.4)

axes[0].legend(loc='upper right')
fig.suptitle("OBS vs ICON Diurnal Cluster Profiles (Hourly, Smoothed)", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/diurnal_obs_vs_icon_clusters_smoothed.png")
plt.show()

## Frequency of occurrence

In [2]:

# === Obs site info ===
# Site index -> (name, name) pair; these two tables are not read in this cell.
obs_sites_ncvar_name = {
    idx: (name, name)
    for idx, name in enumerate(
        ("juelich", "lindenberg", "warsaw", "vienna", "bourges", "zaragoza",
         "sirta", "cabauw", "nuremberg", "aurillac", "dresden")
    )
}
# Site index -> name used in the OBS cluster-label file names.
obs_cluster_sites = dict(
    enumerate(
        ("juelich", "lin", "warsaw", "vienna", "bourges", "zargoza",
         "sirta", "cabauw", "nuremberg", "aurillac", "dresden")
    )
)

# === Icon site info ===
icon_sites_ncvar_name = {
    idx: (name, name)
    for idx, name in enumerate(
        ("juelich", "lin", "warsaw", "vienna", "bourges", "zaragoza",
         "sirta", "cabauw", "nuremberg", "aurillac", "dresden")
    )
}
# Site index -> name used in the ICON cluster-label file names.
icon_cluster_sites = dict(
    enumerate(
        ("juelich", "lin", "warsaw", "vienna", "bourges", "zargoza",
         "sirta", "cabauw", "nuremberg", "aurillac", "dresden")
    )
)

n_clusters = 10
n_sites = 11

# === Running totals of label occurrences, pooled over all sites ===
obs_cluster_counts = np.zeros(n_clusters)
model_cluster_counts = np.zeros(n_clusters)

for site_idx in range(n_sites):
    obs_id, icon_id = obs_cluster_sites[site_idx], icon_cluster_sites[site_idx]

    # OBS labels for this site
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{obs_id}_cluster_10_labels.pth"
    obs_clusters = np.array(torch.load(obs_cluster_file, map_location="cpu"))

    # ICON labels for this site
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{icon_id}_cluster_10_labels.pth"
    model_clusters = np.array(torch.load(model_cluster_file, map_location="cpu"))

    obs_cluster_counts += np.bincount(obs_clusters, minlength=n_clusters)
    model_cluster_counts += np.bincount(model_clusters, minlength=n_clusters)

# === Convert pooled counts to relative frequencies ===
obs_freq = obs_cluster_counts / obs_cluster_counts.sum()
model_freq = model_cluster_counts / model_cluster_counts.sum()

# === Grouped bar chart: OBS left of each cluster tick, ICON right ===
x = np.arange(n_clusters)
bar_width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))

for shift, freq, series_name, colour in (
    (-bar_width/2, obs_freq, 'OBS', 'tab:blue'),
    (bar_width/2, model_freq, 'ICON', 'tab:orange'),
):
    ax.bar(x + shift, freq, width=bar_width, label=series_name, color=colour)

ax.set(
    xlabel="Cluster",
    ylabel="Normalized Frequency",
    title="Cluster Occurrence Frequency (All Sites Combined)",
)
ax.set_xticks(x)
ax.set_xticklabels([str(k) for k in x])
ax.legend()
ax.grid(True, axis='y', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/cluster_freq_all_sites.png", dpi=300)
plt.show()

## Diurnal PCA obs vs model on hourly scale 

In [8]:
# === PCA of site-wise hourly diurnal profiles (OBS vs ICON) ===
# Builds one normalised 24-hour profile per (cluster, site) for OBS and ICON,
# then projects both sets onto the first two principal components of the OBS
# profiles so the two scatter panels share one coordinate system.

# === Obs site info ===
# index -> (name used in the OBS NetCDF file name, second name unused here)
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg","nuremberg"),
    9: ("aurillac","aurillac"),
    10: ("dresden","dresden"),
}
# index -> name used in the OBS cluster-label file name.
# NOTE(review): "lin" and "zargoza" differ from the NetCDF names above;
# presumably they mirror the label files on disk - confirm before renaming.
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

# === Icon site info ===
# index -> (first name unused here, name used in the ICON NetCDF file name)
icon_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lin", "lin"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg","nuremberg"),
    9: ("aurillac","aurillac"),
    10: ("dresden","dresden"),
}
# index -> name used in the ICON cluster-label file name
icon_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

n_clusters = 10
n_hours = 24
n_sites = 11

obs_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))
model_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))

for i in range(n_sites):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    site_nc_icon, site_var_icon = icon_sites_ncvar_name[i]
    cluster_id_icon = icon_cluster_sites[i]

    # === Load OBS ===
    obs_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    # Only the time axis is needed; close the file handle right after reading it.
    with xr.open_dataset(obs_nc) as ds:
        raw_times = ds['time'].values
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")
    obs_clusters = np.array(torch.load(obs_cluster_file, map_location="cpu"))

    hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'),
                                 end=obs_datetimes.max().ceil('H'),
                                 freq='H')

    # For every full hour keep the closest observation, but only if it lies
    # within 900 s (15 min) of that hour.
    matched_indices = []
    for h in hourly_times:
        time_diffs = np.abs((obs_datetimes - h).total_seconds())
        min_idx = np.argmin(time_diffs)
        if time_diffs[min_idx] <= 900:
            matched_indices.append(min_idx)

    obs_hourly_times = obs_datetimes[matched_indices]
    obs_hourly_clusters = obs_clusters[matched_indices]
    obs_hours = obs_hourly_times.hour

    for cl in range(n_clusters):
        counts = np.bincount(obs_hours[obs_hourly_clusters == cl], minlength=24)
        obs_hourly_counts[cl, i, :] = counts

    # === Load MODEL ===
    model_nc = f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_var_icon}crops_icon.nc"
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth"

    with xr.open_dataset(model_nc) as ds_model:
        model_times = pd.to_datetime(ds_model["time"].values)
    model_clusters = np.array(torch.load(model_cluster_file, map_location="cpu"))
    model_hours = model_times.hour

    for cl in range(n_clusters):
        counts = np.bincount(model_hours[model_clusters == cl], minlength=24)
        model_hourly_counts[cl, i, :] = counts

# === Normalize within each site ===
# Profiles with no samples are set to zero for BOTH data sets. Previously only
# ICON was guarded; an empty OBS profile became NaN, which PCA cannot handle.
obs_totals = obs_hourly_counts.sum(axis=2, keepdims=True)
obs_hourly_norm = np.divide(
    obs_hourly_counts, obs_totals,
    out=np.zeros_like(obs_hourly_counts), where=obs_totals > 0,
)
model_totals = model_hourly_counts.sum(axis=2, keepdims=True)
model_hourly_norm = np.divide(
    model_hourly_counts, model_totals,
    out=np.zeros_like(model_hourly_counts), where=model_totals > 0,
)

# === Reshape: (n_clusters * n_sites, n_hours) ===
obs_flat = obs_hourly_norm.reshape(n_clusters * n_sites, n_hours)
model_flat = model_hourly_norm.reshape(n_clusters * n_sites, n_hours)

# === PCA ===
# Fit the components on OBS only and project ICON onto the same components.
# Calling fit_transform a second time would refit the PCA on ICON, so the two
# panels (which share their x/y axes) would show unrelated coordinate systems.
pca = PCA(n_components=2)
obs_pca = pca.fit_transform(obs_flat)
model_pca = pca.transform(model_flat)

# === Create labels for coloring ===
# Row r of the flattened arrays belongs to cluster r // n_sites.
cluster_labels = np.repeat(np.arange(n_clusters), n_sites)

# === Plot ===
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)

# OBS
ax = axes[0]
sc = ax.scatter(obs_pca[:, 0], obs_pca[:, 1], c=cluster_labels, cmap='tab10')
ax.set_title("OBS PCA: Site-level Diurnal Cycles")
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.grid(True, linestyle='--', alpha=0.5)

# ICON
ax = axes[1]
sc = ax.scatter(model_pca[:, 0], model_pca[:, 1], c=cluster_labels, cmap='tab10')
ax.set_title("ICON PCA: Site-level Diurnal Cycles")
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.grid(True, linestyle='--', alpha=0.5)

plt.subplots_adjust(bottom=0.2)  # leave room for the horizontal colorbar
# Legend: one shared colorbar keyed by cluster id
cbar = fig.colorbar(sc, ax=axes, orientation='horizontal', fraction=0.05, pad=0.1, label='Cluster ID')
plt.suptitle("PCA of Site-wise Diurnal Profiles by Cluster", fontsize=14)

plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/PCA_diurnal_sitewise.png", dpi=300)
plt.close()

In [4]:
# Shape check of the flattened OBS profiles fed to PCA:
# (n_clusters * n_sites, n_hours). The PCA cell names this array `obs_flat`;
# `obs_diurnals_flat` is not defined anywhere in the notebook and raises a
# NameError on a fresh kernel.
obs_flat.shape

(110, 24)

In [5]:
# Shape check of the normalised OBS profiles: (n_clusters, n_sites, n_hours).
obs_hourly_norm.shape

(10, 11, 24)