In [35]:
import os
import numpy as np
import xarray as xr
import torch
import pandas as pd
from scipy.special import rel_entr
from scipy.spatial.distance import jensenshannon
from sklearn.decomposition import PCA
import networkx as nx
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import pdist, squareform
from scipy.special import rel_entr
from itertools import combinations
import seaborn as sns

In [8]:
# === Bin edges for 3 channels ===
bin_edges_C5 = np.arange(190, 255, 5)   # 6.2 µm
bin_edges_C6 = np.arange(190, 275, 5)   # 7.3 µm
bin_edges_C9 = np.arange(180, 325, 5)   # 10.8 µm
bins = [bin_edges_C5, bin_edges_C6, bin_edges_C9]

n_clusters = 10
n_hours = 24
n_sites = 11

# === Site info ===
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"), 1: ("lindenberg", "lindenberg"), 2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"), 4: ("bourges", "bourges"), 5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"), 7: ("cabauw", "cabauw"), 8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"), 10: ("dresden", "dresden")
}
obs_cluster_sites = {
    0: "juelich", 1: "lin", 2: "warsaw", 3: "vienna", 4: "bourges", 5: "zargoza",
    6: "sirta", 7: "cabauw", 8: "nuremberg", 9: "aurillac", 10: "dresden"
}

# === Accumulators ===
all_joint_histograms = [np.zeros((len(bin_edges_C5)-1, len(bin_edges_C6)-1, len(bin_edges_C9)-1)) for _ in range(n_clusters)]
obs_hourly_counts = np.zeros((n_clusters, n_sites, n_hours))

# === Load Data and Compute ===
for i in range(n_sites):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]

    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"
    if not os.path.exists(nc_file) or not os.path.exists(cluster_file):
        print(f"Missing: {site_nc}")
        continue

    ds = xr.open_dataset(nc_file)
    c5 = ds[f'sample_{site_var}_data_5'].values.reshape(ds.dims['sample'], -1)
    c6 = ds[f'sample_{site_var}_data_6'].values.reshape(ds.dims['sample'], -1)
    c9 = ds[f'sample_{site_var}_data_9'].values.reshape(ds.dims['sample'], -1)
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))

    raw_times = ds['time'].values
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")
    hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'), end=obs_datetimes.max().ceil('H'), freq='H')

    matched_indices = [np.argmin(np.abs((obs_datetimes - h).total_seconds())) for h in hourly_times
                       if np.abs((obs_datetimes - h).total_seconds()).min() <= 900]

    obs_hourly_clusters = cluster_labels[matched_indices]
    obs_hours = obs_datetimes[matched_indices].hour

    for cl in range(n_clusters):
        counts = np.bincount(obs_hours[obs_hourly_clusters == cl], minlength=24)
        obs_hourly_counts[cl, i, :] = counts

    for cl in range(n_clusters):
        mask = cluster_labels == cl
        if np.any(mask):
            x = c5[mask].ravel()
            y = c6[mask].ravel()
            z = c9[mask].ravel()
            xyz = np.vstack((x, y, z)).T
            joint_hist, _ = np.histogramdd(xyz, bins=bins, density=True)
            all_joint_histograms[cl] += joint_hist

# === Flatten Histograms and Normalize ===
histograms_flat = []
for hist3d in all_joint_histograms:
    hist_flat = hist3d.flatten() + 1e-10
    hist_flat /= hist_flat.sum()
    histograms_flat.append(hist_flat)

# === JS Divergence Matrix ===
js_matrix = np.zeros((n_clusters, n_clusters))
for i in range(n_clusters):
    for j in range(n_clusters):
        js_matrix[i, j] = jensenshannon(histograms_flat[i], histograms_flat[j], base=2) ** 2

# === Diurnal Normalization ===
diurnal_norm = obs_hourly_counts / obs_hourly_counts.sum(axis=2, keepdims=True)

In [11]:
diurnal_norm.shape

(10, 11, 24)

In [21]:
# === Compute PCA distance matrix ===
n_clusters, n_sites, n_hours = diurnal_norm.shape
diurnal_pca_matrix = np.zeros((n_clusters, n_clusters))
pca_profiles = []

for cl in range(n_clusters):
    profile = diurnal_norm[cl].reshape(n_sites, n_hours)
    mean_profile = profile.mean(axis=0)
    pca_profiles.append(mean_profile)

