In [62]:
import xarray as xr
import numpy as np
import pandas as pd
import glob
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
#import cartopy.crs as ccrs
#import cartopy.feature as cfeature
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import torch
from collections import defaultdict
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
import os
import seaborn as sns

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D


### Hourly OBS HIST

In [12]:
ds = xr.open_dataset("/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_cabauw_allchannelcrops.nc")
cluster_data = torch.load("/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_cabauw_cluster_10_labels.pth", map_location="cpu")

model_ds = xr.open_dataset("/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_cabauw_allchannelcrops.ncicon_cabauw_WV_IR_crops.nc")
model_cluster_data = torch.load("/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/" + "icon_cabauw_cluster_10_labels.pth", map_location="cpu")

# Extract brightness temperature and time
bt_data = ds['sample_cabauw_data_9'].values  # shape: (sample, h, w)
raw_times = ds['time'].values
obs_datetimes = pd.to_datetime([str(t)[:12] for t in raw_times], format="%Y%m%d%H%M")

# Match closest timestamp to each full hourly timestamp (±15 min)
hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'),
                             end=obs_datetimes.max().ceil('H'),
                             freq='H')

matched_indices = []
for h in hourly_times:
    time_diffs = np.abs((obs_datetimes - h).total_seconds())
    min_idx = np.argmin(time_diffs)
    if time_diffs[min_idx] <= 900:  # 15 minutes = 900 seconds
        matched_indices.append(min_idx)

# Subset the data
bt_data_hourly = bt_data[matched_indices]
cluster_data_hourly = np.array(cluster_data)[matched_indices]

# Bin settings
bin_edges = np.arange(203, 311, 10)  # Bins: 203–210, 210–220, ..., 300–310
n_clusters = 10

# Plot
fig, axes = plt.subplots(2, 5, figsize=(20, 8), sharex=True, sharey=True)
axes = axes.flatten()

for cluster_idx in range(n_clusters):
    ax = axes[cluster_idx]
    
    # Mask samples for this cluster
    cluster_mask = cluster_data_hourly == cluster_idx
    cluster_samples = bt_data_hourly[cluster_mask]
    
    # Flatten all pixels into 1D
    flattened_bt = cluster_samples.reshape(-1)
    
    # Plot histogram
    ax.hist(flattened_bt, bins=bin_edges, color='skyblue', edgecolor='black')
    ax.set_title(f'Cluster {cluster_idx}')
    ax.set_xlabel('BT (K)')
    ax.set_ylabel('Frequency')

