In [2]:
import numpy as np
import torch
import xarray as xr
import datetime as dt
import pandas as pd
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt

In [96]:
# OBS
#
# Single-site comparison of cluster persistence between the observation file
# and the ICON model file. For every cluster label the cell computes the mean
# length of uninterrupted hourly runs (stored under the historical name
# "probability of persistence") and appends one record per cluster to the
# notebook-wide `pop_summary` list.

# `pop_summary` collects records across repeated runs of this cell (one run per
# site), so it must not be reset here. Create it only when it does not exist
# yet, which keeps the cell runnable on a fresh kernel.
try:
    pop_summary
except NameError:
    pop_summary = []

site_name = "Cabauw"


def _run_lengths_by_label(labels, times):
    """Group the lengths of runs of identical consecutive labels by label.

    A run is extended only while the label stays the same AND the step from the
    previous timestamp is numerically close to one hour; a label change or an
    irregular time step closes the current run and starts a new one.

    Parameters
    ----------
    labels : sequence
        Label per time step (indexable and sliceable, e.g. a 1-D numpy array).
    times : sequence of timestamps
        Timestamp per time step, same length and order as ``labels``; the
        difference of two entries must provide ``total_seconds()``.

    Returns
    -------
    collections.defaultdict
        Maps each label to the list of its run lengths (number of steps).

    Raises
    ------
    ValueError
        If ``labels`` is empty.
    """
    if len(labels) == 0:
        raise ValueError("cannot compute run lengths of an empty label series")

    durations = defaultdict(list)
    prev_label = labels[0]
    prev_time = times[0]
    count = 1

    for label, time in zip(labels[1:], times[1:]):
        step_hours = (time - prev_time).total_seconds() / 3600.0
        if label == prev_label and np.isclose(step_hours, 1.0):
            count += 1
        else:
            # Save the finished run and start a new one
            durations[prev_label].append(count)
            count = 1
        prev_label, prev_time = label, time

    # The last run is still open when the loop ends
    durations[prev_label].append(count)
    return durations


# Load the observation data and its cluster labels.
# NOTE: torch.load unpickles the file; only load label files you trust.
ds = xr.open_dataset("/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_cabauwcrops.nc")
cluster_data = torch.load("/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_cabauw_cluster_10_labels.pth", map_location="cpu")

# Load model data
model_ds = xr.open_dataset("/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_cabauwcrops_icon.nc")
model_cluster_data = torch.load("/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_cabauw_cluster_10_labels.pth", map_location="cpu")

# Parse the times: keep the first 12 characters (YYYYmmddHHMM) of each entry
raw_times = ds['time'].values  # (sample,)
obs_datetimes = pd.to_datetime([t[:12] for t in raw_times], format="%Y%m%d%H%M")  # skip last 2 digits (seconds)

#### Select Closest Times to Full Hours

# Target full hourly timestamps
hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'),
                             end=obs_datetimes.max().ceil('H'),
                             freq='H')

# Match closest obs time to each hourly time
matched_indices = []
for h in hourly_times:
    time_diffs = np.abs((obs_datetimes - h).total_seconds())
    min_idx = np.argmin(time_diffs)
    if time_diffs[min_idx] <= 900:  # Accept if within 15 minutes
        matched_indices.append(min_idx)

# Filter obs time and cluster labels to only matched hourly ones
obs_hourly_datetimes = obs_datetimes[matched_indices]
cluster_labels_hourly = np.array(cluster_data)[matched_indices]

#### Compute Probability of Persistence (PoP)

# Run-length encode with time consistency check, then average per cluster
cluster_durations = _run_lengths_by_label(cluster_labels_hourly, obs_hourly_datetimes)
probability_of_persistence = {c: np.mean(durations) for c, durations in cluster_durations.items()}

## ICON

# Parse model times
model_times = model_ds["time"].values
model_datetimes = pd.to_datetime(model_times)

# The model series is used as-is (no hourly matching); it is only restricted
# to the time window covered by the matched observations.
start_time = obs_hourly_datetimes.min()
end_time = obs_hourly_datetimes.max()

# Mask model datetimes to match observation range
model_mask = (model_datetimes >= start_time) & (model_datetimes <= end_time)
model_datetimes_filtered = model_datetimes[model_mask]
model_cluster_labels_filtered = np.array(model_cluster_data)[model_mask]

model_durations = _run_lengths_by_label(model_cluster_labels_filtered, model_datetimes_filtered)

# Compute average durations (PoP)
model_pop = {c: np.mean(durations) for c, durations in model_durations.items()}

#### Final Output

print("Comparison of Probability of Persistence (in hours):\n")

# Report every cluster seen in either series; a cluster missing from one side
# is reported as 0 for that side.
all_clusters = sorted(set(cluster_labels_hourly) | set(model_cluster_labels_filtered))
for c in all_clusters:
    obs_val = probability_of_persistence.get(c, 0)
    model_val = model_pop.get(c, 0)
    pop_summary.append({
        'site': site_name,
        'cluster': c,
        'obs_pop': obs_val,
        'model_pop': model_val
    })
    print(f"Cluster {c}: Obs = {obs_val:.2f} hrs | Model = {model_val:.2f} hrs")

Comparison of Probability of Persistence (in hours):

Cluster 0: Obs = 5.94 hrs | Model = 3.12 hrs
Cluster 1: Obs = 6.09 hrs | Model = 1.00 hrs
Cluster 2: Obs = 4.60 hrs | Model = 3.69 hrs
Cluster 3: Obs = 4.92 hrs | Model = 3.90 hrs
Cluster 4: Obs = 3.54 hrs | Model = 6.67 hrs
Cluster 5: Obs = 4.93 hrs | Model = 2.60 hrs
Cluster 6: Obs = 4.14 hrs | Model = 3.71 hrs
Cluster 7: Obs = 5.63 hrs | Model = 2.59 hrs
Cluster 8: Obs = 6.55 hrs | Model = 9.60 hrs
Cluster 9: Obs = 3.43 hrs | Model = 3.55 hrs


In [97]:
# Display the per-site, per-cluster records accumulated in `pop_summary` so far.
pop_summary

