This is a first attempt at a super-resolution model: it increases the spatial resolution by 4x and the temporal resolution to 5-minute timesteps.

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation
import xarray as xr
import os
from scipy.interpolate import interp1d
from skimage.transform import resize
from tqdm import tqdm

# --- CONFIGURATION ---
# Factor by which each grid dimension is divided to build the low-resolution
# network input and multiplied to size the super-resolved output.
spatial_factor = 4 # increase resolution
# Spacing, in minutes, of the output time axis used for temporal interpolation.
temporal_minutes = 5 # 5min timesteps
# Directory scanned for model output files; entries ending in '.json' are ignored.
model_directory = './ionosphere_central/CCMC/model_data/TIE-GCM/129/Akshay_Ramesh_042125_IT_5/'

# --- LOAD OBSERVATIONAL DATA ---
# Read the observation table and rescale every value in one step.
# NOTE(review): the 0.035 factor is not explained anywhere in this file --
# confirm which unit conversion / calibration it represents.
df = pd.read_parquet("combined_ionex_vtec.parquet") * 0.035
# --- LOAD MODEL DATA ---
# Every non-JSON entry of model_directory is opened as a dataset.  Files that
# have no 'TEC' variable, or that cannot be read, are reported and skipped.
model_files = sorted(f for f in os.listdir(model_directory) if not f.endswith('.json'))
model_grids, model_times = [], []
# Grid coordinates are taken from the first file that actually contributes
# data, so a skipped or unreadable file can never supply them.
model_lats = model_lons = None
for fname in model_files:
    fpath = os.path.join(model_directory, fname)
    try:
        # The context manager releases the file handle as soon as the arrays
        # have been pulled into memory.
        with xr.open_dataset(fpath) as ds:
            if 'TEC' not in ds.variables:
                print(f"Skipping {fname} (no 'TEC' variable)")
                continue
            # NOTE(review): the 1e12 divisor presumably rescales to TECU -- confirm.
            tec = ds['TEC'].values / 1e12
            stamps = pd.to_datetime(ds['time'].values) if 'time' in ds else None
            # Validate before appending anything so the grid and time lists
            # can never end up with different lengths.
            if stamps is not None and len(stamps) != tec.shape[0]:
                raise ValueError(
                    f"{len(stamps)} timestamps for {tec.shape[0]} TEC grids"
                )
            if model_lats is None:
                model_lats = ds['lat'].values
                model_lons = ds['lon'].values
    except Exception as e:
        # Best effort: one bad file must not abort the whole load.
        print(f"Failed to load {fname}: {e}")
        continue

    if stamps is None:
        # No time coordinate: keep the array as a single, undated grid.
        model_grids.append(tec)
        model_times.append(None)
    else:
        # One grid per time step along axis 0.
        model_grids.extend(tec)
        model_times.extend(stamps)

if not model_grids:
    raise RuntimeError("No valid model grids found! Please check files and data.")

model_grids = np.stack(model_grids)     # (n_times, lat, lon)
model_times = pd.to_datetime([t if t is not None else pd.NaT for t in model_times])

# --- PREP OBSERVATIONAL DATA ---
# Columns are (lat, lon) pairs; collect the distinct values along each axis.
lats = sorted({lat for lat, _ in df.columns})
lons = sorted({lon for _, lon in df.columns})
obs_epochs = pd.to_datetime(df.index)
# Order the columns lat-major so that each row reshapes directly into a
# (lat, lon) grid.
df = df.sort_index(axis=1, level=['lat', 'lon'])
if df.shape[1] != len(lats) * len(lons):
    raise ValueError(
        f"Observation columns do not form a complete lat/lon grid: "
        f"{df.shape[1]} columns for {len(lats)} lats x {len(lons)} lons"
    )
# Reshape the whole table at once instead of looking every epoch up by label:
# a label lookup returns several rows when a timestamp is duplicated, which
# made the per-row reshape fail.
obs_grids = df.to_numpy().reshape(len(df), len(lats), len(lons))   # (n_times, lat, lon)

# --- INTERPOLATE OBSERVATIONS ONTO MODEL GRID ---
from scipy.interpolate import RegularGridInterpolator
# The target points are identical for every epoch, so build them once instead
# of once per grid.
mesh = np.meshgrid(model_lats, model_lons, indexing='ij')
pts = np.column_stack([m.ravel() for m in mesh])
obs_on_model = []
for g in obs_grids:
    # Target points outside the observation grid's extent come back as NaN
    # rather than raising.
    interp_func = RegularGridInterpolator((lats, lons), g, bounds_error=False, fill_value=np.nan)
    out = interp_func(pts).reshape(len(model_lats), len(model_lons))
    obs_on_model.append(out)
obs_on_model = np.stack(obs_on_model)   # (n_times, lat, lon)

