In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import glob
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
#import cartopy.crs as ccrs
#import cartopy.feature as cfeature
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import torch
from collections import defaultdict
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
import os
import seaborn as sns

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D

## Merged clusters: channel 9 (10.8 µm) BT — observations

In [2]:
# Custom label handling: the suffix records which original cluster was folded
# into each final cluster (labels 0 -> 7, 1 -> 5, 2 -> 3 in the merge step).
custom_labels = {
    3: "3(+2)",
    5: "5(+1)",
    7: "7(+0)"
}
final_clusters = [3, 4, 5, 6, 7, 8, 9]
n_clusters = len(final_clusters)
xtick_labels = [custom_labels.get(i, str(i)) for i in final_clusters]

# === Obs site info ===
# index -> (netCDF file stem, variable-name stem) for the observation crops
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg","nuremberg"),
    9: ("aurillac","aurillac"),
    10: ("dresden","dresden"),
}
# index -> stem used in the cluster-label .pth filenames.
# NOTE(review): "lin" and "zargoza" differ from the netCDF stems; presumably they
# match the label files on disk, so do not "fix" them without renaming the files.
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

# Settings
channel = 9
# BT histogram range (K) per channel: 5 = 6.2 µm, 6 = 7.3 µm, 9 = 10.8 µm, 10 = 12.0 µm.
# Looking the range up from `channel` keeps the bins in sync when the channel changes.
# (Channel 10 was previously noted as (320, 180), i.e. reversed, which np.arange
# turns into an empty array.)
BT_BIN_RANGES = {5: (190, 250), 6: (190, 270), 9: (180, 320), 10: (180, 320)}
BT_BIN_WIDTH = 5
if channel not in BT_BIN_RANGES:
    raise KeyError(f"No BT bin range configured for channel {channel}")
# np.arange excludes the stop value: for channel 9 the last edge is 315 K.
bin_edges = np.arange(*BT_BIN_RANGES[channel], BT_BIN_WIDTH)
n_bins = len(bin_edges) - 1
all_bt = [[] for _ in range(n_clusters)]

# Load data and compute BT per final cluster.
# For each site: read the BT crops of `channel`, load the per-sample cluster
# labels, merge original clusters 0/1/2 into 7/5/3, and append the flattened
# pixels of every final cluster to all_bt[<position of that cluster>].
for i in obs_sites_ncvar_name:
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    if not os.path.exists(nc_file) or not os.path.exists(cluster_file):
        print(f"Missing: {site_nc}")
        continue

    # `.values` materialises the array, so the file handle can be closed right away
    # instead of keeping one open dataset per site.
    with xr.open_dataset(nc_file) as ds:
        bt = ds[f'sample_{site_var}_data_{channel}'].values
    # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted label files.
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))
    if len(cluster_labels) != len(bt):
        # The boolean indexing below would otherwise fail with an opaque IndexError.
        raise ValueError(
            f"{site_nc}: {len(cluster_labels)} cluster labels but {len(bt)} BT samples"
        )

    # Merge original clusters into the final set: 0 -> 7, 1 -> 5, 2 -> 3.
    cluster_labels[cluster_labels == 0] = 7
    cluster_labels[cluster_labels == 1] = 5
    cluster_labels[cluster_labels == 2] = 3

    for idx, new_cl in enumerate(final_clusters):
        mask = cluster_labels == new_cl
        if np.any(mask):
            all_bt[idx].append(bt[mask].reshape(-1))

# Histograms: pool every site's pixels per cluster and compute a density-normalised
# BT histogram over `bin_edges`; a cluster without any pixels gets an all-zero row.
histograms = np.array([
    np.histogram(np.concatenate(bt_chunks), bins=bin_edges, density=True)[0]
    if bt_chunks else np.zeros(n_bins)
    for bt_chunks in all_bt
])

