This code tries to run an SRCNN model for the April 9-11 overlap with TIE-GCM. Run the code that combines the ground observations first.

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import os
from skimage.transform import resize
import warnings
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.dates as mdates
import tensorflow as tf
from tensorflow.keras import layers, models

# --- CONFIG ---
# Parquet file read by main_compare_srcnn_unet_regions() via pd.read_parquet;
# its columns are unpacked there as (lat, lon) pairs.
COMBINED_IONEX_PARQUET = "combined_ionex_vtec.parquet"
# Directory scanned by load_model_grid() for model output files.
TIEGCM_DIR = './ionosphere_central/CCMC/model_data/TIE-GCM/129/Akshay_Ramesh_042125_IT_5/'
# NOTE(review): not referenced by any call visible in this file; presumably the
# down-sampling factor meant for make_training_pairs() - confirm.
SPATIAL_FACTOR = 2
# Step, in minutes, of the common time axis built with pd.date_range in main.
TEMPORAL_MINUTES = 10
UNET_SR_FILE = "superres_vtec.npy"       # Pre-saved U-Net predictions
# Loaded with np.load alongside UNET_SR_FILE as the SRCNN predictions.
SRCNN_SR_FILE = "srcnn_superres_vtec.npy"

# Suppress every warning of category UserWarning for the whole process.
warnings.filterwarnings("ignore", category=UserWarning)

def save_if_not_exists(fname, arr):
    """Save ``arr`` to ``fname`` with ``np.save`` unless ``fname`` already exists.

    An existing file is never overwritten, so a rerun keeps whatever was
    written the first time.

    Args:
        fname: Destination path, checked with ``os.path.exists``.
        arr: Array-like object handed to ``np.save``.
    """
    already_there = os.path.exists(fname)
    if already_there:
        return
    np.save(fname, arr)

def load_model_grid(directory):
    """Load every TEC grid found in ``directory`` into one stacked array.

    Each directory entry whose name does not end in ``.json`` is opened with
    xarray, trying the ``netcdf4`` engine first and ``scipy`` as a fallback.
    Files without a ``TEC`` variable and files that fail to load are reported
    with ``print`` and skipped. Every dataset is closed once its values have
    been copied out, so file handles are not leaked across the loop.

    Args:
        directory: Path of the directory holding the model output files.

    Returns:
        Tuple ``(grids, times, lat0, lon0)`` where

        * ``grids`` is the ``np.stack`` of all TEC slices, each divided by 1e12
          (presumably a unit conversion to TECU - confirm against the files),
        * ``times`` is ``pd.to_datetime`` applied to the per-slice times; a
          file without a ``time`` entry contributes ``None``,
        * ``lat0`` / ``lon0`` are the ``lat`` / ``lon`` values of the first
          file that contained ``TEC``.

    Raises:
        RuntimeError: If no grid could be loaded from ``directory``.
    """
    model_files = sorted(
        f for f in os.listdir(directory)
        if not f.endswith('.json')
    )
    grids, times = [], []
    lat0, lon0 = None, None
    for fname in model_files:
        fpath = os.path.join(directory, fname)
        try:
            try:
                ds = xr.open_dataset(fpath, engine='netcdf4')
            except Exception:
                ds = xr.open_dataset(fpath, engine='scipy')
            # `.values` below materialises the data in memory, so the dataset
            # can be closed as soon as this block is left (including via
            # `continue` or an exception).
            with ds:
                if 'TEC' not in ds.variables:
                    print(f"Skipping {fname}: no 'TEC'")
                    continue
                tec = ds['TEC'].values
                if lat0 is None:
                    # Coordinates are taken from the first usable file only.
                    lat0 = ds['lat'].values
                    lon0 = ds['lon'].values
                if 'time' in ds:
                    # One grid per entry along the leading axis of TEC.
                    tvals = pd.to_datetime(ds['time'].values)
                    for i in range(tec.shape[0]):
                        grids.append(tec[i] / 1e12)
                        times.append(tvals[i])
                else:
                    grids.append(tec / 1e12)
                    times.append(None)
                print(f"Loaded {fname} shape {tec.shape}")
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole load.
            print(f"Failed to load {fname}: {e}")
    if not grids:
        raise RuntimeError("No valid model grids found! Please check files and data.")
    grids = np.stack(grids)
    times = pd.to_datetime(times)
    print(f"Loaded {len(grids)} model grids, grid shape {grids.shape}")
    return grids, times, lat0, lon0

