# 🧠 Train DeepESD for ERA5 → CERRA Spatial Downscaling

This notebook trains the DeepESD convolutional model to downscale temperature (`t2m`) from ERA5 (0.25°) to CERRA (0.05°) resolution.

We will:
1. Load preprocessed NetCDF data (train/val sets)
2. Instantiate the DeepESD model
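3. Train it with an MSE loss, the Adam optimizer, a cosine learning-rate schedule and early stopping on the validation loss
4. Save the trained model weights and the per-epoch loss metrics for later inference and XAI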


In [13]:
import os
import logging
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import xarray as xr
from xbatcher import BatchGenerator
from torch.utils.data import DataLoader, TensorDataset

# Route INFO-level messages through the root logger and choose the compute
# device: the first CUDA GPU when one is visible, otherwise the CPU.
logging.basicConfig(level=logging.INFO)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
logging.info(f"Using device: {DEVICE}")

INFO:root:Using device: cuda


In [14]:
class DeepESD(nn.Module):
    """DeepESD downscaling network.

    Three 3x3 same-padding convolutions (50 -> 25 -> ``output_channels``
    filters, each followed by ReLU) are flattened and fed to a single dense
    layer. That layer maps every low-resolution input cell to every
    high-resolution output cell.

    Args:
        input_shape: ``(rows, cols)`` of the predictor grid.
        output_shape: ``(rows, cols)`` of the target grid.
        input_channels: number of predictor channels.
        output_channels: number of predicted channels.

    ``forward`` returns a tensor shaped
    ``(batch, output_channels, output_shape[0], output_shape[1])``.
    """

    def __init__(
        self,
        input_shape: tuple,
        output_shape: tuple,
        input_channels: int,
        output_channels: int,
    ):
        super().__init__()
        # Kept as attributes because forward() needs them to reshape the output.
        self.output_shape = output_shape
        self.output_channels = output_channels

        # Attribute names (conv1..conv3, out) define the state_dict keys, so
        # they must stay stable for previously saved checkpoints to load.
        self.conv1 = nn.Conv2d(input_channels, 50, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(50, 25, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(25, output_channels, kernel_size=3, padding=1)

        # padding=1 keeps the spatial size, so conv3 emits
        # output_channels x in_rows x in_cols features per sample.
        in_rows, in_cols = input_shape[0], input_shape[1]
        out_rows, out_cols = output_shape[0], output_shape[1]
        self.out = nn.Linear(
            in_rows * in_cols * output_channels,
            out_rows * out_cols * output_channels,
        )

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(conv(x))
        dense = self.out(x.flatten(start_dim=1))
        target_rows, target_cols = self.output_shape[0], self.output_shape[1]
        return dense.view(-1, self.output_channels, target_rows, target_cols)

In [15]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    ``padding=1`` keeps the spatial size unchanged; channels go from
    ``in_channels`` to ``out_channels`` in the first convolution and stay at
    ``out_channels`` in the second.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first conv changes the channel count; the second keeps it.
        for channels_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    """U-Net encoder/decoder with skip connections.

    Each encoder level applies a ``DoubleConv`` and then 2x2 max-pooling. The
    bottleneck doubles the deepest width. Each decoder level upsamples with a
    2x2 transposed convolution, concatenates the matching encoder output
    (skip first, then upsampled features) and applies a ``DoubleConv``. A 1x1
    convolution produces the output at the resolution of the first encoder
    level.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the final 1x1 convolution.
        features: channel width of each encoder level, shallowest first. Any
            non-empty sequence of ints (list or tuple).

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super(UNet, self).__init__()
        # Immutable default + local copy: avoids the shared-mutable-default pitfall
        # while still accepting lists from existing callers.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling path
        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Upsampling path (deepest level first)
        self.up_transpose = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for feature in reversed(features):
            self.up_transpose.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input is skip (feature ch) concatenated with upsampled (feature ch).
            self.decoder.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.encoder:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder levels pair with encoder outputs in reverse (deepest first).
        for up, decode, skip in zip(
            self.up_transpose, self.decoder, reversed(skip_connections)
        ):
            x = up(x)
            # Odd spatial sizes lose a row/column to pooling; resize the
            # upsampled map (nearest neighbour) so it can be concatenated.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])
            x = torch.cat((skip, x), dim=1)
            x = decode(x)

        return self.final_conv(x)

In [16]:
def load_netcdf_pair(x_path, y_path, batch_size=16, variable_name="t2m", shuffle=True):
    """Build a DataLoader of time-aligned (predictor, target) batches from two NetCDF files.

    Both files are opened and reordered to ``(time, lat, lon)``. They are then
    restricted to their common timestamps and cut into chunks of
    ``batch_size`` time steps with xbatcher. Every chunk is materialised
    eagerly into a pair of float32 tensors shaped ``(time, 1, lat, lon)``, so
    the files are closed again before this function returns.

    Args:
        x_path: path to the predictor NetCDF file.
        y_path: path to the target NetCDF file.
        batch_size: number of time steps per batch.
        variable_name: data variable read from BOTH files.
        shuffle: whether the DataLoader shuffles the order of the pre-built
            batches each epoch. Which time steps share a batch is fixed.

    Returns:
        A ``DataLoader`` with ``batch_size=None`` that yields ``(x, y)``
        tensor pairs as built here.

    Raises:
        ValueError: if the two files share no timestamps.
    """
    with xr.open_dataset(x_path) as raw_x, xr.open_dataset(y_path) as raw_y:
        ds_x = raw_x.transpose("time", "lat", "lon")
        ds_y = raw_y.transpose("time", "lat", "lon")

        # Keep only timestamps present in both files so batches line up 1:1.
        common_times = np.intersect1d(ds_x["time"].values, ds_y["time"].values)
        if len(common_times) == 0:
            raise ValueError(f"No overlapping timestamps between {x_path} and {y_path}")
        ds_x = ds_x.sel(time=common_times)
        ds_y = ds_y.sel(time=common_times)

        def _batches(ds):
            # One patch = `batch_size` consecutive time steps over the full grid.
            # NOTE(review): xbatcher appears to skip a trailing chunk shorter than
            # `batch_size`, so up to batch_size - 1 time steps may go unused — confirm.
            return BatchGenerator(
                ds[[variable_name]],
                input_dims={
                    "time": batch_size,
                    "lat": ds.sizes["lat"],
                    "lon": ds.sizes["lon"],
                },
                preload_batch=False,
            )

        def _to_tensor(batch):
            # Use the caller's variable (the old helper always read "t2m").
            arr = batch[variable_name].values
            # Restore a leading time axis if the batch came back 2-D.
            if arr.ndim == 2:
                arr = arr[None, :, :]
            # (time, lat, lon) -> (time, 1, lat, lon): add a singleton channel axis.
            return torch.tensor(arr[:, None, :, :], dtype=torch.float32)

        # Materialise all batches while the files are still open.
        data = [
            (_to_tensor(xb), _to_tensor(yb))
            for xb, yb in zip(_batches(ds_x), _batches(ds_y))
        ]

    return DataLoader(data, batch_size=None, shuffle=shuffle)

In [21]:
# File paths
train_era5 = "../data/train_era5.nc"
train_cerra = "../data/train_cerra.nc"
val_era5 = "../data/val_era5.nc"
val_cerra = "../data/val_cerra.nc"
model_path = "../models/deepesd_trained.pt"
metrics_path = "../models/metrics.csv"

# Hyperparameters
epochs = 500
lr = 1e-4
batch_size = 16
patience = 10  # Early stopping: epochs tolerated without a significant improvement
min_rel_improvement = 0.05  # Val loss must drop >= 5% to reset the patience counter

# Create output directories up front: the first checkpoint is written during
# epoch 1, long before the metrics are saved.
for output_dir in {os.path.dirname(model_path), os.path.dirname(metrics_path)}:
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

# Load data
logging.info("Creating dataloaders...")
logging.info("Creating dataloader for training...")
train_dataloader = load_netcdf_pair(
    train_era5, train_cerra, variable_name="t2m", batch_size=batch_size, shuffle=True
)
logging.info("Creating dataloader for validation...")
val_dataloader = load_netcdf_pair(
    val_era5, val_cerra, variable_name="t2m", batch_size=batch_size, shuffle=False
)
logging.info("Dataloaders created.")
if len(train_dataloader) == 0 or len(val_dataloader) == 0:
    raise ValueError(
        "Training and validation sets must each contain at least one full batch"
    )

# Model setup: spatial grid sizes are read from the first (x, y) batch.
input_shape = train_dataloader.dataset[0][0].shape[-2:]
output_shape = train_dataloader.dataset[0][1].shape[-2:]

model = DeepESD(
    input_shape=input_shape,
    output_shape=output_shape,
    input_channels=1,
    output_channels=1,
)
model.to(DEVICE)
logging.info(
    f"Model initialized with input shape {input_shape}, output shape {output_shape}"
)

# Loss
criterion = nn.MSELoss()

# Optimizer and LR schedule (cosine decay over the full epoch budget)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


def _run_epoch(loader, desc, train):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    With ``train=True`` the model is in training mode and is updated by
    ``optimizer``. Otherwise it runs in eval mode with gradients disabled.
    """
    model.train(train)
    total_loss = 0.0
    progress = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(train):
        for predictor, target in progress:
            predictor, target = predictor.to(DEVICE), target.to(DEVICE)
            if train:
                optimizer.zero_grad()
            loss = criterion(model(predictor), target)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            progress.set_postfix(loss=loss.item())
    return total_loss / len(loader)


# Early stopping state
best_val_loss = float("inf")  # lowest val loss so far; always checkpointed
reference_val_loss = float("inf")  # last val loss that counted as significant
epochs_no_improve = 0

# Logging losses
train_losses = []
val_losses = []

logging.info("Starting training loop...")
for epoch in range(epochs):
    logging.info(f"Epoch {epoch+1}/{epochs}")
    # LR actually used this epoch; must be read before scheduler.step().
    current_lr = scheduler.get_last_lr()[0]

    avg_train_loss = _run_epoch(
        train_dataloader, f"[Epoch {epoch + 1}] Training", train=True
    )
    avg_val_loss = _run_epoch(
        val_dataloader, f"[Epoch {epoch + 1}] Validation", train=False
    )
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    # Scheduler step
    scheduler.step()

    # Log losses
    logging.info(
        f"Epoch {epoch+1} — Train Loss: {avg_train_loss:.6f} — Val Loss: {avg_val_loss:.6f} — LR: {current_lr:.6e}"
    )

    # Always keep the checkpoint with the lowest validation loss, even when
    # the gain is too small to count for early stopping.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), model_path)
        logging.info(f"✅ Model improved and saved to {model_path}")

    # Early stopping only resets on a significant relative improvement.
    if avg_val_loss <= reference_val_loss * (1 - min_rel_improvement):
        reference_val_loss = avg_val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        logging.info(f"No significant improvement for {epochs_no_improve} epoch(s)")

    if epochs_no_improve >= patience:
        logging.info(
            f"🛑 Early stopping triggered after {patience} epochs without improvement."
        )
        break

# Save metrics
metrics_df = pd.DataFrame(
    {
        "epoch": list(range(1, len(train_losses) + 1)),
        "train_loss": train_losses,
        "val_loss": val_losses,
    }
)
metrics_df.to_csv(metrics_path, index=False)
logging.info(f"📈 Metrics saved to {metrics_path}")

INFO:root:Creating dataloaders...
INFO:root:Creating dataloader for training...
INFO:root:Creating dataloader for validation...
INFO:root:Dataloaders created.
INFO:root:Model initialized with input spatial dim torch.Size([63, 65]), target dim 43452
INFO:root:Starting training loop...
INFO:root:Epoch 1/500
INFO:root:Epoch 1 — Train Loss: 181.122335 — Val Loss: 140.233745 — LR: 9.999901e-06
INFO:root:✅ Model improved and saved to ../models/deepesd_trained.pt
INFO:root:Epoch 2/500
INFO:root:Epoch 2 — Train Loss: 77.742674 — Val Loss: 26.018454 — LR: 9.999605e-06
INFO:root:✅ Model improved and saved to ../models/deepesd_trained.pt
INFO:root:Epoch 3/500
INFO:root:Epoch 3 — Train Loss: 18.045032 — Val Loss: 15.617580 — LR: 9.999112e-06
INFO:root:✅ Model improved and saved to ../models/deepesd_trained.pt
INFO:root:Epoch 4/500
INFO:root:Epoch 4 — Train Loss: 14.909537 — Val Loss: 14.382968 — LR: 9.998421e-06
INFO:root:✅ Model improved and saved to ../models/deepesd_trained.pt
INFO:root:Epoch 5