In [None]:
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.dates as mdates
from matplotlib.animation import FuncAnimation

# -- Input files (relative to the working directory; edit to suit your layout)

# Interpolated observation / model fields and the matching epoch list.
OBS_INTERP_FILE = "obs_interp.npy"
MODEL_INTERP_FILE = "model_interp.npy"
EPOCHS_FILE = "all_epochs.npy"

# Latitude / longitude axes of the grid.
LATS_FILE = "model_lats.npy"
LONS_FILE = "model_lons.npy"

# Super-resolution outputs (U-Net and SRCNN).
UNET_SR_FILE = "superres_vtec.npy"
SRCNN_SR_FILE = "srcnn_superres_vtec.npy"

# 1. --- LOAD DATA ---
obs_interp = np.load(OBS_INTERP_FILE)
model_interp = np.load(MODEL_INTERP_FILE)
# allow_pickle=True is required for this file (presumably it stores an
# object-dtype array of timestamps -- confirm against the writer).
# NOTE(security): un-pickling can execute arbitrary code, so only load an
# epochs file that comes from a trusted source.
all_epochs = np.load(EPOCHS_FILE, allow_pickle=True)
model_lats = np.load(LATS_FILE)
model_lons = np.load(LONS_FILE)
Y_pred_unet = np.load(UNET_SR_FILE)
Y_pred_srcnn = np.load(SRCNN_SR_FILE)

# Align: truncate every time-indexed array to the shortest common length.
# The epoch list takes part in the minimum as well; otherwise a short epochs
# file would raise IndexError later, when a frame's timestamp is looked up
# via all_epochs[idx].
T = min(
    obs_interp.shape[0],
    model_interp.shape[0],
    Y_pred_unet.shape[0],
    Y_pred_srcnn.shape[0],
    len(all_epochs),
)
obs_interp = obs_interp[:T]
model_interp = model_interp[:T]
all_epochs = all_epochs[:T]
# Keep only index 0 of the last axis of each prediction
# (assumes a trailing channel axis, e.g. (T, lat, lon, 1) -- TODO confirm).
Y_pred_unet = Y_pred_unet[:T, ..., 0]
Y_pred_srcnn = Y_pred_srcnn[:T, ..., 0]

# 2. --- COMPUTE DIFFERENCES ---
# One "prediction minus observation" field per method, keyed by display label.
# Insertion order (TIE-GCM, U-Net SR, SRCNN SR) fixes the column / panel order
# wherever the dict is iterated.
diffs = {
    label: prediction - obs_interp
    for label, prediction in (
        ('TIE-GCM', model_interp),
        ('U-Net SR', Y_pred_unet),
        ('SRCNN SR', Y_pred_srcnn),
    )
}

# 3. --- GLOBAL HISTOGRAM & SUMMARY ---
# Overlay one normalised histogram per method on a single figure.
plt.figure(figsize=(10, 6))
for lbl, arr in diffs.items():
    # Boolean-mask indexing returns the non-NaN values as a flat array, so no
    # separate flatten() copies of the full field are needed.
    plt.hist(arr[~np.isnan(arr)], bins=60, alpha=0.5, label=lbl, density=True)
plt.xlabel('Prediction - Observation (TECU)')
# density=True normalises each histogram to unit area, so the y-axis is a
# probability density rather than a count / frequency.
plt.ylabel('Probability density')
plt.legend()
plt.title('Distribution of Difference Fields')
plt.tight_layout()
plt.savefig('diff_histogram_compare.png')
plt.close()

# NaN-aware summary statistics per method, printed and written to CSV.
print("\nGlobal difference statistics:")
summary_rows = []
for lbl, arr in diffs.items():
    summary = {
        'Method': lbl,
        'Mean': np.nanmean(arr),
        'Median': np.nanmedian(arr),
        'Std': np.nanstd(arr),
        'MAE': np.nanmean(np.abs(arr)),
    }
    summary_rows.append(summary)
    print(f"{lbl}: mean={summary['Mean']:.3f}, std={summary['Std']:.3f}, median={summary['Median']:.3f}, MAE={summary['MAE']:.3f}")
summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv("global_difference_summary.csv", index=False)

