Parsing IONEX observational data

Loading TIE-GCM model data

Interpolating both datasets onto a common grid and timeframe

Creating training tensors explicitly saved as X_train.pt and Y_train.pt

This is a U-Net — an encoder-decoder network with skip connections that is well suited to tasks needing both global context and fine spatial detail (e.g., image segmentation or super-resolution). On top of the basic U-Net, it adds an AttentionBlock to each skip path:

ConvBlock: two back-to-back Conv → BatchNorm → ReLU layers to extract features at each resolution.

AttentionBlock: learns a soft “mask” (via a tiny network + sigmoid) that gates the features coming from the encoder, so the decoder only “pays attention” to the most relevant spatial regions.

Putting it together, the AttentionUNet:

Encodes your input through two downsampling steps (via max-pooling) and ConvBlocks, building up a deep feature map.

Decodes by upsampling (ConvTranspose2d), using AttentionBlocks to select which encoder features to pass along, then ConvBlocks to fuse them.

Outputs a final 1-channel map (via a 1×1 convolution) matching the input’s spatial dimensions.

Why this model?
U-Net structure: preserves high-resolution detail through skip-connections

Attention: lets the network focus on the most informative features and ignore noise, improving accuracy on tasks where some spatial regions matter more than others

In [None]:
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import xarray as xr
import os
from scipy.interpolate import interp1d, RegularGridInterpolator
from skimage.transform import resize
from datetime import datetime, timedelta
from tqdm import tqdm

# --- CONFIGURATION ---
# IONEX observation files handed to parse_ionex_file(), one path per entry.
file_paths = [
    './ionosphere_central/vTEC_data/uqrg1000.24i',
    './ionosphere_central/vTEC_data/uqrg1010.24i'
]
# Directory whose entries are opened with xarray as TIE-GCM model output.
model_directory = './ionosphere_central/CCMC/model_data/TIE-GCM/129/Akshay_Ramesh_042125_IT_5/'
BATCH_SIZE = 16  # mini-batch size given to the training DataLoader
EPOCHS = 50  # upper bound on training epochs (early stopping may end sooner)
LR = 1e-4  # initial learning rate of the AdamW optimizer
SPATIAL_FACTOR = 4  # NOTE(review): not referenced by any visible code — confirm it is still needed