# JS & KL divergence between every ordered pair of cluster histograms.
# Each histogram gets a 1e-10 floor (so no bin is exactly zero) and is rescaled
# to sum to 1; the smoothed vectors are built once and reused for every pair.
smoothed = []
for raw_hist in histograms:
    prob = raw_hist + 1e-10
    prob /= prob.sum()
    smoothed.append(prob)

js_matrix = np.zeros((n_clusters, n_clusters))
kl_matrix = np.zeros((n_clusters, n_clusters))
for row in range(n_clusters):
    for col in range(n_clusters):
        # jensenshannon returns the JS *distance*; squaring it gives the divergence
        js_matrix[row, col] = jensenshannon(smoothed[row], smoothed[col], base=2) ** 2
        kl_matrix[row, col] = rel_entr(smoothed[row], smoothed[col]).sum()

# Plot JS divergence
plt.figure(figsize=(10, 8))
sns.heatmap(js_matrix, annot=True, fmt=".3f", cmap="viridis", xticklabels=xtick_labels, yticklabels=xtick_labels)
plt.title("JS Divergence Between Clusters (Obs)")
plt.xlabel("Cluster")
plt.ylabel("Cluster")
plt.tight_layout()
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/js_divergence_merged_obs_clusters_C{channel}.png")

# Plot KL divergence
plt.figure(figsize=(10, 8))
sns.heatmap(kl_matrix, annot=True, fmt=".2f", cmap="magma", xticklabels=xtick_labels, yticklabels=xtick_labels)
plt.title("KL Divergence Between Clusters (Obs)")
plt.xlabel("Cluster")
plt.ylabel("Cluster")
plt.tight_layout()
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/kl_divergence_merged_obs_clusters_C{channel}.png")

# Plot histograms per cluster on a 4-column grid (7 clusters -> 2 x 4 with one spare panel).
ncols = 4
nrows = -(-n_clusters // ncols)  # ceiling division
fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), sharex=True, sharey=True)
axes = np.atleast_1d(axes).flatten()

# Bar geometry derived from the bin edges rather than hard-coded 2.5 / 5 K offsets.
bin_widths = np.diff(bin_edges)
bar_centers = bin_edges[:-1] + bin_widths / 2

for cluster_idx, label in enumerate(xtick_labels):
    ax = axes[cluster_idx]
    ax.bar(
        x=bar_centers,
        height=histograms[cluster_idx],
        width=bin_widths,
        color='skyblue',
        edgecolor='black'
    )
    ax.set_title(f'Cluster {label}')
    ax.set_xlabel('BT (K)')
    # histograms were built with density=True, so bars integrate to 1 over the bins
    ax.set_ylabel('Probability Density (1/K)')

# Hide unused trailing panels; with sharex the panel directly above each one would
# otherwise have its x tick labels suppressed.
for empty_idx in range(n_clusters, len(axes)):
    axes[empty_idx].set_visible(False)
    if empty_idx >= ncols:
        axes[empty_idx - ncols].tick_params(labelbottom=True)