# --- SYNC TIME AXES: Interpolate obs/model to common set ---
# Common time axis: every distinct epoch that appears in either series
# (undated model grids excluded), in ascending order.
common_epochs = set(obs_epochs) | set(model_times.dropna())
all_epochs = pd.to_datetime(sorted(common_epochs))
# Resample a series onto the common epochs (linear in time; linearly extrapolated outside its own time range)
def interp_time(series_times, series_grids, new_times):
    """Resample a stack of grids onto a new set of timestamps.

    Entries whose timestamp is null (NaT) are discarded first.  When exactly
    one grid remains it is repeated for every requested time; otherwise the
    grids are interpolated linearly along axis 0, with linear extrapolation
    for requested times that fall outside the range of the series.

    Args:
        series_times: timestamps of the grids, one per entry along axis 0.
        series_grids: array of grids, indexed by time along axis 0.
        new_times: timestamps at which grids are wanted.

    Returns:
        Array holding one grid per entry of ``new_times``.
    """
    def epoch_seconds(stamps):
        # interp1d needs a numeric time axis, so use POSIX seconds.
        return [pd.Timestamp(s).timestamp() for s in stamps]

    stamps = pd.to_datetime(series_times)
    dated = ~pd.isnull(stamps)
    stamps, grids = stamps[dated], series_grids[dated]

    if len(grids) == 1:
        # A single sample cannot be interpolated; reuse it for every output time.
        return np.repeat(grids, len(new_times), axis=0)

    resample = interp1d(
        epoch_seconds(stamps),
        grids,
        axis=0,
        kind='linear',
        bounds_error=False,
        fill_value="extrapolate",
    )
    return resample(epoch_seconds(new_times))

# Bring both series onto the common time axis.
obs_interp = interp_time(obs_epochs, obs_on_model, all_epochs)
model_interp = interp_time(model_times, model_grids, all_epochs)

# --- CONCATENATE FOR TRAINING DATASET ---
# Layout: the first len(all_epochs) grids are observations and the next
# len(all_epochs) are model output, both in all_epochs order.  The array is
# deliberately NOT shuffled in place: the inference step pairs grid i with
# all_epochs[i], and the training loop already draws a fresh random
# permutation every epoch, so shuffling here only scrambled the timestamps.
all_grids = np.concatenate([obs_interp, model_interp], axis=0)
# NaN-aware statistics: cells outside the observation coverage are NaN.
mean = np.nanmean(all_grids)
std = np.nanstd(all_grids)
if not np.isfinite(std) or std == 0:
    # Dividing by a zero/NaN spread would silently turn the whole data set
    # into zeros after nan_to_num.
    raise RuntimeError(f"Cannot normalise grids: std={std} (mean={mean})")
all_grids_norm = (all_grids - mean) / std
# Remaining NaNs become 0, i.e. the data-set mean in normalised units.
all_grids_norm = np.nan_to_num(all_grids_norm)

