In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation
import xarray as xr
import os
from scipy.interpolate import interp1d
from skimage.transform import resize
from tqdm import tqdm

# --- CONFIGURATION ---
spatial_factor = 4      # spatial up-sampling factor used when building training pairs
temporal_minutes = 5    # spacing of the common time axis, in minutes
model_directory = './ionosphere_central/CCMC/model_data/TIE-GCM/129/Akshay_Ramesh_042125_IT_5/'

# --- LOAD MODEL DATA ---
# Every non-JSON file in the directory is treated as a candidate dataset.
model_files = sorted([f for f in os.listdir(model_directory) if not f.endswith('.json')])
model_grids, model_times = [], []
# Coordinates are taken from the first file that actually provides TEC, so a
# trailing file that was skipped can never supply (or lack) the grid axes.
model_lats = model_lons = None
for fname in model_files:
    fpath = os.path.join(model_directory, fname)
    try:
        # The context manager closes the file handle once the arrays have
        # been materialised with `.values`.
        with xr.open_dataset(fpath) as ds:
            if 'TEC' not in ds.variables:
                print(f"Skipping {fname} (no 'TEC' variable)")
                continue
            tec = ds['TEC'].values
            file_lats = ds['lat'].values
            file_lons = ds['lon'].values
            # NOTE(review): the 1e12 divisor presumably converts el/cm^2 to
            # TECU — confirm against the units stored in the files.
            if 'time' in ds:
                times = pd.to_datetime(ds['time'].values)
                for i in range(tec.shape[0]):
                    model_grids.append(tec[i]/1e12)
                    model_times.append(times[i])
            else:
                # No time coordinate: keep the array as one undated entry.
                model_grids.append(tec/1e12)
                model_times.append(None)
            if model_lats is None:
                model_lats, model_lons = file_lats, file_lons
    except Exception as e:
        # Best effort: report the unreadable file and carry on with the rest.
        print(f"Failed to load {fname}: {e}")

if not model_grids:
    raise RuntimeError("No valid model grids found! Please check files and data.")

model_grids = np.stack(model_grids)     # (n_times, lat, lon)
# Undated entries become NaT so the result is a homogeneous DatetimeIndex.
model_times = pd.to_datetime([t if t is not None else pd.NaT for t in model_times])

# --- LOAD & PREP OBSERVATION DATA ---
# The parquet table has one row per epoch and (lat, lon) column pairs.
# NOTE(review): the .035 scale factor is undocumented here — confirm its meaning.
df = .035 * pd.read_parquet("combined_ionex_vtec.parquet")
obs_epochs = pd.to_datetime(df.index)
lats = sorted({lat for lat, _ in df.columns})
lons = sorted({lon for _, lon in df.columns})
# Sort the columns so each row reshapes directly into a (lat, lon) grid.
df = df.sort_index(axis=1, level=['lat', 'lon'])
_obs_grid_shape = (len(lats), len(lons))
_obs_frames = []
for epoch in df.index:
    _obs_frames.append(df.loc[epoch].values.reshape(_obs_grid_shape))
obs_grids = np.stack(_obs_frames, axis=0)

# --- UNIT CONVERSION FOR OBSERVATIONS (if needed) ---
# Heuristic: a peak above 200 is taken to mean the data are still in el/m^2.
_obs_peak = np.nanmax(obs_grids)
if _obs_peak > 200:
    print("Converting obs_grids from VTEC units to TECU")
    obs_grids = obs_grids / 1e16  # Now in TECU

# --- INTERPOLATE OBS ONTO MODEL GRID ---
from scipy.interpolate import RegularGridInterpolator
# The target points are the same for every frame, so the mesh and the flat
# (lat, lon) query list are built once instead of once per frame.
mesh = np.meshgrid(model_lats, model_lons, indexing='ij')
pts = np.column_stack([m.ravel() for m in mesh])
_model_grid_shape = (len(model_lats), len(model_lons))
obs_on_model = []
for g in obs_grids:
    # Points outside the observation grid are filled with NaN, not extrapolated.
    interp_func = RegularGridInterpolator((lats, lons), g, bounds_error=False, fill_value=np.nan)
    out = interp_func(pts).reshape(_model_grid_shape)
    obs_on_model.append(out)