fig.suptitle('Brightness Temperature Histograms (10.8µm) by Cluster', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/histcluster_cabauw_obs.png", dpi=100, bbox_inches="tight")

In [7]:
ds

In [11]:
ds['sample_cabauw_data_9'].min().values, ds['sample_cabauw_data_9'].max().values

(array(203.57324219), array(308.64157104))

In [14]:
model_ds = xr.open_dataset("/p/project/exaww/chatterjee1/dataset/warmworld_datasets/icon_cabauw_WV_IR_crops.nc")
model_ds

### Hourly OBS + ICON HIST

In [47]:

place = 'cabauw'

# ==== Load OBSERVATION data ====
obs_ds = xr.open_dataset("/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_"+ place +"_allchannelcrops.nc")
obs_bt = obs_ds['sample_'+ place +'_data_9'].values  # (sample, h, w)
obs_raw_times = obs_ds['time'].values
obs_datetimes = pd.to_datetime([str(t)[:12] for t in obs_raw_times], format="%Y%m%d%H%M")
obs_cluster_labels = torch.load(
    "/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_"+ place +"_cluster_10_labels.pth",
    map_location="cpu"
)

# ==== Load MODEL data ====
model_ds = xr.open_dataset("/p/project/exaww/chatterjee1/dataset/warmworld_datasets/icon_"+ place +"_WV_IR_crops.nc")
model_bt = model_ds[place +'_data_IR'].values  # (sample, h, w)
model_times = model_ds["time"].values
model_datetimes = pd.to_datetime(model_times)
model_cluster_labels = torch.load(
    "/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_"+ place +"_cluster_10_labels.pth",
    map_location="cpu"
)

# ==== Match timestamps to full hourly timestamps (±15min tolerance) ====
# Define common hourly grid
hourly_times = pd.date_range(start=max(obs_datetimes.min(), model_datetimes.min()).floor('H'),
                             end=min(obs_datetimes.max(), model_datetimes.max()).ceil('H'),
                             freq='H')

# Helper: match closest within ±15 min
def match_closest(datetimes, reference_times):
    matched = []
    for ref in reference_times:
        diffs = np.abs((datetimes - ref).total_seconds())
        min_idx = np.argmin(diffs)
        if diffs[min_idx] <= 900:  # ±15 min
            matched.append(min_idx)
    return matched

# Match OBS data to nearest hourly times (±15 min)
obs_matched_idx = match_closest(obs_datetimes, hourly_times)
obs_bt_hourly = obs_bt[obs_matched_idx]
obs_clusters_hourly = np.array(obs_cluster_labels)[obs_matched_idx]

# Model data is already hourly — exact match
model_mask = model_datetimes.isin(hourly_times)
model_bt_hourly = model_bt[model_mask]
model_clusters_hourly = np.array(model_cluster_labels)[model_mask]

# ==== Plot settings ====
bin_edges = np.arange(203, 311, 10)
n_clusters = 10

fig, axes = plt.subplots(2, 5, figsize=(20, 8), sharex=True, sharey=True)
axes = axes.flatten()

for cluster_idx in range(n_clusters):
    ax = axes[cluster_idx]

    # OBS samples for this cluster
    obs_mask = obs_clusters_hourly == cluster_idx
    obs_samples = obs_bt_hourly[obs_mask].reshape(-1)

    # MODEL samples for this cluster
    model_mask = model_clusters_hourly == cluster_idx
    model_samples = model_bt_hourly[model_mask].reshape(-1)
    model_samples_orig = model_bt_hourly[model_mask]
    if cluster_idx == 1:
        print(model_samples_orig.shape)

    # Plot histograms
    ax.hist(obs_samples, bins=bin_edges, color='skyblue', edgecolor='black', alpha=0.6, label='Observation',density=True)
    ax.hist(model_samples, bins=bin_edges, color='tomato', edgecolor='black', alpha=0.6, label='Model',density=True)

    ax.set_title(f'Cluster {cluster_idx}')
    ax.set_xlabel('BT (K)')
    ax.set_ylabel('Frequency')

# Add legend only to one axis to avoid clutter
axes[0].legend(loc='upper right')

fig.suptitle('BT (10.8µm) Histogram: '+ place , fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/histcluster_"+ place +"_obs_icon.png", dpi=100, bbox_inches="tight")


(2, 78, 78)


## hourly OBS + ICON all sites combined

In [91]:

# === Site info ===
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"),
    10: ("dresden", "dresden"),
}
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}
icon_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lin", "lin"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"),
    10: ("dresden", "dresden"),
}
icon_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

n_clusters = 10
bin_edges = np.arange(180, 320, 5)
obs_all_bt = [[] for _ in range(n_clusters)]
model_all_bt = [[] for _ in range(n_clusters)]

def match_closest(datetimes, reference_times):
    matched = []
    for ref in reference_times:
        diffs = np.abs((datetimes - ref).total_seconds())
        min_idx = np.argmin(diffs)
        if diffs[min_idx] <= 900:
            matched.append(min_idx)
    return matched

