In [2]:
import torch
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_distances

## Day wise cosine velocity plot OBS and MODEL

In [55]:
# ==== Site selection ====
site_id = 0  # Change for different sites

obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lin", "lin"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zargoza", "zargoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"),
    10: ("dresden", "dresden"),
}

icon_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lin", "lin"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zaragoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"),
    10: ("dresden", "dresden"),
}

site_nc, _ = obs_sites_ncvar_name[site_id]
site_nc_icon, _ = icon_sites_ncvar_name[site_id]

# === Load OBS data ===
ds_obs = xr.open_dataset(f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_{site_nc}crops.nc")
obs_times_raw = ds_obs['time'].values
obs_datetimes = pd.to_datetime([t[:12] for t in obs_times_raw], format="%Y%m%d%H%M")

# === Load ICON data ===
ds_model = xr.open_dataset(f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_nc_icon}crops_icon.nc")
model_datetimes = pd.to_datetime(ds_model['time'].values)

# === Define common hourly grid ===
hourly_times = pd.date_range(start=max(obs_datetimes.min(), model_datetimes.min()).floor('H'),
                             end=min(obs_datetimes.max(), model_datetimes.max()).ceil('H'),
                             freq='H')

# === Helper function ===
def match_closest(datetimes, reference_times, tolerance=900):
    matched_indices = []
    matched_times = []
    for ref in reference_times:
        diffs = np.abs((datetimes - ref).total_seconds())
        min_idx = np.argmin(diffs)
        if diffs[min_idx] <= tolerance:
            matched_indices.append(min_idx)
            matched_times.append(ref)
    return matched_indices, matched_times

# === Match indices ===
obs_matched_idx, aligned_obs_times = match_closest(obs_datetimes, hourly_times)
model_matched_idx = [i for i, t in enumerate(model_datetimes) if t in hourly_times]

# === Load features and clusters ===
obs_features_all = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/trainfeat_obs_{site_nc}.pth", map_location="cpu")
obs_clusters_all = np.array(torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{site_nc}_cluster_10_labels.pth", map_location="cpu"))
obs_clusters_all[obs_clusters_all == 0] = 7
obs_clusters_all[obs_clusters_all == 1] = 5
obs_clusters_all[obs_clusters_all == 2] = 3

model_features_all = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/trainfeat_icon_{site_nc_icon}.pth", map_location="cpu")
model_clusters_all = np.array(torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{site_nc_icon}_cluster_10_labels.pth", map_location="cpu"))
model_clusters_all[model_clusters_all == 0] = 7
model_clusters_all[model_clusters_all == 1] = 5
model_clusters_all[model_clusters_all == 2] = 3

# === Filter to matched indices ===
obs_features_hourly = obs_features_all[obs_matched_idx]
obs_clusters_hourly = obs_clusters_all[obs_matched_idx]
obs_hourly_datetimes = pd.to_datetime(aligned_obs_times)

model_features_hourly = model_features_all[model_matched_idx]
model_clusters_hourly = model_clusters_all[model_matched_idx]
model_hourly_datetimes = model_datetimes[model_matched_idx]

# === Compute cosine velocities ===
def compute_cosine_velocity(features):
    return [cosine_distances(features[i].unsqueeze(0), features[i - 1].unsqueeze(0))[0, 0]
            for i in range(1, len(features))]

obs_velocity = compute_cosine_velocity(obs_features_hourly)
model_velocity = compute_cosine_velocity(model_features_hourly)

obs_aligned_clusters = obs_clusters_hourly[1:]
model_aligned_clusters = model_clusters_hourly[1:]

obs_aligned_times = obs_hourly_datetimes[1:]
model_aligned_times = model_hourly_datetimes[1:]

# === Select a top day ===
obs_df = pd.DataFrame({
    "timestamp": obs_aligned_times,
    "date": obs_aligned_times.date,
    "cluster": obs_aligned_clusters
})

model_df = pd.DataFrame({
    "timestamp": model_aligned_times,
    "date": model_aligned_times.date,
    "cluster": model_aligned_clusters
})

######## HERE #########
label_of_interest = 3  # Change as needed
#top_dates = obs_df[obs_df["cluster"] == label_of_interest].groupby("date").size().sort_values(ascending=False).head(5).index
top_dates = model_df[model_df["cluster"] == label_of_interest].groupby("date").size().sort_values(ascending=False).head(5).index
selected_date = top_dates[1]  # Choose top day
#print(f"Date: {selected_date} | OBS cluster {custom_labels.get(label_of_interest, label_of_interest)} count: {obs_count} | MODEL cluster {custom_labels.get(label_of_interest, label_of_interest)} count: {model_count}")
#######################

# === Extract for selected day ===
day_mask_obs = obs_df["date"] == selected_date
day_mask_model = pd.to_datetime(model_aligned_times).date == selected_date

obs_day_times = obs_aligned_times[day_mask_obs]
obs_day_velocity = np.array(obs_velocity)[day_mask_obs]
obs_day_clusters = obs_aligned_clusters[day_mask_obs]

model_day_times = model_aligned_times[day_mask_model]
model_day_velocity = np.array(model_velocity)[day_mask_model]
model_day_clusters = model_aligned_clusters[day_mask_model]

# === Print counts ===
obs_count = (obs_day_clusters == label_of_interest).sum()
model_count = (model_day_clusters == label_of_interest).sum()
print(f"Date: {selected_date} | OBS cluster {label_of_interest} count: {obs_count} | MODEL cluster {label_of_interest} count: {model_count}")

# === Plotting ===
plt.figure(figsize=(12, 5))
plt.plot(obs_day_times, obs_day_velocity, color='blue', linestyle='-', linewidth=1.5, label='OBS Velocity')
plt.plot(model_day_times, model_day_velocity, color='orange', linestyle='--', linewidth=1.5, label='Model Velocity')

# OBS markers
obs_cluster_labeled = False
for t, v, c in zip(obs_day_times, obs_day_velocity, obs_day_clusters):
    if c == label_of_interest and not obs_cluster_labeled:
        plt.scatter(t, v, color='blue', marker='s', s=100, edgecolor='k', label=f'OBS Cluster {label_of_interest}')
        obs_cluster_labeled = True
    elif c == label_of_interest:
        plt.scatter(t, v, color='blue', marker='s', s=100, edgecolor='k')

# MODEL markers
model_cluster_labeled = False
for t, v, c in zip(model_day_times, model_day_velocity, model_day_clusters):
    if c == label_of_interest and not model_cluster_labeled:
        plt.scatter(t, v, color='orange', marker='o', s=100, edgecolor='k', label=f'Model Cluster {label_of_interest}')
        model_cluster_labeled = True
    elif c == label_of_interest:
        plt.scatter(t, v, color='orange', marker='o', s=100, edgecolor='k')

plt.xlabel("Time")
plt.ylabel("Cosine Distance to Previous Hour")
plt.title(f"Latent Cosine Velocity (OBS vs MODEL)\n{selected_date}")
plt.xticks(rotation=45)
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/lcv_cl_model_{label_of_interest}_{selected_date}_merged.png", dpi=100, bbox_inches="tight")
plt.show()

Date: 2023-04-08 | OBS cluster 3 count: 4 | MODEL cluster 3 count: 16


## All sites

In [26]:
# === Cluster merging map ===
merge_map = {
    0: 7,
    1: 5,
    2: 3,
    # others remain unchanged
}

# === Site info ===
obs_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lin", "lin"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zargoza", "zargoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"),
    10: ("dresden", "dresden"),
}

icon_sites_ncvar_name = {
    0: ("juelich", "juelich"),
    1: ("lin", "lin"),
    2: ("warsaw", "warsaw"),
    3: ("vienna", "vienna"),
    4: ("bourges", "bourges"),
    5: ("zaragoza", "zargoza"),
    6: ("sirta", "sirta"),
    7: ("cabauw", "cabauw"),
    8: ("nuremberg", "nuremberg"),
    9: ("aurillac", "aurillac"),
    10: ("dresden", "dresden"),
}

label_of_interest = 9
top_k = 10

all_hourly_diffs = defaultdict(list)

def apply_merge_map(clusters):
    for old, new in merge_map.items():
        clusters[clusters == old] = new
    return clusters

for site_id in range(11):
    site_nc, _ = obs_sites_ncvar_name[site_id]
    site_nc_icon, site_feat_icon = icon_sites_ncvar_name[site_id]

    ds_obs = xr.open_dataset(f"/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_{site_nc}crops.nc")
    obs_times_raw = ds_obs['time'].values
    obs_datetimes = pd.to_datetime([t[:12] for t in obs_times_raw], format="%Y%m%d%H%M")

    ds_model = xr.open_dataset(f"/p/scratch/exaww/chatterjee1/nn_obs/continuous/msgobs_108_{site_nc_icon}crops_icon.nc")
    model_datetimes = pd.to_datetime(ds_model['time'].values)

    hourly_times = pd.date_range(start=max(obs_datetimes.min(), model_datetimes.min()).floor('H'),
                                 end=min(obs_datetimes.max(), model_datetimes.max()).ceil('H'),
                                 freq='H')

    def match_closest(datetimes, reference_times, tolerance=900):
        matched_indices, matched_times = [], []
        for ref in reference_times:
            diffs = np.abs((datetimes - ref).total_seconds())
            min_idx = np.argmin(diffs)
            if diffs[min_idx] <= tolerance:
                matched_indices.append(min_idx)
                matched_times.append(ref)
        return matched_indices, matched_times

    obs_matched_idx, aligned_obs_times = match_closest(obs_datetimes, hourly_times)
    model_matched_idx = [i for i, t in enumerate(model_datetimes) if t in hourly_times]

    obs_features_all = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/trainfeat_obs_{site_nc}.pth", map_location="cpu")
    obs_clusters_all = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_obs_features/obs_{site_nc}_cluster_10_labels.pth", map_location="cpu")
    obs_clusters_all = apply_merge_map(np.array(obs_clusters_all))

    model_features_all = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/trainfeat_icon_{site_feat_icon}.pth", map_location="cpu")
    model_clusters_all = torch.load(f"/p/project/exaww/chatterjee1/mcspss_continuous/analysis/location_icon_features/icon_{site_feat_icon}_cluster_10_labels.pth", map_location="cpu")
    model_clusters_all = apply_merge_map(np.array(model_clusters_all))

    obs_features_hourly = obs_features_all[obs_matched_idx]
    obs_clusters_hourly = obs_clusters_all[obs_matched_idx]
    obs_hourly_datetimes = pd.to_datetime(aligned_obs_times)

    model_features_hourly = model_features_all[model_matched_idx]
    model_clusters_hourly = model_clusters_all[model_matched_idx]
    model_hourly_datetimes = model_datetimes[model_matched_idx]

    #obs_df = pd.DataFrame({"timestamp": obs_hourly_datetimes, "date": obs_hourly_datetimes.date, "cluster": obs_clusters_hourly})
    #top_dates = obs_df[obs_df["cluster"] == label_of_interest].groupby("date").size().sort_values(ascending=False).head(top_k).index
    model_df = pd.DataFrame({"timestamp": model_hourly_datetimes, "date": model_hourly_datetimes.date, "cluster": model_clusters_hourly})
    top_dates = model_df[model_df["cluster"] == label_of_interest].groupby("date").size().sort_values(ascending=False).head(top_k).index

    for selected_date in top_dates:
        day_indices_obs = obs_df[obs_df["date"] == selected_date].index
        obs_day_features = obs_features_hourly[day_indices_obs]
        obs_day_times = obs_hourly_datetimes[day_indices_obs]
        obs_day_clusters = obs_clusters_hourly[day_indices_obs]

        model_df = pd.DataFrame({"timestamp": model_hourly_datetimes, "date": model_hourly_datetimes.date, "cluster": model_clusters_hourly})
        day_indices_model = model_df[model_df["date"] == selected_date].index
        model_day_features = model_features_hourly[day_indices_model]
        model_day_times = model_hourly_datetimes[day_indices_model]
        model_day_clusters = model_clusters_hourly[day_indices_model]

        def compute_cosine_velocity(features):
            return [cosine_distances(features[i].unsqueeze(0), features[i - 1].unsqueeze(0))[0, 0] for i in range(1, len(features))]

        obs_velocity = compute_cosine_velocity(obs_day_features)
        model_velocity = compute_cosine_velocity(model_day_features)

        obs_times = obs_day_times[1:]
        model_times = model_day_times[1:]
        common_times = sorted(set(obs_times).intersection(set(model_times)))

        for t in common_times:
            obs_idx = np.where(obs_times == t)[0][0]
            model_idx = np.where(model_times == t)[0][0]
            diff = obs_velocity[obs_idx] - model_velocity[model_idx]
            hour_label = t.strftime("%H:%M")
            all_hourly_diffs[hour_label].append(diff)

        print(f"Site: {site_nc}, Date: {selected_date}, OBS cluster {label_of_interest} count: {(obs_day_clusters == label_of_interest).sum()}, MODEL cluster {label_of_interest} count: {(model_day_clusters == label_of_interest).sum()}")

# === Plotting ===
hour_labels = sorted(all_hourly_diffs.keys(), key=lambda x: int(x.split(":")[0]))
plt.figure(figsize=(12, 6))

for hour in hour_labels:
    y_vals = all_hourly_diffs[hour]
    y_vals_filtered = [y for y in y_vals if -0.35 <= y <= 0.35]
    plt.scatter([hour] * len(y_vals_filtered), y_vals_filtered, color='dimgray', alpha=0.9, marker='o')

mean_vals = [np.mean(all_hourly_diffs[h]) for h in hour_labels]
std_vals = [np.std(all_hourly_diffs[h]) for h in hour_labels]

plt.plot(hour_labels, mean_vals, color='black', linewidth=2, label='Mean Difference')
plt.fill_between(hour_labels, np.array(mean_vals) - np.array(std_vals), np.array(mean_vals) + np.array(std_vals),
                 color='gray', alpha=0.3, label='Std Dev')

plt.xticks(rotation=45)
plt.grid(True)
plt.xlabel("Hour of Day")
plt.ylabel("Cosine Velocity Difference (OBS - MODEL)")
plt.title(f"Obs - Model, Hourly Cosine Velocity Difference Ensemble\n(Cluster {label_of_interest}) All Sites, Top {top_k} Days")
plt.legend()
plt.tight_layout()
plt.savefig(f"/p/project1/exaww/chatterjee1/plots/continuous/lcv_cl_model_{label_of_interest}_ensemble_all_sites_top{top_k}_merged.png", dpi=100)
plt.show()

Site: juelich, Date: 2023-09-23, OBS cluster 9 count: 0, MODEL cluster 9 count: 19
Site: juelich, Date: 2023-04-29, OBS cluster 9 count: 0, MODEL cluster 9 count: 17
Site: juelich, Date: 2023-08-07, OBS cluster 9 count: 0, MODEL cluster 9 count: 15
Site: juelich, Date: 2023-09-22, OBS cluster 9 count: 0, MODEL cluster 9 count: 15
Site: juelich, Date: 2023-05-23, OBS cluster 9 count: 0, MODEL cluster 9 count: 12
Site: juelich, Date: 2023-05-26, OBS cluster 9 count: 4, MODEL cluster 9 count: 11
Site: juelich, Date: 2023-05-22, OBS cluster 9 count: 0, MODEL cluster 9 count: 11
Site: juelich, Date: 2023-09-30, OBS cluster 9 count: 0, MODEL cluster 9 count: 10
Site: juelich, Date: 2023-06-02, OBS cluster 9 count: 6, MODEL cluster 9 count: 10
Site: juelich, Date: 2023-08-26, OBS cluster 9 count: 1, MODEL cluster 9 count: 10
Site: lin, Date: 2023-09-24, OBS cluster 9 count: 1, MODEL cluster 9 count: 15
Site: lin, Date: 2023-09-14, OBS cluster 9 count: 3, MODEL cluster 9 count: 13
Site: lin, D

  plt.figure(figsize=(12, 6))


In [29]:
d = xr.open_dataset("/p/scratch/exaww/chatterjee1/louis/data/Input_CGAN_63978_STH20.nc",engine="netcdf4")
d

OSError: [Errno -51] NetCDF: Unknown file format: b'/p/scratch/exaww/chatterjee1/louis/data/Input_CGAN_63978_STH20.nc'

In [30]:
import h5py
f = h5py.File("/p/scratch/exaww/chatterjee1/louis/data/Input_CGAN_63978_STH20.nc", "r")
print(list(f.keys()))

OSError: Unable to open file (file signature not found)