# Wavelength in the title follows `channel` instead of being hard-coded.
channel_wavelengths = {5: "6.2µm", 6: "7.3µm", 9: "10.8µm", 10: "12.0µm"}
band_label = channel_wavelengths.get(channel, f"C{channel}")
fig.suptitle(f'BT ({band_label}) Histogram for Each Cluster (All Sites, All Timestamps)', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/histcluster_allsites_merged_obs_C{channel}.png")


## Merged clusters: channel 9 (10.8 µm) BT — observations + ICON model

In [7]:

# === Cluster config ===
final_clusters = [3, 4, 5, 6, 7, 8, 9]
n_clusters = len(final_clusters)
# final cluster id -> position in the per-cluster lists below
cluster_idx_map = dict(zip(final_clusters, range(n_clusters)))

# Tick labels: merged clusters carry the original cluster folded into them.
custom_labels = {3: "3(+2)", 5: "5(+1)", 7: "7(+0)"}
xtick_labels = [custom_labels[cl] if cl in custom_labels else str(cl) for cl in final_clusters]

# === Site info ===
# The same 11 sites, in the same order, for observations and ICON.
_site_stems = [
    "juelich", "lindenberg", "warsaw", "vienna", "bourges", "zaragoza",
    "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
]
# Stems used in the cluster-label filenames ("lin", "zargoza" presumably match the files on disk).
_label_stems = [
    "juelich", "lin", "warsaw", "vienna", "bourges", "zargoza",
    "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
]

# index -> (netCDF file stem, variable-name stem); index -> label-file stem
obs_sites_ncvar_name = {k: (stem, stem) for k, stem in enumerate(_site_stems)}
obs_cluster_sites = dict(enumerate(_label_stems))
icon_sites_ncvar_name = {k: (stem, stem) for k, stem in enumerate(_site_stems)}
icon_cluster_sites = dict(enumerate(_label_stems))

# === Bin settings ===
# 5 K bins; np.arange excludes 320, so the last edge is 315 K.
bin_edges = np.arange(180, 320, 5)
bin_centers = bin_edges[:-1] + np.diff(bin_edges) / 2
obs_all_bt = [[] for _ in range(n_clusters)]
model_all_bt = [[] for _ in range(n_clusters)]

def match_closest(datetimes, reference_times, tolerance_s=900):
    """Return, for each reference time, the index of the nearest entry in `datetimes`.

    Reference times whose nearest entry is more than `tolerance_s` seconds away
    are skipped, so the result can be shorter than `reference_times` and is not
    positionally aligned with it.

    Args:
        datetimes: datetime-like sequence (anything ``pd.DatetimeIndex`` accepts)
            to search in.
        reference_times: iterable of timestamps to match.
        tolerance_s: maximum allowed absolute difference in seconds
            (default 900, i.e. 15 minutes).

    Returns:
        list[int]: positional indices into `datetimes`, one per matched reference
        time, in the order of `reference_times`. Empty when `datetimes` is empty.
    """
    datetimes = pd.DatetimeIndex(datetimes)
    matched = []
    if len(datetimes) == 0:
        # np.argmin raises on an empty array; nothing can match anyway.
        return matched
    for ref in reference_times:
        diffs = np.abs(np.asarray((datetimes - ref).total_seconds()))
        min_idx = int(np.argmin(diffs))
        if diffs[min_idx] <= tolerance_s:
            matched.append(min_idx)
    return matched

# === Main loop over sites ===
# For each site: load OBS and ICON BT crops plus their merged cluster labels,
# restrict both to a common hourly time axis, and collect the flattened BT
# pixels of every final cluster into obs_all_bt / model_all_bt.
for i in range(11):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]
    site_nc_icon, site_var_icon = icon_sites_ncvar_name[i]
    cluster_id_icon = icon_cluster_sites[i]

    obs_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"
    model_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/icon_{site_nc_icon}_WV_IR_crops.nc"
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth"

    if not os.path.exists(obs_nc):
        print(f"1. Missing files for {site_nc}")
        continue
    elif not os.path.exists(obs_cluster_file):
        print(f"2. Missing files for {obs_cluster_file}")
        continue
    elif not os.path.exists(model_nc):
        print(f"3. Missing files for {model_nc}")
        continue
    elif not os.path.exists(model_cluster_file):
        print(f"4. Missing files for {model_cluster_file}")
        continue

    # === Load OBS === (values are read eagerly, so the file is closed on exit)
    with xr.open_dataset(obs_nc) as ds:
        obs_bt = ds[f"sample_{site_var}_data_9"].values
        # the first 12 characters of each time value are parsed as YYYYMMDDHHMM
        obs_datetimes = pd.to_datetime([str(t)[:12] for t in ds["time"].values], format="%Y%m%d%H%M")
    obs_cluster_labels = np.array(torch.load(obs_cluster_file, map_location="cpu"))

    # Merge cluster labels: 0 -> 7, 1 -> 5, 2 -> 3
    obs_cluster_labels[obs_cluster_labels == 0] = 7
    obs_cluster_labels[obs_cluster_labels == 1] = 5
    obs_cluster_labels[obs_cluster_labels == 2] = 3

    # === Load MODEL ===
    with xr.open_dataset(model_nc) as ds_model:
        model_bt = ds_model[f"{site_var_icon}_data_IR"].values
        model_datetimes = pd.to_datetime(ds_model["time"].values)
    model_cluster_labels = np.array(torch.load(model_cluster_file, map_location="cpu"))

    model_cluster_labels[model_cluster_labels == 0] = 7
    model_cluster_labels[model_cluster_labels == 1] = 5
    model_cluster_labels[model_cluster_labels == 2] = 3

    # === Match to hourly timestamps === (overlap of the OBS and MODEL time spans)
    # "h" is the hourly alias; uppercase "H" is deprecated in recent pandas.
    hourly_times = pd.date_range(start=max(obs_datetimes.min(), model_datetimes.min()).floor("h"),
                                 end=min(obs_datetimes.max(), model_datetimes.max()).ceil("h"),
                                 freq="h")

    # OBS: nearest sample to each hour, within match_closest's tolerance.
    obs_idx = match_closest(obs_datetimes, hourly_times)
    obs_bt_hourly = obs_bt[obs_idx]
    obs_clusters_hourly = obs_cluster_labels[obs_idx]

    # MODEL: only samples whose timestamp lies exactly on the hourly grid.
    model_time_mask = model_datetimes.isin(hourly_times)
    model_bt_hourly = model_bt[model_time_mask]
    model_clusters_hourly = model_cluster_labels[model_time_mask]

    # === Aggregate BT values by final cluster ===
    # Empty selections reshape to size-0 arrays and are simply not appended.
    for cl in final_clusters:
        idx = cluster_idx_map[cl]
        obs_bt_cluster = obs_bt_hourly[obs_clusters_hourly == cl].reshape(-1)
        model_bt_cluster = model_bt_hourly[model_clusters_hourly == cl].reshape(-1)

        if obs_bt_cluster.size > 0:
            obs_all_bt[idx].append(obs_bt_cluster)
        if model_bt_cluster.size > 0:
            model_all_bt[idx].append(model_bt_cluster)

