# Explainability for Deep Learning Climate Downscaling

This notebook demonstrates how to apply **Explainable AI (XAI)** techniques to deep learning models trained to downscale ERA5 temperature fields to the higher-resolution CERRA grid over South-East Europe (SEE). We compute saliency-based attributions with Captum and evaluate them with the [Quantus](https://github.com/understandable-machine-intelligence-lab/quantus) library.
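A saliency map scores each input grid cell by how strongly the model output responds to a small change in that cell, i.e. the absolute gradient of the output with respect to the input. The minimal sketch below shows this with plain `torch.autograd` on a toy linear function of a 3 × 4 field. The toy weights and field size are illustrative assumptions, not part of the trained models.

```python
import torch

# Toy stand-in for a model: a weighted sum over a 3x4 input field (illustrative only)
torch.manual_seed(0)
weights = torch.randn(3, 4)


def toy_model(field: torch.Tensor) -> torch.Tensor:
    """Return one scalar 'prediction' for a 2-D input field."""
    return (weights * field).sum()


x = torch.randn(3, 4, requires_grad=True)

# Vanilla-gradient saliency: |d prediction / d input| at every grid cell
prediction = toy_model(x)
(grad,) = torch.autograd.grad(prediction, x)
saliency_map = grad.abs()

print(saliency_map.shape)  # torch.Size([3, 4])
```

For a linear function the saliency map is simply the absolute weight of each cell; for a deep network the gradient varies with the input, which is why maps are computed per sample.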

## 🔍 Goals
- Load trained models and test data
- Generate saliency maps with Captum and prepare them for evaluation with Quantus
- Compare explanation patterns between DeepESD and U-Net
- Visualize attribution over the spatial domain
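One simple way to compare explanation patterns between two models is the rank correlation of their attribution maps over grid cells. The sketch below uses NumPy only; the choice of Spearman correlation is an assumption for illustration, and the random 63 × 65 arrays (the input grid size seen in the logged sample shape) are placeholders for real DeepESD and U-Net attributions.

```python
import numpy as np


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Spearman rank correlation between two maps (assumes no tied values)."""
    # argsort of argsort turns values into ranks 0..n-1
    rank_a = np.argsort(np.argsort(a.ravel()))
    rank_b = np.argsort(np.argsort(b.ravel()))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])


rng = np.random.default_rng(0)
attr_deepesd = rng.normal(size=(63, 65))  # placeholder for a DeepESD saliency map
attr_unet = attr_deepesd + 0.5 * rng.normal(size=(63, 65))  # placeholder for a U-Net map

similarity = spearman(attr_deepesd, attr_unet)
print(f"Spearman rank correlation: {similarity:.3f}")
```

Because rank correlation ignores scale, the two maps do not need to be normalised first, which matters when the two architectures produce attributions of very different magnitudes.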

In [17]:
import logging

logging.basicConfig(level=logging.INFO)
logging.info("🔁 Starting imports...")

try:
    import os

    logging.info("✅ Imported os")

    import torch

    logging.info("✅ Imported torch")

    import xarray as xr

    logging.info("✅ Imported xarray")

    import numpy as np

    logging.info("✅ Imported numpy")

    import matplotlib.pyplot as plt

    logging.info("✅ Imported matplotlib")

    import quantus

    logging.info("✅ Imported quantus")

    from xbatcher import BatchGenerator

    logging.info("✅ Imported xbatcher")

    from torch.utils.data import DataLoader

    logging.info("✅ Imported torch.utils.data")

    from IPython.display import display, Image

    logging.info("✅ Imported IPython.display")

    import cartopy.crs as ccrs

    logging.info("✅ Imported cartopy.crs")

    import cartopy.feature as cfeature

    logging.info("✅ Imported cartopy.feature")

    import warnings

    logging.info("✅ Imported warnings")

    from source.model_deepesd import DeepESD
    from source.model_unet import UNet

    logging.info("✅ Imported local models")

except Exception as e:
    logging.error(f"❌ Import failed: {e}")
    raise

warnings.filterwarnings("ignore")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"📦 Using device: {DEVICE}")

INFO:root:🔁 Starting imports...
INFO:root:✅ Imported os
INFO:root:✅ Imported torch
INFO:root:✅ Imported xarray
INFO:root:✅ Imported numpy
INFO:root:✅ Imported matplotlib
INFO:root:✅ Imported quantus
INFO:root:✅ Imported xbatcher
INFO:root:✅ Imported torch.utils.data
INFO:root:✅ Imported IPython.display
INFO:root:✅ Imported cartopy.crs
INFO:root:✅ Imported cartopy.feature
INFO:root:✅ Imported local models
INFO:root:📦 Using device: cuda


## 📥 Load Test Data and Models

In this section, we load the preprocessed test data, with ERA5 fields as inputs and CERRA fields as targets. The trained models are restored from disk in the next section.

In [18]:
from source.generate_dataloader import load_netcdf_pair

logging.info("📥 Loading test data from NetCDF...")

# Paths to the preprocessed ERA5 (input) and CERRA (target) test files
test_era5 = "../data/test_era5.nc"
test_cerra = "../data/test_cerra.nc"

test_dataloader = load_netcdf_pair(test_era5, test_cerra)

# Inspect the first item; its last two dimensions are the spatial grid
input_sample, target_sample = test_dataloader.dataset[0]
logging.info(f"✅ Loaded one test sample with shape: {input_sample.shape}")

INFO:root:📥 Loading test data from NetCDF...
INFO:root:✅ Loaded one test sample with shape: torch.Size([16, 1, 63, 65])
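`load_netcdf_pair` is the project's own helper and its internals are not shown here. As a rough sketch of what such a loader has to do, the hypothetical `PairedFieldDataset` below keeps only the time steps present in both the ERA5 and CERRA arrays and returns them as float tensors with a leading channel dimension. The dimension names (`time`, `lat`, `lon`) and the synthetic arrays are assumptions.

```python
import numpy as np
import pandas as pd
import torch
import xarray as xr
from torch.utils.data import Dataset


class PairedFieldDataset(Dataset):
    """Hypothetical dataset pairing ERA5 inputs with CERRA targets by time step."""

    def __init__(self, era5: xr.DataArray, cerra: xr.DataArray):
        # Keep only the timestamps present in both products
        common_times = np.intersect1d(era5["time"].values, cerra["time"].values)
        self.era5 = era5.sel(time=common_times)
        self.cerra = cerra.sel(time=common_times)

    def __len__(self) -> int:
        return self.era5.sizes["time"]

    def __getitem__(self, idx: int):
        # Add a leading channel dimension: (1, lat, lon)
        x = torch.from_numpy(self.era5.isel(time=idx).values[None]).float()
        y = torch.from_numpy(self.cerra.isel(time=idx).values[None]).float()
        return x, y


# Synthetic example: coarse 4x5 inputs, fine 8x10 targets, partially overlapping times
era5 = xr.DataArray(
    np.random.rand(5, 4, 5),
    dims=("time", "lat", "lon"),
    coords={"time": pd.date_range("2020-01-01", periods=5, freq="D")},
)
cerra = xr.DataArray(
    np.random.rand(4, 8, 10),
    dims=("time", "lat", "lon"),
    coords={"time": pd.date_range("2020-01-02", periods=4, freq="D")},
)

dataset = PairedFieldDataset(era5, cerra)
x, y = dataset[0]
print(len(dataset), x.shape, y.shape)  # 4 torch.Size([1, 4, 5]) torch.Size([1, 8, 10])
```

Intersecting on time rather than calling `xr.align` avoids accidentally joining the two different spatial grids, since ERA5 and CERRA coordinates do not match.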


## 🔧 Load Trained DeepESD and U-Net Models

Next, we instantiate both architectures and load their trained weights. Both models were trained for temperature downscaling in the WS4 training notebook.


In [19]:
# Paths to trained models
model_path_deepesd = "../models/model_deepesd.pt"
model_path_unet = "../models/model_unet.pt"

# Infer shapes from test data
input_shape = input_sample.shape[-2:]
output_shape = target_sample.shape[-2:]

logging.info("🔧 Loading DeepESD model...")
deepesd = DeepESD(input_shape, output_shape, 1, 1)
deepesd.load_state_dict(torch.load(model_path_deepesd, map_location=DEVICE))
deepesd.to(DEVICE).eval()
logging.info("✅ DeepESD loaded.")

logging.info("🔧 Loading U-Net model...")
unet = UNet(input_shape, output_shape, 1, 1)
unet.load_state_dict(torch.load(model_path_unet, map_location=DEVICE))
unet.to(DEVICE).eval()
logging.info("✅ U-Net loaded.")


INFO:root:🔧 Loading DeepESD model...
INFO:root:✅ DeepESD loaded.
INFO:root:🔧 Loading U-Net model...
INFO:root:✅ U-Net loaded.
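The loading pattern above (`torch.load(..., map_location=DEVICE)`, then `load_state_dict`, then `.eval()`) can be checked end to end on a tiny stand-in network. The sketch below saves the state dict of a hypothetical `TinyNet` to a temporary file, restores it onto the CPU, and confirms that the restored copy is in inference mode and reproduces the original outputs. `weights_only=True` restricts unpickling to tensors and plain containers, which is all a state dict needs.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for DeepESD/U-Net: one conv layer followed by dropout."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.drop(self.conv(x))


torch.manual_seed(0)
original = TinyNet().eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny.pt")
    torch.save(original.state_dict(), path)

    restored = TinyNet()
    # map_location lets weights saved on a GPU be restored on a CPU-only machine
    state = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state)
    restored.eval()  # disables dropout so predictions (and explanations) are deterministic

x = torch.randn(2, 1, 8, 8)
with torch.no_grad():
    same = torch.allclose(original(x), restored(x))
print(restored.training, same)  # False True
```

Calling `.eval()` matters for XAI as well as inference: with dropout active, repeated attribution runs on the same input would give different maps.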


## Load or Generate Explanations

To evaluate explanation methods with Quantus, we need a batch of **inputs**, **targets**, and the corresponding **attributions** (explanations).

There are two options:

- 🔁 **Option A: Use pre-computed attributions**, e.g., from Captum
- 🧠 **Option B: Use a callable explanation function**, such as `quantus.explain()` or a custom function
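For Option B, Quantus calls the explanation function itself instead of receiving precomputed maps. Recent Quantus versions pass `model`, `inputs` and `targets` as keyword arguments and expect a NumPy array back; treat that signature as an assumption to check against the installed version. The hypothetical `gradient_saliency_explain` below follows that shape for a field-to-field regression model by explaining the domain-mean prediction.

```python
import numpy as np
import torch
from torch import nn


def gradient_saliency_explain(model: nn.Module, inputs: np.ndarray, targets=None, **kwargs) -> np.ndarray:
    """Hypothetical Quantus-style explain_func for a downscaling (regression) model.

    Returns |d(domain-mean prediction)/d(input)| summed over channels, with shape
    (batch, height, width). `targets` is accepted for signature compatibility but
    unused, because there are no class labels to select.
    """
    device = next(model.parameters()).device
    x = torch.as_tensor(inputs, dtype=torch.float32, device=device).requires_grad_(True)
    model.eval()
    # Reduce each predicted field to one scalar so a gradient is well defined
    scalar_per_sample = model(x).flatten(start_dim=1).mean(dim=1)
    # Summing over the batch is valid because samples do not interact in eval mode
    (grads,) = torch.autograd.grad(scalar_per_sample.sum(), x)
    return grads.abs().sum(dim=1).detach().cpu().numpy()


# Usage with a toy convolutional model standing in for DeepESD
toy_model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 1, 3, padding=1))
x_demo = np.random.rand(2, 1, 16, 16).astype(np.float32)
attributions = gradient_saliency_explain(toy_model, x_demo, targets=None)
print(attributions.shape)  # (2, 16, 16)
```

Extra keyword arguments that Quantus may forward (for example a device) are absorbed by `**kwargs`.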

In this example, we use **Captum** to compute **Saliency** and **Integrated Gradients** attributions for the DeepESD model on a single batch from the test set.
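Integrated Gradients attributes the difference between the prediction at the input $x$ and at a baseline $x'$ by averaging gradients along the straight path between them: $\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \partial_i f\big(x' + \alpha(x - x')\big)\,d\alpha$. The sketch below approximates that integral with a midpoint Riemann sum on a toy scalar function and checks the completeness property, namely that the attributions sum to approximately $f(x) - f(x')$. Captum uses its own quadrature rule by default, so this is an illustration of the idea, not a reimplementation of Captum.

```python
import torch


def integrated_gradients(f, x: torch.Tensor, baseline: torch.Tensor, n_steps: int = 64) -> torch.Tensor:
    """Midpoint Riemann-sum approximation of Integrated Gradients for a scalar-valued f."""
    alphas = (torch.arange(n_steps, dtype=x.dtype) + 0.5) / n_steps
    total_grad = torch.zeros_like(x)
    for alpha in alphas:
        # Point on the straight path from the baseline to the input
        point = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        (grad,) = torch.autograd.grad(f(point), point)
        total_grad += grad
    return (x - baseline) * total_grad / n_steps


def f(field: torch.Tensor) -> torch.Tensor:
    """Toy nonlinear scalar 'model' over a 3x3 field (illustrative only)."""
    return torch.tanh(field).sum() + 0.5 * (field ** 2).mean()


torch.manual_seed(0)
x = torch.randn(3, 3, dtype=torch.float64)
baseline = torch.zeros_like(x)  # zero baseline, as in this notebook's Captum call

attr = integrated_gradients(f, x, baseline)
# Completeness: attributions should sum (approximately) to f(x) - f(baseline)
gap = (attr.sum() - (f(x) - f(baseline))).abs().item()
print(f"completeness gap: {gap:.2e}")
```

The baseline choice shapes the explanation: with a zero baseline, attributions describe the prediction relative to an all-zero input field, which is only meaningful if zero corresponds to a sensible reference state of the (possibly normalised) inputs.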


In [20]:
from captum.attr import Saliency, IntegratedGradients


class SpatialMeanOutput(torch.nn.Module):
    """Reduce a (batch, channels, H, W) prediction to one scalar per sample.

    Captum can only take gradients of one output value per sample, but DeepESD
    predicts a full high-resolution field. Here we explain the domain-mean
    predicted temperature. To explain a single grid cell instead, return
    `out.flatten(start_dim=1)` and pass that cell's flat index as `target`.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        return out.flatten(start_dim=1).mean(dim=1)  # shape: (batch,)


explained_model = SpatialMeanOutput(deepesd).eval()

logging.info("📥 Sampling input batch for explanation...")
x_batch, y_batch = next(iter(test_dataloader))
x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)

logging.info("🧠 Generating Saliency and Integrated Gradients attributions with Captum...")
# Summing over dim=1 (channels) gives one 2-D attribution map per sample
saliency_attr = Saliency(explained_model).attribute(inputs=x_batch, abs=True).sum(dim=1).detach().cpu().numpy()
intgrad_attr = IntegratedGradients(explained_model).attribute(
    inputs=x_batch, baselines=torch.zeros_like(x_batch)
).sum(dim=1).detach().cpu().numpy()

# Convert input and target to numpy
x_batch_np = x_batch.detach().cpu().numpy()
y_batch_np = y_batch.detach().cpu().numpy()

# Quick assertion to confirm shapes and types
assert all(isinstance(arr, np.ndarray) for arr in [x_batch_np, y_batch_np, saliency_attr, intgrad_attr])
logging.info(f"✅ Shapes | X: {x_batch_np.shape}, Y: {y_batch_np.shape}, Saliency: {saliency_attr.shape}, IG: {intgrad_attr.shape}")


INFO:root:📥 Sampling input batch for explanation...
INFO:root:🧠 Generating Saliency and Integrated Gradients attributions with Captum...