# --- Parse IONEX Observations ---
def parse_ionex_file(file_path):
    """Parse the TEC maps of one IONEX file into a time-indexed DataFrame.

    The latitude/longitude axes are rebuilt from the ``LAT1 / LAT2 / DLAT``
    and ``LON1 / LON2 / DLON`` header records.  Every ``START OF TEC MAP`` ..
    ``END OF TEC MAP`` section that contains a complete grid becomes one row;
    an incomplete map is dropped *together with its epoch* so that rows and
    timestamps always stay aligned.

    Raw values are multiplied by ``10 ** EXPONENT`` (header record, -1 when
    the record is absent, optionally overridden by an ``EXPONENT`` record
    inside a map) so the result is expressed in TECU.  The literal ``9999``
    marks a missing value and is returned as NaN.

    Args:
        file_path: Path of the IONEX text file.

    Returns:
        pandas.DataFrame indexed by map epoch with a ``(lat, lon)`` MultiIndex
        on the columns, latitude-major in file order.  The frame has zero rows
        when the file holds no complete TEC map.

    Raises:
        ValueError: If the ``END OF HEADER`` record or one of the two grid
            definition records is missing.
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    header_end = next(
        (idx for idx, text in enumerate(lines) if 'END OF HEADER' in text), None
    )
    if header_end is None:
        raise ValueError(f"{file_path}: 'END OF HEADER' record not found")

    lat_spec = lon_spec = None
    header_exponent = -1  # IONEX default when no EXPONENT record is given
    for line in lines[:header_end]:
        if 'LAT1 / LAT2 / DLAT' in line:
            lat_spec = [float(v) for v in line.split()[:3]]
        if 'LON1 / LON2 / DLON' in line:
            lon_spec = [float(v) for v in line.split()[:3]]
        # Match on the trailing label so COMMENT lines cannot trigger this.
        if line.rstrip().endswith('EXPONENT'):
            header_exponent = int(line.split()[0])
    if lat_spec is None or lon_spec is None:
        raise ValueError(
            f"{file_path}: header lacks the LAT1/LAT2/DLAT or LON1/LON2/DLON record"
        )

    lat1, lat2, dlat = lat_spec
    lon1, lon2, dlon = lon_spec
    lats = np.arange(lat1, lat2 - 0.1, -abs(dlat))
    lons = np.arange(lon1, lon2 + 0.1, dlon)

    number = re.compile(r'[-+]?[0-9]*\.?[0-9]+')
    row_marker = 'LAT/LON1/LON2/DLON/H'
    map_end = 'END OF TEC MAP'
    n_lines = len(lines)

    maps, epochs = [], []
    i = header_end + 1
    while i < n_lines:
        if 'START OF TEC MAP' not in lines[i]:
            i += 1
            continue
        if i + 1 >= n_lines:
            break  # file truncated right after the map start record

        y, mo, d, h, mi, s = map(int, lines[i + 1][:36].split())
        # Hour 24 is folded into the following day.
        epoch = datetime(y, mo, d, h % 24, mi, s) + timedelta(days=h // 24)
        exponent = header_exponent
        grid = []
        i += 2
        while i < n_lines and map_end not in lines[i]:
            if row_marker in lines[i]:
                row = []
                i += 1
                while (
                    i < n_lines
                    and lines[i].strip()
                    and row_marker not in lines[i]
                    and map_end not in lines[i]
                ):
                    row.extend(
                        float(v) if v != '9999' else np.nan
                        for v in number.findall(lines[i])
                    )
                    i += 1
                if len(row) == len(lons):
                    grid.append(row)
            else:
                if lines[i].rstrip().endswith('EXPONENT'):
                    exponent = int(lines[i].split()[0])
                i += 1

        # Record the epoch only for maps that are kept; appending it earlier
        # would shift every later timestamp onto the wrong map.
        if len(grid) == len(lats):
            maps.append(np.asarray(grid, dtype=float) * 10.0 ** exponent)
            epochs.append(epoch)

    columns = pd.MultiIndex.from_product([lats, lons], names=['lat', 'lon'])
    if maps:
        values = np.stack(maps).reshape(len(maps), -1)
    else:
        values = np.empty((0, len(columns)))
    return pd.DataFrame(values, index=pd.DatetimeIndex(epochs), columns=columns)

# --- Combine observations from all IONEX files ---
# Parse each configured file; empty path entries are ignored.
dfs = []
for fp in file_paths:
    if fp:
        dfs.append(parse_ionex_file(fp))

# Narrow the column set down to the (lat, lon) pairs every file provides.
common_cols = dfs[0].columns
for df in dfs[1:]:
    common_cols = common_cols.intersection(df.columns)

# Restrict every frame to that shared column set.
dfs_aligned = []
for df in dfs:
    dfs_aligned.append(df[common_cols])

# Stack the frames in time, keep the first row of any repeated epoch,
# then order the result chronologically.
obs_df = pd.concat(dfs_aligned)
is_repeat = obs_df.index.duplicated(keep='first')
obs_df = obs_df[~is_repeat]
obs_df = obs_df.sort_index()

# Unique grid coordinates in ascending order.
obs_lats = sorted({lat for lat, lon in obs_df.columns})
obs_lons = sorted({lon for lat, lon in obs_df.columns})


# --- Load TIE-GCM Model Data ---
# Every directory entry except JSON side-cars and the notebook checkpoint
# folder is treated as a model output file.
model_files = sorted(
    f for f in os.listdir(model_directory)
    if not f.endswith('.json') and f != '.ipynb_checkpoints'
)
model_grids, model_times = [], []
model_lats = model_lons = None
for fname in model_files:
    # The context manager closes the file handle once the arrays are in memory.
    with xr.open_dataset(os.path.join(model_directory, fname), engine='netcdf4') as ds:
        if 'TEC' not in ds:
            continue
        # Same 1e12 scaling as before (presumably cm^-2 -> TECU; confirm units).
        tec = ds['TEC'].values / 1e12
        times = pd.to_datetime(ds['time'].values)
        # Read the grid axes from a file that really carries TEC instead of
        # from whichever file happened to be opened last.
        model_lats, model_lons = ds['lat'].values, ds['lon'].values
    model_grids.extend(tec)
    model_times.extend(times)

if not model_grids:
    raise RuntimeError(f"No file with a 'TEC' variable found in {model_directory}")

model_grids = np.stack(model_grids)
model_times = pd.to_datetime(model_times)

# --- Interpolation ---
# Shared 5-minute timeline restricted to the period covered by both datasets.
common_times = pd.date_range(
    start=max(obs_df.index.min(), model_times.min()),
    end=min(obs_df.index.max(), model_times.max()),
    freq='5min',
)
if len(common_times) == 0:
    raise ValueError('Observation and model time ranges do not overlap.')

# Ascending axes, as required by RegularGridInterpolator.
obs_lats = sorted(set(lat for lat, lon in obs_df.columns))
obs_lons = sorted(set(lon for lat, lon in obs_df.columns))

# The DataFrame columns follow the IONEX file order (latitude descending).
# Reorder them to match the ascending axes before reshaping; reshaping the
# raw column order against sorted latitudes flips every map north-south.
grid_columns = pd.MultiIndex.from_product([obs_lats, obs_lons])
obs_cube = (
    obs_df.reindex(columns=grid_columns)
    .to_numpy(dtype=float)
    .reshape(len(obs_df), len(obs_lats), len(obs_lons))
)

# The target points are identical for every epoch, so build them once.
mesh_lats, mesh_lons = np.meshgrid(model_lats, model_lons, indexing='ij')
interp_points = np.column_stack([mesh_lats.ravel(), mesh_lons.ravel()])
target_shape = (len(model_lats), len(model_lons))

obs_grids = np.stack([
    RegularGridInterpolator(
        (obs_lats, obs_lons), obs_values, bounds_error=False, fill_value=np.nan
    )(interp_points).reshape(target_shape)
    for obs_values in obs_cube
])

# Express all time axes as float seconds relative to the first common epoch.
# This is independent of the datetime resolution and of the platform integer
# width, and keeps the magnitudes small enough for exact float arithmetic.
t0 = common_times[0]
one_second = pd.Timedelta(seconds=1)
obs_seconds = np.asarray((obs_df.index - t0) / one_second, dtype=float)
model_seconds = np.asarray((model_times - t0) / one_second, dtype=float)
common_seconds = np.asarray((common_times - t0) / one_second, dtype=float)

interp_obs = interp1d(
    obs_seconds, obs_grids, axis=0, bounds_error=False, fill_value='extrapolate'
)(common_seconds)
interp_model = interp1d(
    model_seconds, model_grids, axis=0, bounds_error=False, fill_value='extrapolate'
)(common_seconds)

# --- Training Data Prep ---
# Observation and model frames are stacked along the sample axis.
data = np.concatenate((interp_obs, interp_model), axis=0)

# Global NaN-aware statistics, written to disk for later de-normalisation.
mean = np.nanmean(data)
std = np.nanstd(data)
np.save('tec_mean.npy', mean)
np.save('tec_std.npy', std)

# Standardise; remaining NaNs become 0, i.e. the mean in normalised units.
data = (data - mean) / std
data = np.nan_to_num(data)

# Insert a channel axis -> (samples, 1, lat, lon).
with_channel = np.expand_dims(data, axis=1)
# NOTE(review): input and target are the same array, so the network is
# trained to reproduce its input — confirm this is intended.
X_train = torch.tensor(with_channel, dtype=torch.float32)
Y_train = torch.tensor(with_channel, dtype=torch.float32)

# Persist the prepared tensors.
torch.save(X_train, 'X_train.pt')
torch.save(Y_train, 'Y_train.pt')


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm
from torchvision.transforms.functional import resize

# --- CONFIGURATION ---
# Same values as the configuration in the data-preparation cell.
BATCH_SIZE = 16  # mini-batch size given to the training DataLoader
EPOCHS = 50  # upper bound on training epochs (early stopping may end sooner)
LR = 1e-4  # initial learning rate of the AdamW optimizer
SPATIAL_FACTOR = 4  # NOTE(review): not referenced by any visible code — confirm it is still needed

# --- Model: Attention U-Net (attention-gated skip connections) ---
class AttentionBlock(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal ``g`` (``F_g`` channels) and the skip features ``x``
    (``F_l`` channels) are each projected to ``F_int`` channels with 1x1
    convolutions, summed, passed through ReLU, reduced to one channel and
    squashed with a sigmoid.  The resulting per-pixel mask scales ``x``.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # 1x1 projections of the gate and the skip features to F_int channels.
        self.W_g = nn.Conv2d(F_g, F_int, kernel_size=1)
        self.W_x = nn.Conv2d(F_l, F_int, kernel_size=1)
        # Reduces the combined features to a single-channel attention logit.
        self.psi = nn.Conv2d(F_int, 1, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, g, x):
        """Return ``x`` scaled by the attention mask computed from ``g`` and ``x``."""
        combined = self.W_g(g) + self.W_x(x)
        activated = self.relu(combined)
        mask = self.sigmoid(self.psi(activated))
        return x * mask

class ConvBlock(nn.Module):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    Both convolutions use ``padding=1``, so height and width are unchanged;
    only the channel count goes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        features = self.conv(x)
        return features

class AttentionUNet(nn.Module):
    """Two-level U-Net whose skip connections are gated by AttentionBlocks.

    Encoder: ConvBlock(in_ch->64), max-pool, ConvBlock(64->128), max-pool,
    ConvBlock(128->256).  Decoder: two transposed-convolution upsampling
    steps; at each one the matching encoder features are gated by an
    AttentionBlock, concatenated with the upsampled features and fused by a
    ConvBlock.  A final 1x1 convolution produces ``out_ch`` channels at the
    input's spatial size.

    Max-pooling floors odd sizes, so for inputs whose height or width is not
    a multiple of 4 the upsampled maps come out one pixel short of the
    encoder maps.  They are zero-padded to the encoder size before gating and
    concatenation; for sizes divisible by 4 the padding is a no-op.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()
        # Encoder
        self.c1 = ConvBlock(in_ch, 64)
        self.p1 = nn.MaxPool2d(2)
        self.c2 = ConvBlock(64, 128)
        self.p2 = nn.MaxPool2d(2)
        self.c3 = ConvBlock(128, 256)

        # Decoder level 2 (half resolution)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.att2 = AttentionBlock(128, 128, 64)
        self.c4 = ConvBlock(256, 128)

        # Decoder level 1 (full resolution)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.att1 = AttentionBlock(64, 64, 32)
        self.c5 = ConvBlock(128, 64)

        self.final = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _pad_to(up, skip):
        """Zero-pad ``up`` so that its height and width equal those of ``skip``."""
        dh = skip.shape[-2] - up.shape[-2]
        dw = skip.shape[-1] - up.shape[-1]
        if dh == 0 and dw == 0:
            return up
        # Pad order expected by F.pad: (left, right, top, bottom).
        return nn.functional.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        """Map ``x`` of shape (N, in_ch, H, W) to an (N, out_ch, H, W) tensor."""
        c1 = self.c1(x)
        c2 = self.c2(self.p1(c1))
        c3 = self.c3(self.p2(c2))

        u2 = self._pad_to(self.up2(c3), c2)
        gated2 = self.att2(u2, c2)
        c4 = self.c4(torch.cat([u2, gated2], dim=1))

        u1 = self._pad_to(self.up1(c4), c1)
        gated1 = self.att1(u1, c1)
        c5 = self.c5(torch.cat([u1, gated1], dim=1))

        return self.final(c5)