for i in range(11):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]
    site_nc_icon, site_var_icon = icon_sites_ncvar_name[i]
    cluster_id_icon = icon_cluster_sites[i]

    obs_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    obs_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"
    model_nc = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/icon_{site_var_icon}_WV_IR_crops.nc"  # icon_"+ place +"_WV_IR_crops.nc"
    model_cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth"

    if not os.path.exists(obs_nc) or not os.path.exists(obs_cluster_file) or not os.path.exists(model_nc) or not os.path.exists(model_cluster_file):
        continue

    # === Load OBS ===
    ds = xr.open_dataset(obs_nc)
    obs_bt = ds[f"sample_{site_var}_data_9"].values
    obs_datetimes = pd.to_datetime([str(t)[:12] for t in ds["time"].values], format="%Y%m%d%H%M")
    obs_cluster_labels = np.array(torch.load(obs_cluster_file, map_location="cpu"))

    # === Load MODEL ===
    ds_model = xr.open_dataset(model_nc)
    model_bt = ds_model[f"{site_var_icon}_data_IR"].values  # model_ds[place +'_data_IR'].values
    model_datetimes = pd.to_datetime(ds_model["time"].values)
    model_cluster_labels = np.array(torch.load(model_cluster_file, map_location="cpu"))

    hourly_times = pd.date_range(start=max(obs_datetimes.min(), model_datetimes.min()).floor('H'),
                                 end=min(obs_datetimes.max(), model_datetimes.max()).ceil('H'),
                                 freq='H')

    obs_idx = match_closest(obs_datetimes, hourly_times)
    obs_bt_hourly = obs_bt[obs_idx]
    obs_clusters_hourly = obs_cluster_labels[obs_idx]

    model_mask = model_datetimes.isin(hourly_times)
    model_bt_hourly = model_bt[model_mask]
    model_clusters_hourly = model_cluster_labels[model_mask]

    for cl in range(n_clusters):
        obs_mask = obs_clusters_hourly == cl
        model_mask = model_clusters_hourly == cl

        obs_bt_cluster = obs_bt_hourly[obs_mask].reshape(-1)
        model_bt_cluster = model_bt_hourly[model_mask].reshape(-1)

        obs_all_bt[cl].append(obs_bt_cluster)
        model_all_bt[cl].append(model_bt_cluster)

# === Plotting ===
fig, axes = plt.subplots(2, 5, figsize=(20, 8), sharex=True, sharey=True)
axes = axes.flatten()

bin_centers = bin_edges[:-1] + np.diff(bin_edges) / 2

for cluster_idx in range(n_clusters):
    ax = axes[cluster_idx]

    obs_samples = np.concatenate(obs_all_bt[cluster_idx]) if obs_all_bt[cluster_idx] else np.array([])
    model_samples = np.concatenate(model_all_bt[cluster_idx]) if model_all_bt[cluster_idx] else np.array([])

    ax.hist(obs_samples, bins=bin_edges, color='skyblue', edgecolor='black', alpha=0.6, label='Observation', density=True)
    ax.hist(model_samples, bins=bin_edges, color='tomato', edgecolor='black', alpha=0.6, label='Model', density=True)

    ax.set_title(f'Cluster {cluster_idx}')
    ax.set_xlabel('BT (K)')
    ax.set_ylabel('Frequency')
    ax.grid(True, linestyle='--', alpha=0.3)

axes[0].legend(loc='upper right')
fig.suptitle('BT (10.8µm) Histogram: Combined All Sites', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/allsite_combined_histogram_obs_model_clusterwise.png", dpi=300)
plt.show()

### Cluster wise KL and JS Divergence on all time stamps of OBS 

In [87]:
channel = 9

# === Obs site info ===
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg","nuremberg"),
    9: ("aurillac","aurillac"),
    10: ("dresden","dresden"),
}
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

# === Bin settings ===
bin_edges = np.arange(180,320,5) #for 6.2 or channel_5 = (190,250,5) || for 7.3 or channel_6 = (190,270,5) || for 10.8 or channel_9 = (180, 320, 5) || for channel_10 i.e 12.0 um = (320,180,5)
n_bins = len(bin_edges) - 1
n_clusters = 10
all_bt = [[] for _ in range(n_clusters)]

# === Load data and aggregate BT values by cluster ===
for i in range(11):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]
    
    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"
    
    if not os.path.exists(nc_file) or not os.path.exists(cluster_file):
        print(f"Missing: {site_nc}")
        continue
    
    ds = xr.open_dataset(nc_file)
    bt = ds[f'sample_{site_var}_data_{channel}'].values  # shape: (samples, h, w)
    print(bt.shape, bt.max(), bt.min())
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))
    
    for cl in range(n_clusters):
        mask = cluster_labels == cl
        if np.any(mask):
            all_bt[cl].append(bt[mask].reshape(-1))