# === Plotting ===
# One panel per final cluster overlaying the OBS and MODEL BT distributions.
# density=True, so each series integrates to 1 over the binned range and the
# very different sample counts of OBS and MODEL do not matter.
ncols = 4
nrows = -(-n_clusters // ncols)  # ceiling division: 7 clusters -> 2 rows
fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), sharex=True, sharey=True)
axes = np.atleast_1d(axes).flatten()

for i, cl in enumerate(final_clusters):
    ax = axes[i]

    obs_samples = np.concatenate(obs_all_bt[i]) if obs_all_bt[i] else np.array([])
    model_samples = np.concatenate(model_all_bt[i]) if model_all_bt[i] else np.array([])

    ax.hist(obs_samples, bins=bin_edges, color='skyblue', edgecolor='black', alpha=0.6, label='Observation', density=True)
    ax.hist(model_samples, bins=bin_edges, color='tomato', edgecolor='black', alpha=0.6, label='Model', density=True)

    ax.set_title(f'Cluster {xtick_labels[i]}')
    ax.set_xlabel('BT (K)')
    # density=True -> probability density, not raw counts
    ax.set_ylabel('Density (1/K)')
    ax.grid(True, linestyle='--', alpha=0.3)

# Hide unused trailing panels; with sharex the panel directly above each one would
# otherwise have its x tick labels suppressed.
for empty_idx in range(n_clusters, len(axes)):
    axes[empty_idx].set_visible(False)
    if empty_idx >= ncols:
        axes[empty_idx - ncols].tick_params(labelbottom=True)