# --- Data Loading (assume preprocessed and normalized tensors) ---
train_dataset = TensorDataset(X_train, Y_train)
if len(train_dataset) == 0:
    raise ValueError('Training set is empty; nothing to train on.')
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# --- Initialize Model, Optimizer, Scheduler, Loss ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttentionUNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=LR)
# Halve the learning rate after 5 epochs without improvement of the loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
criterion = nn.MSELoss()

# --- Training Loop with Early Stopping ---
# NOTE(review): both the scheduler and early stopping monitor the *training*
# loss; no validation data is held out in this cell.
best_loss = float('inf')
patience, trigger = 10, 0

for epoch in range(EPOCHS):
    model.train()
    loss_sum, n_seen = 0.0, 0

    for xb, yb in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}'):
        xb, yb = xb.to(device), yb.to(device)

        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so that a short final batch
        # does not count as much as a full one.
        batch_n = xb.size(0)
        loss_sum += loss.item() * batch_n
        n_seen += batch_n

    epoch_loss = loss_sum / n_seen
    scheduler.step(epoch_loss)

    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.6f}')

    # Early stopping: checkpoint on improvement, stop after `patience`
    # consecutive epochs without one.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'attention_unet_best.pt')
        trigger = 0
    else:
        trigger += 1
        if trigger >= patience:
            print('Early stopping triggered.')
            break

