In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import os
import re
from datetime import datetime, timedelta
from scipy.interpolate import interp1d, RegularGridInterpolator
from skimage.transform import resize
import warnings
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation
import matplotlib.dates as mdates

warnings.filterwarnings("ignore", category=UserWarning)

# --- CONFIG ---
# IONEX VTEC map files; main() parses and merges them with load_obs_grid().
IONEX_PATHS = [
    './ionosphere_central/vTEC_data/uqrg1000.24i',
    './ionosphere_central/vTEC_data/uqrg1010.24i',
]
# Parquet cache of the merged observations (read back by the later cells).
COMBINED_IONEX_PARQUET = "combined_ionex_vtec.parquet"
# Directory of TIE-GCM output files opened with xarray; *.json files are skipped.
TIEGCM_DIR = (
    './ionosphere_central/CCMC/model_data/TIE-GCM/129/'
    'Akshay_Ramesh_042125_IT_5/'
)
# Downsampling factor used to synthesize low-resolution training inputs.
SPATIAL_FACTOR = 2
# Cadence, in minutes, of the common time axis for obs and model.
TEMPORAL_MINUTES = 10

# --- Sun-center utilities ---
def get_subsolar_lon(dt):
    """Return ``15 * UT_hours - 180`` (degrees) for the UTC time of day of *dt*.

    NOTE: this is the *negative* of the mean subsolar longitude. The sun is
    overhead at 0 deg at 12:00 UTC (this returns 0) and at 90 deg W at
    18:00 UTC (this returns +90). Callers in this notebook center their maps on
    ``central_longitude=-get_subsolar_lon(dt)``, which is the true subsolar
    longitude. A value printed straight from this function is sign-flipped.

    Only hour, minute and second are used. The date, sub-second part and any
    equation-of-time correction are ignored.

    Args:
        dt: ``pandas.Timestamp`` or ``datetime.datetime``, assumed UTC.

    Returns:
        A float in [-180, 180).
    """
    seconds_into_day = dt.hour * 3600 + dt.minute * 60 + dt.second
    return seconds_into_day / 86400.0 * 360.0 - 180.0