[{'site': 'Juelich',
  'cluster': 0,
  'obs_pop': 5.949367088607595,
  'model_pop': 2.857142857142857},
 {'site': 'Juelich',
  'cluster': 1,
  'obs_pop': 5.911111111111111,
  'model_pop': 3.0},
 {'site': 'Juelich',
  'cluster': 2,
  'obs_pop': 3.393162393162393,
  'model_pop': 2.9464285714285716},
 {'site': 'Juelich',
  'cluster': 3,
  'obs_pop': 3.845528455284553,
  'model_pop': 3.4615384615384617},
 {'site': 'Juelich',
  'cluster': 4,
  'obs_pop': 4.3283582089552235,
  'model_pop': 7.348258706467662},
 {'site': 'Juelich',
  'cluster': 5,
  'obs_pop': 4.164285714285715,
  'model_pop': 2.611111111111111},
 {'site': 'Juelich', 'cluster': 6, 'obs_pop': 4.768, 'model_pop': 3.5},
 {'site': 'Juelich',
  'cluster': 7,
  'obs_pop': 4.838095238095238,
  'model_pop': 3.4444444444444446},
 {'site': 'Juelich',
  'cluster': 8,
  'obs_pop': 6.96078431372549,
  'model_pop': 9.584415584415584},
 {'site': 'Juelich',
  'cluster': 9,
  'obs_pop': 3.3333333333333335,
  'model_pop': 3.4172185430463577},
 

### Plotting 

In [98]:
# Grouped bar chart of the mean persistence per cluster (obs vs. model).
# Uses `cluster_labels_hourly`, `model_cluster_labels_filtered`,
# `probability_of_persistence` and `model_pop` from the persistence cell.

# Sorted union of the labels seen in either series, as strings so that the
# x-axis is treated as categorical
all_clusters = sorted(set(cluster_labels_hourly) | set(model_cluster_labels_filtered))
all_clusters = [str(c) for c in all_clusters]  # Convert to str for consistent plotting

# Prepare dataframes
obs_df = pd.DataFrame({
    'cluster': [str(k) for k in probability_of_persistence.keys()],
    'pop': list(probability_of_persistence.values()),
    'source': 'obs'
})

model_df = pd.DataFrame({
    'cluster': [str(k) for k in model_pop.keys()],
    'pop': list(model_pop.values()),
    'source': 'model'
})

combined_df = pd.concat([obs_df, model_df], ignore_index=True)

# Ensure correct cluster ordering
combined_df['cluster'] = pd.Categorical(combined_df['cluster'], categories=all_clusters, ordered=True)

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(
    data=combined_df,
    x='cluster',
    y='pop',
    hue='source',
    palette='tab10'
)

plt.xlabel('Mesoscale regimes')
plt.ylabel('Average Persistence Duration (hours)')
plt.title('Cabauw')
plt.legend(title='cabauw PoP')
plt.grid(True)
plt.tight_layout()
# Save BEFORE showing: once plt.show() has rendered the figure, a later
# plt.savefig() would write a new, empty figure instead of this plot.
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/cabauw_pop_.png", dpi=100, bbox_inches="tight")
plt.show()

In [99]:
# Tabulate the accumulated records: one row per (site, cluster) with the
# 'obs_pop' and 'model_pop' values stored in `pop_summary`.
df_summary = pd.DataFrame(pop_summary)
df_summary

Unnamed: 0,site,cluster,obs_pop,model_pop
0,Juelich,0,5.949367,2.857143
1,Juelich,1,5.911111,3.000000
2,Juelich,2,3.393162,2.946429
3,Juelich,3,3.845528,3.461538
4,Juelich,4,4.328358,7.348259
...,...,...,...,...
75,Cabauw,5,4.931507,2.595238
76,Cabauw,6,4.141732,3.714286
77,Cabauw,7,5.630435,2.588235
78,Cabauw,8,6.550000,9.604167


In [100]:
# Persist the all-site summary table. With pandas' default `index=True` the row
# index is written as an additional, unnamed first column.
df_summary.to_csv('/p/project/exaww/chatterjee1/mcspss_continuous/analysis/persistence_all_sites.csv')

### Cluster-wise box plot across different sites

In [101]:
# Box plot of persistence per cluster across all records in `pop_summary`
# (obs vs. model), saved to disk and then displayed.
df_pop = pd.DataFrame(pop_summary)

# Convert cluster to string if you prefer categorical x-axis
df_pop['cluster'] = df_pop['cluster'].astype(str)

# Melt the dataframe for long-form plotting
df_melted = df_pop.melt(
    id_vars=['site', 'cluster'],
    value_vars=['obs_pop', 'model_pop'],
    var_name='source',
    value_name='persistence_time'
)

# Clean source labels for clarity ('obs_pop' -> 'obs', 'model_pop' -> 'model')
df_melted['source'] = df_melted['source'].str.replace('_pop', '', regex=False)

# Plot
plt.figure(figsize=(12, 6))
sns.boxplot(
    data=df_melted,
    x='cluster',
    y='persistence_time',
    hue='source',
    palette='tab10'
)
plt.xlabel('Cluster')
plt.ylabel('Persistence Duration (hours)')
plt.title('Cluster-wise Persistence Duration Across Sites')
plt.grid(True, axis='y')
plt.tight_layout()

# The figure must be saved in the same cell and before plt.show(): saving from
# a separate, later cell writes a new empty figure rather than this plot.
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/allsite_pop.png", dpi=100, bbox_inches="tight")
plt.show()

## all site persistence

In [16]:
# === Site info ===
# Lookup tables keyed by site position 0..10.
#   *_cluster_sites      -> name used in the cluster-label file names
#   *_sites_ncvar_name   -> (name, name) pair; the first entry is used in the
#                           NetCDF file names
# NOTE(review): position 5 is spelled "zargoza" everywhere except in
# `icon_sites_ncvar_name`, where it is "zaragoza" — presumably this mirrors the
# file names on disk; confirm.
obs_cluster_sites = dict(enumerate((
    "juelich",
    "lin",
    "warsaw",
    "vienna",
    "bourges",
    "zargoza",
    "sirta",
    "cabauw",
    "nuremberg",
    "aurillac",
    "dresden",
)))
obs_sites_ncvar_name = {pos: (name, name) for pos, name in obs_cluster_sites.items()}

# The ICON label files use the same names as the observation label files.
icon_cluster_sites = dict(obs_cluster_sites)
icon_sites_ncvar_name = {pos: (name, name) for pos, name in icon_cluster_sites.items()}
icon_sites_ncvar_name[5] = ("zaragoza", "zaragoza")

# Accumulate persistence summaries: one record per (site, cluster) holding the
# mean run length of that cluster in the observation and in the ICON series.
pop_summary = []

for site_idx in range(11):
    try:
        site_nc, site_var = obs_sites_ncvar_name[site_idx]
        cluster_id = obs_cluster_sites[site_idx]
        site_nc_icon, site_var_icon = icon_sites_ncvar_name[site_idx]
        cluster_id_icon = icon_cluster_sites[site_idx]

        # Load obs data. Only the time coordinate is needed, so read it and
        # close the file again instead of leaving one handle open per site.
        with xr.open_dataset(f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_{site_nc}crops.nc") as ds:
            raw_times = ds['time'].values
        # NOTE: torch.load unpickles the file; only load label files you trust.
        cluster_data = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth", map_location="cpu")
        # Merge cluster 0 into cluster 7
        cluster_data[cluster_data == 0] = 7

        # Parse obs times (first 12 characters: YYYYmmddHHMM)
        obs_datetimes = pd.to_datetime([t[:12] for t in raw_times], format="%Y%m%d%H%M")

        # Find closest timestamps to full hours; the distances are computed
        # once per target hour and a match is kept only within 15 minutes.
        hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'), end=obs_datetimes.max().ceil('H'), freq='H')
        matched_indices = []
        for h in hourly_times:
            offsets = np.abs((obs_datetimes - h).total_seconds())
            nearest = np.argmin(offsets)
            if offsets[nearest] <= 900:
                matched_indices.append(nearest)

        obs_hourly_datetimes = obs_datetimes[matched_indices]
        cluster_labels_hourly = np.array(cluster_data)[matched_indices]

        # Compute obs persistence: a run continues only while the label is
        # unchanged and the time step is one hour. The inner index is named
        # `k` so it does not clobber the site index of the outer loop.
        cluster_durations = defaultdict(list)
        prev_label = cluster_labels_hourly[0]
        prev_time = obs_hourly_datetimes[0]
        count = 1
        for k in range(1, len(cluster_labels_hourly)):
            curr_label = cluster_labels_hourly[k]
            curr_time = obs_hourly_datetimes[k]
            if curr_label == prev_label and np.isclose((curr_time - prev_time).total_seconds() / 3600.0, 1.0):
                count += 1
            else:
                cluster_durations[prev_label].append(count)
                count = 1
            prev_label = curr_label
            prev_time = curr_time
        cluster_durations[prev_label].append(count)
        probability_of_persistence = {c: np.mean(durations) for c, durations in cluster_durations.items()}

        # Load ICON model data (time coordinate only, file closed afterwards)
        with xr.open_dataset(f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_nc_icon}crops_icon.nc") as model_ds:
            model_datetimes = pd.to_datetime(model_ds['time'].values)
        model_cluster_data = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth", map_location="cpu")
        model_cluster_data[model_cluster_data == 0] = 7

        # Restrict the model series to the window covered by the matched obs
        mask = (model_datetimes >= obs_hourly_datetimes.min()) & (model_datetimes <= obs_hourly_datetimes.max())
        model_datetimes_filtered = model_datetimes[mask]
        model_cluster_labels_filtered = np.array(model_cluster_data)[mask]

        # Compute model persistence with the same rule as for the observations
        model_durations = defaultdict(list)
        prev_label = model_cluster_labels_filtered[0]
        prev_time = model_datetimes_filtered[0]
        count = 1
        for k in range(1, len(model_cluster_labels_filtered)):
            curr_label = model_cluster_labels_filtered[k]
            curr_time = model_datetimes_filtered[k]
            if curr_label == prev_label and np.isclose((curr_time - prev_time).total_seconds() / 3600.0, 1.0):
                count += 1
            else:
                model_durations[prev_label].append(count)
                count = 1
            prev_label = curr_label
            prev_time = curr_time
        model_durations[prev_label].append(count)
        model_pop = {c: np.mean(durations) for c, durations in model_durations.items()}

        # A cluster missing from one series is recorded as 0 for that series
        all_clusters = sorted(set(cluster_labels_hourly) | set(model_cluster_labels_filtered))
        for c in all_clusters:
            pop_summary.append({
                'site': site_nc,
                'cluster': c,
                'obs_pop': probability_of_persistence.get(c, 0),
                'model_pop': model_pop.get(c, 0)
            })

    except Exception as e:
        # Best effort: a site that cannot be processed is reported and skipped
        print(f"Skipping {site_nc} due to error: {e}")

# Tabulate the records with the cluster label as a string so the x-axis is
# treated as categorical.
df_pop = pd.DataFrame(pop_summary).astype({'cluster': str})

# Long form for seaborn: one row per (site, cluster, source)
df_melted = pd.melt(
    df_pop,
    id_vars=['site', 'cluster'],
    value_vars=['obs_pop', 'model_pop'],
    var_name='source',
    value_name='persistence_time',
)
# 'obs_pop' -> 'obs', 'model_pop' -> 'model'
df_melted['source'] = df_melted['source'].str.replace('_pop', '', regex=False)

# Box plot of persistence per cluster, obs vs. model, across sites
fig, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(
    data=df_melted,
    x='cluster',
    y='persistence_time',
    hue='source',
    palette='tab10',
    ax=ax,
)
ax.set_xlabel('Cluster')
ax.set_ylabel('Persistence Duration (hours)')
ax.set_title('Cluster-wise Persistence Duration Across Sites')
ax.grid(True, axis='y')
fig.tight_layout()
# Written to disk before the figure is shown
fig.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/allsite_pop_new_.png", dpi=100, bbox_inches="tight")
plt.show()

In [20]:
# Inspect the matched hourly observation times left over from the last site
# processed by the loop above.
obs_hourly_datetimes

DatetimeIndex(['2023-04-01 00:12:00', '2023-04-01 00:57:00',
               '2023-04-01 01:57:00', '2023-04-01 02:57:00',
               '2023-04-01 03:57:00', '2023-04-01 04:57:00',
               '2023-04-01 05:57:00', '2023-04-01 06:57:00',
               '2023-04-01 07:57:00', '2023-04-01 08:57:00',
               ...
               '2023-09-30 14:57:00', '2023-09-30 15:57:00',
               '2023-09-30 16:57:00', '2023-09-30 17:57:00',
               '2023-09-30 18:57:00', '2023-09-30 19:57:00',
               '2023-09-30 20:57:00', '2023-09-30 21:57:00',
               '2023-09-30 22:57:00', '2023-09-30 23:57:00'],
              dtype='datetime64[ns]', length=4368, freq=None)

## all site transition probability

In [10]:
def compute_transition_matrix(cluster_sequence, n_clusters):
    """Row-normalised transition matrix of a label sequence.

    Every pair of neighbouring entries ``(sequence[k], sequence[k + 1])`` is
    counted as one transition; each row is then divided by its total so that
    it holds the relative frequencies of the next label. Rows without any
    transition stay all-zero.

    Parameters
    ----------
    cluster_sequence : sequence of int
        Labels used directly as row/column indices of the matrix.
    n_clusters : int
        Size of the square matrix.

    Returns
    -------
    pandas.DataFrame
        ``n_clusters x n_clusters`` frame with index ``from_<k>`` and columns
        ``to_<k>`` for ``k`` in ``0..n_clusters-1``.
    """
    counts = np.zeros((n_clusters, n_clusters), dtype=np.float64)

    # Tally each (current, next) pair of neighbouring labels
    for src, dst in zip(cluster_sequence[:-1], cluster_sequence[1:]):
        counts[src, dst] += 1

    # Divide each row by its total; rows with a zero total keep their zeros
    totals = counts.sum(axis=1, keepdims=True)
    probs = np.zeros_like(counts)
    with np.errstate(divide='ignore', invalid='ignore'):
        np.divide(counts, totals, out=probs, where=totals != 0)

    row_names = [f"from_{k}" for k in range(n_clusters)]
    col_names = [f"to_{k}" for k in range(n_clusters)]
    return pd.DataFrame(probs, columns=col_names, index=row_names)

In [17]:
# Transition matrices (obs vs. model) for the label series currently held in
# `cluster_labels_hourly` / `model_cluster_labels_filtered`.
# NOTE(review): those arrays are whatever the most recently executed
# persistence cell left behind (a single site), while the figure is saved
# under an "allsite_" name — confirm this is intended.

n_clusters = 9

# Step 1: Convert the 1-based labels to 0-based matrix indices
cluster_labels_hourly_mapped = cluster_labels_hourly - 1
model_cluster_labels_filtered_mapped = model_cluster_labels_filtered - 1

# A label outside 1..n_clusters must not reach the matrix: label 0 would map
# to index -1 and be counted silently in the last row/column.
for _series_name, _mapped in (
    ("cluster_labels_hourly", cluster_labels_hourly_mapped),
    ("model_cluster_labels_filtered", model_cluster_labels_filtered_mapped),
):
    _mapped = np.asarray(_mapped)
    if np.any((_mapped < 0) | (_mapped >= n_clusters)):
        raise ValueError(
            f"{_series_name} contains labels outside 1..{n_clusters}; "
            f"cannot build a {n_clusters}x{n_clusters} transition matrix"
        )

# Step 2: Compute transition matrices
obs_trans_matrix = compute_transition_matrix(cluster_labels_hourly_mapped, n_clusters)
model_trans_matrix = compute_transition_matrix(model_cluster_labels_filtered_mapped, n_clusters)

# Step 3: Relabel axes for 1-indexed display (derived from n_clusters so the
# labels always match the matrix size)
from_labels = [f"from_{k}" for k in range(1, n_clusters + 1)]
to_labels = [f"to_{k}" for k in range(1, n_clusters + 1)]
obs_trans_matrix.index = from_labels
obs_trans_matrix.columns = to_labels
model_trans_matrix.index = from_labels
model_trans_matrix.columns = to_labels

# Step 4: Plot both matrices side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

sns.heatmap(obs_trans_matrix, annot=True, cmap="YlGnBu", fmt=".2f", ax=axes[0])
axes[0].set_title("Observation Transition Matrix (Clusters 1–9)")
axes[0].set_xlabel("Next Cluster")
axes[0].set_ylabel("Current Cluster")

sns.heatmap(model_trans_matrix, annot=True, cmap="YlGnBu", fmt=".2f", ax=axes[1])
axes[1].set_title("Model Transition Matrix (Clusters 1–9)")
axes[1].set_xlabel("Next Cluster")
axes[1].set_ylabel("")

plt.tight_layout()
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/allsite_model_obs_tp.png", dpi=100, bbox_inches="tight")
plt.show()

In [12]:
# Inspect the time-filtered model labels left over from the last processed site.
model_cluster_labels_filtered

array([4, 4, 4, ..., 9, 9, 4])

In [18]:
# Self-transition probability = diagonal of each transition matrix
obs_self_trans = obs_trans_matrix.values.diagonal()
model_self_trans = model_trans_matrix.values.diagonal()

# Wide table (one row per cluster, labelled 1..n_clusters) reshaped to long
# form with one row per (cluster, source)
self_wide = pd.DataFrame({
    'cluster': list(map(str, range(1, n_clusters + 1))),
    'obs': obs_self_trans,
    'model': model_self_trans,
})
df_self = pd.melt(self_wide, id_vars='cluster', var_name='source', value_name='self_transition_prob')

# Grouped bar plot, obs vs. model
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(data=df_self, x='cluster', y='self_transition_prob', hue='source', palette='tab10', ax=ax)
ax.set_title("Self-Transition Probabilities (Obs vs Model)")
ax.set_xlabel("Cluster")
ax.set_ylabel("Self-Transition Probability")
ax.grid(True, axis='y')
fig.tight_layout()
fig.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/allsite_model_obs_stp.png", dpi=100, bbox_inches="tight")
plt.show()

In [19]:
def compute_strict_nonself_transition_matrix(cluster_sequence, n_clusters):
    """Transition frequencies conditioned on the label actually changing.

    All neighbouring pairs are tallied, the diagonal (label stays the same) is
    then discarded, and each row is divided by the number of remaining
    transitions. Rows without any such transition stay all-zero.

    Parameters
    ----------
    cluster_sequence : sequence of int
        0-based labels used directly as row/column indices.
    n_clusters : int
        Size of the square matrix.

    Returns
    -------
    pandas.DataFrame
        ``n_clusters x n_clusters`` frame labelled 1-based: index ``from_<k>``
        and columns ``to_<k>`` for ``k`` in ``1..n_clusters``.
    """
    tallies = np.zeros((n_clusters, n_clusters), dtype=np.float64)
    for src, dst in zip(cluster_sequence[:-1], cluster_sequence[1:]):
        tallies[src, dst] += 1

    # Drop the "stayed in the same cluster" counts
    np.fill_diagonal(tallies, 0)

    # Normalise each row by its number of non-self transitions
    changes_per_row = tallies.sum(axis=1, keepdims=True)
    shares = np.zeros_like(tallies)
    with np.errstate(divide='ignore', invalid='ignore'):
        np.divide(tallies, changes_per_row, out=shares, where=changes_per_row != 0)

    display_ids = range(1, n_clusters + 1)
    return pd.DataFrame(
        shares,
        columns=[f"to_{k}" for k in display_ids],
        index=[f"from_{k}" for k in display_ids],
    )

# Non-self transition heatmaps (obs vs. model) for the label series currently
# held in `cluster_labels_hourly` / `model_cluster_labels_filtered`.
# NOTE(review): these arrays come from the last site processed by the
# persistence loop, although the output file is named "allsite_" — confirm.

# Shift the labels from 1–9 to the 0–8 indices the function expects
cluster_labels_hourly_mapped = cluster_labels_hourly - 1
model_cluster_labels_filtered_mapped = model_cluster_labels_filtered - 1

# One matrix per source
obs_strict_nonself = compute_strict_nonself_transition_matrix(cluster_labels_hourly_mapped, n_clusters)
model_strict_nonself = compute_strict_nonself_transition_matrix(model_cluster_labels_filtered_mapped, n_clusters)

# Two panels sharing the y-axis: observations left, model right
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
ax_obs, ax_model = axes

sns.heatmap(obs_strict_nonself, annot=True, cmap="Blues", fmt=".2f", ax=ax_obs)
ax_obs.set(title="Obs Strict Non-Self Transition Probabilities", xlabel="Next Cluster", ylabel="Current Cluster")

sns.heatmap(model_strict_nonself, annot=True, cmap="Blues", fmt=".2f", ax=ax_model)
ax_model.set(title="Model Strict Non-Self Transition Probabilities", xlabel="Next Cluster", ylabel="")

fig.tight_layout()
fig.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/allsite_model_obs_nstp.png", dpi=100, bbox_inches="tight")
plt.show()

## After reducing the three clusters

In [6]:
# Cluster ids that remain after the merge; these are the x-axis categories
plot_clusters = list(range(3, 10))

# Position of each remaining cluster id (3 -> 0, 4 -> 1, ..., 9 -> 6).
# Clusters 3, 5 and 7 additionally contain the original clusters 2, 1 and 0.
merged_to_final_label = {label: position for position, label in enumerate(range(3, 10))}

# Display names for the clusters that absorbed another one
custom_labels = dict([
    (3, "3(+2)"),
    (5, "5(+1)"),
    (7, "7(+0)"),
])

# === Site info ===
# Lookup tables keyed by site position 0..10.
#   *_cluster_sites      -> name used in the cluster-label file names
#   *_sites_ncvar_name   -> (name, name) pair; the first entry is used in the
#                           NetCDF file names
# NOTE(review): position 5 is spelled "zargoza" everywhere except in
# `icon_sites_ncvar_name`, where it is "zaragoza" — presumably this mirrors the
# file names on disk; confirm.
obs_cluster_sites = dict(enumerate((
    "juelich",
    "lin",
    "warsaw",
    "vienna",
    "bourges",
    "zargoza",
    "sirta",
    "cabauw",
    "nuremberg",
    "aurillac",
    "dresden",
)))
obs_sites_ncvar_name = {pos: (name, name) for pos, name in obs_cluster_sites.items()}

# The ICON label files use the same names as the observation label files.
icon_cluster_sites = dict(obs_cluster_sites)
icon_sites_ncvar_name = {pos: (name, name) for pos, name in icon_cluster_sites.items()}
icon_sites_ncvar_name[5] = ("zaragoza", "zaragoza")

# Accumulate persistence summaries after merging clusters 0, 1 and 2 into
# clusters 7, 5 and 3: one record per (site, cluster) holding the mean run
# length of that cluster in the observation and in the ICON series.
pop_summary = []

for site_idx in range(11):
    try:
        site_nc, site_var = obs_sites_ncvar_name[site_idx]
        cluster_id = obs_cluster_sites[site_idx]
        site_nc_icon, site_var_icon = icon_sites_ncvar_name[site_idx]
        cluster_id_icon = icon_cluster_sites[site_idx]

        # Load obs data. Only the time coordinate is needed, so read it and
        # close the file again instead of leaving one handle open per site.
        with xr.open_dataset(f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_{site_nc}crops.nc") as ds:
            raw_times = ds['time'].values
        # NOTE: torch.load unpickles the file; only load label files you trust.
        cluster_data = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth", map_location="cpu")
        cluster_data = np.array(cluster_data)
        # Merge 0 -> 7, 1 -> 5, 2 -> 3
        cluster_data[cluster_data == 0] = 7
        cluster_data[cluster_data == 1] = 5
        cluster_data[cluster_data == 2] = 3

        # Parse obs times (first 12 characters: YYYYmmddHHMM)
        obs_datetimes = pd.to_datetime([t[:12] for t in raw_times], format="%Y%m%d%H%M")

        # Find closest timestamps to full hours; the distances are computed
        # once per target hour and a match is kept only within 15 minutes.
        hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'), end=obs_datetimes.max().ceil('H'), freq='H')
        matched_indices = []
        for h in hourly_times:
            offsets = np.abs((obs_datetimes - h).total_seconds())
            nearest = np.argmin(offsets)
            if offsets[nearest] <= 900:
                matched_indices.append(nearest)

        obs_hourly_datetimes = obs_datetimes[matched_indices]
        cluster_labels_hourly = np.array(cluster_data)[matched_indices]

        # Compute obs persistence: a run continues only while the label is
        # unchanged and the time step is one hour. The inner index is named
        # `k` so it does not clobber the site index of the outer loop.
        cluster_durations = defaultdict(list)
        prev_label = cluster_labels_hourly[0]
        prev_time = obs_hourly_datetimes[0]
        count = 1
        for k in range(1, len(cluster_labels_hourly)):
            curr_label = cluster_labels_hourly[k]
            curr_time = obs_hourly_datetimes[k]
            if curr_label == prev_label and np.isclose((curr_time - prev_time).total_seconds() / 3600.0, 1.0):
                count += 1
            else:
                cluster_durations[prev_label].append(count)
                count = 1
            prev_label = curr_label
            prev_time = curr_time
        cluster_durations[prev_label].append(count)
        probability_of_persistence = {c: np.mean(durations) for c, durations in cluster_durations.items()}

        # Load ICON model data (time coordinate only, file closed afterwards)
        with xr.open_dataset(f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_nc_icon}crops_icon.nc") as model_ds:
            model_datetimes = pd.to_datetime(model_ds['time'].values)
        model_cluster_data = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth", map_location="cpu")
        model_cluster_data = np.array(model_cluster_data)
        model_cluster_data[model_cluster_data == 0] = 7
        model_cluster_data[model_cluster_data == 1] = 5
        model_cluster_data[model_cluster_data == 2] = 3

        # Restrict the model series to the window covered by the matched obs
        mask = (model_datetimes >= obs_hourly_datetimes.min()) & (model_datetimes <= obs_hourly_datetimes.max())
        model_datetimes_filtered = model_datetimes[mask]
        model_cluster_labels_filtered = np.array(model_cluster_data)[mask]

        # Compute model persistence with the same rule as for the observations
        model_durations = defaultdict(list)
        prev_label = model_cluster_labels_filtered[0]
        prev_time = model_datetimes_filtered[0]
        count = 1
        for k in range(1, len(model_cluster_labels_filtered)):
            curr_label = model_cluster_labels_filtered[k]
            curr_time = model_datetimes_filtered[k]
            if curr_label == prev_label and np.isclose((curr_time - prev_time).total_seconds() / 3600.0, 1.0):
                count += 1
            else:
                model_durations[prev_label].append(count)
                count = 1
            prev_label = curr_label
            prev_time = curr_time
        model_durations[prev_label].append(count)
        model_pop = {c: np.mean(durations) for c, durations in model_durations.items()}

        # A cluster missing from one series is recorded as 0 for that series
        all_clusters = sorted(set(cluster_labels_hourly) | set(model_cluster_labels_filtered))
        for c in all_clusters:
            pop_summary.append({
                'site': site_nc,
                'cluster': c,
                'obs_pop': probability_of_persistence.get(c, 0),
                'model_pop': model_pop.get(c, 0)
            })

    except Exception as e:
        # Best effort: a site that cannot be processed is reported and skipped
        print(f"Skipping {site_nc} due to error: {e}")

# Convert to DataFrame; the cluster label becomes a string so the x-axis is
# treated as categorical
df_pop = pd.DataFrame(pop_summary)
df_pop['cluster'] = df_pop['cluster'].astype(str)

# Melt for seaborn: one row per (site, cluster, source)
df_melted = df_pop.melt(
    id_vars=['site', 'cluster'],
    value_vars=['obs_pop', 'model_pop'],
    var_name='source',
    value_name='persistence_time'
)
df_melted['source'] = df_melted['source'].str.replace('_pop', '', regex=False)

# The custom tick labels below are assigned by position, so the category order
# on the x-axis must be exactly `plot_clusters`. Without an explicit `order`
# seaborn uses the order of appearance in the data, which would mislabel the
# boxes whenever the first site lacks one of the clusters.
cluster_order = [str(cl) for cl in plot_clusters]

# Plot boxplot
plt.figure(figsize=(12, 6))
sns.boxplot(
    data=df_melted,
    x='cluster',
    y='persistence_time',
    hue='source',
    order=cluster_order,
    palette='tab10'
)

plt.ylabel('Persistence Duration (hours)')
plt.title('Cluster-wise Persistence Duration Across Sites')
plt.grid(True, axis='y')

# Generate xtick labels ("3(+2)", "4", "5(+1)", ...)
xtick_labels = [custom_labels.get(cl, str(cl)) for cl in plot_clusters]

# Apply to plot
plt.xlabel('Cluster')
plt.xticks(ticks=range(len(plot_clusters)), labels=xtick_labels)

plt.tight_layout()
# Saved before plt.show() so the written file contains the rendered plot
plt.savefig("/p/project1/exaww/chatterjee1/plots/continuous/allsite_pop_new_merged_123.png", dpi=100, bbox_inches="tight")
plt.show()

In [7]:
# Number of time-filtered model labels for the last processed site.
model_cluster_labels_filtered.shape

(4573,)

### merged transition probability

In [9]:


def compute_transition_matrix(cluster_sequence, n_clusters, valid_clusters):
    """Row-normalised transition matrix restricted to a set of cluster ids.

    Neighbouring pairs of ``cluster_sequence`` are counted only when both
    labels occur in ``valid_clusters``; the position of a label in
    ``valid_clusters`` is its row/column in the matrix. Each row is divided by
    its total, and rows without any counted transition stay all-zero.

    Parameters
    ----------
    cluster_sequence : sequence
        Cluster label per time step.
    n_clusters : int
        Size of the square matrix.
    valid_clusters : list
        Cluster ids to keep, in row/column order.

    Returns
    -------
    pandas.DataFrame
        Columns are the display names of ``valid_clusters`` (3, 5 and 7 are
        shown as "3(+2)", "5(+1)" and "7(+0)"); the index uses the same names
        prefixed with ``from_``.
    """
    counts = np.zeros((n_clusters, n_clusters), dtype=np.float64)

    for src, dst in zip(cluster_sequence[:-1], cluster_sequence[1:]):
        # Pairs touching a label outside `valid_clusters` are not counted
        if src not in valid_clusters or dst not in valid_clusters:
            continue
        counts[valid_clusters.index(src), valid_clusters.index(dst)] += 1

    # Divide each row by its total; rows with a zero total keep their zeros
    totals = counts.sum(axis=1, keepdims=True)
    probs = np.zeros_like(counts)
    with np.errstate(divide='ignore', invalid='ignore'):
        np.divide(counts, totals, out=probs, where=totals != 0)

    custom_labels = {3: "3(+2)", 5: "5(+1)", 7: "7(+0)"}
    columns = [custom_labels.get(c, str(c)) for c in valid_clusters]
    index = ["from_" + shown for shown in columns]

    return pd.DataFrame(probs, columns=columns, index=index)

# Cluster ids that remain after the merge, in matrix row/column order
final_clusters = list(range(3, 10))
n_clusters = len(final_clusters)

# Display names for the clusters that absorbed another one
custom_labels = dict([
    (3, "3(+2)"),
    (5, "5(+1)"),
    (7, "7(+0)"),
])

# Lookup tables keyed by site position 0..10.
#   *_cluster_sites      -> name used in the cluster-label file names
#   *_sites_ncvar_name   -> (name, name) pair; the first entry is used in the
#                           NetCDF file names
# NOTE(review): position 5 is spelled "zargoza" everywhere except in
# `icon_sites_ncvar_name`, where it is "zaragoza" — presumably this mirrors the
# file names on disk; confirm.
obs_cluster_sites = dict(enumerate((
    "juelich",
    "lin",
    "warsaw",
    "vienna",
    "bourges",
    "zargoza",
    "sirta",
    "cabauw",
    "nuremberg",
    "aurillac",
    "dresden",
)))
obs_sites_ncvar_name = {pos: (name, name) for pos, name in obs_cluster_sites.items()}

# The ICON label files use the same names as the observation label files.
icon_cluster_sites = dict(obs_cluster_sites)
icon_sites_ncvar_name = {pos: (name, name) for pos, name in icon_cluster_sites.items()}
icon_sites_ncvar_name[5] = ("zaragoza", "zaragoza")

# Per-site transition matrices (obs and model) on the merged labels, followed
# by their element-wise average across all sites that could be processed.
all_obs_labels = []
all_model_labels = []
obs_matrices = []
model_matrices = []

for site_idx in range(11):
    try:
        site_nc, site_var = obs_sites_ncvar_name[site_idx]
        cluster_id = obs_cluster_sites[site_idx]
        site_nc_icon, site_var_icon = icon_sites_ncvar_name[site_idx]
        cluster_id_icon = icon_cluster_sites[site_idx]

        # Only the time coordinate is needed, so read it and close the file
        # again instead of leaving one handle open per site.
        with xr.open_dataset(f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_{site_nc}crops.nc") as ds:
            raw_times = ds['time'].values
        # NOTE: torch.load unpickles the file; only load label files you trust.
        cluster_data = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{cluster_id}_cluster_10_labels.pth", map_location="cpu")
        cluster_data = np.array(cluster_data)
        # Merge 0 -> 7, 1 -> 5, 2 -> 3
        cluster_data[cluster_data == 0] = 7
        cluster_data[cluster_data == 1] = 5
        cluster_data[cluster_data == 2] = 3

        # Parse obs times (first 12 characters: YYYYmmddHHMM)
        obs_datetimes = pd.to_datetime([t[:12] for t in raw_times], format="%Y%m%d%H%M")

        # Closest observation to every full hour; distances are computed once
        # per target hour and a match is kept only within 15 minutes.
        hourly_times = pd.date_range(start=obs_datetimes.min().floor('H'), end=obs_datetimes.max().ceil('H'), freq='H')
        matched_indices = []
        for h in hourly_times:
            offsets = np.abs((obs_datetimes - h).total_seconds())
            nearest = np.argmin(offsets)
            if offsets[nearest] <= 900:
                matched_indices.append(nearest)

        obs_hourly_datetimes = obs_datetimes[matched_indices]
        cluster_labels_hourly = np.array(cluster_data)[matched_indices]

        with xr.open_dataset(f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_nc_icon}crops_icon.nc") as model_ds:
            model_datetimes = pd.to_datetime(model_ds['time'].values)
        model_cluster_data = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{cluster_id_icon}_cluster_10_labels.pth", map_location="cpu")
        model_cluster_data = np.array(model_cluster_data)
        model_cluster_data[model_cluster_data == 0] = 7
        model_cluster_data[model_cluster_data == 1] = 5
        model_cluster_data[model_cluster_data == 2] = 3

        # Restrict the model series to the window covered by the matched obs
        mask = (model_datetimes >= obs_hourly_datetimes.min()) & (model_datetimes <= obs_hourly_datetimes.max())
        model_datetimes_filtered = model_datetimes[mask]
        model_cluster_labels_filtered = np.array(model_cluster_data)[mask]

        # Both matrices are computed before either is stored, so the two lists
        # always stay the same length even if one computation fails.
        obs_matrix = compute_transition_matrix(cluster_labels_hourly, n_clusters, final_clusters)
        model_matrix = compute_transition_matrix(model_cluster_labels_filtered, n_clusters, final_clusters)
        obs_matrices.append(obs_matrix)
        model_matrices.append(model_matrix)

    except Exception as e:
        # Best effort: a site that cannot be processed is reported and skipped
        print(f"Skipping {site_nc} due to error: {e}")

# Average across sites. Fail with a clear message instead of a bare
# ZeroDivisionError when every site was skipped.
if not obs_matrices:
    raise RuntimeError("No site could be processed; cannot average transition matrices.")
obs_avg_matrix = sum(obs_matrices) / len(obs_matrices)
model_avg_matrix = sum(model_matrices) / len(model_matrices)


In [10]:
# Display names for the merged clusters and the cluster order of the matrices
custom_labels = {3: "3(+2)", 5: "5(+1)", 7: "7(+0)"}
final_clusters = [3, 4, 5, 6, 7, 8, 9]

# Column labels are the display names; row labels carry a "from_" prefix
xtick_labels = [custom_labels.get(cl, str(cl)) for cl in final_clusters]
ytick_labels = ["from_" + shown for shown in xtick_labels]

# Site-averaged matrices wrapped with the tick labels used for plotting
obs_trans_matrix = pd.DataFrame(obs_avg_matrix, index=ytick_labels, columns=xtick_labels)
model_trans_matrix = pd.DataFrame(model_avg_matrix, index=ytick_labels, columns=xtick_labels)

# Two panels sharing the y-axis: observations left, model right
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
ax_obs, ax_model = axes

sns.heatmap(obs_trans_matrix, annot=True, cmap="YlGnBu", fmt=".2f", ax=ax_obs)
ax_obs.set(title="Observation Transition Matrix (Clusters 3–9)", xlabel="Next Cluster", ylabel="Current Cluster")

sns.heatmap(model_trans_matrix, annot=True, cmap="YlGnBu", fmt=".2f", ax=ax_model)
ax_model.set(title="Model Transition Matrix (Clusters 3–9)", xlabel="Next Cluster", ylabel="")

fig.tight_layout()
fig.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/allsite_model_obs_tp_merged.png", dpi=100, bbox_inches="tight")
plt.show()

In [11]:
# Display name per cluster, in the order of `final_clusters`
tick_labels = [custom_labels.get(cl, str(cl)) for cl in final_clusters]

# Self-transition probability = diagonal of each matrix; build a wide table
# (one row per cluster) and reshape it to one row per (cluster, source)
self_wide = pd.DataFrame({
    'cluster': tick_labels,
    'obs': obs_trans_matrix.values.diagonal(),
    'model': model_trans_matrix.values.diagonal(),
})
df_self = pd.melt(self_wide, id_vars='cluster', var_name='source', value_name='self_transition_prob')

# Grouped bar plot, obs vs. model
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(data=df_self, x='cluster', y='self_transition_prob', hue='source', palette='tab10', ax=ax)
ax.set_title("Self-Transition Probabilities (Obs vs Model)")
ax.set_xlabel("Cluster")
ax.set_ylabel("Self-Transition Probability")
ax.grid(True, axis='y')
fig.tight_layout()
fig.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/allsite_model_obs_stp_merged.png", dpi=100, bbox_inches="tight")


In [13]:
def compute_strict_nonself_transition_matrix(cluster_sequence, n_clusters, valid_clusters, label_map):
    """Row-normalised probabilities of switching to a *different* cluster.

    Consecutive pairs ``(cluster_sequence[i], cluster_sequence[i + 1])`` are
    counted only when both labels are in ``valid_clusters`` and they differ.
    Each row of the count matrix is then divided by its own sum, so every row
    with at least one counted transition sums to 1. Rows with no counted
    transition stay all-zero.

    Parameters
    ----------
    cluster_sequence : sequence or 1-D array of cluster labels
        Labels in sequence order. Adjacent entries are always treated as
        consecutive steps; no time-gap check is performed here.
    n_clusters : int
        Kept for backward compatibility only. The matrix is sized from
        ``len(valid_clusters)``, because both the indices and the row/column
        labels come from ``valid_clusters``; sizing it from a different
        ``n_clusters`` previously raised an IndexError or a pandas shape error.
    valid_clusters : sequence
        Cluster ids to keep. Their order defines the row/column order.
    label_map : dict
        Optional display name per cluster id; ids without an entry use ``str(id)``.

    Returns
    -------
    pandas.DataFrame
        Square frame of shape ``(len(valid_clusters), len(valid_clusters))``
        with a zero diagonal, rows labelled ``"from_<name>"`` and columns
        labelled ``"<name>"``.
    """
    valid_clusters = list(valid_clusters)
    n_valid = len(valid_clusters)

    # Position of each cluster id in valid_clusters. setdefault keeps the
    # first occurrence, matching what list.index() would return.
    position = {}
    for pos, cluster_id in enumerate(valid_clusters):
        position.setdefault(cluster_id, pos)

    # Convert to plain Python scalars so the dict lookup behaves the same for
    # lists, numpy arrays and other array-likes.
    labels = np.asarray(cluster_sequence).tolist()

    count_matrix = np.zeros((n_valid, n_valid), dtype=np.float64)
    for from_c, to_c in zip(labels, labels[1:]):
        from_idx = position.get(from_c)
        to_idx = position.get(to_c)
        # Skip pairs that touch an excluded cluster, and self-transitions.
        if from_idx is None or to_idx is None or from_idx == to_idx:
            continue
        count_matrix[from_idx, to_idx] += 1

    # `where` skips rows with zero total, so those rows keep the zeros from
    # `out` and no divide-by-zero warning is raised.
    row_sums = count_matrix.sum(axis=1, keepdims=True)
    nonself_probs = np.divide(
        count_matrix,
        row_sums,
        out=np.zeros_like(count_matrix),
        where=row_sums != 0,
    )

    col_labels = [label_map.get(c, str(c)) for c in valid_clusters]
    row_labels = [f"from_{label}" for label in col_labels]

    return pd.DataFrame(nonself_probs, columns=col_labels, index=row_labels)

# Strict non-self transition probabilities, computed from the hourly label
# sequences (observations and model) and restricted to final_clusters.
# NOTE(review): the output file is named "allsite", but only one obs and one
# model label sequence are passed here — confirm these cover all sites.
# NOTE(review): confirm these sequences already use the merged cluster ids
# that custom_labels refers to.
obs_strict_nonself = compute_strict_nonself_transition_matrix(
    cluster_labels_hourly, n_clusters, final_clusters, custom_labels
)
model_strict_nonself = compute_strict_nonself_transition_matrix(
    model_cluster_labels_filtered, n_clusters, final_clusters, custom_labels
)

# Side-by-side heatmaps: observations on the left, model on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

nonself_panels = (
    (axes[0], obs_strict_nonself, "Obs Strict Non-Self Transition Probabilities", "Current Cluster"),
    (axes[1], model_strict_nonself, "Model Strict Non-Self Transition Probabilities", ""),
)
for panel_ax, panel_matrix, panel_title, panel_ylabel in nonself_panels:
    sns.heatmap(panel_matrix, annot=True, cmap="Blues", fmt=".2f", ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Next Cluster")
    panel_ax.set_ylabel(panel_ylabel)  # the right panel shares the y axis, so its label is blank

plt.tight_layout()
plt.savefig(
    "/p/project1/exaww/chatterjee1/plots/continuous/allsite_model_obs_nstp_merged.png",
    dpi=100,
    bbox_inches="tight",
)