# === Histograms ===
histograms = []
for cl_bt in all_bt:
    if cl_bt:
        all_pixels = np.concatenate(cl_bt)
        hist, _ = np.histogram(all_pixels, bins=bin_edges, density=True)
    else:
        hist = np.zeros(n_bins)
    histograms.append(hist)

histograms = np.array(histograms)

# === JS and KL divergence ===
js_matrix = np.zeros((n_clusters, n_clusters))
kl_matrix = np.zeros((n_clusters, n_clusters))

for i in range(n_clusters):
    for j in range(n_clusters):
        P = histograms[i] + 1e-10
        Q = histograms[j] + 1e-10
        P /= P.sum()
        Q /= Q.sum()
        js_matrix[i, j] = jensenshannon(P, Q, base=2) ** 2
        kl_matrix[i, j] = np.sum(rel_entr(P, Q))

# === Plot heatmaps ===
plt.figure(figsize=(10, 8))
sns.heatmap(js_matrix, annot=True, fmt=".3f", cmap="viridis")
plt.title("JS Divergence Between Clusters (Obs)")
plt.xlabel("Cluster")
plt.ylabel("Cluster")
plt.tight_layout()
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/js_divergence_obs_clusters_C{channel}.png")

plt.figure(figsize=(10, 8))
sns.heatmap(kl_matrix, annot=True, fmt=".2f", cmap="magma", norm=None)
plt.title("KL Divergence Between Clusters (Obs)")
plt.xlabel("Cluster")
plt.ylabel("Cluster")
plt.tight_layout()
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/kl_divergence_obs_clusters_C{channel}.png")

# Plot 2x5 histogram subplots for the 10 clusters using the computed histograms
fig, axes = plt.subplots(2, 5, figsize=(20, 8), sharex=True, sharey=True)
axes = axes.flatten()

for cluster_idx in range(n_clusters):
    ax = axes[cluster_idx]
    ax.bar(
        x=bin_edges[:-1] + 2.5,  # center of each bin
        height=histograms[cluster_idx],
        width=5,
        color='skyblue',
        edgecolor='black'
    )
    ax.set_title(f'Cluster {cluster_idx}')
    ax.set_xlabel('BT (K)')
    ax.set_ylabel('Normalized Frequency')