# --- Load Best Model ---
# map_location keeps the reload working when the checkpoint was written on a
# different device than the one currently selected.
model.load_state_dict(torch.load('attention_unet_best.pt', map_location=device))
model.eval()


In [None]:
from sklearn.model_selection import train_test_split

# Reload the tensors written by the data-preparation cell.
X = torch.load('X_train.pt')
Y = torch.load('Y_train.pt')

# 80/20 split with a fixed seed.
# NOTE(review): if the training cell above was run first, the model has
# already seen every sample, so this "test" set is not held out — confirm.
split_parts = train_test_split(X, Y, test_size=0.2, random_state=42)
X_train, X_test, Y_train, Y_test = split_parts

# Move the evaluation tensors to the device the model lives on.
X_test, Y_test = X_test.to(device), Y_test.to(device)

In [None]:
# Forward pass over the whole test set with gradient tracking disabled.
with torch.no_grad():
    test_output = model(X_test)
predictions = test_output.cpu().numpy()

# Targets as a NumPy array for comparison with the predictions.
Y_test_np = Y_test.cpu().numpy()

Plot one frame of the observations next to the model prediction, together with their difference.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# --- Load Trained Model ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttentionUNet().to(device)
# map_location lets a checkpoint saved on another device load here.
model.load_state_dict(torch.load('attention_unet_best.pt', map_location=device))
model.eval()