def make_training_pairs(data, spatial_factor):
    """Build (degraded, original) image pairs from a stack of grids.

    Each frame is resized down by ``spatial_factor`` (cubic, anti-aliased) and
    then resized back to its original size (cubic, no anti-aliasing). The
    round-tripped frame is the input and the untouched frame is the target.

    Args:
        data: Array whose shape unpacks into three values ``(n, h, w)``.
        spatial_factor: Integer divisor applied to ``h`` and ``w`` for the
            intermediate low-resolution size.

    Returns:
        Tuple ``(X, Y)`` of stacked arrays, each with a trailing axis of
        length 1 appended.
    """
    _, height, width = data.shape
    coarse_shape = (height // spatial_factor, width // spatial_factor)

    def round_trip(frame):
        # Down-sample, then bring the coarse frame back to full size.
        coarse = resize(frame, coarse_shape, order=3, anti_aliasing=True)
        return resize(coarse, (height, width), order=3, anti_aliasing=False)

    degraded = np.stack([round_trip(frame) for frame in data])
    originals = np.stack(list(data))
    return degraded[..., None], originals[..., None]

def interp_time(series_times, series_grids, new_times):
    """Linearly interpolate a stack of grids along its first axis in time.

    Times are converted with ``pd.Timestamp(t).timestamp()`` and used as the
    abscissa of a ``scipy.interpolate.interp1d`` over axis 0. Requested times
    outside the source range are linearly extrapolated rather than rejected.

    Args:
        series_times: Iterable of time values, one per entry of
            ``series_grids`` along axis 0.
        series_grids: Array of values to interpolate.
        new_times: Iterable of times at which to evaluate.

    Returns:
        Array of interpolated values, one entry along axis 0 per element of
        ``new_times``.
    """
    from scipy.interpolate import interp1d

    def epoch_seconds(stamps):
        return [pd.Timestamp(stamp).timestamp() for stamp in stamps]

    resample = interp1d(
        epoch_seconds(series_times),
        series_grids,
        axis=0,
        kind='linear',
        bounds_error=False,
        fill_value="extrapolate",
    )
    return resample(epoch_seconds(new_times))

def interpolate_to_grid(obs_grids, obs_lats, obs_lons, tgt_lats, tgt_lons):
    """Resample each observation grid onto the target latitude/longitude grid.

    Every grid in ``obs_grids`` is wrapped in a
    ``scipy.interpolate.RegularGridInterpolator`` defined on
    ``(obs_lats, obs_lons)`` and evaluated at every ``(lat, lon)`` pair of the
    target grid. Target points outside the observation grid become ``NaN``.

    Args:
        obs_grids: Iterable of 2-D arrays laid out as ``(obs_lats, obs_lons)``.
        obs_lats: 1-D latitude coordinates of the observation grids.
        obs_lons: 1-D longitude coordinates of the observation grids.
        tgt_lats: 1-D latitudes of the target grid.
        tgt_lons: 1-D longitudes of the target grid.

    Returns:
        Array of shape ``(len(obs_grids), len(tgt_lats), len(tgt_lons))``.

    Raises:
        ValueError: If ``obs_grids`` is empty (raised by ``np.stack``).
    """
    from scipy.interpolate import RegularGridInterpolator

    # The query points depend only on the target grid, so build them once
    # instead of once per observation grid.
    lat_mesh, lon_mesh = np.meshgrid(tgt_lats, tgt_lons, indexing='ij')
    pts = np.stack([lat_mesh.ravel(), lon_mesh.ravel()], axis=-1)
    out_shape = (len(tgt_lats), len(tgt_lons))

    out = []
    for grid in obs_grids:
        interp = RegularGridInterpolator(
            (obs_lats, obs_lons), grid, bounds_error=False, fill_value=np.nan
        )
        out.append(interp(pts).reshape(out_shape))
    return np.stack(out)

def get_lat_band_indices(lats, bands=((-90, -60), (-60, -30), (-30, 0), (0, 30), (30, 60), (60, 90))):
    """Map each latitude band to the indices of ``lats`` that fall inside it.

    Bands are half-open, ``lower <= lat < upper``, so adjacent bands never
    share a latitude. The one exception is a band whose upper edge is the
    pole (``>= 90``): it also includes its upper edge, otherwise a grid row
    at exactly 90 degrees would belong to no band at all.

    Args:
        lats: Sequence of latitudes.
        bands: Iterable of ``(lat1, lat2)`` pairs; the two values may be given
            in either order. The default is an immutable tuple so it cannot
            be mutated between calls.

    Returns:
        Dict keyed by ``"{lat1} to {lat2}"`` whose values are integer index
        arrays into ``lats``.
    """
    band_indices = {}
    lats = np.asarray(lats)
    for lat1, lat2 in bands:
        lower, upper = min(lat1, lat2), max(lat1, lat2)
        # Close the interval at the pole so lat == 90 is not dropped.
        under_upper = (lats <= upper) if upper >= 90 else (lats < upper)
        mask = (lats >= lower) & under_upper
        band_indices[f"{lat1} to {lat2}"] = np.where(mask)[0]
    return band_indices

def plot_histograms(errors_dict, out_prefix="compare_unet_srcnn"):
    """Save an overlaid histogram and a boxplot of absolute errors.

    NaN entries are removed from every series before plotting. Two PNG files
    are written: ``{out_prefix}_error_histograms.png`` and
    ``{out_prefix}_error_boxplots.png``.

    Args:
        errors_dict: Mapping of series label to a NumPy array of errors.
        out_prefix: Prefix of the two output file names.
    """
    # Filter NaNs once; both figures use the same cleaned series.
    cleaned = {
        label: values[~np.isnan(values)]
        for label, values in errors_dict.items()
    }

    # Figure 1: normalised histograms, one per series, drawn on top of each other.
    plt.figure(figsize=(10,6))
    for label, values in cleaned.items():
        plt.hist(values, bins=50, alpha=0.5, label=label, density=True)
    plt.xlabel('Absolute Error (TECU)')
    plt.ylabel('Frequency')
    plt.title('Histogram of Absolute Errors')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_error_histograms.png", dpi=120)
    plt.close()

    # Figure 2: one box per series, with the mean marked.
    plt.figure(figsize=(8,5))
    # tick_labels (not labels) is the keyword expected by matplotlib >= 3.9.
    plt.boxplot(
        list(cleaned.values()),
        vert=True,
        tick_labels=list(cleaned),
        showmeans=True,
    )
    plt.ylabel('Absolute Error (TECU)')
    plt.title('Boxplot of Absolute Errors')
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_error_boxplots.png", dpi=120)
    plt.close()

def plot_region_mae_rmse(band_names, region_mae, region_rmse, model_names, out_prefix="compare_unet_srcnn"):
    """Save grouped bar charts of MAE and RMSE per latitude band.

    One figure is written per metric:
    ``{out_prefix}_region_mae_latbands.png`` and
    ``{out_prefix}_region_rmse_latbands.png``. Each band gets one bar per
    model, and the band label is centred under its group of bars.

    Args:
        band_names: Labels of the latitude bands (x axis).
        region_mae: 2-D array-like indexed ``[band, model]`` of MAE values.
        region_rmse: 2-D array-like indexed ``[band, model]`` of RMSE values.
        model_names: Names of the models, in column order.
        out_prefix: Prefix of the output file names.
    """
    n_models = len(model_names)
    # 0.25 is the original bar width (used for up to three models); shrink it
    # for larger model counts so a group never spills into the next band.
    width = min(0.25, 0.8 / max(n_models, 1))
    x = np.arange(len(band_names))
    # Centre of each group of bars; equals x + width for three models.
    tick_positions = x + width * (n_models - 1) / 2

    panels = [
        ("MAE", region_mae, 'Mean Absolute Error (TECU)', "region_mae"),
        ("RMSE", region_rmse, 'Root Mean Square Error (TECU)', "region_rmse"),
    ]
    for error_name, error_array, ylabel, fname in panels:
        error_array = np.asarray(error_array)
        plt.figure(figsize=(12,6))
        for i, model in enumerate(model_names):
            plt.bar(x + i*width, error_array[:, i], width, label=model)
        plt.xticks(tick_positions, band_names, rotation=20)
        plt.ylabel(ylabel)
        plt.title(f'{error_name} by Latitude Band')
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{out_prefix}_{fname}_latbands.png", dpi=120)
        plt.close()

def plot_comparisons(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_unet, Y_srcnn, out_prefix='compare_unet_srcnn'):
    """Write maps, error summaries and an animation comparing four VTEC fields.

    Files written (all prefixed with ``out_prefix``):

    * ``_static_XX.png``       - up to 10 evenly spaced four-panel snapshots,
    * ``_rmse_timeseries.png`` - per-epoch RMSE against the observations,
    * ``_mae_maps.png``        - time-averaged absolute-error maps,
    * ``_compare.gif``         - four-panel animation of every 5th epoch.

    Mean RMSE and mean MAE for each field are also printed.

    Args:
        model_lats: Latitudes of the plotting grid (y axis of ``pcolormesh``).
        model_lons: Longitudes of the plotting grid (x axis of ``pcolormesh``).
        all_epochs: Time stamps, one per entry along axis 0 of the fields.
        obs_interp: Observations, indexed ``[time, lat, lon]``.
        model_interp: Physical-model field, same layout as ``obs_interp``.
        Y_unet: U-Net predictions with a trailing channel axis; channel 0 is
            used.
        Y_srcnn: SRCNN predictions with a trailing channel axis; channel 0 is
            used.
        out_prefix: Prefix of every output file name.
    """
    # One colour scale shared by every VTEC panel so the maps are comparable.
    vmin = np.nanmin([np.nanmin(obs_interp), np.nanmin(model_interp), np.nanmin(Y_unet), np.nanmin(Y_srcnn)])
    vmax = np.nanmax([np.nanmax(obs_interp), np.nanmax(model_interp), np.nanmax(Y_unet), np.nanmax(Y_srcnn)])

    # --- Static snapshots: at most 10 epochs, evenly spread over the series ---
    idxs = np.linspace(0, len(all_epochs)-1, min(10, len(all_epochs))).astype(int)
    for i, idx in enumerate(idxs):
        fig, axs = plt.subplots(1, 4, figsize=(24, 5), subplot_kw={'projection': ccrs.PlateCarree()})
        titles = [
            f"Obs {all_epochs[idx]:%m-%d %H:%M}",
            f"Physical Model {all_epochs[idx]:%m-%d %H:%M}",
            f"U-Net SR {all_epochs[idx]:%m-%d %H:%M}",
            f"SRCNN SR {all_epochs[idx]:%m-%d %H:%M}",
        ]
        data = [obs_interp[idx], model_interp[idx], Y_unet[idx][..., 0], Y_srcnn[idx][..., 0]]
        for ax, dat, title in zip(axs, data, titles):
            im = ax.pcolormesh(model_lons, model_lats, dat, vmin=vmin, vmax=vmax, shading='auto', cmap='plasma')
            ax.coastlines()
            ax.set_title(title)
            cb = plt.colorbar(im, ax=ax, orientation='vertical')
            cb.set_label('VTEC (TECU)')
        plt.tight_layout()
        plt.savefig(f"{out_prefix}_static_{i:02d}.png", dpi=120)
        plt.close(fig)

    # Predictions may be longer than the observation series; compare only the
    # leading part that lines up with obs_interp.
    n_obs = obs_interp.shape[0]
    unet_2d = Y_unet[:n_obs, ..., 0]
    srcnn_2d = Y_srcnn[:n_obs, ..., 0]

    # --- RMSE Time Series ---
    rmse_model = np.sqrt(np.nanmean((obs_interp - model_interp) ** 2, axis=(1, 2)))
    rmse_unet = np.sqrt(np.nanmean((obs_interp - unet_2d) ** 2, axis=(1, 2)))
    rmse_srcnn = np.sqrt(np.nanmean((obs_interp - srcnn_2d) ** 2, axis=(1, 2)))
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(all_epochs, rmse_model, label='Physical Model', color='tab:blue')
    ax.plot(all_epochs, rmse_unet, label='U-Net SR', color='tab:orange')
    ax.plot(all_epochs, rmse_srcnn, label='SRCNN SR', color='tab:green')
    ax.set_ylabel('RMSE (TECU)')
    ax.set_title('RMSE of Model vs. U-Net vs. SRCNN vs. Obs')
    ax.legend()
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
    fig.autofmt_xdate()
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_rmse_timeseries.png", dpi=120)
    plt.close(fig)

    # --- MAE MAPS ---
    mae_model = np.nanmean(np.abs(obs_interp - model_interp), axis=0)
    mae_unet = np.nanmean(np.abs(obs_interp - unet_2d), axis=0)
    mae_srcnn = np.nanmean(np.abs(obs_interp - srcnn_2d), axis=0)
    fig, axs = plt.subplots(1, 3, figsize=(18, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    for ax, dat, title in zip(axs, [mae_unet, mae_srcnn, mae_model], ['U-Net MAE', 'SRCNN MAE', 'Model MAE']):
        # No shared vmin/vmax here: each error map uses its own colour range.
        im = ax.pcolormesh(model_lons, model_lats, dat, shading='auto', cmap='inferno')
        ax.coastlines()
        ax.set_title(title)
        cb = plt.colorbar(im, ax=ax, orientation='vertical')
        cb.set_label('Mean Absolute Error (TECU)')
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_mae_maps.png", dpi=120)
    plt.close(fig)

    print(f"\n--- RMSE Averages ---")
    print(f"Physical Model: {np.nanmean(rmse_model):.3f} TECU")
    print(f"U-Net SR:       {np.nanmean(rmse_unet):.3f} TECU")
    print(f"SRCNN SR:       {np.nanmean(rmse_srcnn):.3f} TECU")
    print(f"\n--- MAE Global Mean ---")
    print(f"Physical Model: {np.nanmean(mae_model):.3f} TECU")
    print(f"U-Net MAE:      {np.nanmean(mae_unet):.3f} TECU")
    print(f"SRCNN MAE:      {np.nanmean(mae_srcnn):.3f} TECU")

    # --- Animation (skip frames for speed!) ---
    from matplotlib.animation import FuncAnimation
    anim_fig, anim_axs = plt.subplots(1, 4, figsize=(28, 6), subplot_kw={'projection': ccrs.PlateCarree()})
    FRAME_SKIP = 5
    total_frames = obs_interp.shape[0]
    frame_indices = np.arange(0, total_frames, FRAME_SKIP)
    labels = ['Obs (Interp)', 'Model', 'U-Net SR', 'SRCNN SR']
    # Colorbars live in their own axes and survive ax.cla(). Creating new ones
    # on every frame would stack them up and shrink the maps, so they are made
    # once; vmin/vmax/cmap never change, so the first set stays valid.
    colorbars = []

    def make_frame(idx):
        for ax in anim_axs:
            ax.cla()
        data = [
            obs_interp[idx], model_interp[idx], Y_unet[idx][..., 0], Y_srcnn[idx][..., 0]
        ]
        first_draw = not colorbars
        for ax, dat, lbl in zip(anim_axs, data, labels):
            im = ax.pcolormesh(model_lons, model_lats, dat, vmin=vmin, vmax=vmax, shading='auto', cmap='plasma')
            ax.coastlines()
            ax.set_title(lbl)
            if first_draw:
                cb = anim_fig.colorbar(im, ax=ax, orientation='vertical')
                cb.set_label('VTEC (TECU)')
                colorbars.append(cb)
        anim_fig.suptitle(f'Global VTEC {all_epochs[idx]:%Y-%m-%d %H:%M}')
        return anim_axs

    ani = FuncAnimation(anim_fig, make_frame, frames=frame_indices, blit=False)
    ani.save(f"{out_prefix}_compare.gif", writer='pillow')
    plt.close(anim_fig)
    print(f"Animation saved to {out_prefix}_compare.gif (every {FRAME_SKIP}th frame)")

def main_compare_srcnn_unet_regions():
    """Compare the physical model, U-Net and SRCNN VTEC fields against observations.

    Steps:

    1. Read the observation table from ``COMBINED_IONEX_PARQUET`` and the model
       grids from ``TIEGCM_DIR``.
    2. Interpolate the observations onto the model grid, then put both series
       on a common ``TEMPORAL_MINUTES`` time axis covering their overlap.
    3. Cache the intermediates as ``.npy`` files (never overwriting existing
       ones).
    4. Load the pre-computed U-Net / SRCNN predictions and compute absolute
       errors, per-latitude-band MAE / RMSE (also written to CSV), and all
       comparison plots.
    """
    # --- Load data as before ---
    obs_df = pd.read_parquet(COMBINED_IONEX_PARQUET)
    obs_lats = sorted({lat for lat, _ in obs_df.columns})
    obs_lons = sorted({lon for _, lon in obs_df.columns})
    obs_epochs = pd.to_datetime(obs_df.index)
    # NOTE(review): the reshape assumes the columns are already ordered
    # latitude-major with both coordinates ascending (matching the sorted
    # lists above) - confirm against the script that writes the parquet file.
    obs_grids = np.stack([obs_df.loc[t].values.reshape((len(obs_lats), len(obs_lons))) for t in obs_df.index])

    model_grids, model_times, model_lats, model_lons = load_model_grid(TIEGCM_DIR)
    obs_on_model = interpolate_to_grid(obs_grids, obs_lats, obs_lons, model_lats, model_lons)
    # Common time axis restricted to the overlap of both series.
    all_epochs = pd.date_range(start=max(obs_epochs[0], model_times[0]), end=min(obs_epochs[-1], model_times[-1]), freq=f'{TEMPORAL_MINUTES}min')
    obs_interp = interp_time(obs_epochs, obs_on_model, all_epochs)
    model_interp = interp_time(model_times, model_grids, all_epochs)

    # Save all important intermediates for reproducible, fast reruns.
    # Existing files are kept as-is, so delete them to refresh the cache.
    save_if_not_exists("model_grids.npy", model_grids)
    save_if_not_exists("model_times.npy", np.array(model_times, dtype='O'))
    save_if_not_exists("model_lats.npy", model_lats)
    save_if_not_exists("model_lons.npy", model_lons)
    save_if_not_exists("obs_on_model.npy", obs_on_model)
    save_if_not_exists("all_epochs.npy", np.array(all_epochs, dtype='O'))
    save_if_not_exists("obs_interp.npy", obs_interp)
    save_if_not_exists("model_interp.npy", model_interp)

    # --- U-Net and SRCNN predictions ---
    Y_pred_unet = np.load(UNET_SR_FILE)
    Y_pred_srcnn = np.load(SRCNN_SR_FILE)
    # Align every series to the shortest one. Trimming only the predictions
    # would leave obs_interp longer than a short prediction file and make the
    # subtractions below fail on mismatched shapes.
    N = min(obs_interp.shape[0], Y_pred_unet.shape[0], Y_pred_srcnn.shape[0])
    if N < obs_interp.shape[0]:
        print(f"Prediction files cover only {N} of {obs_interp.shape[0]} epochs; trimming all series to {N}.")
    obs_interp = obs_interp[:N]
    Y_pred_unet = Y_pred_unet[:N]
    Y_pred_srcnn = Y_pred_srcnn[:N]
    model_interp = model_interp[:N]
    all_epochs = all_epochs[:N]

    # --- Absolute errors for statistics ---
    err_model = np.abs(obs_interp - model_interp)
    err_unet = np.abs(obs_interp - Y_pred_unet[...,0])
    err_srcnn = np.abs(obs_interp - Y_pred_srcnn[...,0])
    plot_histograms({
        "Physical Model": err_model.flatten(),
        "U-Net SR": err_unet.flatten(),
        "SRCNN SR": err_srcnn.flatten()
    })

    # --- Per-region (latitude band) RMSE/MAE ---
    bands = [(-90,-60),(-60,-30),(-30,0),(0,30),(30,60),(60,90)]
    band_indices = get_lat_band_indices(model_lats, bands=bands)
    band_names = list(band_indices.keys())
    region_mae = np.zeros((len(band_names), 3))
    region_rmse = np.zeros((len(band_names), 3))
    # Column order matches the CSV headers and plot legend below.
    errors_by_model = (err_model, err_unet, err_srcnn)
    for i, band in enumerate(band_names):
        rows = band_indices[band]
        for j, err in enumerate(errors_by_model):
            band_err = err[:, rows, :]
            region_mae[i, j] = np.nanmean(band_err)
            # |e|**2 == e**2, so the RMSE can reuse the absolute errors.
            region_rmse[i, j] = np.sqrt(np.nanmean(band_err ** 2))

    summary_df = pd.DataFrame(region_mae, columns=["Physical Model MAE", "U-Net MAE", "SRCNN MAE"], index=band_names)
    summary_df.to_csv("compare_unet_srcnn_region_mae.csv")
    summary2_df = pd.DataFrame(region_rmse, columns=["Physical Model RMSE", "U-Net RMSE", "SRCNN RMSE"], index=band_names)
    summary2_df.to_csv("compare_unet_srcnn_region_rmse.csv")
    print("\nPer-region MAE and RMSE (TECU):")
    print(summary_df)
    print(summary2_df)

    plot_region_mae_rmse(band_names, region_mae, region_rmse, ["Physical Model","U-Net SR","SRCNN SR"])
    plot_comparisons(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred_unet, Y_pred_srcnn)
    print("Per-region comparison, histograms, maps, and summary files complete.")

# Run the full comparison only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main_compare_srcnn_unet_regions()