# 4. --- REGIONAL VALIDATION ---
# Latitude bands in degrees. Each band is half-open [low, high) so that a grid
# row on a shared edge is counted once; the northernmost band is closed at its
# top edge so that a row at exactly +90 is not dropped from every band.
bands = [(-90, -60), (-60, -30), (-30, 0), (0, 30), (30, 60), (60, 90)]
top_edge = max(max(band) for band in bands)
lat_indices = []
lat_labels = []
for lat1, lat2 in bands:
    lo, hi = min(lat1, lat2), max(lat1, lat2)
    under_hi = (model_lats <= hi) if hi == top_edge else (model_lats < hi)
    # np.where(...)[0] yields row indices along the first axis of model_lats
    # (assumes model_lats is a 1-D latitude axis -- TODO confirm).
    idx = np.where((model_lats >= lo) & under_hi)[0]
    lat_indices.append(idx)
    lat_labels.append(f"{lat1} to {lat2}")

# Rows: latitude bands, columns: methods (in diffs order). Entries stay NaN
# for a band that contains no grid rows.
region_mae = np.full((len(lat_labels), len(diffs)), np.nan)
region_mean = np.full((len(lat_labels), len(diffs)), np.nan)

for b, idxs in enumerate(lat_indices):
    if idxs.size == 0:
        # Nothing to average; leave NaN instead of triggering numpy's
        # "mean of empty slice" warning.
        continue
    for m, (lbl, arr) in enumerate(diffs.items()):
        # Axis 1 of each difference field is indexed with latitude rows.
        band_vals = arr[:, idxs, :]
        region_mae[b, m] = np.nanmean(np.abs(band_vals))
        region_mean[b, m] = np.nanmean(band_vals)

region_df = pd.DataFrame(region_mae, columns=list(diffs.keys()), index=lat_labels)
region_df.to_csv("regional_mae.csv")
print("\nRegional MAE (rows: lat bands, cols: model):")
print(region_df)

regionmean_df = pd.DataFrame(region_mean, columns=list(diffs.keys()), index=lat_labels)
regionmean_df.to_csv("regional_mean_diff.csv")
print("\nRegional Mean (rows: lat bands, cols: model):")
print(regionmean_df)

# 5. --- PLOTS: PER-REGION MAE AND MEAN ---
# Both figures are grouped bar charts with one group per latitude band and one
# bar per method; only the data, y-label, title and output file differ.
width = 0.25
x = np.arange(len(lat_labels))
bar_charts = (
    (region_mae, "Mean Absolute Difference (TECU)",
     "Regional MAE by Latitude Band", "regional_mae_bar.png"),
    (region_mean, "Mean Difference (TECU)",
     "Regional Mean (Bias) by Latitude Band", "regional_mean_bar.png"),
)
for values, y_label, heading, out_png in bar_charts:
    plt.figure(figsize=(12, 5))
    for col, method in enumerate(diffs):
        # Shift each method's bars by one bar width within the group.
        plt.bar(x + col * width, values[:, col], width, label=method)
    # Tick under the second bar of each group.
    plt.xticks(x + width, lat_labels)
    plt.ylabel(y_label)
    plt.title(heading)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_png)
    plt.close()

# 6. --- REGIONAL HISTOGRAMS ---
# One figure per latitude band, overlaying the difference distribution of
# every method restricted to that band's latitude rows.
for b, idxs in enumerate(lat_indices):
    plt.figure(figsize=(10, 6))
    for lbl, arr in diffs.items():
        region_vals = arr[:, idxs, :]
        # Mask indexing flattens and drops NaNs in one step.
        plt.hist(region_vals[~np.isnan(region_vals)], bins=50, alpha=0.5, label=lbl, density=True)
    plt.xlabel('Prediction - Observation (TECU)')
    # density=True normalises each histogram to unit area, so the y-axis is a
    # probability density rather than a count / frequency.
    plt.ylabel('Probability density')
    plt.title(f'Difference Distribution ({lat_labels[b]})')
    plt.legend()
    plt.tight_layout()
    # "-90 to -60" -> "-90_-60": strip spaces, then turn "to" into "_".
    band_tag = lat_labels[b].replace(' ', '').replace('to', '_')
    plt.savefig(f"regional_hist_{band_tag}.png")
    plt.close()