obs_on_model = np.stack(obs_on_model)   # (n_obs_times, lat, lon)

# --- CHOOSE OVERLAPPING TIME WINDOW & BUILD THE 5-MIN AXIS ---
# The window is the overlap of the two records: it starts at the later of the
# two first samples and ends at the earlier of the two last samples.
# NaT entries (undated model grids) are dropped explicitly, because
# np.nanmin/np.nanmax do not treat NaT as missing for datetime data.
_model_valid_times = pd.DatetimeIndex(model_times).dropna()
_obs_valid_times = pd.DatetimeIndex(obs_epochs).dropna()
if len(_model_valid_times) == 0 or len(_obs_valid_times) == 0:
    raise RuntimeError("No valid timestamps found in model or observation data!")
min_time = max(_model_valid_times.min(), _obs_valid_times.min())
max_time = min(_model_valid_times.max(), _obs_valid_times.max())
if min_time > max_time:
    raise RuntimeError("No overlapping times found between obs and model!")
all_times = pd.date_range(start=min_time, end=max_time, freq=f'{temporal_minutes}min')

def interp_time(series_times, series_grids, new_times):
    """Linearly interpolate a time series of grids onto a new time axis.

    Samples whose timestamp is missing (NaT/None) and samples whose grid is
    entirely NaN are discarded before interpolating. Requested times outside
    the remaining samples are linearly extrapolated.

    Args:
        series_times: Timestamps of the input samples (anything accepted by
            ``pd.to_datetime``), one per grid.
        series_grids: Array-like whose first axis is time.
        new_times: Timestamps at which to evaluate the series.

    Returns:
        ``np.ndarray`` with ``len(new_times)`` entries along the first axis
        and the per-sample shape of ``series_grids`` on the remaining axes.

    Raises:
        ValueError: If no usable sample remains after filtering.
    """
    times = pd.DatetimeIndex(pd.to_datetime(series_times))
    grids = np.asarray(series_grids)

    # Drop samples without a timestamp.
    dated = ~np.asarray(times.isna())
    times, grids = times[dated], grids[dated]

    # Drop fully missing grids so they cannot poison their neighbours.
    if grids.ndim > 1:
        has_data = ~np.isnan(grids).all(axis=tuple(range(1, grids.ndim)))
    else:
        has_data = ~np.isnan(grids)
    times, grids = times[has_data], grids[has_data]

    if len(grids) == 0:
        raise ValueError("interp_time: no valid samples left to interpolate from")
    if len(grids) == 1:
        # One sample cannot define a slope; hold it constant over new_times.
        return np.repeat(grids, len(new_times), axis=0)

    src_seconds = [pd.Timestamp(t).timestamp() for t in times]
    dst_seconds = [pd.Timestamp(t).timestamp() for t in new_times]
    interp = interp1d(
        src_seconds, grids, axis=0, kind='linear',
        bounds_error=False, fill_value="extrapolate"
    )
    return interp(dst_seconds)

# Put both records on the shared time axis.
obs_interp = interp_time(obs_epochs, obs_on_model, all_times)
model_interp = interp_time(model_times, model_grids, all_times)


def _nan_stats(field):
    """Return (min, max, mean) of *field*, ignoring NaNs."""
    return np.nanmin(field), np.nanmax(field), np.nanmean(field)


# Quick sanity check of value ranges and array shapes.
print("\nDIAGNOSTICS:")
print("Obs Interp min/max/mean:", *_nan_stats(obs_interp))
print("Model Interp min/max/mean:", *_nan_stats(model_interp))
print("Obs shape:", obs_interp.shape, "Model shape:", model_interp.shape)

# --- SUPER-RESOLUTION TRAINING PAIRS ---
# Obs and model frames are pooled, z-scored with global statistics, and NaNs
# are replaced by 0 (the mean after normalisation).
all_grids = np.concatenate([obs_interp, model_interp], axis=0)
mean = np.nanmean(all_grids)
std = np.nanstd(all_grids)
all_grids_norm = np.nan_to_num((all_grids - mean) / std)