fig.suptitle('BT (10.8µm) Histogram for Each Cluster (All Sites, All Timestamps)', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/histcluster_allsites_obs_C{channel}.png")

(18567, 128, 128) 308.64157 201.1547
(18567, 128, 128) 308.07324 179.1197
(18567, 128, 128) 306.00558 180.0001
(18567, 128, 128) 309.43298 181.69418
(18567, 128, 128) 313.864 196.97769
(18567, 128, 128) 324.54245 178.21509
(18567, 128, 128) 312.5505 196.97769
(18567, 128, 128) 308.64157 203.57324
(18567, 128, 128) 308.07324 196.97769
(18567, 128, 128) 316.23877 132.62302
(18567, 128, 128) 308.07324 180.8578


In [52]:
histograms.max()

0.14001186658423803

## Checking channel 1 i.e VIS 0.6 um min max

In [13]:
ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_juelich_allchannelcrops.nc')
print(ds.sample_juelich_data_1.max().values, '|', ds.sample_juelich_data_1.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_lindenberg_allchannelcrops.nc')
print(ds.sample_lindenberg_data_1.max().values, '|', ds.sample_lindenberg_data_1.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_warsaw_allchannelcrops.nc')
print(ds.sample_warsaw_data_1.max().values, '|', ds.sample_warsaw_data_1.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_vienna_allchannelcrops.nc')
print(ds.sample_vienna_data_1.max().values, '|', ds.sample_vienna_data_1.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_bourges_allchannelcrops.nc')
print(ds.sample_bourges_data_1.max().values, '|', ds.sample_bourges_data_1.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_zaragoza_allchannelcrops.nc')
print(ds.sample_zaragoza_data_1.max().values, '|', ds.sample_zaragoza_data_1.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_sirta_allchannelcrops.nc')
print(ds.sample_sirta_data_1.max().values, '|', ds.sample_sirta_data_1.min().values)

244.31982421875 | 0.0
382.85687255859375 | -4.438074111938477
1457.006103515625 | -5.87713623046875
746.41357421875 | -1.8636250495910645
1150.609619140625 | 0.0
1711.7685546875 | -5.101034641265869
1858.959228515625 | 0.0


## Checking channel 5 i.e 6.2 um min max

In [14]:
ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_juelich_allchannelcrops.nc')
print(ds.sample_juelich_data_5.max().values, '|', ds.sample_juelich_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_lindenberg_allchannelcrops.nc')
print(ds.sample_lindenberg_data_5.max().values, '|', ds.sample_lindenberg_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_warsaw_allchannelcrops.nc')
print(ds.sample_warsaw_data_5.max().values, '|', ds.sample_warsaw_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_vienna_allchannelcrops.nc')
print(ds.sample_vienna_data_5.max().values, '|', ds.sample_vienna_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_bourges_allchannelcrops.nc')
print(ds.sample_bourges_data_5.max().values, '|', ds.sample_bourges_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_zaragoza_allchannelcrops.nc')
print(ds.sample_zaragoza_data_5.max().values, '|', ds.sample_zaragoza_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_sirta_allchannelcrops.nc')
print(ds.sample_sirta_data_5.max().values, '|', ds.sample_sirta_data_5.min().values)

249.46510314941406 | 202.43875122070312
248.96591186523438 | 199.27882385253906
247.32826232910156 | 198.9933319091797
251.13706970214844 | 200.1131591796875
250.1787567138672 | 200.1131591796875
252.1121368408203 | 198.11294555664062
250.1787567138672 | 200.1131591796875


In [16]:
ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_juelich_allchannelcrops.nc')
print(ds.sample_juelich_data_5.max().values, '|', ds.sample_juelich_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_lindenberg_allchannelcrops.nc')
print(ds.sample_lindenberg_data_5.max().values, '|', ds.sample_lindenberg_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_warsaw_allchannelcrops.nc')
print(ds.sample_warsaw_data_5.max().values, '|', ds.sample_warsaw_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_vienna_allchannelcrops.nc')
print(ds.sample_vienna_data_5.max().values, '|', ds.sample_vienna_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_bourges_allchannelcrops.nc')
print(ds.sample_bourges_data_5.max().values, '|', ds.sample_bourges_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_zaragoza_allchannelcrops.nc')
print(ds.sample_zaragoza_data_5.max().values, '|', ds.sample_zaragoza_data_5.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_sirta_allchannelcrops.nc')
print(ds.sample_sirta_data_5.max().values, '|', ds.sample_sirta_data_5.min().values)

249.46510314941406 | 202.43875122070312
248.96591186523438 | 199.27882385253906
247.32826232910156 | 198.9933319091797
251.13706970214844 | 200.1131591796875
250.1787567138672 | 200.1131591796875
252.1121368408203 | 198.11294555664062
250.1787567138672 | 200.1131591796875


## Checking channel 9 i.e 10.8 um min max

In [28]:
ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_juelich_allchannelcrops.nc')
print(ds.sample_juelich_data_9.max().values, '|', ds.sample_juelich_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_lindenberg_allchannelcrops.nc')
print(ds.sample_lindenberg_data_9.max().values, '|', ds.sample_lindenberg_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_warsaw_allchannelcrops.nc')
print(ds.sample_warsaw_data_9.max().values, '|', ds.sample_warsaw_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_vienna_allchannelcrops.nc')
print(ds.sample_vienna_data_9.max().values, '|', ds.sample_vienna_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_bourges_allchannelcrops.nc')
print(ds.sample_bourges_data_9.max().values, '|', ds.sample_bourges_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_zaragoza_allchannelcrops.nc')
print(ds.sample_zaragoza_data_9.max().values, '|', ds.sample_zaragoza_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_sirta_allchannelcrops.nc')
print(ds.sample_sirta_data_9.max().values, '|', ds.sample_sirta_data_9.min().values)

308.6415710449219 | 201.15469360351562
308.0732421875 | 179.1197052001953
306.0055847167969 | 180.00010681152344
309.4329833984375 | 181.69418334960938
313.864013671875 | 196.97769165039062
324.5424499511719 | 178.215087890625
312.5505065917969 | 196.97769165039062


In [29]:
ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_juelich_allchannelcrops.nc')
print(ds.sample_juelich_data_9.max().values, '|', ds.sample_juelich_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_lindenberg_allchannelcrops.nc')
print(ds.sample_lindenberg_data_9.max().values, '|', ds.sample_lindenberg_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_warsaw_allchannelcrops.nc')
print(ds.sample_warsaw_data_9.max().values, '|', ds.sample_warsaw_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_vienna_allchannelcrops.nc')
print(ds.sample_vienna_data_9.max().values, '|', ds.sample_vienna_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_bourges_allchannelcrops.nc')
print(ds.sample_bourges_data_9.max().values, '|', ds.sample_bourges_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_zaragoza_allchannelcrops.nc')
print(ds.sample_zaragoza_data_9.max().values, '|', ds.sample_zaragoza_data_9.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_sirta_allchannelcrops.nc')
print(ds.sample_sirta_data_9.max().values, '|', ds.sample_sirta_data_9.min().values)

308.6415710449219 | 201.15469360351562
308.0732421875 | 179.1197052001953
306.0055847167969 | 180.00010681152344
309.4329833984375 | 181.69418334960938
313.864013671875 | 196.97769165039062
324.5424499511719 | 178.215087890625
312.5505065917969 | 196.97769165039062


## Checking channel 10 i.e 12.0 um min max

In [30]:
ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_juelich_allchannelcrops.nc')
print(ds.sample_juelich_data_10.max().values, '|', ds.sample_juelich_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_lindenberg_allchannelcrops.nc')
print(ds.sample_lindenberg_data_10.max().values, '|', ds.sample_lindenberg_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_warsaw_allchannelcrops.nc')
print(ds.sample_warsaw_data_10.max().values, '|', ds.sample_warsaw_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_vienna_allchannelcrops.nc')
print(ds.sample_vienna_data_10.max().values, '|', ds.sample_vienna_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_bourges_allchannelcrops.nc')
print(ds.sample_bourges_data_10.max().values, '|', ds.sample_bourges_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_zaragoza_allchannelcrops.nc')
print(ds.sample_zaragoza_data_10.max().values, '|', ds.sample_zaragoza_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/msgobs_sirta_allchannelcrops.nc')
print(ds.sample_sirta_data_10.max().values, '|', ds.sample_sirta_data_10.min().values)

306.499755859375 | 201.58216857910156
305.1568603515625 | 192.33349609375
302.4330139160156 | 190.77215576171875
306.7425842285156 | 193.84230041503906
311.16510009765625 | 197.18199157714844
322.0851135253906 | 190.77215576171875
309.5069885253906 | 197.18199157714844


In [31]:
ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_juelich_allchannelcrops.nc')
print(ds.sample_juelich_data_10.max().values, '|', ds.sample_juelich_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_lindenberg_allchannelcrops.nc')
print(ds.sample_lindenberg_data_10.max().values, '|', ds.sample_lindenberg_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_warsaw_allchannelcrops.nc')
print(ds.sample_warsaw_data_10.max().values, '|', ds.sample_warsaw_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_vienna_allchannelcrops.nc')
print(ds.sample_vienna_data_10.max().values, '|', ds.sample_vienna_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_bourges_allchannelcrops.nc')
print(ds.sample_bourges_data_10.max().values, '|', ds.sample_bourges_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_zaragoza_allchannelcrops.nc')
print(ds.sample_zaragoza_data_10.max().values, '|', ds.sample_zaragoza_data_10.min().values)

ds = xr.open_dataset('/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_sirta_allchannelcrops.nc')
print(ds.sample_sirta_data_10.max().values, '|', ds.sample_sirta_data_10.min().values)

306.499755859375 | 201.58216857910156
305.1568603515625 | 192.33349609375
302.4330139160156 | 190.77215576171875
306.7425842285156 | 193.84230041503906
311.16510009765625 | 197.18199157714844
322.0851135253906 | 190.77215576171875
309.5069885253906 | 197.18199157714844


## 3D histogram

In [75]:
# === Obs site info ===
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lindenberg", "lindenberg"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg","nuremberg"),
    9: ("aurillac","aurillac"),
    10: ("dresden","dresden"),
}
obs_cluster_sites = {
    0: "juelich",
    1: "lin",
    2: "warsaw",
    3: "vienna",
    4: "bourges",
    5: "zargoza",
    6: "sirta",
    7: "cabauw",
    8: "nuremberg",
    9: "aurillac",
    10: "dresden",
}

# === Bin edges for 3 channels ===
bin_edges_C5 = np.arange(190, 255, 5)   # 6.2 µm
bin_edges_C6 = np.arange(190, 275, 5)   # 7.3 µm
bin_edges_C9 = np.arange(180, 325, 5)   # 10.8 µm

bins = [bin_edges_C5, bin_edges_C6, bin_edges_C9]
n_clusters = 10
all_joint_histograms = [np.zeros((len(bin_edges_C5)-1, len(bin_edges_C6)-1, len(bin_edges_C9)-1)) for _ in range(n_clusters)]

# === Aggregate 3-channel BT values by cluster ===
for i in range(8):
    site_nc, site_var = obs_sites_ncvar_name[i]
    cluster_id = obs_cluster_sites[i]
    
    nc_file = f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_{site_nc}_allchannelcrops.nc"
    cluster_file = f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth"
    
    if not os.path.exists(nc_file) or not os.path.exists(cluster_file):
        print(f"Missing: {site_nc}")
        continue

    ds = xr.open_dataset(nc_file)
    c5 = ds[f'sample_{site_var}_data_5'].values.reshape(ds.dims['sample'], -1)
    c6 = ds[f'sample_{site_var}_data_6'].values.reshape(ds.dims['sample'], -1)
    c9 = ds[f'sample_{site_var}_data_9'].values.reshape(ds.dims['sample'], -1)
    cluster_labels = np.array(torch.load(cluster_file, map_location="cpu"))

    for cl in range(n_clusters):
        mask = cluster_labels == cl
        if np.any(mask):
            x = c5[mask].ravel()
            y = c6[mask].ravel()
            z = c9[mask].ravel()
            xyz = np.vstack((x, y, z)).T
            joint_hist, _ = np.histogramdd(xyz, bins=bins, density=True)
            all_joint_histograms[cl] += joint_hist

# === Flatten and normalize each joint histogram ===
histograms_flat = []
for hist3d in all_joint_histograms:
    hist_flat = hist3d.flatten()
    hist_flat += 1e-10  # regularize
    hist_flat /= hist_flat.sum()
    histograms_flat.append(hist_flat)

# === Compute JS & KL matrices ===
js_matrix = np.zeros((n_clusters, n_clusters))
kl_matrix = np.zeros((n_clusters, n_clusters))

for i in range(n_clusters):
    for j in range(n_clusters):
        P = histograms_flat[i]
        Q = histograms_flat[j]
        js_matrix[i, j] = jensenshannon(P, Q, base=2) ** 2
        kl_matrix[i, j] = np.sum(rel_entr(P, Q))

# === Plot heatmaps ===
plt.figure(figsize=(10, 8))
sns.heatmap(js_matrix, annot=True, fmt=".3f", cmap="viridis")
plt.title("JS Divergence Between Clusters (Joint 6.2, 7.3, 10.8 µm)")
plt.xlabel("Cluster")
plt.ylabel("Cluster")
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/js_divergence_joint_C5C6C9.png")

plt.figure(figsize=(10, 8))
sns.heatmap(kl_matrix, annot=True, fmt=".2f", cmap="magma")
plt.title("KL Divergence Between Clusters (Joint 6.2, 7.3, 10.8 µm)")
plt.xlabel("Cluster")
plt.ylabel("Cluster")
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/kl_divergence_joint_C5C6C9.png")

### tSNE and PCA of joint histograms

In [76]:
# Stack flattened histograms into matrix (10, features)
X = np.stack(histograms_flat)  # shape: (10, N)

# === PCA ===
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

plt.figure(figsize=(7, 6))
for i in range(n_clusters):
    plt.scatter(X_pca[i, 0], X_pca[i, 1], label=f'Cluster {i}', s=100)
plt.title("PCA Projection of Joint BT Histograms (C5, C6, C9)")
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/pca_joint_C5C6C9.png")
plt.show()

# === t-SNE ===
tsne = TSNE(n_components=2, perplexity=5, random_state=42, init="pca", learning_rate="auto")
X_tsne = tsne.fit_transform(X)

plt.figure(figsize=(7, 6))
for i in range(n_clusters):
    plt.scatter(X_tsne[i, 0], X_tsne[i, 1], label=f'Cluster {i}', s=100)
plt.title("t-SNE Projection of Joint BT Histograms (C5, C6, C9)")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/tsne_joint_C5C6C9.png")
plt.show()

### Visulaize 3D joint histograms

In [64]:
# One cluster

hist = all_joint_histograms[cluster_idx]  # e.g., cluster 0

x_bins = bin_edges_C5[:-1]
y_bins = bin_edges_C6[:-1]
z_bins = bin_edges_C9[:-1]

# Get all non-zero bins
x_idx, y_idx, z_idx = np.nonzero(hist)
values = hist[x_idx, y_idx, z_idx]

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Create voxel grid
dx = dy = dz = 5  # bin width
ax.bar3d(
    x_bins[x_idx], y_bins[y_idx], z_bins[z_idx],
    dx, dy, dz,
    shade=True,
    color='skyblue',
    alpha=0.6,
    zsort='average'
)

ax.set_xlabel('BT (C5 - 6.2 µm)')
ax.set_ylabel('BT (C6 - 7.3 µm)')
ax.set_zlabel('BT (C9 - 10.8 µm)')
ax.set_title(f'3D Joint Histogram for Cluster {cluster_idx}')
plt.tight_layout()
plt.savefig(f'/p/project1/exaww/chatterjee1/plots/continuous/cluster{cluster_idx}_3D_C5C6C9.png')

In [77]:
# all cluster

# Bins without rightmost edge (needed for coordinates)
# Bins without rightmost edge
x_bins = bin_edges_C5[:-1]
y_bins = bin_edges_C6[:-1]
z_bins = bin_edges_C9[:-1]

dx = dy = dz = 5

# Axis ranges
xlim = (bin_edges_C5[0], bin_edges_C5[-1])
ylim = (bin_edges_C6[0], bin_edges_C6[-1])
zlim = (bin_edges_C9[0], bin_edges_C9[-1])

# Create figure
fig = plt.figure(figsize=(12, 20))

for cluster_idx in range(n_clusters):
    hist = all_joint_histograms[cluster_idx]
    
    x_idx, y_idx, z_idx = np.nonzero(hist)

    ax = fig.add_subplot(5, 2, cluster_idx + 1, projection='3d')

    ax.bar3d(
        x_bins[x_idx], y_bins[y_idx], z_bins[z_idx],
        dx, dy, dz,
        shade=True,
        color='skyblue',
        alpha=0.6,
        zsort='average'
    )

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_zlim(zlim)
    
    ax.set_title(f'Cluster {cluster_idx}', fontsize=10)
    if cluster_idx == 0:
        ax.set_xlabel('C5 (6.2 µm)')
        ax.set_ylabel('C6 (7.3 µm)')
        ax.set_zlabel('C9 (10.8 µm)')
    else:
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_zlabel('')

    # Optional: reduce tick clutter
    #if cluster_idx == 1 or cluster_idx == 3 or cluster_idx == 5 or cluster_idx == 7 or cluster_idx == 9 :
    #    ax.set_xticks([])
    #    ax.set_yticks([])
    #    ax.set_zticks([])

fig.suptitle('3D Joint Histograms of C5-C6-C9 for All Clusters', fontsize=16)
plt.tight_layout()
plt.subplots_adjust(top=0.92)  # leave space for suptitle
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/bar3d_allclusters_C5C6C9.png", dpi=300)
plt.show()