# --- GENERATE TRAINING PAIRS (LOWRES -> UPSAMPLED, TARGET=SUPERRES) ---
# For every normalised grid the network input is the grid shrunk by
# spatial_factor and then enlarged to the super-resolved size; the target is
# the grid itself resized to that same size.
n_lat, n_lon = all_grids_norm.shape[1:3]
coarse_shape = (n_lat//spatial_factor, n_lon//spatial_factor)
fine_shape = (n_lat*spatial_factor, n_lon*spatial_factor)

X_train, Y_train = [], []
for grid in all_grids_norm:
    coarse = resize(grid, coarse_shape, order=3, anti_aliasing=True)
    X_train.append(resize(coarse, fine_shape, order=3, anti_aliasing=False))
    Y_train.append(resize(grid, fine_shape, order=3, anti_aliasing=True))

# Insert a channel axis: (N, H, W) -> (N, 1, H, W) float32 tensors.
X_train = torch.tensor(np.stack(X_train)[:, None, :, :], dtype=torch.float32)
Y_train = torch.tensor(np.stack(Y_train)[:, None, :, :], dtype=torch.float32)

# --- DEFINE SRCNN MODEL ---
class SRCNN(nn.Module):
    """Three-layer convolutional network whose output has the input's height and width.

    Layers: a 9x9 convolution to ``num_filters`` channels, a 5x5 convolution
    to ``num_filters // 2`` channels, and a 5x5 convolution to
    ``out_channels``.  ReLU follows the first two convolutions only.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output image.
        num_filters: channel width of the first convolution.
    """

    def __init__(self, in_channels=1, out_channels=1, num_filters=64):
        super().__init__()
        hidden = num_filters // 2
        # Each padding equals (kernel_size - 1) // 2, so spatial size is preserved.
        self.conv1 = nn.Conv2d(in_channels, num_filters, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(num_filters, hidden, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the three convolutions on ``x`` and return the result."""
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        return self.conv3(mapped)

# --- TRAIN SRCNN ---
# Resolve the device once; the model and every sample are moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SRCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
epochs_train = 11
n_samples = X_train.shape[0]
for epoch in range(epochs_train):
    # Fresh random sample order each epoch; one sample per optimizer step.
    perm = torch.randperm(n_samples)
    running_loss = 0.0
    for idx in tqdm(perm.tolist(), desc=f"Epoch {epoch+1}/{epochs_train}"):
        inp = X_train[idx:idx+1].to(device)
        tgt = Y_train[idx:idx+1].to(device)
        out = model(inp)
        loss = loss_fn(out, tgt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean over the epoch: the loss of the last (randomly chosen)
    # sample alone says little about how training is going.
    print(f"Epoch {epoch+1} Loss: {running_loss / max(n_samples, 1):.5f}")

# --- INFERENCE: SUPER-RESOLVE ALL TIMES ---
# Coordinate axes of the super-resolved grid: same extent as the model grid,
# spatial_factor times as many points per axis.
up_lats = np.linspace(min(model_lats), max(model_lats), n_lat*spatial_factor)
up_lons = np.linspace(min(model_lons), max(model_lons), n_lon*spatial_factor)
lr_shape = (n_lat//spatial_factor, n_lon//spatial_factor)
hr_shape = (n_lat*spatial_factor, n_lon*spatial_factor)
on_gpu = torch.cuda.is_available()
n_dated = len(all_epochs)

hr_grids, hr_grid_times = [], []
model.eval()
with torch.no_grad():
    for i, grid in enumerate(all_grids_norm):
        # Same degradation as used for the training inputs.
        lowres = resize(grid, lr_shape, order=3, anti_aliasing=True)
        upsampled = resize(lowres, hr_shape, order=3, anti_aliasing=False)
        inp = torch.tensor(upsampled[None,None,:,:], dtype=torch.float32)
        if on_gpu:
            inp = inp.cuda()
        out = model(inp).cpu().numpy()[0,0]
        # Undo the normalisation to get back to data units.
        hr_grids.append(out * std + mean)
        # Only the first len(all_epochs) grids receive a timestamp; the rest
        # are tagged NaT.
        # NOTE(review): this relies on all_grids_norm[i] still corresponding
        # to all_epochs[i] -- confirm the ordering upstream.
        hr_grid_times.append(all_epochs[i] if i < n_dated else pd.NaT)
hr_grids = np.stack(hr_grids)
hr_grid_times = pd.to_datetime(hr_grid_times)

# --- TEMPORAL INTERPOLATION TO EVERY 5 MIN ---
# Grids tagged NaT have no position on the time axis, so they are left out.
dated = ~pd.isnull(hr_grid_times)
hr_grids_valid, hr_times_valid = hr_grids[dated], hr_grid_times[dated]

# Regular output axis from the first to the last dated grid.
all_times = pd.date_range(start=hr_times_valid[0], end=hr_times_valid[-1], freq=f'{temporal_minutes}min')


def _posix_seconds(stamps):
    """Return the timestamps as POSIX seconds, the numeric axis interp1d needs."""
    return [pd.Timestamp(s).timestamp() for s in stamps]


if len(hr_grids_valid) != 1:
    # Linear interpolation of the grids along the time axis.
    time_interp = interp1d(
        _posix_seconds(hr_times_valid),
        hr_grids_valid,
        axis=0,
        kind='linear',
        bounds_error=False,
        fill_value="extrapolate",
    )
    all_hr_grids = time_interp(_posix_seconds(all_times))
else:
    # A single grid cannot be interpolated; repeat it for every output time.
    all_hr_grids = np.repeat(hr_grids_valid, len(all_times), axis=0)

# --- OUTPUT MODEL & RESULTS ---
weights_path = "superres_tec_srcnn.pt"
grids_path = "superres_tec_5min.npy"
# Persist the network parameters (state dict only) and the interpolated
# super-resolved grids.
torch.save(model.state_dict(), weights_path)
np.save(grids_path, all_hr_grids)
print(f"Model and super-resolved TEC saved. Shape: {all_hr_grids.shape}")

# --- OPTIONAL: ANIMATION (FLAT AND ORTHOGRAPHIC) ---
import matplotlib
matplotlib.use('Agg')  # render to files only; no display is needed


def _save_tec_animation(projection, figsize, out_path):
    """Write all_hr_grids as an animated GIF drawn on the given map projection.

    Args:
        projection: cartopy projection used for the axes.
        figsize: figure size in inches, as (width, height).
        out_path: file name of the GIF to write.
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1, projection=projection)
    # The data are on a regular lat/lon grid, hence the PlateCarree transform
    # regardless of the projection being drawn.
    mesh = ax.pcolormesh(up_lons, up_lats, all_hr_grids[0], transform=ccrs.PlateCarree(), shading='auto')
    fig.colorbar(mesh, ax=ax, orientation='vertical', label='TEC (TECU)')

    def update(idx):
        # Swap the frame's data into the existing mesh instead of redrawing it.
        mesh.set_array(all_hr_grids[idx].ravel())
        ax.set_title(f'Super-Resolved TEC {all_times[idx]}')
        return [mesh]

    ani = animation.FuncAnimation(fig, update, frames=len(all_times), interval=100, blit=False)
    ani.save(out_path, writer='pillow', fps=10)
    plt.close(fig)


flat_gif = 'superres_tec_flat.gif'
ortho_gif = 'superres_tec_ortho.gif'
_save_tec_animation(ccrs.PlateCarree(), (12, 6), flat_gif)
_save_tec_animation(ccrs.Orthographic(central_longitude=0), (8, 8), ortho_gif)

# Report the names that were actually written.
print(f"Animations saved: {flat_gif}, {ortho_gif}")