# --- Data processing (unchanged) ---
def parse_ionex_file(file_path):
    """Parse the TEC maps of one IONEX file into a wide DataFrame.

    Only 'START OF TEC MAP' ... 'END OF TEC MAP' blocks are read. Other blocks
    (for example RMS maps) are skipped. A map whose row count or row length
    does not match the header grid is dropped together with its epoch, so
    epochs and maps always stay aligned.

    Args:
        file_path: path to an IONEX text file.

    Returns:
        DataFrame indexed by map epoch (naive datetimes, taken as written in
        the file). Columns are a ``MultiIndex`` named ``('lat', 'lon')``,
        ordered latitude-major in the file's LAT1 -> LAT2 order, then
        LON1 -> LON2. Raw values are scaled by ``10 ** EXPONENT``: the header
        value, which defaults to -1 (i.e. 0.1), may be overridden by an
        EXPONENT record inside a map. Values of ``9999`` become NaN.

    Raises:
        ValueError: if the header lacks 'END OF HEADER' or the grid definition.
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    header_end = next((i for i, line in enumerate(lines) if 'END OF HEADER' in line), None)
    if header_end is None:
        raise ValueError(f"{file_path}: no 'END OF HEADER' record found")

    lat_spec = lon_spec = None
    default_exponent = -1  # IONEX default when the header has no EXPONENT record
    for line in lines[:header_end]:
        if 'LAT1 / LAT2 / DLAT' in line:
            lat_spec = tuple(map(float, line.split()[:3]))
        elif 'LON1 / LON2 / DLON' in line:
            lon_spec = tuple(map(float, line.split()[:3]))
        elif 'EXPONENT' in line and 'COMMENT' not in line:
            default_exponent = int(line.split()[0])
    if lat_spec is None or lon_spec is None:
        raise ValueError(f"{file_path}: header lacks LAT1 / LAT2 / DLAT or LON1 / LON2 / DLON")
    lat1, lat2, dlat = lat_spec
    lon1, lon2, dlon = lon_spec

    # Step from LAT1 toward LAT2 (IONEX grids usually run north -> south).
    # The 0.1-degree pad makes the end point inclusive.
    lat_step = abs(dlat) if lat2 >= lat1 else -abs(dlat)
    lats = np.arange(lat1, lat2 + np.copysign(0.1, lat_step), lat_step)
    lons = np.arange(lon1, lon2 + 0.1, dlon)

    number_re = re.compile(r'[-+]?[0-9]*\.?[0-9]+')
    maps, epochs = [], []
    i = header_end + 1
    while i < len(lines):
        if 'START OF TEC MAP' not in lines[i]:
            i += 1
            continue
        epoch_line = lines[i + 1]
        y, mo, d, h, mi, s = map(int, epoch_line[:36].split())
        # IONEX may write hour 24 (midnight of the next day); datetime cannot.
        epoch = datetime(y, mo, d, h % 24, mi, s) + timedelta(days=h // 24)
        exponent = default_exponent
        grid = []
        i += 2
        while 'END OF TEC MAP' not in lines[i]:
            line = lines[i]
            if 'LAT/LON1/LON2/DLON/H' in line:
                # One latitude row, possibly wrapped over several data lines.
                row = []
                i += 1
                while (lines[i].strip()
                       and 'LAT/LON1/LON2/DLON/H' not in lines[i]
                       and 'END OF TEC MAP' not in lines[i]):
                    row.extend(float(v) if v != '9999' else np.nan
                               for v in number_re.findall(lines[i]))
                    i += 1
                if len(row) == len(lons):
                    grid.append(row)
            elif 'EXPONENT' in line:
                # Per-map override of the header exponent.
                exponent = int(line.split()[0])
                i += 1
            else:
                i += 1
        # Record the epoch only together with a complete map.
        if len(grid) == len(lats):
            maps.append(np.asarray(grid, dtype=float) * 10.0 ** exponent)
            epochs.append(epoch)

    columns = pd.MultiIndex.from_product([lats, lons], names=['lat', 'lon'])
    values = np.asarray(maps, dtype=float).reshape(len(maps), len(lats) * len(lons))
    return pd.DataFrame(values, index=epochs, columns=columns)

def load_obs_grid(paths, save_path=None):
    """Parse several IONEX files and merge them into one time-sorted DataFrame.

    Files that fail to parse are reported and skipped (best effort). Empty or
    None entries in *paths* are ignored. If the same epoch occurs in more than
    one file, the copy from the earliest path in *paths* is kept.

    Args:
        paths: iterable of IONEX file paths.
        save_path: if given, the merged frame is also written there as Parquet.

    Returns:
        DataFrame in the layout produced by ``parse_ionex_file``.

    Raises:
        ValueError: if no file could be parsed.
    """
    paths = list(paths)
    frames = []
    for fp in paths:
        if not fp:
            continue
        try:
            frames.append(parse_ionex_file(fp))
        except Exception as e:  # best effort: one bad file must not abort the run
            print(f"Failed to process {fp}: {e}")
    if not frames:
        raise ValueError(f"No IONEX files could be parsed from {paths!r}")
    df = pd.concat(frames)
    # keep='first' is applied before sorting, so the earlier file's copy wins.
    df = df[~df.index.duplicated(keep='first')].sort_index()
    if save_path:
        df.to_parquet(save_path)
    return df

def load_model_grid(directory):
    """Load every model output file in *directory* that contains a ``TEC`` variable.

    Files are visited in sorted name order and ``*.json`` files are skipped.
    Each file is opened with xarray's netcdf4 engine, falling back to scipy.
    The file is closed again once its arrays have been copied out. Files that
    fail to load or have no ``TEC`` are reported and skipped.

    Returns:
        grids: stacked array of ``TEC / 1e12`` (presumably el/cm^2 -> TECU;
            confirm against the file's units attribute).
        times: DatetimeIndex with one entry per grid (NaT for files without
            a ``time`` variable).
        lat0, lon0: coordinate arrays from the first file with TEC. They are
            assumed identical for all files.

    Raises:
        RuntimeError: if no file produced a TEC grid.
    """
    model_files = sorted(f for f in os.listdir(directory) if not f.endswith('.json'))
    grids, times = [], []
    lat0, lon0 = None, None
    for fname in model_files:
        fpath = os.path.join(directory, fname)
        ds = None
        try:
            try:
                ds = xr.open_dataset(fpath, engine='netcdf4')
            except Exception:
                # Fall back to the scipy backend if netcdf4 cannot open the file.
                ds = xr.open_dataset(fpath, engine='scipy')
            if 'TEC' not in ds.variables:
                print(f"Skipping {fname}: no 'TEC'")
                continue
            tec = ds['TEC'].values
            if lat0 is None:
                lat0 = ds['lat'].values
                lon0 = ds['lon'].values
            if 'time' in ds:
                # One grid per time step; assumes time is TEC's leading axis.
                tvals = pd.to_datetime(ds['time'].values)
                for i in range(tec.shape[0]):
                    grids.append(tec[i] / 1e12)
                    times.append(tvals[i])
            else:
                grids.append(tec / 1e12)
                times.append(None)
            print(f"Loaded {fname} shape {tec.shape}")
        except Exception as e:
            print(f"Failed to load {fname}: {e}")
        finally:
            # .values above materialized the data, so the file handle can be released.
            if ds is not None:
                ds.close()
    if not grids:
        raise RuntimeError("No valid model grids found! Please check files and data.")
    grids = np.stack(grids)
    times = pd.to_datetime(times)
    print(f"Loaded {len(grids)} model grids, grid shape {grids.shape}")
    return grids, times, lat0, lon0

def interpolate_to_grid(obs_grids, obs_lats, obs_lons, tgt_lats, tgt_lons):
    """Bilinearly resample each 2-D grid onto a target lat/lon mesh.

    Args:
        obs_grids: sequence of 2-D arrays indexed ``[lat, lon]`` that match
            *obs_lats* / *obs_lons*. The coordinates must meet
            ``RegularGridInterpolator``'s monotonicity requirement, and each
            grid's row order must follow *obs_lats*.
        obs_lats, obs_lons: 1-D source coordinates.
        tgt_lats, tgt_lons: 1-D target coordinates.

    Returns:
        Array of shape ``(len(obs_grids), len(tgt_lats), len(tgt_lons))``.
        Target points outside the source grid are NaN.
    """
    tgt_lats = np.asarray(tgt_lats)
    tgt_lons = np.asarray(tgt_lons)
    out_shape = (len(tgt_lats), len(tgt_lons))
    # The query points are identical for every frame, so build them once.
    lat_mesh, lon_mesh = np.meshgrid(tgt_lats, tgt_lons, indexing='ij')
    pts = np.column_stack([lat_mesh.ravel(), lon_mesh.ravel()])
    out = []
    for grid in obs_grids:
        interp = RegularGridInterpolator((obs_lats, obs_lons), grid, bounds_error=False, fill_value=np.nan)
        out.append(interp(pts).reshape(out_shape))
    if not out:
        return np.empty((0,) + out_shape)
    return np.stack(out)

def interp_time(series_times, series_grids, new_times):
    """Linearly interpolate a stack of grids along time (axis 0).

    Both time axes are converted with ``pd.Timestamp(t).timestamp()``, i.e.
    seconds since the Unix epoch. Requested times outside the sampled span are
    linearly *extrapolated*, not clipped.

    Returns:
        Array shaped like *series_grids* except that axis 0 has
        ``len(new_times)`` entries.
    """
    def to_seconds(stamps):
        return [pd.Timestamp(stamp).timestamp() for stamp in stamps]

    along_time = interp1d(
        to_seconds(series_times),
        series_grids,
        axis=0,
        kind='linear',
        bounds_error=False,
        fill_value="extrapolate",
    )
    return along_time(to_seconds(new_times))

def make_training_pairs(data, spatial_factor):
    """Build (degraded input, original target) pairs for super-resolution training.

    Each frame is shrunk to ``(h // spatial_factor, w // spatial_factor)`` with
    cubic, anti-aliased ``resize``. It is then resized back to ``(h, w)`` with
    cubic interpolation, so inputs and targets share a shape.

    Args:
        data: array of shape (n, h, w).
        spatial_factor: integer downsampling factor.

    Returns:
        (X, Y): arrays of shape (n, h, w, 1) holding the blurred inputs and
        the original frames.
    """
    _, height, width = data.shape
    coarse_shape = (height // spatial_factor, width // spatial_factor)
    inputs, targets = [], []
    for frame in data:
        coarse = resize(frame, coarse_shape, order=3, anti_aliasing=True)
        inputs.append(resize(coarse, (height, width), order=3, anti_aliasing=False))
        targets.append(frame)
    return np.stack(inputs)[..., np.newaxis], np.stack(targets)[..., np.newaxis]

def compute_spatial_gradients(grids, lats, lons):
    """Central-difference spatial gradients of a (time, lat, lon) stack.

    The index-space ``np.gradient`` along each axis is divided by the
    index-space gradient of that axis' coordinate vector, giving derivatives
    per degree. The longitude derivative is per degree of longitude, with no
    cos(latitude) factor, so ``grad_mag`` combines two different ground
    distances away from the equator.

    Returns:
        (grad_lat, grad_lon, grad_mag), each shaped like *grids*.
    """
    lat_spacing = np.gradient(lats)[np.newaxis, :, np.newaxis]
    lon_spacing = np.gradient(lons)[np.newaxis, np.newaxis, :]
    d_dlat = np.gradient(grids, axis=1) / lat_spacing
    d_dlon = np.gradient(grids, axis=2) / lon_spacing
    return d_dlat, d_dlon, np.sqrt(d_dlat ** 2 + d_dlon ** 2)

def compute_temporal_gradient(grids, times):
    """Return d(grids)/dt in units per second, with time on axis 0.

    ``np.gradient`` receives the actual sample times, so irregularly spaced
    epochs are differentiated correctly. It uses second-order central
    differences in the interior and one-sided differences at the ends. For a
    uniform cadence, such as ``pd.date_range`` output, this agrees with
    dividing the index-space gradient by the time step.

    Args:
        grids: array whose axis 0 is time, e.g. (time, lat, lon).
        times: datetime-like sequence with one entry per frame of *grids*.

    Raises:
        ValueError: if fewer than two time steps are given.
    """
    stamps = np.asarray(pd.to_datetime(times), dtype='datetime64[ns]')
    if stamps.size < 2:
        raise ValueError("compute_temporal_gradient needs at least two time steps.")
    seconds = (stamps - stamps[0]) / np.timedelta64(1, 's')
    return np.gradient(grids, seconds, axis=0)

# --- MODELS (TF) ---
import tensorflow as tf
from tensorflow.keras import layers, models

def build_unet(input_shape):
    """Build a small two-level U-Net with a single-channel linear output.

    Encoder: two stages of two 3x3 ReLU convolutions followed by 2x2
    max-pooling (32, then 64 filters), then a 128-filter 3x3 convolution at
    the bottom. Decoder: two stride-2 transposed convolutions. Each is
    concatenated with the encoder features of the same resolution (skip
    connection) and followed by one 3x3 convolution. A final 1x1 convolution
    with linear activation produces one output channel.

    Args:
        input_shape: shape of one sample without the batch axis, e.g.
            ``(height, width, channels)``. NOTE(review): the two 2x2
            poolings and the skip concatenations appear to require height and
            width divisible by 4; confirm for any new grid.

    Returns:
        An uncompiled ``tf.keras`` Model.
    """
    inputs = layers.Input(shape=input_shape)
    # Encoder level 1 (input resolution)
    c1 = layers.Conv2D(32, (3,3), activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(32, (3,3), activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2,2))(c1)
    # Encoder level 2 (1/2 resolution)
    c2 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2,2))(c2)
    # Bottleneck (1/4 resolution)
    m = layers.Conv2D(128, (3,3), activation='relu', padding='same')(p2)
    # Decoder level 2: upsample x2 and fuse with the c2 skip connection
    u2 = layers.Conv2DTranspose(64, (2,2), strides=(2,2), padding='same')(m)
    u2 = layers.concatenate([u2, c2])
    c3 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(u2)
    # Decoder level 1: upsample x2 and fuse with the c1 skip connection
    u1 = layers.Conv2DTranspose(32, (2,2), strides=(2,2), padding='same')(c3)
    u1 = layers.concatenate([u1, c1])
    c4 = layers.Conv2D(32, (3,3), activation='relu', padding='same')(u1)
    # 1x1 linear projection to a single output channel
    outputs = layers.Conv2D(1, (1,1), activation='linear', padding='same')(c4)
    return models.Model(inputs, outputs)

def build_srcnn(input_shape):
    """Build the three-layer SRCNN baseline (9x9 -> 1x1 -> 5x5 convolutions).

    The layers are feature extraction (64 filters, ReLU), non-linear mapping
    (32 filters, ReLU) and a linear 5x5 reconstruction to one channel. All use
    'same' padding and there is no upsampling, so the input must already be
    at the target resolution.

    Returns:
        An uncompiled ``tf.keras`` Model.
    """
    conv_stack = [
        layers.Conv2D(64, 9, padding='same', activation='relu'),
        layers.Conv2D(32, 1, padding='same', activation='relu'),
        layers.Conv2D(1, 5, padding='same'),
    ]
    inputs = layers.Input(shape=input_shape)
    hidden = inputs
    for conv in conv_stack:
        hidden = conv(hidden)
    return models.Model(inputs, hidden)

def train_tf(model, X_train, Y_train, epochs=20, batch_size=8):
    """Compile *model* with Adam and MSE loss, then fit it in place.

    10% of the samples are held out through Keras' ``validation_split``.
    NOTE(review): Keras takes that split from the *end* of the arrays, before
    shuffling. With main()'s obs-then-model concatenation the validation set
    is therefore model frames only; confirm this is intended.

    Returns:
        The same, now trained, model object.
    """
    model.compile(loss='mse', optimizer='adam')
    model.fit(
        X_train,
        Y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.1,
    )
    return model

# --- All animations are now "rotate map under sun" style. ---
def animate_sun_centered_data(model_lats, model_lons, all_epochs, data, fname='sun_centered_data_animation.gif', cmap='plasma', label='VTEC'):
    """Save a GIF of *data* on a PlateCarree map centered on the subsolar point.

    Args:
        model_lats, model_lons: 1-D grid coordinates.
        all_epochs: one UTC timestamp per frame.
        data: array of shape (N_frames, n_lat, n_lon).
        fname: output GIF path, written with the pillow writer at 6 fps.
        cmap: matplotlib colormap name.
        label: title prefix and colorbar label.

    All frames share one color scale (NaN-aware global min/max), so they can
    be compared.
    """
    n_frames = data.shape[0]
    vmin = np.nanmin(data)
    vmax = np.nanmax(data)
    fig = plt.figure(figsize=(10, 5))

    def update(frame_idx):
        # cartopy has no supported way to change a GeoAxes' projection after it
        # is created, so rebuild the axes for every frame.
        fig.clf()
        arr = data[frame_idx]
        dt = all_epochs[frame_idx]
        # get_subsolar_lon() returns the negated subsolar longitude; undo it.
        # Using "0.0 -" avoids printing "-0.0" at 12:00 UTC.
        subsolar_lon = 0.0 - get_subsolar_lon(dt)
        ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=subsolar_lon))
        ax.set_global()
        mesh = ax.pcolormesh(model_lons, model_lats, arr, cmap=cmap, shading='auto', vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
        ax.coastlines()
        fig.colorbar(mesh, ax=ax, orientation='vertical', label=label)
        ax.set_title(f'{label}\nUTC: {dt:%Y-%m-%d %H:%M} | Subsolar Lon: {subsolar_lon:.1f}°', pad=14)
        return [mesh]

    ani = animation.FuncAnimation(fig, update, frames=n_frames, blit=False)
    try:
        ani.save(fname, writer='pillow', fps=6)
    finally:
        plt.close(fig)
    print(f"Saved {fname}")

def animate_panel_sun_centered(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred, out_prefix):
    """Save a 3-panel sun-centered GIF: observations, physical model, super-resolution.

    Args:
        model_lats, model_lons: 1-D coordinates shared by all three fields.
        all_epochs: one UTC timestamp per frame.
        obs_interp, model_interp: arrays (N, n_lat, n_lon).
        Y_pred: array (N, n_lat, n_lon, 1); channel 0 is plotted.
        out_prefix: the GIF is written to ``{out_prefix}_compare_suncentered.gif``
            (pillow writer, 5 fps).

    All panels share one color scale so they can be compared directly.
    """
    n_frames = obs_interp.shape[0]
    vmin = np.nanmin([np.nanmin(obs_interp), np.nanmin(model_interp), np.nanmin(Y_pred)])
    vmax = np.nanmax([np.nanmax(obs_interp), np.nanmax(model_interp), np.nanmax(Y_pred)])
    fig = plt.figure(figsize=(21, 6))
    titles = ['Obs (Interp)', 'Physical Model', 'Super-Res']

    def update(idx):
        # Rebuild the axes and their colorbars every frame:
        # - a GeoAxes' projection cannot be changed after creation;
        # - FuncAnimation without init_func draws the first frame for its
        #   initial draw and again while saving, so colorbars added to
        #   persistent axes at idx == 0 were duplicated.
        fig.clf()
        dt = all_epochs[idx]
        # get_subsolar_lon() returns the negated subsolar longitude, hence the minus.
        projection = ccrs.PlateCarree(central_longitude=-get_subsolar_lon(dt))
        fields = [obs_interp[idx], model_interp[idx], Y_pred[idx][..., 0]]
        meshes = []
        for panel, (arr, title) in enumerate(zip(fields, titles), start=1):
            ax = fig.add_subplot(1, 3, panel, projection=projection)
            ax.set_global()
            mesh = ax.pcolormesh(model_lons, model_lats, arr, vmin=vmin, vmax=vmax, shading='auto', cmap='plasma', transform=ccrs.PlateCarree())
            ax.coastlines()
            ax.set_title(f"{title}\n{dt:%Y-%m-%d %H:%M} Sun-centered")
            fig.colorbar(mesh, ax=ax, orientation='vertical', label='VTEC (TECU)')
            meshes.append(mesh)
        return meshes

    ani = animation.FuncAnimation(fig, update, frames=n_frames, blit=False)
    try:
        ani.save(f'{out_prefix}_compare_suncentered.gif', writer='pillow', fps=5)
    finally:
        plt.close(fig)
    print(f"Sun-centered panel animation saved to {out_prefix}_compare_suncentered.gif")

def animate_gradient_sun_centered(all_epochs, data, lats, lons, label, fname_prefix, cmap, vmin=None, vmax=None):
    """Save ``fname_prefix + '.gif'``: *data* frames on a sun-centered PlateCarree map.

    Args:
        all_epochs: one UTC timestamp per frame; its length is the frame count.
        data: array (n_frames, n_lat, n_lon).
        lats, lons: 1-D grid coordinates.
        label: title prefix and colorbar label.
        fname_prefix: output path without the '.gif' suffix (pillow writer, 8 fps).
        cmap: matplotlib colormap name.
        vmin, vmax: fixed color limits. None lets matplotlib autoscale each frame.
    """
    fig = plt.figure(figsize=(9, 4.5))

    def update(i):
        # cartopy has no supported way to change a GeoAxes' projection after it
        # is created, so rebuild the axes for every frame.
        fig.clf()
        dt = all_epochs[i]
        # get_subsolar_lon() returns the negated subsolar longitude, hence the minus.
        ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=-get_subsolar_lon(dt)))
        ax.set_global()
        im = ax.pcolormesh(lons, lats, data[i], cmap=cmap, vmin=vmin, vmax=vmax, shading='auto', transform=ccrs.PlateCarree())
        ax.coastlines()
        fig.colorbar(im, ax=ax, orientation='vertical', label=label)
        ax.set_title(f"{label}\n{dt:%Y-%m-%d %H:%M} (Sun fixed @ center)")
        return [im]

    ani = animation.FuncAnimation(fig, update, frames=len(all_epochs), blit=False)
    try:
        ani.save(fname_prefix + '.gif', writer='pillow', fps=8)
    finally:
        plt.close(fig)

# --- Static/PNG can still use rolling if you want the local comparison plot to have aligned features. ---

def _symmetric_limits(values, percentile=100.0):
    """Return ``(-v, v)``, where v is the given percentile of ``|values|``.

    Used with diverging colormaps such as 'coolwarm' so that zero falls on the
    neutral midpoint. Returns ``(None, None)`` (matplotlib autoscaling) when
    there is no finite, non-zero data.
    """
    magnitudes = np.abs(np.asarray(values, dtype=float))
    magnitudes = magnitudes[np.isfinite(magnitudes)]
    if magnitudes.size == 0:
        return None, None
    lim = float(np.percentile(magnitudes, percentile))
    if lim <= 0.0:
        return None, None
    return -lim, lim


def _plot_static_snapshots(model_lats, model_lons, all_epochs, fields, vmin, vmax, out_prefix, n_snapshots=5):
    """Save side-by-side sun-centered PNGs of *fields* at evenly spaced epochs.

    *fields* is a sequence of ``(title, stack)`` pairs, with stacks shaped
    (time, lat, lon). Files are named ``{out_prefix}_static_NN_suncentered.png``.
    """
    n = len(all_epochs)
    # Spread the snapshots over the whole period rather than only its beginning.
    idxs = np.linspace(0, n - 1, min(n_snapshots, n)).astype(int)
    for i, idx in enumerate(idxs):
        dt = all_epochs[idx]
        # get_subsolar_lon() returns the negated subsolar longitude, hence the minus.
        projection = ccrs.PlateCarree(central_longitude=-get_subsolar_lon(dt))
        fig, axs = plt.subplots(1, len(fields), figsize=(18, 5), subplot_kw={'projection': projection})
        for ax, (name, stack) in zip(axs, fields):
            im = ax.pcolormesh(model_lons, model_lats, stack[idx], vmin=vmin, vmax=vmax,
                               shading='auto', cmap='plasma', transform=ccrs.PlateCarree())
            ax.coastlines()
            ax.set_title(f"{name} {dt:%Y-%m-%d %H:%M}" + "\nSun-centered")
            cb = plt.colorbar(im, ax=ax, orientation='vertical')
            cb.set_label('VTEC (TECU)')
        plt.tight_layout()
        plt.savefig(f"{out_prefix}_static_{i:02d}_suncentered.png")
        plt.close(fig)


def _plot_rmse_timeseries(all_epochs, obs, model, sr, out_prefix):
    """Save ``{out_prefix}_rmse_timeseries.png``: per-epoch spatial RMSE against obs (NaNs ignored)."""
    rmse_model = np.sqrt(np.nanmean((obs - model) ** 2, axis=(1, 2)))
    rmse_sr = np.sqrt(np.nanmean((obs - sr) ** 2, axis=(1, 2)))
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(all_epochs, rmse_model, label='Physical Model', color='tab:blue')
    ax.plot(all_epochs, rmse_sr, label='Super-Res', color='tab:orange')
    ax.set_ylabel('RMSE (TECU)')
    ax.set_title('RMSE of Model vs. Super-Res vs. Obs')
    ax.legend()
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
    fig.autofmt_xdate()
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_rmse_timeseries.png")
    plt.close(fig)


def _plot_mae_maps(model_lats, model_lons, obs, model, sr, out_prefix):
    """Save ``{out_prefix}_mae_maps.png``: time-mean absolute-error maps against obs."""
    mae_model = np.nanmean(np.abs(obs - model), axis=0)
    mae_sr = np.nanmean(np.abs(obs - sr), axis=0)
    fig, axs = plt.subplots(1, 2, figsize=(14, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    for ax, dat, title in zip(axs, [mae_model, mae_sr], ['Physical Model MAE', 'Super-Res MAE']):
        # Explicit data transform, as in the other maps (lon/lat input).
        im = ax.pcolormesh(model_lons, model_lats, dat, shading='auto', cmap='inferno', transform=ccrs.PlateCarree())
        ax.coastlines()
        ax.set_title(title)
        cb = plt.colorbar(im, ax=ax, orientation='vertical')
        cb.set_label('Mean Absolute Error (TECU)')
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_mae_maps.png")
    plt.close(fig)


def _plot_gradients_and_speeds(grid, grid_label, all_epochs, model_lats, model_lons, out_prefix):
    """Save gradient snapshots, gradient/speed time series and sun-centered GIFs for one field.

    Units: the spatial gradient is TECU per degree, the time derivative is
    TECU/s, and the apparent propagation speed
    ``|dVTEC/dt| / (|grad VTEC| + 1e-6)`` is in deg/s.
    """
    _, _, grad_mag = compute_spatial_gradients(grid, model_lats, model_lons)
    grad_t = compute_temporal_gradient(grid, all_epochs)
    stem = f"{out_prefix}_{grid_label.replace(' ','_').lower()}"

    # --- Sun-centered static snapshots ---
    idxs = np.linspace(0, grid.shape[0] - 1, min(4, grid.shape[0])).astype(int)
    for j, idx in enumerate(idxs):
        dt = all_epochs[idx]
        projection = ccrs.PlateCarree(central_longitude=-get_subsolar_lon(dt))
        fig, axs = plt.subplots(1, 2, figsize=(14, 5), subplot_kw={'projection': projection})
        im0 = axs[0].pcolormesh(model_lons, model_lats, grad_mag[idx], cmap='viridis', shading='auto', transform=ccrs.PlateCarree())
        axs[0].coastlines()
        axs[0].set_title(f"{grid_label} | |∇VTEC| {dt:%Y-%m-%d %H:%M}\nSun-centered")
        cb0 = plt.colorbar(im0, ax=axs[0])
        cb0.set_label("Gradient (TECU/deg)")
        # Symmetric limits put zero change at coolwarm's neutral midpoint.
        t_vmin, t_vmax = _symmetric_limits(grad_t[idx])
        im1 = axs[1].pcolormesh(model_lons, model_lats, grad_t[idx], cmap='coolwarm', vmin=t_vmin, vmax=t_vmax, shading='auto', transform=ccrs.PlateCarree())
        axs[1].coastlines()
        axs[1].set_title(f"{grid_label} | dVTEC/dt {dt:%Y-%m-%d %H:%M}\nSun-centered")
        cb1 = plt.colorbar(im1, ax=axs[1])
        cb1.set_label("Time Gradient (TECU/s)")
        plt.tight_layout()
        plt.savefig(f"{stem}_gradient_{j:02d}_suncentered.png")
        plt.close(fig)

    # --- Global-mean gradient time series ---
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(all_epochs, np.nanmean(grad_mag, axis=(1, 2)), label="|∇VTEC| (space)")
    ax.plot(all_epochs, np.nanmean(np.abs(grad_t), axis=(1, 2)), label="|dVTEC/dt| (time)")
    ax.set_title(f"{grid_label} | Global mean gradients")
    ax.set_xlabel("Time")
    ax.set_ylabel("Mean gradient")
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{stem}_gradient_timeseries.png")
    plt.close(fig)

    # --- Apparent propagation speed; the 1e-6 guards against division by zero ---
    with np.errstate(divide='ignore', invalid='ignore'):
        speed_est = np.abs(grad_t) / (grad_mag + 1e-6)
        median_speed = np.nanmedian(speed_est, axis=(1, 2))
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(all_epochs, median_speed)
    ax.set_title(f"{grid_label} | Median apparent propagation speed (deg/s)")
    ax.set_xlabel("Time")
    ax.set_ylabel("deg/s")
    fig.tight_layout()
    fig.savefig(f"{stem}_speed_est_timeseries.png")
    plt.close(fig)

    # --- Sun-centered animations with robust color limits ---
    gradmag_vmin, gradmag_vmax = np.nanpercentile(grad_mag, [1, 99])
    gradt_vmin, gradt_vmax = _symmetric_limits(grad_t, 99)
    speed_vmin, speed_vmax = np.nanpercentile(speed_est, [1, 99])  # speed_est is already >= 0
    animate_gradient_sun_centered(all_epochs, grad_mag, model_lats, model_lons, f"{grid_label} |∇VTEC| (TECU/deg)", stem + "_gradmag_suncentered", 'viridis', gradmag_vmin, gradmag_vmax)
    animate_gradient_sun_centered(all_epochs, grad_t, model_lats, model_lons, f"{grid_label} dVTEC/dt (TECU/s)", stem + "_gradt_suncentered", 'coolwarm', gradt_vmin, gradt_vmax)
    animate_gradient_sun_centered(all_epochs, speed_est, model_lats, model_lons, f"{grid_label} |speed| (deg/s)", stem + "_speed_suncentered", 'plasma', speed_vmin, speed_vmax)


def verify_and_plot_results(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred, out_prefix='comparison'):
    """Write every comparison figure and animation for obs vs. physical model vs. super-resolution.

    Outputs, all prefixed with *out_prefix*:
    - up to five sun-centered static snapshots;
    - an RMSE time series and time-mean absolute-error maps;
    - per-field gradient and apparent-speed plots and GIFs;
    - a 3-panel sun-centered GIF.

    Args:
        model_lats, model_lons: 1-D coordinates of the common grid.
        all_epochs: one timestamp per observation frame.
        obs_interp: observations, (N, n_lat, n_lon).
        model_interp: physical model, (>= N, n_lat, n_lon); its first N frames
            are compared.
        Y_pred: super-resolution output, (>= N, n_lat, n_lon, 1); channel 0 of
            its first N frames is compared.
        out_prefix: filename prefix.
    """
    n_obs = obs_interp.shape[0]
    sr = Y_pred[:n_obs, ..., 0]
    model = model_interp[:n_obs]
    # One shared VTEC color scale for the static comparison panels.
    vmin = np.nanmin([np.nanmin(obs_interp), np.nanmin(model_interp), np.nanmin(Y_pred)])
    vmax = np.nanmax([np.nanmax(obs_interp), np.nanmax(model_interp), np.nanmax(Y_pred)])

    fields = [("Obs (Interp)", obs_interp), ("Physical Model", model), ("Super-Res Output", sr)]
    _plot_static_snapshots(model_lats, model_lons, all_epochs, fields, vmin, vmax, out_prefix)
    _plot_rmse_timeseries(all_epochs, obs_interp, model, sr, out_prefix)
    _plot_mae_maps(model_lats, model_lons, obs_interp, model, sr, out_prefix)

    grids_dict = {
        "Obs (Interp)": obs_interp,
        "Super-Res": sr,
        "Physical Model": model,
    }
    for grid_label, grid in grids_dict.items():
        _plot_gradients_and_speeds(grid, grid_label, all_epochs, model_lats, model_lons, out_prefix)

    print("Comparison, gradient plots, and sun-centered animations saved.")

    # Panel animation: all three fields, with the map rotating under a fixed sun.
    animate_panel_sun_centered(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred, out_prefix)

def main(backend='tf', architecture='unet'):
    """End-to-end run: load obs and model, align them, train a super-res net, then plot.

    Steps:
    1. Parse the IONEX observations and the TIE-GCM output.
    2. Interpolate the observations bilinearly onto the model grid.
    3. Interpolate both onto a common TEMPORAL_MINUTES cadence over their
       overlapping time span.
    4. Normalize the concatenated (obs, model) stack with one global mean/std.
    5. Train on (degraded, original) pairs, then predict and plot.

    NOTE(review): predictions are made on the training inputs
    (``model.predict(X)``), so the super-resolution metrics are in-sample.

    Args:
        backend: only 'tf' is implemented.
        architecture: 'unet' or 'srcnn'.

    Raises:
        NotImplementedError: for any backend other than 'tf'.
        ValueError: for an unknown *architecture*, or when the observation and
            model time ranges do not overlap.
    """
    builders = {'unet': build_unet, 'srcnn': build_srcnn}
    # Validate the options before any expensive loading.
    if backend != 'tf':
        raise NotImplementedError("Only TensorFlow U-Net and SRCNN implemented here.")
    if architecture not in builders:
        raise ValueError(f"Unknown architecture {architecture!r}; expected one of {sorted(builders)}.")

    # --- Observations on their native grid ---
    obs_df = load_obs_grid(IONEX_PATHS, save_path=COMBINED_IONEX_PARQUET)
    obs_lats = sorted(set(lat for lat, _ in obs_df.columns))
    obs_lons = sorted(set(lon for _, lon in obs_df.columns))
    obs_epochs = pd.to_datetime(obs_df.index)
    # parse_ionex_file() stores columns in the file's latitude order (north -> south
    # for the usual descending LAT1/LAT2), but obs_lats is ascending. Select the
    # columns in obs_lats x obs_lons order before reshaping. Reshaping the raw
    # column order would flip every observed map north-south.
    grid_columns = pd.MultiIndex.from_product([obs_lats, obs_lons], names=['lat', 'lon'])
    obs_grids = obs_df.reindex(columns=grid_columns).to_numpy().reshape(len(obs_df), len(obs_lats), len(obs_lons))

    # --- Model output, common grid and common time axis ---
    model_grids, model_times, model_lats, model_lons = load_model_grid(TIEGCM_DIR)
    obs_on_model = interpolate_to_grid(obs_grids, obs_lats, obs_lons, model_lats, model_lons)
    all_epochs = pd.date_range(start=max(obs_epochs[0], model_times[0]), end=min(obs_epochs[-1], model_times[-1]), freq=f'{TEMPORAL_MINUTES}min')
    if len(all_epochs) == 0:
        raise ValueError("Observation and model time ranges do not overlap.")
    obs_interp = interp_time(obs_epochs, obs_on_model, all_epochs)
    model_interp = interp_time(model_times, model_grids, all_epochs)

    # --- Normalize with one global mean/std; NaNs (e.g. outside the obs grid) become 0, i.e. the mean ---
    all_grids = np.concatenate([obs_interp, model_interp], axis=0)
    mean, std = np.nanmean(all_grids), np.nanstd(all_grids)
    all_grids = np.nan_to_num((all_grids - mean) / std)
    X, Y = make_training_pairs(all_grids, SPATIAL_FACTOR)

    # --- Train and predict ---
    model = builders[architecture](input_shape=X.shape[1:])
    model = train_tf(model, X, Y, epochs=30)
    Y_pred = model.predict(X)
    Y_pred = Y_pred * std + mean

    # The first N samples of the concatenated stack are the observation frames.
    N = obs_interp.shape[0]
    Y_pred = Y_pred[:N]
    model_interp = model_interp[:N]
    all_epochs = all_epochs[:N]
    verify_and_plot_results(model_lats, model_lons, all_epochs, obs_interp, model_interp, Y_pred)
    # Optional single-field sun-centered GIFs
    animate_sun_centered_data(model_lats, model_lons, all_epochs, Y_pred[..., 0], fname='sun_centered_superres.gif', cmap='plasma', label='Super-Res VTEC')
    animate_sun_centered_data(model_lats, model_lons, all_epochs, obs_interp, fname='sun_centered_obs.gif', cmap='plasma', label='Obs (Interp) VTEC')
    animate_sun_centered_data(model_lats, model_lons, all_epochs, model_interp, fname='sun_centered_model.gif', cmap='plasma', label='Physical Model VTEC')
    np.save("superres_vtec.npy", Y_pred)
    print("Saved super-resolved output.")

if __name__ == '__main__':
    # Default run trains the TensorFlow U-Net. For the SRCNN baseline, call
    # main(backend='tf', architecture='srcnn') instead.
    main(backend='tf', architecture='unet')


In [None]:
# Scratch check: is the kernel still running?


Sun-centered rotation of the ground-based (IONEX) observations.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation

# --- Load Data ---
# Merged IONEX VTEC table cached by the first cell: index = epochs, columns = (lat, lon).
df = pd.read_parquet("combined_ionex_vtec.parquet")

if not isinstance(df.columns, pd.MultiIndex):
    raise RuntimeError("Expected MultiIndex columns for [lat, lon]!")

lats = sorted(set(lat for lat, _ in df.columns), reverse=True)  # north -> south
lons = sorted(set(lon for _, lon in df.columns))
all_epochs = df.index
# Select the columns explicitly in (lat descending, lon ascending) order before
# reshaping, so every grid row matches `lats` whatever order the columns were stored in.
grid_columns = pd.MultiIndex.from_product([lats, lons])
all_grids = list(df.reindex(columns=grid_columns).to_numpy().reshape(len(df), len(lats), len(lons)))
vmin = np.nanmin([np.nanmin(g) for g in all_grids])
vmax = np.nanmax([np.nanmax(g) for g in all_grids])

def get_subsolar_lon(dt):
    """Return ``15 * UT_hours - 180`` degrees, the *negated* mean subsolar longitude.

    Callers negate the result to get the map center. Only the UTC hour, minute
    and second of *dt* are used.
    """
    return (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400.0 * 360.0 - 180.0

fig = plt.figure(figsize=(13, 6))

def update(frame_idx):
    """Redraw one frame as a PlateCarree map centered on the subsolar point."""
    plt.clf()  # Clear the figure, NOT just the axes, to allow new projection
    epoch = all_epochs[frame_idx]
    # get_subsolar_lon() returns the negated subsolar longitude; undo it so the
    # map center and the title both use the true value ("0.0 -" avoids "-0.0").
    subsolar_lon = 0.0 - get_subsolar_lon(pd.to_datetime(epoch))
    ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=subsolar_lon))
    ax.set_global()
    ax.coastlines()
    mesh = ax.pcolormesh(lons, lats, all_grids[frame_idx],
                         transform=ccrs.PlateCarree(), shading='auto', vmin=vmin, vmax=vmax, cmap='gist_ncar')
    title = ax.set_title(f'Sun-centered Observations VTEC\n{epoch} | Subsolar Lon: {subsolar_lon:.1f}°')
    plt.colorbar(mesh, ax=ax, orientation='vertical', label='VTEC (TECU)')
    return mesh, title

ani = animation.FuncAnimation(fig, update, frames=len(df), interval=250, blit=False)

ani.save('sun_centered_ionexOBS_vtec.gif', writer='pillow', fps=4)
print("Saved sun_centered_ionexOBS_vtec.gif")


# --- For Jupyter notebook inline display (optional) ---
from IPython.display import HTML
# NOTE: matplotlib documents animation.embed_limit in MB, so this value
# effectively removes the size cap (e.g. 50 would mean ~50 MB).
plt.rcParams['animation.embed_limit'] = 50_000_000
HTML(ani.to_jshtml())  # re-renders every frame; last expression, so Jupyter displays it

# --- Save as MP4 (requires ffmpeg) ---
#ani.save('sun_centered_ionex_vtec.mp4', writer='ffmpeg', fps=4)


Now repeat it with an orthographic projection, showing both the sunlit side and the dark side.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation

# --- Load Data ---
# Merged IONEX VTEC table cached by the first cell: index = epochs, columns = (lat, lon).
df = pd.read_parquet("combined_ionex_vtec.parquet")

if not isinstance(df.columns, pd.MultiIndex):
    raise RuntimeError("Expected MultiIndex columns for [lat, lon]!")

lats = sorted(set(lat for lat, _ in df.columns), reverse=True)  # north -> south
lons = sorted(set(lon for _, lon in df.columns))
all_epochs = df.index
# Select the columns explicitly in (lat descending, lon ascending) order before
# reshaping, so every grid row matches `lats` whatever order the columns were stored in.
grid_columns = pd.MultiIndex.from_product([lats, lons])
all_grids = list(df.reindex(columns=grid_columns).to_numpy().reshape(len(df), len(lats), len(lons)))
vmin = np.nanmin([np.nanmin(g) for g in all_grids])
vmax = np.nanmax([np.nanmax(g) for g in all_grids])

def get_subsolar_lon(dt):
    """Return the *negated* mean subsolar longitude, ``15 * UT_hours - 180`` degrees.

    Callers negate the result to center maps on the sun. Only the UTC hour,
    minute and second of *dt* are used.
    """
    elapsed = dt.hour * 3600 + dt.minute * 60 + dt.second
    day_fraction = elapsed / 86400.0
    return day_fraction * 360.0 - 180.0

fig = plt.figure(figsize=(15, 7))

def update(frame_idx):
    """Redraw one frame: orthographic globes of the sunlit and dark sides plus a shared colorbar.

    The whole figure is cleared and rebuilt every frame, because each frame
    needs new projection centers.
    """
    plt.clf()
    epoch = all_epochs[frame_idx]
    # get_subsolar_lon() returns the *negated* subsolar longitude (15*UT - 180).
    # Undo that and wrap to [-180, 180) so both centers and titles are true longitudes.
    subsolar_lon = (180.0 - get_subsolar_lon(pd.to_datetime(epoch))) % 360.0 - 180.0
    # Antisolar (local-midnight) longitude: the center of the dark-side globe.
    antisolar_lon = subsolar_lon % 360.0 - 180.0

    panels = [
        (subsolar_lon, f"Sun-Facing Side\nSubsolar Lon: {subsolar_lon:.1f}°"),
        (antisolar_lon, f"Dark Side\nOpposite Lon: {antisolar_lon:.1f}°"),
    ]
    meshes = []
    for position, (center_lon, title) in enumerate(panels, start=1):
        ax = plt.subplot(1, 2, position, projection=ccrs.Orthographic(central_longitude=center_lon, central_latitude=0))
        ax.set_global()
        ax.coastlines()
        meshes.append(ax.pcolormesh(
            lons, lats, all_grids[frame_idx],
            transform=ccrs.PlateCarree(), shading='auto', vmin=vmin, vmax=vmax, cmap='gist_ncar'
        ))
        ax.set_title(title, fontsize=12)

    # --- Colorbar below both plots (both meshes share vmin/vmax) ---
    cbar_ax = fig.add_axes([0.25, 0.10, 0.5, 0.03])
    plt.colorbar(meshes[0], cax=cbar_ax, orientation='horizontal', label='VTEC (TECU)')

    fig.suptitle(f'VTEC | {epoch}', fontsize=14)
    return tuple(meshes)

ani = animation.FuncAnimation(
    fig,
    update,
    frames=len(df),
    interval=250,
    blit=False,
)
ani.save('orthographic_day_night_vtec.gif', writer='pillow', fps=4)
print("Saved orthographic_day_night_vtec.gif")

# --- Jupyter inline display (optional) ---
# matplotlib documents the animation.embed_limit rcParam in MB, so this value
# effectively disables the size cap on the embedded HTML player.
from IPython.display import HTML
plt.rcParams['animation.embed_limit'] = 50_000_000
HTML(ani.to_jshtml())  # final expression, so Jupyter renders the player


## A few things stand out:
### 1. There appears to be a leading edge and a trailing edge.
#### Is the ionosphere being dragged into the shade by the Earth/atmosphere?
### 2. Need to look at more data. The continent positions pull the "wave" up and down, specifically over Africa and then South America. (Worth checking against the geomagnetic (dip) equator, which is offset from the geographic equator in these longitude sectors.)