# Final touches
axes[0].legend(loc='upper right')
fig.suptitle('BT (10.8µm) Histogram: Combined All Sites (Clusters 3–9)', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/allsite_combined_merged_histogram_obs_model_clusterwise.png", dpi=300)
plt.show()


## Joint 3-channel BT histograms (C5, C6, C9) — observations

In [8]:


# === Cluster config ===
final_clusters = [3, 4, 5, 6, 7, 8, 9]
n_clusters = len(final_clusters)
# final cluster id -> row in all_joint_histograms / the divergence matrices
cluster_idx_map = {cl: pos for pos, cl in enumerate(final_clusters)}
custom_labels = {3: "3(+2)", 5: "5(+1)", 7: "7(+0)"}
xtick_labels = list(map(lambda cl: custom_labels.get(cl, str(cl)), final_clusters))

# === Obs site info ===
# index -> (netCDF file stem, variable-name stem)
_obs_site_stems = (
    "juelich", "lindenberg", "warsaw", "vienna", "bourges", "zaragoza",
    "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
)
obs_sites_ncvar_name = {k: (stem, stem) for k, stem in enumerate(_obs_site_stems)}
# index -> stem of the cluster-label file ("lin", "zargoza" presumably as named on disk)
obs_cluster_sites = dict(enumerate([
    "juelich", "lin", "warsaw", "vienna", "bourges",
    "zargoza", "sirta", "cabauw", "nuremberg", "aurillac", "dresden",
]))

# === Bin edges for 3 channels ===
# 5 K bins; np.arange excludes the stop value, so the last edges are 250, 270 and 320 K.
bin_edges_C5 = np.arange(190, 255, 5)   # 6.2 µm
bin_edges_C6 = np.arange(190, 275, 5)   # 7.3 µm
bin_edges_C9 = np.arange(180, 325, 5)   # 10.8 µm
bins = [bin_edges_C5, bin_edges_C6, bin_edges_C9]

# One zero-initialised (C5 x C6 x C9) accumulator per final cluster.
hist_shape = tuple(len(edges) - 1 for edges in bins)
all_joint_histograms = [np.zeros(hist_shape) for _ in range(n_clusters)]

# === Aggregate 3-channel BT values by cluster ===
# Accumulate raw joint (C5, C6, C9) pixel COUNTS per final cluster across all
# sites; normalisation happens once, after pooling, so every pixel carries equal
# weight. This matches the single-channel analysis, which concatenates all sites'
# pixels before histogramming. Summing per-site density=True histograms would
# instead give each site equal weight regardless of its sample count.
for i in range(11):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"

    if not os.path.exists(nc_file) or not os.path.exists(cluster_file):
        print(f"Missing: {site_nc}")
        continue

    # One row of flattened pixels per sample; `ds.sizes` gives the dimension length
    # (mapping access via `ds.dims[...]` is deprecated in recent xarray). Values are
    # read eagerly, so the file is closed when the block exits.
    with xr.open_dataset(nc_file) as ds:
        n_samples = ds.sizes['sample']
        c5 = ds[f'sample_{site_var}_data_5'].values.reshape(n_samples, -1)
        c6 = ds[f'sample_{site_var}_data_6'].values.reshape(n_samples, -1)
        c9 = ds[f'sample_{site_var}_data_9'].values.reshape(n_samples, -1)
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))

    # === Merge clusters === 0 -> 7, 1 -> 5, 2 -> 3
    cluster_labels[cluster_labels == 0] = 7
    cluster_labels[cluster_labels == 1] = 5
    cluster_labels[cluster_labels == 2] = 3

    for cl in final_clusters:
        idx = cluster_idx_map[cl]
        mask = cluster_labels == cl
        if np.any(mask):
            # (n_pixels, 3) matrix: one (C5, C6, C9) triple per pixel
            xyz = np.column_stack((c5[mask].ravel(), c6[mask].ravel(), c9[mask].ravel()))
            # raw counts (density=False); pixels outside the bin ranges are not counted
            joint_counts, _ = np.histogramdd(xyz, bins=bins)
            all_joint_histograms[idx] += joint_counts