n_lat, n_lon = all_grids_norm.shape[1:3]
_coarse_shape = (n_lat//spatial_factor, n_lon//spatial_factor)
_fine_shape = (n_lat*spatial_factor, n_lon*spatial_factor)

X_train, Y_train = [], []
for frame in all_grids_norm:
    # Input: degrade the frame, then blow it up to the fine grid.
    coarse = resize(frame, _coarse_shape, order=3, anti_aliasing=True)
    X_train.append(resize(coarse, _fine_shape, order=3, anti_aliasing=False))
    # Target: the undegraded frame resized straight to the fine grid.
    Y_train.append(resize(frame, _fine_shape, order=3, anti_aliasing=True))

# Add the channel axis -> (N, 1, H, W) float32 tensors.
X_train = torch.tensor(np.stack(X_train)[:, None, :, :], dtype=torch.float32)
Y_train = torch.tensor(np.stack(Y_train)[:, None, :, :], dtype=torch.float32)

print("X_train shape:", X_train.shape, "Y_train shape:", Y_train.shape)

# --- OPTIONAL: ANIMATION FOR VALIDATION ---
# Draws observations and model side by side on the shared time axis and
# writes the result to a GIF.
import matplotlib
# NOTE(review): the backend is switched after pyplot was already imported at
# the top of the file — confirm this takes effect in your environment.
matplotlib.use('Agg')
# Shared colour limits so both panels are directly comparable.
vmin = min(np.nanmin(obs_interp), np.nanmin(model_interp))
vmax = max(np.nanmax(obs_interp), np.nanmax(model_interp))
fig, axes = plt.subplots(1, 2, figsize=(14, 6), subplot_kw={'projection': ccrs.PlateCarree()})
titles = ["Obs (Interpolated)", "TIE-GCM Model (Interpolated)"]
for ax, title in zip(axes, titles):
    ax.coastlines()
    ax.set_global()
    ax.set_title(title)

# First frame of each series; later frames only swap the mesh data.
im_obs = axes[0].pcolormesh(model_lons, model_lats, obs_interp[0], transform=ccrs.PlateCarree(),
                            shading='auto', vmin=vmin, vmax=vmax, cmap='viridis')
im_mod = axes[1].pcolormesh(model_lons, model_lats, model_interp[0], transform=ccrs.PlateCarree(),
                            shading='auto', vmin=vmin, vmax=vmax, cmap='viridis')

# One colour bar serves both panels because they share vmin/vmax.
cb = fig.colorbar(im_obs, ax=axes, orientation='horizontal', fraction=0.04, pad=0.09, label='TEC (TECU)')

def update(idx):
    """Load frame *idx* into both meshes and stamp the titles with its time."""
    im_obs.set_array(obs_interp[idx].ravel())
    im_mod.set_array(model_interp[idx].ravel())
    axes[0].set_title(f"Obs (Interp)\n{all_times[idx]}")
    axes[1].set_title(f"Model (Interp)\n{all_times[idx]}")
    return [im_obs, im_mod]

# One animation frame per entry of the shared time axis.
ani = animation.FuncAnimation(fig, update, frames=len(all_times), interval=100, blit=False)
ani.save('obs_vs_model_5min.gif', writer='pillow', fps=10)
plt.close(fig)
print("Saved synchronized validation animation: obs_vs_model_5min.gif")

# --- DATA READY FOR ML ---
print("Ready for super-resolution training with PyTorch or TensorFlow.")


In [None]:
import numpy as np

# Training arrays come from the cached .npy files when present, otherwise from
# the in-memory tensors, which are laid out as [N, 1, H, W].
X = np.load('X_train.npy') if os.path.exists('X_train.npy') else X_train.numpy()
Y = np.load('Y_train.npy') if os.path.exists('Y_train.npy') else Y_train.numpy()

# Remove only the singleton channel axis. A bare np.squeeze would also drop
# the batch or a spatial axis whenever one of them happens to have length 1.
if X.ndim == 4 and X.shape[1] == 1:
    X = np.squeeze(X, axis=1)
if Y.ndim == 4 and Y.shape[1] == 1:
    Y = np.squeeze(Y, axis=1)

# Keras wants channel-last: [N, H, W, 1]
X = X[..., None]
Y = Y[..., None]
print("X:", X.shape, "Y:", Y.shape)


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

def unet_upsampler(input_shape, base_filters=32):
    """Build a two-level U-Net that maps an image to a single-channel image.

    The encoder has two stages of paired 3x3 ReLU convolutions, each followed
    by max pooling. The bottleneck and the two decoder stages use the same
    paired convolutions; each decoder stage upsamples and concatenates the
    matching encoder output. A final 1x1 linear convolution emits one channel.

    NOTE(review): with Keras' default 2x pooling/upsampling the skip
    concatenations presumably require height and width divisible by 4 —
    confirm for the grids used.

    Args:
        input_shape: Shape of one input sample, passed to ``layers.Input``.
        base_filters: Filter count of the first stage; deeper stages use
            2x and 4x this value.

    Returns:
        An uncompiled ``models.Model``.
    """
    def conv_pair(tensor, n_filters):
        # Two consecutive 3x3 'same' convolutions with ReLU.
        for _ in range(2):
            tensor = layers.Conv2D(n_filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    net_in = layers.Input(shape=input_shape)

    # Encoder
    enc1 = conv_pair(net_in, base_filters)
    down1 = layers.MaxPooling2D()(enc1)
    enc2 = conv_pair(down1, base_filters*2)
    down2 = layers.MaxPooling2D()(enc2)

    # Bottleneck
    bottleneck = conv_pair(down2, base_filters*4)

    # Decoder with skip connections
    up2 = layers.UpSampling2D()(bottleneck)
    merged2 = layers.Concatenate()([up2, enc2])
    dec2 = conv_pair(merged2, base_filters*2)
    up1 = layers.UpSampling2D()(dec2)
    merged1 = layers.Concatenate()([up1, enc1])
    dec1 = conv_pair(merged1, base_filters)

    # Final convolution
    net_out = layers.Conv2D(1, 1, activation='linear', padding='same')(dec1)
    return models.Model(net_in, net_out)

# Example: input lowres upsampled (n_lat*sf, n_lon*sf, 1), output target highres
# The per-sample shape is read from X, so the network matches the prepared data.
input_shape = X.shape[1:]  # e.g. (n_lat*sf, n_lon*sf, 1)
model = unet_upsampler(input_shape, base_filters=32)
# Mean-squared error between predicted and target fields, optimised with Adam.
model.compile(optimizer='adam', loss='mse')
model.summary()


In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split

# Hold out 10% of the samples for validation (fixed seed for repeatability).
# NOTE(review): the split is random over frames that are 5 minutes apart, so
# validation frames are presumably very similar to training frames — confirm
# this is acceptable for the reported validation loss.
Xtr, Xval, Ytr, Yval = train_test_split(X, Y, test_size=0.1, random_state=42)

# Stop when validation stalls and keep the best weights on disk.
_early_stop = EarlyStopping(patience=8, restore_best_weights=True)
_checkpoint = ModelCheckpoint('best_upsampler_unet.h5', save_best_only=True)
callbacks = [_early_stop, _checkpoint]

_fit_options = dict(
    validation_data=(Xval, Yval),
    batch_size=8,
    epochs=50,
    callbacks=callbacks,
)
history = model.fit(Xtr, Ytr, **_fit_options)


In [None]:
# Run the trained network on the held-out samples.
Y_pred = model.predict(Xval)
print("Y_pred shape:", Y_pred.shape)

# Show input, target and prediction for one validation sample side by side.
import matplotlib.pyplot as plt

i = 0  # first sample
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
_panels = (
    (Xval, 'Input (Low-res upsampled)'),
    (Yval, 'Target (High-res)'),
    (Y_pred, 'Prediction (Super-res)'),
)
for _panel_ax, (_stack, _panel_title) in zip(axs, _panels):
    _panel_ax.imshow(_stack[i, ..., 0], aspect='auto')
    _panel_ax.set_title(_panel_title)
plt.tight_layout()
plt.savefig("validation_example.png")  # written before show() so the file is not blank
plt.show()


In [None]:
# Persist the trained network; a later cell reloads this file as `upsampler`.
model.save('final_upsampler_unet.h5')


In [None]:
from skimage.transform import resize

# NOTE(review): `upsampler` is not assigned in any earlier cell shown here.
# The trained network is called `model`, and a later cell loads
# 'final_upsampler_unet.h5' into `upsampler`; make sure one of them is bound
# to this name before running.

# 1. Check your original model input shape:
expected_shape = upsampler.input_shape  # (None, h, w, 1)
_, n_lat, n_lon, n_chan = expected_shape
print("Model expects:", expected_shape)

# 2. If your model_interp is (time, orig_lat, orig_lon), upsample it:
spatial_factor = n_lat // model_interp.shape[1]
print("Spatial upsample factor:", spatial_factor)

model_up = []
for t in range(model_interp.shape[0]):
    upsampled = resize(
        model_interp[t], (n_lat, n_lon),
        order=3, anti_aliasing=True, preserve_range=True
    )
    model_up.append(upsampled)
model_up = np.stack(model_up)  # (time, n_lat, n_lon)

# 3. Normalise exactly like the training data and add the channel dim.
# The network was fitted on z-scored fields (global `mean`/`std`, NaN -> 0),
# so raw TECU values would be far outside what it has seen.
model_in = np.nan_to_num((model_up - mean) / std)[..., np.newaxis]  # (time, n_lat, n_lon, 1)

# 4. Predict, then undo the normalisation so the result is back in TECU.
superres = upsampler.predict(model_in, verbose=1)
# Drop only the channel axis; a bare squeeze would also remove a time axis of length 1.
superres = np.squeeze(superres, axis=-1) * std + mean
print("Superres shape:", superres.shape)



In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.animation as animation
import numpy as np

# --- Make sure you have these variables: ---
# superres: (T, new_lat, new_lon)
# model_lats, model_lons: original grids (1D)
# all_times or times_for_sr: (T,) datetime array matching superres

# Fall back to the shared time axis when no dedicated one was supplied;
# without this the cell stops with a NameError on `times_for_sr`.
try:
    times_for_sr
except NameError:
    times_for_sr = all_times

# --- Construct new (upsampled) lat/lon grid ---
# Evenly spaced coordinates spanning the same extent as the original grid.
upsampled_lats = np.linspace(model_lats.min(), model_lats.max(), superres.shape[1])
upsampled_lons = np.linspace(model_lons.min(), model_lons.max(), superres.shape[2])

# Explicit check instead of `assert`, which is stripped under `python -O`.
if superres.shape[0] != len(times_for_sr):
    raise ValueError(
        f"Time axis and frames out of sync: {superres.shape[0]} vs {len(times_for_sr)}"
    )

# --- Animation setup ---
fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_global()
ax.coastlines(resolution='110m', linewidth=1)
ax.add_feature(cfeature.BORDERS, linewidth=0.7, edgecolor='gray')
ax.add_feature(cfeature.LAND, zorder=0, edgecolor='black', alpha=0.1)
ax.add_feature(cfeature.LAKES, alpha=0.2)
ax.add_feature(cfeature.RIVERS, alpha=0.2)
ax.set_title("Super-resolved TEC (Model)")

# Colour limits are fixed over the whole sequence so frames are comparable.
vmin, vmax = np.nanmin(superres), np.nanmax(superres)
im = ax.pcolormesh(
    upsampled_lons, upsampled_lats, superres[0],
    cmap='viridis', vmin=vmin, vmax=vmax, shading='auto', transform=ccrs.PlateCarree()
)
cb = fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.08, label='TEC (TECU)')

def update_sr(i):
    """Load frame *i* into the mesh and stamp the title with its time."""
    im.set_array(superres[i].ravel())
    ax.set_title(f"Super-resolved TEC (Model)\n{times_for_sr[i]}")
    return [im]

ani_sr = animation.FuncAnimation(
    fig, update_sr, frames=superres.shape[0], interval=150, blit=False
)
ani_sr.save('model_superres.gif', writer='pillow', fps=8)
plt.close(fig)
print("Saved super-resolved model animation: model_superres.gif")



In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import os
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.animation as animation
from scipy.interpolate import RegularGridInterpolator, interp1d
import tensorflow as tf
from skimage.transform import resize

# --- LOAD OBSERVATIONAL DATA ---
# The parquet table has one row per epoch and (lat, lon) column pairs.
# NOTE(review): this cell scales by 0.04 while the first cell of this file
# scales the same table by .035 — confirm which factor is intended.
df = pd.read_parquet("combined_ionex_vtec.parquet") * 0.04
obs_epochs = pd.to_datetime(df.index)
lats = sorted({lat for lat, _ in df.columns})
lons = sorted({lon for _, lon in df.columns})
# Sort the columns so each row reshapes directly into a (lat, lon) grid.
df = df.sort_index(axis=1, level=['lat', 'lon'])
_obs_grid_shape = (len(lats), len(lons))
_obs_frames = []
for epoch in df.index:
    _obs_frames.append(df.loc[epoch].values.reshape(_obs_grid_shape))
obs_grids = np.stack(_obs_frames, axis=0)
# Heuristic: a peak above 200 is taken to mean the data are still in el/m^2.
if np.nanmax(obs_grids) > 200:
    obs_grids = obs_grids / 1e16

# --- LOAD MODEL DATA ---
model_directory = './ionosphere_central/CCMC/model_data/TIE-GCM/129/Akshay_Ramesh_042125_IT_5/'
# Every non-JSON file in the directory is treated as a candidate dataset.
model_files = sorted([f for f in os.listdir(model_directory) if not f.endswith('.json')])

model_grids, model_times = [], []
# Coordinates are taken from the first file that actually provides TEC, so a
# trailing file that was skipped can never supply (or lack) the grid axes.
model_lats = model_lons = None
for fname in model_files:
    fpath = os.path.join(model_directory, fname)
    try:
        # The context manager closes the file handle once the arrays have
        # been materialised with `.values`.
        with xr.open_dataset(fpath) as ds:
            if 'TEC' not in ds.variables:
                print(f"Skipping {fname} (no 'TEC' variable)")
                continue
            tec = ds['TEC'].values
            file_lats = ds['lat'].values
            file_lons = ds['lon'].values
            if 'time' in ds:
                times = pd.to_datetime(ds['time'].values)
                for i in range(tec.shape[0]):
                    model_grids.append(tec[i]/1e12)  # Convert to TECU if data is in 1e12 el/m^2
                    model_times.append(times[i])
            else:
                # No time coordinate: keep the array as one undated entry.
                model_grids.append(tec/1e12)
                model_times.append(None)
            if model_lats is None:
                model_lats, model_lons = file_lats, file_lons
    except Exception as e:
        # Best effort: report the unreadable file and carry on with the rest.
        print(f"Failed to load {fname}: {e}")

if not model_grids:
    raise RuntimeError("No valid model grids found! Please check files and data.")

model_grids = np.stack(model_grids)     # (n_times, lat, lon)
# Undated entries become NaT so the result is a homogeneous DatetimeIndex.
model_times = pd.to_datetime([t if t is not None else pd.NaT for t in model_times])

# --- INTERPOLATE OBSERVATIONS ONTO MODEL GRID ---
# The target points are the same for every frame, so the mesh and the flat
# (lat, lon) query list are built once instead of once per frame.
mesh = np.meshgrid(model_lats, model_lons, indexing='ij')
pts = np.column_stack([m.ravel() for m in mesh])
_model_grid_shape = (len(model_lats), len(model_lons))
obs_on_model = []
for g in obs_grids:
    # Points outside the observation grid are filled with NaN, not extrapolated.
    interp_func = RegularGridInterpolator((lats, lons), g, bounds_error=False, fill_value=np.nan)
    out = interp_func(pts).reshape(_model_grid_shape)
    obs_on_model.append(out)
obs_on_model = np.stack(obs_on_model)   # (n_obs_times, lat, lon)

# --- SYNC TIME AXES: Interpolate both datasets to the same time axis ---
# Only timestamps present in BOTH records are kept (undated model grids are
# ignored), in ascending order.
_model_time_set = set(model_times.dropna())
_obs_time_set = set(obs_epochs)
common_times = sorted(_model_time_set & _obs_time_set)
if not common_times:
    raise RuntimeError("No overlapping times found between obs and model!")
common_times = pd.to_datetime(common_times)
def interp_time(series_times, series_grids, new_times):
    """Linearly interpolate a time series of grids onto a new time axis.

    Samples without a timestamp (NaT/None) and samples whose grid is entirely
    NaN are discarded first, matching the earlier definition of this helper in
    the file; otherwise one missing frame would turn every interpolated value
    next to it into NaN. Requested times outside the remaining samples are
    linearly extrapolated.

    Args:
        series_times: Timestamps of the input samples (anything accepted by
            ``pd.to_datetime``), one per grid.
        series_grids: Array-like whose first axis is time.
        new_times: Timestamps at which to evaluate the series.

    Returns:
        ``np.ndarray`` with ``len(new_times)`` entries along the first axis
        and the per-sample shape of ``series_grids`` on the remaining axes.

    Raises:
        ValueError: If no usable sample remains after filtering.
    """
    times = pd.DatetimeIndex(pd.to_datetime(series_times))
    grids = np.asarray(series_grids)  # also accepts plain lists

    # Drop samples without a timestamp.
    dated = ~np.asarray(times.isna())
    times, grids = times[dated], grids[dated]

    # Drop fully missing grids.
    if grids.ndim > 1:
        has_data = ~np.isnan(grids).all(axis=tuple(range(1, grids.ndim)))
    else:
        has_data = ~np.isnan(grids)
    times, grids = times[has_data], grids[has_data]

    if len(grids) == 0:
        raise ValueError("interp_time: no valid samples left to interpolate from")
    if len(grids) == 1:
        # One sample cannot define a slope; hold it constant over new_times.
        arr = np.repeat(grids, len(new_times), axis=0)
        return arr

    src_seconds = [pd.Timestamp(t).timestamp() for t in times]
    dst_seconds = [pd.Timestamp(t).timestamp() for t in new_times]
    interp = interp1d(
        src_seconds, grids, axis=0, kind='linear',
        bounds_error=False, fill_value="extrapolate"
    )
    return interp(dst_seconds)
# Put both records on the common time axis.
obs_interp = interp_time(obs_epochs, obs_on_model, common_times)
model_interp = interp_time(model_times, model_grids, common_times)

# --- USE YOUR UPSAMPLER TO PRODUCE SUPERRES ---
# Loaded without its training configuration, then re-compiled.
upsampler = tf.keras.models.load_model('final_upsampler_unet.h5', compile=False)
upsampler.compile(optimizer='adam', loss='mse')
_, h, w, c = upsampler.input_shape
print(f"Superres model expects: {(h, w)} spatial, {c} channels")

# --- Upsample model_interp to (h, w) as model input ---
# NOTE(review): the training cell earlier in this file z-scores its inputs
# (global mean/std, NaN -> 0) before fitting. Here the network receives raw
# values and its output is not de-normalised; the statistics are not
# available in this cell — confirm and persist/reload them if needed.
model_up = []
for t in range(model_interp.shape[0]):
    upsampled = resize(model_interp[t], (h, w), order=3, anti_aliasing=True, preserve_range=True)
    model_up.append(upsampled)
model_up = np.stack(model_up)
model_up = model_up[..., np.newaxis]  # (frames, h, w, 1)

# --- Run the upsampler ---
superres = upsampler.predict(model_up, verbose=1)
# Drop only the channel axis; a bare squeeze would also remove the frame axis
# when there is a single common time, breaking the per-frame indexing below.
superres = np.squeeze(superres, axis=-1)  # (frames, h, w)

# --- Upsample obs_interp to superres grid for direct comparison ---
# Evenly spaced coordinates spanning the same extent as the model grid.
upsampled_lats = np.linspace(model_lats.min(), model_lats.max(), h)
upsampled_lons = np.linspace(model_lons.min(), model_lons.max(), w)
# The query points are identical for every frame, so build them once.
mesh = np.meshgrid(upsampled_lats, upsampled_lons, indexing='ij')
pts = np.column_stack([m.ravel() for m in mesh])
obs_up = []
for arr in obs_interp:
    # Points outside the model grid are filled with NaN, not extrapolated.
    interp_func = RegularGridInterpolator((model_lats, model_lons), arr, bounds_error=False, fill_value=np.nan)
    up = interp_func(pts).reshape(h, w)
    obs_up.append(up)
obs_up = np.stack(obs_up)

# --- For visualization, also upsample model_interp (to compare true model at superres grid) ---
# This reuses the resized network input with its channel axis removed.
model_up_vis = model_up[...,0]  # (frames, h, w)

# --- ANIMATION SETUP (4-PANEL: Obs, Model, Superres, Diff) ---
# Draws the three fields plus the superres-minus-obs difference for every
# common time and writes the sequence to a GIF.
import matplotlib
# NOTE(review): the backend is switched after pyplot was already imported —
# confirm this takes effect in your environment.
matplotlib.use('Agg')
titles = [
    "Observations (upsampled)",
    "TIE-GCM Model (upsampled)",
    "Super-resolved Model",
    "Difference (Superres - Obs)"
]
diff_sr_obs = superres - obs_up
# The three field panels share one colour range; the difference panel gets a
# range symmetric about zero.
vmin = np.nanmin([obs_up, model_up_vis, superres])
vmax = np.nanmax([obs_up, model_up_vis, superres])
abs_max_diff = np.nanmax(np.abs(diff_sr_obs))

fig, axes = plt.subplots(1, 4, figsize=(28, 7), subplot_kw={'projection': ccrs.PlateCarree()})
ims = []
for i, (ax, title) in enumerate(zip(axes, titles)):
    ax.set_global()
    ax.coastlines(resolution='110m', linewidth=1)
    ax.add_feature(cfeature.BORDERS, linewidth=0.7, edgecolor='gray')
    ax.add_feature(cfeature.LAND, zorder=0, edgecolor='black', alpha=0.1)
    ax.set_title(title)
    if i < 3:
        # Panels 0-2: first frame of obs, model and superres respectively.
        ims.append(ax.pcolormesh(upsampled_lons, upsampled_lats, [obs_up, model_up_vis, superres][i][0],
                                 cmap='viridis', vmin=vmin, vmax=vmax, shading='auto', transform=ccrs.PlateCarree()))
    else:
        # Panel 3: first frame of the difference on a diverging colour map.
        ims.append(ax.pcolormesh(upsampled_lons, upsampled_lats, diff_sr_obs[0],
                                 cmap='bwr', vmin=-abs_max_diff, vmax=abs_max_diff, shading='auto', transform=ccrs.PlateCarree()))
# One colour bar for the three field panels, a separate one for the difference.
cb = fig.colorbar(ims[0], ax=axes[:3], orientation='horizontal', pad=0.09, label='TEC (TECU)', fraction=0.04)
cb_diff = fig.colorbar(ims[3], ax=axes[3], orientation='horizontal', pad=0.09, label='Superres - Obs (TECU)', fraction=0.04)

def update_all(i):
    """Load frame *i* into all four meshes and stamp every title with its time."""
    ims[0].set_array(obs_up[i].ravel())
    ims[1].set_array(model_up_vis[i].ravel())
    ims[2].set_array(superres[i].ravel())
    ims[3].set_array(diff_sr_obs[i].ravel())
    for idx, ax in enumerate(axes):
        ax.set_title(f"{titles[idx]}\n{common_times[i]}")
    return ims

# One animation frame per super-resolved time step.
ani = animation.FuncAnimation(fig, update_all, frames=superres.shape[0], interval=150, blit=False)
ani.save('obs_vs_model_vs_superres_vs_diff.gif', writer='pillow', fps=8)
plt.close(fig)
print("\nSaved 4-panel animation: obs_vs_model_vs_superres_vs_diff.gif")