# Assume X_test and Y_test are already loaded and preprocessed.

# --- Predict ---
with torch.no_grad():
    predictions = model(X_test.to(device)).cpu().numpy()

# Convert tensors to numpy arrays for plotting
Y_test_np = Y_test.cpu().numpy()

# --- De-normalize the predictions and targets ---
mean = np.load('tec_mean.npy')
std = np.load('tec_std.npy')

predictions = predictions * std + mean
Y_test_np = Y_test_np * std + mean

# --- Select a Sample for Plotting ---
idx = 0  # You can change the index to visualize different samples
prediction = predictions[idx, 0]
observation = Y_test_np[idx, 0]
difference = prediction - observation

# NOTE(review): the axes are assumed to span the full globe evenly; the
# coordinates actually used during preprocessing (model_lats / model_lons)
# would be more accurate if they are available in this session.
lats = np.linspace(-90, 90, prediction.shape[0])
lons = np.linspace(-180, 180, prediction.shape[1])
lon_grid, lat_grid = np.meshgrid(lons, lats)

# Shared colour limits for the two TEC panels, symmetric limits for the
# difference panel; all NaN-aware.
vmin = min(np.nanmin(observation), np.nanmin(prediction))
vmax = max(np.nanmax(observation), np.nanmax(prediction))
diff_limit = np.nanmax(np.abs(difference))

# --- Plot Static Example ---
fig, axs = plt.subplots(1, 3, figsize=(18, 6), subplot_kw={'projection': ccrs.PlateCarree()})

# Observation
im0 = axs[0].pcolormesh(lon_grid, lat_grid, observation, shading='auto', cmap='viridis', vmin=vmin, vmax=vmax)
axs[0].add_feature(cfeature.COASTLINE)
axs[0].set_title('Observation (TECU)')
plt.colorbar(im0, ax=axs[0], orientation='horizontal', fraction=0.046, pad=0.04, label='TECU')

# Prediction
im1 = axs[1].pcolormesh(lon_grid, lat_grid, prediction, shading='auto', cmap='viridis', vmin=vmin, vmax=vmax)
axs[1].add_feature(cfeature.COASTLINE)
axs[1].set_title('Model Prediction (TECU)')
plt.colorbar(im1, ax=axs[1], orientation='horizontal', fraction=0.046, pad=0.04, label='TECU')

# Difference
im2 = axs[2].pcolormesh(lon_grid, lat_grid, difference, shading='auto', cmap='RdBu_r', vmin=-diff_limit, vmax=diff_limit)
axs[2].add_feature(cfeature.COASTLINE)
axs[2].set_title('Difference (Prediction - Observation)')
plt.colorbar(im2, ax=axs[2], orientation='horizontal', fraction=0.046, pad=0.04, label='Δ TECU')

# Save before show(): once shown, the figure may no longer be current.
plt.tight_layout()
plt.savefig('TECcomparison_example.png')
plt.show()


No time codes were kept with the samples, so the animation is badly broken. Dead end — start again.