# === Flatten and normalize each joint histogram ===
def _to_probability_vector(hist3d, eps=1e-10):
    """Flatten a 3-D histogram into a probability vector with an `eps` floor.

    The floor keeps every bin strictly positive (no log(0) in the KL term);
    the vector is then rescaled to sum to 1. The input array is not modified.
    """
    vec = hist3d.ravel() + eps
    return vec / vec.sum()


histograms_flat = [_to_probability_vector(hist3d) for hist3d in all_joint_histograms]

# === Compute JS & KL matrices ===
# js_matrix: squared Jensen-Shannon distance, i.e. the JS divergence (base 2, in [0, 1]).
# kl_matrix: KL(P_row || Q_col); rel_entr uses the natural log, and KL is not symmetric.
js_matrix = np.zeros((n_clusters, n_clusters))
kl_matrix = np.zeros_like(js_matrix)
for row in range(n_clusters):
    p_vec = histograms_flat[row]
    for col in range(n_clusters):
        q_vec = histograms_flat[col]
        js_matrix[row, col] = jensenshannon(p_vec, q_vec, base=2) ** 2
        kl_matrix[row, col] = rel_entr(p_vec, q_vec).sum()

# === Plot JS and KL divergence heatmaps ===
# (matrix, annotation format, colormap, title, output path) for each heatmap
_heatmap_specs = [
    (js_matrix, ".3f", "viridis",
     "JS Divergence Between Clusters (Joint 6.2, 7.3, 10.8 µm)",
     "/p/project1/exaww/chatterjee1/plots/continuous/js_divergence_merged_joint_C5C6C9.png"),
    (kl_matrix, ".2f", "magma",
     "KL Divergence Between Clusters (Joint 6.2, 7.3, 10.8 µm)",
     "/p/project1/exaww/chatterjee1/plots/continuous/kl_divergence_merged_joint_C5C6C9.png"),
]
for matrix, value_fmt, cmap_name, title, out_path in _heatmap_specs:
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=value_fmt, cmap=cmap_name, xticklabels=xtick_labels, yticklabels=xtick_labels)
    plt.title(title)
    plt.xlabel("Cluster")
    plt.ylabel("Cluster")
    plt.tight_layout()
    plt.savefig(out_path)


## t-SNE and PCA of the joint BT histograms

In [9]:
# Stack the flattened joint histograms into an (n_clusters, n_joint_bins) matrix:
# one probability vector (row) per final cluster.
X = np.stack(histograms_flat)

# === PCA ===
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
explained = pca.explained_variance_ratio_

plt.figure(figsize=(7, 6))
for i in range(n_clusters):
    # Legend shows the actual merged-cluster id (e.g. "3(+2)"), not the row index 0..6.
    plt.scatter(X_pca[i, 0], X_pca[i, 1], label=f'Cluster {xtick_labels[i]}', s=100)
plt.title("PCA Projection of Joint BT Histograms (C5, C6, C9)")
plt.xlabel(f"PC 1 ({explained[0]:.1%} of variance)")
plt.ylabel(f"PC 2 ({explained[1]:.1%} of variance)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/pca_merged_joint_C5C6C9.png")
plt.show()

# === t-SNE ===
# scikit-learn requires perplexity < n_samples. There is one sample per cluster, so
# cap the original perplexity of 5 at n_samples - 1 (unchanged for 7 clusters).
# With this few points the embedding is only a qualitative picture.
tsne_perplexity = min(5, max(1, X.shape[0] - 1))
tsne = TSNE(n_components=2, perplexity=tsne_perplexity, random_state=42, init="pca", learning_rate="auto")
X_tsne = tsne.fit_transform(X)

plt.figure(figsize=(7, 6))
for i in range(n_clusters):
    # Legend shows the actual merged-cluster id (e.g. "3(+2)"), not the row index 0..6.
    plt.scatter(X_tsne[i, 0], X_tsne[i, 1], label=f'Cluster {xtick_labels[i]}', s=100)
plt.title("t-SNE Projection of Joint BT Histograms (C5, C6, C9)")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/tsne_merged_joint_C5C6C9.png")
plt.show()