Self-contained U-Net with comparisons.

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import os
import re
from datetime import datetime, timedelta
from scipy.interpolate import interp1d, RegularGridInterpolator
from skimage.transform import resize
import warnings
import matplotlib
matplotlib.use('Agg')  # Force non-interactive backend for animation safety!
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation
import matplotlib.dates as mdates

# Silence every UserWarning raised by the imported libraries.
warnings.filterwarnings("ignore", category=UserWarning)

# --- CONFIG ---
# IONEX observation files that main() hands to load_obs_grid().
IONEX_PATHS = [
    './ionosphere_central/vTEC_data/uqrg1000.24i',
    './ionosphere_central/vTEC_data/uqrg1010.24i'
]
# Parquet file to which main() saves the combined observation table.
COMBINED_IONEX_PARQUET = "combined_ionex_vtec.parquet"
# Directory that main() scans for model output files via load_model_grid().
TIEGCM_DIR = './ionosphere_central/CCMC/model_data/TIE-GCM/129/Akshay_Ramesh_042125_IT_5/'
# Down-sampling factor used by make_training_pairs() to build low-res inputs.
SPATIAL_FACTOR = 2
# Step, in minutes, of the common timeline both data sets are interpolated to.
TEMPORAL_MINUTES = 10

# --- Parse IONEX Observations ---
def parse_ionex_file(file_path):
    """Parse the TEC maps of one IONEX file into a DataFrame.

    The header supplies the latitude/longitude grid definition. Every
    ``START OF TEC MAP`` ... ``END OF TEC MAP`` section is then read row by
    row. Raw values equal to ``9999`` become NaN and all values are scaled
    by 0.1.

    Args:
        file_path: Path of the IONEX text file.

    Returns:
        DataFrame indexed by map epoch with one column per ``(lat, lon)``
        grid node (column MultiIndex levels ``lat`` and ``lon``, in the
        order the nodes appear in the file grid definition). Maps whose row
        count or row lengths do not match the header grid are dropped
        together with their epoch, so rows and epochs always stay aligned.

    Raises:
        ValueError: If the header terminator, the grid definition records,
            or any complete TEC map is missing.
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    header_end = next(
        (idx for idx, text in enumerate(lines) if 'END OF HEADER' in text), None
    )
    if header_end is None:
        raise ValueError(f"{file_path}: no 'END OF HEADER' record found")

    lat_def = lon_def = None
    for line in lines[:header_end]:
        if 'LAT1 / LAT2 / DLAT' in line:
            lat_def = tuple(map(float, line.split()[:3]))
        if 'LON1 / LON2 / DLON' in line:
            lon_def = tuple(map(float, line.split()[:3]))
    if lat_def is None or lon_def is None:
        raise ValueError(f"{file_path}: header lacks the LAT/LON grid definition")
    lat1, lat2, dlat = lat_def
    lon1, lon2, dlon = lon_def
    lats = np.arange(lat1, lat2 - 0.1, -abs(dlat))
    lons = np.arange(lon1, lon2 + 0.1, dlon)

    # Compiled once; the same pattern is applied to every data line.
    number_re = re.compile(r'[-+]?[0-9]*\.?[0-9]+')
    n_lines = len(lines)
    maps, epochs = [], []
    i = header_end + 1
    while i < n_lines:
        if 'START OF TEC MAP' not in lines[i]:
            i += 1
            continue
        if i + 1 >= n_lines:
            # File ends right after the map start record: nothing to read.
            break
        y, mo, d, h, mi, s = map(int, lines[i + 1][:36].split())
        # An hour field of 24 rolls over to 00:00 of the following day.
        epoch = datetime(y, mo, d, h % 24, mi, s) + timedelta(days=h // 24)
        grid = []
        i += 2
        # Every loop checks the bound so a truncated file ends the parse
        # instead of raising IndexError.
        while i < n_lines and 'END OF TEC MAP' not in lines[i]:
            if 'LAT/LON1/LON2/DLON/H' not in lines[i]:
                i += 1
                continue
            row = []
            i += 1
            while (i < n_lines
                   and lines[i].strip()
                   and 'LAT/LON1/LON2/DLON/H' not in lines[i]
                   and 'END OF TEC MAP' not in lines[i]):
                line_values = number_re.findall(lines[i])
                row.extend(float(v) if v != '9999' else np.nan for v in line_values)
                i += 1
            if len(row) == len(lons):
                grid.append(row)
        # Record the epoch only together with an accepted map; appending it
        # earlier would misalign epochs and maps whenever a map is dropped.
        if len(grid) == len(lats):
            maps.append(grid)
            epochs.append(epoch)

    if not maps:
        raise ValueError(f"{file_path}: no complete TEC map found")

    maps = np.array(maps)
    data = {
        (lat, lon): maps[:, lat_idx, lon_idx]
        for lat_idx, lat in enumerate(lats)
        for lon_idx, lon in enumerate(lons)
    }
    df = pd.DataFrame(data, index=epochs)
    df = 0.1 * df  # Correction factor as requested
    df.columns = pd.MultiIndex.from_tuples(df.columns, names=['lat', 'lon'])
    return df

def load_obs_grid(paths, save_path=None):
    """Parse several IONEX files and merge them into one time-sorted table.

    Files that fail to parse are reported on stdout and skipped, so one bad
    file does not abort the whole load.

    Args:
        paths: Iterable of IONEX file paths; falsy entries are ignored.
        save_path: Optional path; when given, the merged table is also
            written there as a Parquet file.

    Returns:
        DataFrame of all parsed maps, sorted by epoch, keeping the first
        occurrence of any epoch that appears more than once.

    Raises:
        ValueError: If none of the given paths could be parsed.
    """
    dfs = []
    for fp in paths:
        if not fp:
            continue
        try:
            dfs.append(parse_ionex_file(fp))
        except Exception as e:
            # Deliberate best effort: report the failure and keep going.
            print(f"Failed to process {fp}: {e}")
    if not dfs:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError("No IONEX files could be parsed; check the input paths.")
    df = pd.concat(dfs)
    df = df[~df.index.duplicated(keep='first')].sort_index()
    if save_path:
        df.to_parquet(save_path)
    return df

def load_model_grid(directory):
    """Load the ``TEC`` variable from every non-JSON file in a directory.

    Files are visited in sorted filename order. Each file is opened with the
    ``netcdf4`` engine, falling back to ``scipy``. Files that cannot be
    opened or lack a ``TEC`` variable are reported on stdout and skipped.
    Every dataset is closed once its arrays have been copied out.

    Args:
        directory: Directory containing the model output files.

    Returns:
        Tuple ``(grids, times, lat0, lon0)``:
        ``grids`` is the stacked TEC arrays divided by 1e12 (presumably a
        conversion to TECU; confirm against the model's units),
        ``times`` is the matching timestamps (NaT for files without a
        ``time`` variable), and ``lat0``/``lon0`` are the coordinate arrays
        taken from the first file that provided ``TEC``.

    Raises:
        RuntimeError: If no file yielded a TEC grid.
    """
    model_files = sorted(
        f for f in os.listdir(directory)
        if not f.endswith('.json')
    )
    grids, times = [], []
    lat0, lon0 = None, None
    for fname in model_files:
        fpath = os.path.join(directory, fname)
        try:
            try:
                ds = xr.open_dataset(fpath, engine='netcdf4')
            except Exception:
                ds = xr.open_dataset(fpath, engine='scipy')
            # The context manager closes the file handle on every exit path,
            # including the early `continue` and any exception below.
            with ds:
                if 'TEC' not in ds.variables:
                    print(f"Skipping {fname}: no 'TEC'")
                    continue
                # `.values` copies the data into memory, so it stays valid
                # after the dataset is closed.
                tec = ds['TEC'].values
                if lat0 is None:
                    lat0 = ds['lat'].values
                    lon0 = ds['lon'].values
                if 'time' in ds:
                    tvals = pd.to_datetime(ds['time'].values)
                    for i in range(tec.shape[0]):
                        grids.append(tec[i] / 1e12)
                        times.append(tvals[i])
                else:
                    grids.append(tec / 1e12)
                    times.append(None)
            print(f"Loaded {fname} shape {tec.shape}")
        except Exception as e:
            print(f"Failed to load {fname}: {e}")
    if not grids:
        raise RuntimeError("No valid model grids found! Please check files and data.")
    grids = np.stack(grids)
    times = pd.to_datetime(times)
    print(f"Loaded {len(grids)} model grids, grid shape {grids.shape}")
    return grids, times, lat0, lon0

def interpolate_to_grid(obs_grids, obs_lats, obs_lons, tgt_lats, tgt_lons):
    """Resample a sequence of lat/lon grids onto a different lat/lon grid.

    Args:
        obs_grids: Iterable of 2-D arrays laid out as ``(obs_lats, obs_lons)``.
        obs_lats: Latitude coordinate of the source grids.
        obs_lons: Longitude coordinate of the source grids.
        tgt_lats: Latitudes of the target grid.
        tgt_lons: Longitudes of the target grid.

    Returns:
        Array of shape ``(len(obs_grids), len(tgt_lats), len(tgt_lons))``.
        Target points outside the source grid are NaN.
    """
    # The target points are identical for every time step, so build them once.
    mesh = np.meshgrid(tgt_lats, tgt_lons, indexing='ij')
    pts = np.stack([m.ravel() for m in mesh], axis=-1)
    tgt_shape = (len(tgt_lats), len(tgt_lons))

    out = []
    for grid in obs_grids:
        interp = RegularGridInterpolator(
            (obs_lats, obs_lons), grid, bounds_error=False, fill_value=np.nan
        )
        out.append(interp(pts).reshape(tgt_shape))
    return np.stack(out)

def interp_time(series_times, series_grids, new_times):
    """Linearly interpolate a stack of grids along its first (time) axis.

    Args:
        series_times: Timestamps of the input grids, one per entry of axis 0.
        series_grids: Array whose axis 0 corresponds to ``series_times``.
        new_times: Timestamps at which grids are wanted.

    Returns:
        Array with ``len(new_times)`` entries along axis 0. Times outside the
        input range are linearly extrapolated.
    """
    def _epoch_seconds(stamps):
        # interp1d needs plain numbers, so use POSIX seconds as the abscissa.
        return [pd.Timestamp(stamp).timestamp() for stamp in stamps]

    resampler = interp1d(
        _epoch_seconds(series_times),
        series_grids,
        axis=0,
        kind='linear',
        bounds_error=False,
        fill_value="extrapolate",
    )
    return resampler(_epoch_seconds(new_times))

def make_training_pairs(data, spatial_factor):
    """Build (degraded input, original target) pairs for super-resolution.

    Each frame is shrunk by ``spatial_factor`` with anti-aliased cubic
    resampling and then blown back up to its original size with cubic
    resampling, which yields a blurred copy of the frame.

    Args:
        data: Array of shape ``(n, h, w)``.
        spatial_factor: Integer divisor applied to ``h`` and ``w`` for the
            intermediate low-resolution frame.

    Returns:
        Tuple ``(X, Y)`` of arrays shaped ``(n, h, w, 1)``: the blurred
        frames and the untouched originals.
    """
    _, height, width = data.shape
    full_shape = (height, width)
    coarse_shape = (height // spatial_factor, width // spatial_factor)

    def _degrade(frame):
        coarse = resize(frame, coarse_shape, order=3, anti_aliasing=True)
        return resize(coarse, full_shape, order=3, anti_aliasing=False)

    blurred = np.stack([_degrade(frame) for frame in data])
    originals = np.stack(list(data))
    # A trailing axis of length 1 serves as the channel dimension.
    return blurred[..., None], originals[..., None]

# --- MODELS (TF) ---
import tensorflow as tf
from tensorflow.keras import layers, models

def build_unet(input_shape):
    """Assemble a two-level U-Net with a single-channel linear output.

    Encoder: two levels of paired 3x3 convolutions (32, then 64 filters),
    each followed by 2x2 max pooling. Bottleneck: one 3x3 convolution with
    128 filters. Decoder: two stride-2 transposed convolutions, each
    concatenated with the matching encoder output and followed by one 3x3
    convolution.

    Args:
        input_shape: Shape of one input sample, passed to ``layers.Input``.
            NOTE(review): height and width presumably need to be multiples
            of 4 so the skip concatenations line up; confirm.

    Returns:
        An uncompiled Keras model.
    """
    def conv3x3(filters):
        return layers.Conv2D(filters, (3,3), activation='relu', padding='same')

    def upsample2x(filters):
        return layers.Conv2DTranspose(filters, (2,2), strides=(2,2), padding='same')

    net_in = layers.Input(shape=input_shape)

    # Encoder
    enc1 = conv3x3(32)(net_in)
    enc1 = conv3x3(32)(enc1)
    down1 = layers.MaxPooling2D((2,2))(enc1)
    enc2 = conv3x3(64)(down1)
    enc2 = conv3x3(64)(enc2)
    down2 = layers.MaxPooling2D((2,2))(enc2)

    # Bottleneck
    bottleneck = conv3x3(128)(down2)

    # Decoder with skip connections
    up2 = upsample2x(64)(bottleneck)
    up2 = layers.concatenate([up2, enc2])
    dec2 = conv3x3(64)(up2)
    up1 = upsample2x(32)(dec2)
    up1 = layers.concatenate([up1, enc1])
    dec1 = conv3x3(32)(up1)

    net_out = layers.Conv2D(1, (1,1), activation='linear', padding='same')(dec1)
    return models.Model(net_in, net_out)

def build_srcnn(input_shape):
    """Assemble a three-layer SRCNN-style convolutional model.

    The layers are a 9x9 convolution (64 filters, ReLU), a 1x1 convolution
    (32 filters, ReLU), and a 5x5 convolution producing one channel with no
    activation. All use ``same`` padding.

    Args:
        input_shape: Shape of one input sample, passed to ``layers.Input``.

    Returns:
        An uncompiled Keras model.
    """
    net_in = layers.Input(shape=input_shape)
    stages = [
        layers.Conv2D(64, 9, padding='same', activation='relu'),
        layers.Conv2D(32, 1, padding='same', activation='relu'),
        layers.Conv2D(1, 5, padding='same'),
    ]
    tensor = net_in
    for stage in stages:
        tensor = stage(tensor)
    return models.Model(net_in, tensor)

def train_tf(model, X_train, Y_train, epochs=20, batch_size=8):
    """Compile ``model`` with Adam and MSE loss, fit it, and return it.

    Args:
        model: Model object exposing ``compile`` and ``fit``.
        X_train: Training inputs.
        Y_train: Training targets.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.

    Returns:
        The same ``model`` instance after fitting; 10% of the training data
        is held out for validation during the fit.
    """
    model.compile(loss='mse', optimizer='adam')
    fit_options = {
        'batch_size': batch_size,
        'epochs': epochs,
        'validation_split': 0.1,
    }
    model.fit(X_train, Y_train, **fit_options)
    return model

# --- PLOTTING & VERIFICATION ---
def verify_and_plot_results(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred, out_prefix='comparison'):
    """Write comparison figures and an animation for obs, model, and network output.

    Files written (all prefixed with ``out_prefix``):
      * ``_static_NN.png``: up to five three-panel map snapshots,
      * ``_rmse_timeseries.png``: per-epoch RMSE of model and network vs. obs,
      * ``_mae_maps.png``: time-averaged absolute error maps,
      * ``_compare.gif``: three-panel animation over all frames.

    Args:
        model_lats: Latitude coordinate used for the y axis of every map.
        model_lons: Longitude coordinate used for the x axis of every map.
        all_epochs: Timestamps, one per frame; used in titles and as the
            x axis of the RMSE plot.
        obs_interp: Observation frames, indexed by frame on axis 0
            (assumed to be (time, lat, lon) — TODO confirm against caller).
        model_interp: Physical-model frames, same layout as ``obs_interp``.
        Y_pred: Network output frames with a trailing channel axis; only
            channel 0 is used.
        out_prefix: Filename prefix for every output file.
    """
    # Shared colour limits so the three panels are directly comparable.
    vmin = np.nanmin([np.nanmin(obs_interp), np.nanmin(model_interp), np.nanmin(Y_pred)])
    vmax = np.nanmax([np.nanmax(obs_interp), np.nanmax(model_interp), np.nanmax(Y_pred)])

    # Up to 20 evenly spaced frame indices; only the first five are plotted below.
    idxs = np.linspace(0, len(all_epochs)-1, min(20, len(all_epochs))).astype(int)

    # --- Static Example Comparison Plots ---
    for i, idx in enumerate(idxs[:5]):
        fig, axs = plt.subplots(1, 3, figsize=(18, 5), subplot_kw={'projection': ccrs.PlateCarree()})
        titles = [
            f"Obs (Interp) {all_epochs[idx]:%Y-%m-%d %H:%M}",
            f"Physical Model {all_epochs[idx]:%Y-%m-%d %H:%M}",
            f"Super-Res Output {all_epochs[idx]:%Y-%m-%d %H:%M}",
        ]
        data = [obs_interp[idx], model_interp[idx], Y_pred[idx][..., 0]]
        for ax, dat, title in zip(axs, data, titles):
            im = ax.pcolormesh(model_lons, model_lats, dat, vmin=vmin, vmax=vmax, shading='auto', cmap='plasma')
            ax.coastlines()
            ax.set_title(title)
            cb = plt.colorbar(im, ax=ax, orientation='vertical')
            cb.set_label('VTEC (TECU)')
        plt.tight_layout()
        plt.savefig(f"{out_prefix}_static_{i:02d}.png")
        # Close each figure explicitly so figures do not accumulate in memory.
        plt.close(fig)

    # --- RMSE Time Series Plot ---
    # Errors are averaged over the two spatial axes, giving one value per frame.
    rmse_model = np.sqrt(np.nanmean((obs_interp - model_interp) ** 2, axis=(1, 2)))
    # Y_pred is truncated to the number of observation frames before comparing.
    rmse_sr = np.sqrt(np.nanmean((obs_interp - Y_pred[:obs_interp.shape[0],..., 0]) ** 2, axis=(1, 2)))
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(all_epochs, rmse_model, label='Physical Model', color='tab:blue')
    ax.plot(all_epochs, rmse_sr, label='Super-Res', color='tab:orange')
    ax.set_ylabel('RMSE (TECU)')
    ax.set_title('RMSE of Model vs. Super-Res vs. Obs')
    ax.legend()
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
    fig.autofmt_xdate()
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_rmse_timeseries.png")
    plt.close(fig)

    # --- Global Error Map (Mean Absolute Error) ---
    # Errors are averaged over the frame axis, giving one map per data source.
    mae_model = np.nanmean(np.abs(obs_interp - model_interp), axis=0)
    mae_sr = np.nanmean(np.abs(obs_interp - Y_pred[:obs_interp.shape[0],..., 0]), axis=0)
    fig, axs = plt.subplots(1, 2, figsize=(14, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    for ax, dat, title in zip(axs, [mae_model, mae_sr], ['Physical Model MAE', 'Super-Res MAE']):
        # No shared vmin/vmax here: each MAE panel uses its own colour scale.
        im = ax.pcolormesh(model_lons, model_lats, dat, shading='auto', cmap='inferno')
        ax.coastlines()
        ax.set_title(title)
        cb = plt.colorbar(im, ax=ax, orientation='vertical')
        cb.set_label('Mean Absolute Error (TECU)')
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_mae_maps.png")
    plt.close(fig)

    print("Comparison plots saved.")

    # --- ANIMATION (Robust, no ax.clear, no memory leak) ---
    def animate_panel(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred, out_prefix):
        """Save a three-panel GIF with one frame per entry of ``obs_interp``.

        The axes, coastlines and colorbars are drawn once; each frame only
        replaces the data array of the three existing meshes.
        """
        n_frames = obs_interp.shape[0]
        vmin = np.nanmin([np.nanmin(obs_interp), np.nanmin(model_interp), np.nanmin(Y_pred)])
        vmax = np.nanmax([np.nanmax(obs_interp), np.nanmax(model_interp), np.nanmax(Y_pred)])
        fig, axs = plt.subplots(1, 3, figsize=(21, 6), subplot_kw={'projection': ccrs.PlateCarree()})
        titles = ['Obs (Interp)', 'Physical Model', 'Super-Res']
        # Frame 0 seeds the meshes; update() swaps in the later frames.
        arrs = [obs_interp[0], model_interp[0], Y_pred[0][..., 0]]
        meshes = []
        for ax, arr, title in zip(axs, arrs, titles):
            mesh = ax.pcolormesh(model_lons, model_lats, arr, vmin=vmin, vmax=vmax, shading='auto', cmap='plasma')
            ax.coastlines()
            ax.set_title(title)
            cb = plt.colorbar(mesh, ax=ax, orientation='vertical')
            cb.set_label('VTEC (TECU)')
            meshes.append(mesh)
        plt.tight_layout()

        def update(idx):
            """Load frame ``idx`` into the three meshes and refresh the title."""
            arrs = [obs_interp[idx], model_interp[idx], Y_pred[idx][..., 0]]
            for mesh, arr in zip(meshes, arrs):
                mesh.set_array(arr.ravel())
            fig.suptitle(f'Global VTEC {all_epochs[idx]:%Y-%m-%d %H:%M}')
            return meshes

        ani = animation.FuncAnimation(fig, update, frames=n_frames, blit=False)
        ani.save(f'{out_prefix}_compare.gif', writer='pillow', fps=5)
        plt.close(fig)
        print(f"Animation saved to {out_prefix}_compare.gif")

    # Render the animation last, after all static figures have been written.
    animate_panel(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred, out_prefix)

# --- MAIN ---
def main(backend='tf', architecture='unet'):
    """Run the whole pipeline: load, align, train, predict, plot, save.

    Steps:
      1. Load IONEX observations and model grids.
      2. Interpolate the observations onto the model's lat/lon grid.
      3. Interpolate both data sets onto a common timeline covering their
         overlap, stepped by ``TEMPORAL_MINUTES``.
      4. Normalise, build degraded/original training pairs, train the
         chosen network and predict on the training inputs.
      5. De-normalise the predictions for the observation frames, write the
         comparison figures and save the predictions to
         ``superres_vtec.npy``.

    Args:
        backend: Deep-learning backend; only ``'tf'`` is supported.
        architecture: ``'unet'`` or ``'srcnn'``.

    Raises:
        NotImplementedError: If ``backend`` is not ``'tf'``.
        ValueError: If ``architecture`` is not a known model name.
        RuntimeError: If observation and model time ranges do not overlap.
    """
    # Validate the arguments before any expensive loading or training.
    if backend != 'tf':
        raise NotImplementedError("Only TensorFlow U-Net implemented here.")
    builders = {'unet': build_unet, 'srcnn': build_srcnn}
    if architecture not in builders:
        # Previously an unknown name left `model` unbound and crashed later.
        raise ValueError(
            f"Unknown architecture {architecture!r}; expected one of {sorted(builders)}"
        )

    obs_df = load_obs_grid(IONEX_PATHS, save_path=COMBINED_IONEX_PARQUET)
    # The parsed columns follow the file's grid order (latitude descending),
    # while the interpolator below is given ascending coordinates. Sort the
    # columns by (lat, lon) so the reshaped rows match `obs_lats`; without
    # this the observations are interpolated flipped north-south.
    obs_df = obs_df.sort_index(axis=1)
    obs_lats = sorted({lat for lat, _ in obs_df.columns})
    obs_lons = sorted({lon for _, lon in obs_df.columns})
    obs_epochs = pd.to_datetime(obs_df.index)
    obs_grids = obs_df.values.reshape(len(obs_df), len(obs_lats), len(obs_lons))

    model_grids, model_times, model_lats, model_lons = load_model_grid(TIEGCM_DIR)
    obs_on_model = interpolate_to_grid(obs_grids, obs_lats, obs_lons, model_lats, model_lons)

    # Common timeline restricted to the period covered by both data sets.
    all_epochs = pd.date_range(
        start=max(obs_epochs[0], model_times[0]),
        end=min(obs_epochs[-1], model_times[-1]),
        freq=f'{TEMPORAL_MINUTES}min',
    )
    if len(all_epochs) == 0:
        raise RuntimeError("Observation and model time ranges do not overlap.")
    obs_interp = interp_time(obs_epochs, obs_on_model, all_epochs)
    model_interp = interp_time(model_times, model_grids, all_epochs)

    # Train on observation and model frames together, normalised with their
    # joint statistics; NaNs become 0 after normalisation.
    all_grids = np.concatenate([obs_interp, model_interp], axis=0)
    mean, std = np.nanmean(all_grids), np.nanstd(all_grids)
    all_grids = np.nan_to_num((all_grids - mean) / std)

    X, Y = make_training_pairs(all_grids, SPATIAL_FACTOR)
    model = builders[architecture](input_shape=X.shape[1:])
    model = train_tf(model, X, Y, epochs=30)
    Y_pred = model.predict(X)

    # Undo the normalisation and keep only the frames that came from the
    # observations (they form the first block of `all_grids`).
    Y_pred = Y_pred * std + mean
    Y_pred = Y_pred[:obs_interp.shape[0]]

    verify_and_plot_results(
        model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred
    )

    np.save("superres_vtec.npy", Y_pred)
    print("Saved super-resolved output.")

# --- RUN ---
# Run the pipeline only when executed as a script, using the TensorFlow
# backend and the U-Net architecture.
if __name__ == '__main__':
    main(backend='tf', architecture='unet')
   