# 7. --- STATIC DIFFERENCE MAPS (all methods side by side) ---
# Up to six evenly spaced time indices across the aligned series; each one
# becomes a figure with one map panel per method on a shared colour scale.
idxs = np.linspace(0, T-1, min(6, T)).astype(int)
n_methods = len(diffs)
for i, idx in enumerate(idxs):
    # squeeze=False keeps axs 2-D, so axs[0] is iterable for any panel count.
    fig, axs = plt.subplots(
        1, n_methods, figsize=(6 * n_methods, 5), squeeze=False,
        subplot_kw={'projection': ccrs.PlateCarree()},
    )
    # Symmetric colour limit: largest finite |difference| over all methods at
    # this time index, but never below 1. Non-finite values are filtered first
    # because Python's max() returns NaN when a NaN comes first in the list,
    # which would leave the colour scale undefined for an all-NaN frame.
    peaks = []
    for arr in diffs.values():
        frame = arr[idx]
        finite_vals = np.abs(frame[np.isfinite(frame)])
        if finite_vals.size:
            peaks.append(finite_vals.max())
    vmax = max(peaks + [1])
    vmin = -vmax
    # Timestamp truncated to 16 characters (down to the minute for the usual
    # "YYYY-MM-DD HH:MM:SS" rendering).
    stamp = str(pd.to_datetime(all_epochs[idx]))[:16]
    for ax, (lbl, arr) in zip(axs[0], diffs.items()):
        im = ax.pcolormesh(model_lons, model_lats, arr[idx], vmin=vmin, vmax=vmax, shading='auto', cmap='bwr')
        ax.coastlines()
        ax.set_title(f"{lbl} - Obs\n{stamp}")
        cb = plt.colorbar(im, ax=ax, orientation='vertical')
        cb.set_label('Difference (TECU)')
    plt.tight_layout()
    plt.savefig(f"static_diff_compare_{i:02d}.png")
    plt.close(fig)

# --- ANIMATION ---
def make_diff_animation(diff_array, label, all_epochs, model_lats, model_lons, vmax=None):
    """Render a difference field as an animated GIF, one frame per time step.

    Args:
        diff_array: Difference field indexed by time along its first axis;
            ``diff_array[k]`` is drawn as frame ``k``.
        label: Method name, used in the frame title and in the output file
            name (lower-cased, spaces replaced by underscores).
        all_epochs: Per-frame timestamps; ``all_epochs[k]`` is parsed with
            ``pd.to_datetime`` and truncated to 16 characters for the title.
        model_lats: Latitude coordinates passed to ``pcolormesh``.
        model_lons: Longitude coordinates passed to ``pcolormesh``.
        vmax: Symmetric colour limit. When falsy (``None`` or ``0``) the
            largest non-NaN absolute value of ``diff_array`` is used.

    Side effects:
        Writes ``anim_<label>_minus_obs.gif`` (pillow writer, 4 fps) to the
        current working directory and prints a confirmation line.
    """
    from matplotlib.animation import FuncAnimation

    def caption(frame_no):
        # Title text for one frame: method label plus truncated timestamp.
        return f"{label} - Obs {str(pd.to_datetime(all_epochs[frame_no]))[:16]}"

    color_limit = vmax or np.nanmax(np.abs(diff_array))
    out_name = f"anim_{label.lower().replace(' ', '_')}_minus_obs.gif"

    figure, axis = plt.subplots(figsize=(10, 4), subplot_kw={'projection': ccrs.PlateCarree()})
    quad = axis.pcolormesh(
        model_lons, model_lats, diff_array[0],
        vmin=-color_limit, vmax=color_limit, shading='auto', cmap='bwr',
    )
    axis.coastlines()
    colorbar = plt.colorbar(quad, ax=axis, orientation='vertical')
    colorbar.set_label('Difference (TECU)')
    heading = axis.set_title(caption(0))

    def update(idx):
        # Swap in the data and title of frame ``idx`` on the existing artists.
        quad.set_array(diff_array[idx].ravel())
        heading.set_text(caption(idx))
        return quad, heading

    movie = FuncAnimation(figure, update, frames=diff_array.shape[0], blit=False)
    movie.save(out_name, writer='pillow', fps=4)
    plt.close(figure)
    print(f"Animation saved: {out_name}")

# Render one GIF per method, each with its own data-driven colour scale
# (vmax is left at its default).
for method_name, diff_field in diffs.items():
    make_diff_animation(
        diff_field,
        method_name,
        all_epochs,
        model_lats,
        model_lons,
    )

print("All regional statistics, histograms, static difference maps, and animations complete.")


The code above splits the results into latitude bands.