pca_profiles = np.array(pca_profiles)
pca = PCA(n_components=1)
pca_components = pca.fit_transform(pca_profiles)
for i in range(n_clusters):
    for j in range(n_clusters):
        diurnal_pca_matrix[i, j] = np.abs(pca_components[i] - pca_components[j])

# === Compute EMD distance matrix ===
emd_matrix = np.zeros((n_clusters, n_clusters))
hour_bins = np.arange(n_hours)
for i in range(n_clusters):
    for j in range(n_clusters):
        emd = 0
        for site in range(n_sites):
            p = diurnal_norm[i, site]
            q = diurnal_norm[j, site]
            emd += wasserstein_distance(hour_bins, hour_bins, u_weights=p, v_weights=q)
        emd_matrix[i, j] = emd / n_sites

# === Merge cluster logic and plotting ===
def merge_clusters(js_matrix, pca_matrix, emd_matrix, js_th, pca_th, emd_th, mode="pca"):
    merged_pairs = []
    for i, j in combinations(range(n_clusters), 2):
        if js_matrix[i, j] < js_th:
            if mode == "pca" and pca_matrix[i, j] < pca_th:
                merged_pairs.append((i, j))
            elif mode == "emd" and emd_matrix[i, j] < emd_th:
                merged_pairs.append((i, j))
    return merged_pairs

def build_and_plot_graph(merged_pairs, title, save_path):
    G = nx.Graph()
    G.add_nodes_from(range(n_clusters))
    G.add_edges_from(merged_pairs)

    pos = nx.spring_layout(G, seed=42)
    plt.figure(figsize=(8, 6))
    nx.draw(G, pos, with_labels=True, node_color="skyblue", edge_color="gray", node_size=800, font_size=14)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    
# Extract upper triangle values
def get_upper_triangular_values(matrix):
    return matrix[np.triu_indices(n_clusters, k=1)]

# Compute min, max, and percentile thresholds
def compute_threshold_range(matrix, name, percentile, num):
    values = get_upper_triangular_values(matrix)
    min_val = np.min(values)
    max_val = np.max(values)
    percentile_50 = np.percentile(values, percentile)
    thresholds = np.linspace(min_val, percentile_50, num=num)
    return {
        "min": min_val,
        "max": max_val,
        "percentile_50": percentile_50,
        "thresholds": thresholds
    }

pca_threshold_info = compute_threshold_range(diurnal_pca_matrix, "PCA",50,10)
emd_threshold_info = compute_threshold_range(emd_matrix, "EMD",50,10)
js_threshold_info = compute_threshold_range(js_matrix, "JS",50,10)

# === Run merge logic for different thresholds ===
js_thresholds = js_threshold_info['thresholds']
pca_thresholds = pca_threshold_info['thresholds']
emd_thresholds = emd_threshold_info['thresholds']

results = []

for js_th in js_thresholds:
    for pca_th in pca_thresholds:
        merged_pca = merge_clusters(js_matrix, diurnal_pca_matrix, emd_matrix, js_th, pca_th, 0, mode="pca")
        print(f"[JS < {js_th}, PCA < {pca_th}] Merged Clusters: {merged_pca}")
        build_and_plot_graph(
            merged_pca,
            title=f"Merge Graph (JS<{js_th}, PCA<{pca_th})",
            save_path=f"/p/project1/exaww/chatterjee1/plots/continuous/merge_graph_js{js_th}_pca{pca_th}.png"
        )

    for emd_th in emd_thresholds:
        merged_emd = merge_clusters(js_matrix, diurnal_pca_matrix, emd_matrix, js_th, 0, emd_th, mode="emd")
        print(f"[JS < {js_th}, EMD < {emd_th}] Merged Clusters: {merged_emd}")
        build_and_plot_graph(
            merged_emd,
            title=f"Merge Graph (JS<{js_th}, EMD<{emd_th})",
            save_path=f"/p/project1/exaww/chatterjee1/plots/continuous/merge_graph_js{js_th}_emd{emd_th}.png"
        )

[JS < 0.009682720143900949, EMD < 0.39438671628546546] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 0.6764462838743355] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 0.9585058514632054] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 1.2405654190520756] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 1.5226249866409456] Merged Clusters: []


  plt.tight_layout()


[JS < 0.009682720143900949, EMD < 1.8046845542298156] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 2.0867441218186857] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 2.3688036894075557] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 2.6508632569964257] Merged Clusters: []
[JS < 0.009682720143900949, EMD < 2.9329228245852956] Merged Clusters: []
[JS < 0.04268421155602693, EMD < 0.39438671628546546] Merged Clusters: []
[JS < 0.04268421155602693, EMD < 0.6764462838743355] Merged Clusters: [(0, 7)]
[JS < 0.04268421155602693, EMD < 0.9585058514632054] Merged Clusters: [(0, 7)]
[JS < 0.04268421155602693, EMD < 1.2405654190520756] Merged Clusters: [(0, 7)]
[JS < 0.04268421155602693, EMD < 1.5226249866409456] Merged Clusters: [(0, 7)]
[JS < 0.04268421155602693, EMD < 1.8046845542298156] Merged Clusters: [(0, 7)]
[JS < 0.04268421155602693, EMD < 2.0867441218186857] Merged Clusters: [(0, 7)]
[JS < 0.04268421155602693, EMD < 2.3688036894075557] Merged Clusters: [(0, 7)]
[JS

In [15]:
np.linspace(np.min(emd_matrix), np.percentile(emd_matrix, 50), num=10)

array([0.        , 0.29433812, 0.58867624, 0.88301436, 1.17735248,
       1.47169061, 1.76602873, 2.06036685, 2.35470497, 2.64904309])

In [17]:
np.linspace(np.min(diurnal_pca_matrix), np.percentile(diurnal_pca_matrix, 50), num=10)

array([0.        , 0.01235041, 0.02470083, 0.03705124, 0.04940165,
       0.06175207, 0.07410248, 0.08645289, 0.0988033 , 0.11115372])

In [18]:
pca_threshold_info['thresholds']

array([0.00428861, 0.01914   , 0.03399138, 0.04884276, 0.06369415,
       0.07854553, 0.09339691, 0.1082483 , 0.12309968, 0.13795107])

## Original 15 minutes

In [54]:
'''
# === Configuration ===
n_clusters = 10
n_bins = 96  # 15-minute intervals
n_sites = 11
bin_edges_C5 = np.arange(190, 255, 5)
bin_edges_C6 = np.arange(190, 275, 5)
bin_edges_C9 = np.arange(180, 325, 5)
bins = [bin_edges_C5, bin_edges_C6, bin_edges_C9]

obs_sites_ncvar_name = {
    0: ("juelich", "juelich"), 1: ("lindenberg", "lindenberg"), 2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"), 4: ("bourges", "bourges"), 5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"), 7: ("cabauw", "cabauw"), 8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"), 10: ("dresden", "dresden")
}
obs_cluster_sites = {
    0: "juelich", 1: "lin", 2: "warsaw", 3: "vienna", 4: "bourges", 5: "zargoza",
    6: "sirta", 7: "cabauw", 8: "nuremberg", 9: "aurillac", 10: "dresden"
}

# === Initialize containers ===
all_joint_histograms = [np.zeros((len(bin_edges_C5)-1, len(bin_edges_C6)-1, len(bin_edges_C9)-1)) for _ in range(n_clusters)]
cluster_diurnal = np.zeros((n_clusters, n_sites, n_bins))

# === Data Loop ===
for i in range(n_sites):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]
    
    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"
    
    if not os.path.exists(nc_file) or not os.path.exists(cluster_file):
        continue

    ds = xr.open_dataset(nc_file)
    c5 = ds[f'sample_{site_var}_data_5'].values.reshape(ds.dims['sample'], -1)
    c6 = ds[f'sample_{site_var}_data_6'].values.reshape(ds.dims['sample'], -1)
    c9 = ds[f'sample_{site_var}_data_9'].values.reshape(ds.dims['sample'], -1)
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))

    obs_datetimes = pd.to_datetime([str(t)[:12] for t in ds['time'].values], format="%Y%m%d%H%M")
    obs_datetimes_adj = obs_datetimes - pd.Timedelta(minutes=12)
    time_bins = obs_datetimes_adj.hour * 4 + obs_datetimes_adj.minute // 15

    for cl in range(n_clusters):
        mask = cluster_labels == cl
        if np.any(mask):
            x, y, z = c5[mask].ravel(), c6[mask].ravel(), c9[mask].ravel()
            xyz = np.stack((x, y, z), axis=1)
            hist, _ = np.histogramdd(xyz, bins=bins, density=True)
            all_joint_histograms[cl] += hist

            binned_counts = np.bincount(time_bins[mask], minlength=n_bins)
            cluster_diurnal[cl, i, :] = binned_counts
'''
# === Normalize ===
histograms_flat = [h.flatten() + 1e-10 for h in all_joint_histograms]
histograms_flat = [h / h.sum() for h in histograms_flat]
cluster_diurnal_norm = cluster_diurnal / cluster_diurnal.sum(axis=2, keepdims=True)

# === JS Divergence Matrix ===
js_matrix = np.zeros((n_clusters, n_clusters))
for i, j in combinations(range(n_clusters), 2):
    js_val = jensenshannon(histograms_flat[i], histograms_flat[j], base=2) ** 2
    js_matrix[i, j] = js_matrix[j, i] = js_val

# === PCA Matrix ===
pca_profiles = np.array([cluster_diurnal_norm[cl].mean(axis=0) for cl in range(n_clusters)])
pca = PCA(n_components=1)
pca_components = pca.fit_transform(pca_profiles).flatten()
diurnal_pca_matrix = np.abs(pca_components[:, None] - pca_components[None, :])

# === EMD Matrix ===
bin_pos = np.arange(n_bins)
emd_matrix = np.zeros((n_clusters, n_clusters))
for i, j in combinations(range(n_clusters), 2):
    emd = np.mean([
        wasserstein_distance(bin_pos, bin_pos, u_weights=cluster_diurnal_norm[i, site], v_weights=cluster_diurnal_norm[j, site])
        for site in range(n_sites)
    ])
    emd_matrix[i, j] = emd_matrix[j, i] = emd

# === Merge + Plot Functions ===
def get_upper_tri(matrix): 
    return matrix[np.triu_indices_from(matrix, k=1)]

def threshold_range(matrix, percentile=50, num=15):
    upper = get_upper_tri(matrix)
    return np.linspace(upper.min(), np.percentile(upper, percentile), num=num)

def merge_clusters(js, pca, emd, js_th, pca_th, emd_th, mode="pca"):
    merged = []
    for i, j in combinations(range(n_clusters), 2):
        if js[i, j] < js_th:
            if mode == "pca" and pca[i, j] < pca_th:
                merged.append((i, j))
            elif mode == "emd" and emd[i, j] < emd_th:
                merged.append((i, j))
            elif mode == "all" and pca[i, j] < pca_th and emd[i, j] < emd_th:
                merged.append((i, j))
    return merged
'''
def plot_merge_graph(pairs, title, path):
    G = nx.Graph()
    G.add_nodes_from(range(n_clusters))
    G.add_edges_from(pairs)
    pos = nx.spring_layout(G, seed=42)
    plt.figure(figsize=(8, 6))
    nx.draw(G, pos, with_labels=True, node_color="skyblue", edge_color="gray", node_size=800)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
'''    
def plot_merge_graph(pairs, title, path):

    G = nx.Graph()
    G.add_nodes_from(range(n_clusters))
    G.add_edges_from(pairs)

    pos = nx.spring_layout(G, seed=42)

    plt.figure(figsize=(8, 6))

    # Draw nodes with edge colors and borders
    nx.draw_networkx_nodes(
        G, pos,
        node_size=800,
        node_color="#A2C4C9",
        edgecolors="black",
        linewidths=1.2
    )

    # Draw curved edges with transparency
    arc_rad = 0.2
    for edge in pairs:
        nx.draw_networkx_edges(
            G, pos,
            edgelist=[edge],
            width=2,
            edge_color="#555555",
            alpha=0.7,
            connectionstyle=f'arc3,rad={arc_rad}'
        )

    # Draw text labels
    nx.draw_networkx_labels(
        G, pos,
        font_size=14,
        font_color="black"
    )

    # Final plot formatting
    plt.title(title, fontsize=16)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()

# === Threshold Sweeps ===
js_ths = threshold_range(js_matrix)
pca_ths = threshold_range(diurnal_pca_matrix)
emd_ths = threshold_range(emd_matrix)



for js_th in js_ths:
    for pca_th in pca_ths:
        merged = merge_clusters(js_matrix, diurnal_pca_matrix, emd_matrix, js_th, pca_th, 0, mode="pca")
        print(f"[JS<{js_th:.3f}, PCA<{pca_th:.3f}] Merged: {merged}")
        #if js_th<0.307 and pca_th<0.009:
        #    plot_merge_graph(merged, f"JS<{js_th:.3f}, PCA<{pca_th:.3f}", f"/p/project1/exaww/chatterjee1/plots/continuous/js{js_th:.3f}_pca{pca_th:.3f}.png")
        #    print(f"[JS<{js_th:.3f}, PCA<{pca_th:.3f}] Merged: {merged}")
    for emd_th in emd_ths:
        merged = merge_clusters(js_matrix, diurnal_pca_matrix, emd_matrix, js_th, 0, emd_th, mode="emd")
        print(f"[JS<{js_th:.3f}, EMD<{emd_th:.3f}] Merged: {merged}")
        plot_merge_graph(merged, f"JS<{js_th:.3f}, EMD<{emd_th:.3f}", f"/p/project1/exaww/chatterjee1/plots/continuous/js{js_th:.3f}_emd{emd_th:.3f}.png")
'''        
# === Full Agreement Sweep ===
for js_th in js_ths:
    for pca_th in pca_ths:
        for emd_th in emd_ths:
            merged_all = merge_clusters(js_matrix, diurnal_pca_matrix, emd_matrix, js_th, pca_th, emd_th, mode="all")
            print(f"[JS<{js_th:.3f}, PCA<{pca_th:.3f}, EMD<{emd_th:.3f}] Merged: {merged_all}")
            plot_merge_graph(
                merged_all,
                title=f"Merge Graph (ALL: JS<{js_th:.3f}, PCA<{pca_th:.3f}, EMD<{emd_th:.3f})",
                path=f"/p/project1/exaww/chatterjee1/plots/continuous/merge_all_js{js_th:.3f}_pca{pca_th:.3f}_emd{emd_th:.3f}.png"
            )
'''

[JS<0.010, PCA<0.004] Merged: []
[JS<0.010, PCA<0.009] Merged: []
[JS<0.010, PCA<0.013] Merged: []
[JS<0.010, PCA<0.018] Merged: []
[JS<0.010, PCA<0.023] Merged: []
[JS<0.010, PCA<0.027] Merged: []
[JS<0.010, PCA<0.032] Merged: []
[JS<0.010, PCA<0.037] Merged: []
[JS<0.010, PCA<0.041] Merged: []
[JS<0.010, PCA<0.046] Merged: []
[JS<0.010, PCA<0.051] Merged: []
[JS<0.010, PCA<0.055] Merged: []
[JS<0.010, PCA<0.060] Merged: []
[JS<0.010, PCA<0.065] Merged: []
[JS<0.010, PCA<0.069] Merged: []
[JS<0.010, EMD<1.544] Merged: []
[JS<0.010, EMD<2.309] Merged: []
[JS<0.010, EMD<3.075] Merged: []
[JS<0.010, EMD<3.840] Merged: []
[JS<0.010, EMD<4.606] Merged: []
[JS<0.010, EMD<5.371] Merged: []
[JS<0.010, EMD<6.137] Merged: []
[JS<0.010, EMD<6.902] Merged: []
[JS<0.010, EMD<7.668] Merged: []
[JS<0.010, EMD<8.433] Merged: []
[JS<0.010, EMD<9.199] Merged: []
[JS<0.010, EMD<9.965] Merged: []
[JS<0.010, EMD<10.730] Merged: []
[JS<0.010, EMD<11.496] Merged: []
[JS<0.010, EMD<12.261] Merged: []
[JS<0.0

'        \n# === Full Agreement Sweep ===\nfor js_th in js_ths:\n    for pca_th in pca_ths:\n        for emd_th in emd_ths:\n            merged_all = merge_clusters(js_matrix, diurnal_pca_matrix, emd_matrix, js_th, pca_th, emd_th, mode="all")\n            print(f"[JS<{js_th:.3f}, PCA<{pca_th:.3f}, EMD<{emd_th:.3f}] Merged: {merged_all}")\n            plot_merge_graph(\n                merged_all,\n                title=f"Merge Graph (ALL: JS<{js_th:.3f}, PCA<{pca_th:.3f}, EMD<{emd_th:.3f})",\n                path=f"/p/project1/exaww/chatterjee1/plots/continuous/merge_all_js{js_th:.3f}_pca{pca_th:.3f}_emd{emd_th:.3f}.png"\n            )\n'

In [23]:
# === Mean and spread across sites ===
mean_diurnal = cluster_diurnal_norm.mean(axis=1)
std_diurnal = cluster_diurnal_norm.std(axis=1)

# === Plotting ===
fig, axes = plt.subplots(2, 5, figsize=(22, 8), sharex=True, sharey=True)
axes = axes.flatten()

x = np.arange(96)

for cl in range(n_clusters):
    ax = axes[cl]
    mean = mean_diurnal[cl]
    std = std_diurnal[cl]
    
    ax.plot(x, mean, color='tab:blue', label='Mean')
    ax.fill_between(x, mean - std, mean + std, color='tab:blue', alpha=0.3, label='±1 STD')
    
    ax.set_title(f'Cluster {cl}')
    #ax.set_xlim([0, 95])
    #ax.set_xticks(np.linspace(0, 95, 9))  # Label every 3 hours
    #ax.set_xticklabels(pd.date_range("00:00", "24:00", freq="3H").strftime("%H:%M"), rotation=45)
    
    xtick_pos = np.arange(0, 96, 12)  # 0, 12, 24, ..., 84
    xtick_labels = [f"{h:02d}:00" for h in range(0, 24, 3)]  # 8 labels
    ax.set_xticks(xtick_pos)
    ax.set_xticklabels(xtick_labels, rotation=45)
    ax.set_xticks(xtick_pos)
    ax.set_xticklabels(xtick_labels, rotation=45)
    ax.set_ylabel("Norm. Frequency")
    
    ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.4)

fig.suptitle('Diurnal Cycle of Clusters (hourly Bins, Obs Across Sites)', fontsize=18)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/diurnal_mergingcheck_15min.png", dpi=300)
plt.show()

In [26]:
all_joint_histograms_ = np.array(all_joint_histograms)
cluster_diurnal_ = np.array(cluster_diurnal)
all_joint_histograms_.shape, cluster_diurnal_.shape

((10, 12, 16, 28), (10, 11, 96))

In [27]:
histograms_flat_ = np.array(histograms_flat)
cluster_diurnal_norm_ = np.array(cluster_diurnal_norm)
histograms_flat_.shape, cluster_diurnal_norm_.shape

((10, 5376), (10, 11, 96))

In [28]:
pca_profiles.shape

(10, 96)

In [36]:
# === Plot heatmaps ===
plt.figure(figsize=(10, 8))
sns.heatmap(js_matrix, annot=True, fmt=".3f", cmap="viridis")
plt.title("JS Divergence Between Clusters (Joint 6.2, 7.3, 10.8 µm)")
plt.xlabel("Cluster")
plt.ylabel("Cluster")
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/js_divergence_merged_C5C6C9.png")

In [44]:
# Define cluster count and merged pairs
n_clusters = 10
merged_pairs = [(0, 7), (1, 5), (2, 3)]

# Create graph and add nodes and edges
G = nx.Graph()
G.add_nodes_from(range(n_clusters))
G.add_edges_from(merged_pairs)

# Use spring layout for aesthetics
pos = nx.spring_layout(G, seed=42)

# Begin plotting
plt.figure(figsize=(8, 6))

# Draw nodes
nx.draw_networkx_nodes(G, pos, node_size=800, node_color="#A2C4C9", edgecolors="black", linewidths=1.2)

# Draw curved edges with alpha blending
arc_rad = 0.2  # curvature
for edge in merged_pairs:
    nx.draw_networkx_edges(
        G, pos, edgelist=[edge],
        width=2, edge_color="#555555",
        alpha=0.7, connectionstyle=f'arc3,rad={arc_rad}'
    )

# Draw labels
nx.draw_networkx_labels(G, pos, font_size=14, font_color="black")

# Final touches
plt.title("Merged Cluster Graph", fontsize=16)
plt.axis("off")
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/merged_decision.png", dpi=300)