# Explainability for Deep Learning Climate Downscaling

This notebook demonstrates how to apply **Explainable AI (XAI)** techniques to deep learning models trained for spatial downscaling from ERA5 to CERRA temperature data over South-East Europe (SEE). Saliency-based attributions are computed with the [Captum](https://captum.ai) library; [Quantus](https://github.com/understandable-machine-intelligence-lab/quantus), a toolkit for evaluating explanations, is imported alongside it.

## 🔍 Goals
- Load trained models and test data
- Generate saliency and Integrated Gradients maps using Captum
- Compare explanation patterns between DeepESD and U-Net
- Visualize attribution over the spatial domain

In [1]:
import logging

logging.basicConfig(level=logging.INFO)
logging.info("🔁 Starting imports...")

try:
    import os

    logging.info("✅ Imported os")

    import torch

    logging.info("✅ Imported torch")

    import xarray as xr

    logging.info("✅ Imported xarray")

    import numpy as np

    logging.info("✅ Imported numpy")

    import pandas as pd

    logging.info("✅ Imported pandas")

    import matplotlib.pyplot as plt

    logging.info("✅ Imported matplotlib")

    import quantus

    logging.info("✅ Imported quantus")

    from xbatcher import BatchGenerator

    logging.info("✅ Imported xbatcher")

    from torch.utils.data import DataLoader

    logging.info("✅ Imported torch.utils.data")

    from IPython.display import display, Image, HTML

    logging.info("✅ Imported IPython.display")

    import cartopy.crs as ccrs

    logging.info("✅ Imported cartopy.crs")

    import cartopy.feature as cfeature

    logging.info("✅ Imported cartopy.feature")

    import warnings

    logging.info("✅ Imported warnings")

    from source.model_deepesd import DeepESD
    from source.model_unet import UNet

    logging.info("✅ Imported local models")

except Exception as e:
    logging.error(f"❌ Import failed: {e}")
    raise

warnings.filterwarnings("ignore")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"📦 Using device: {DEVICE}")

INFO:root:🔁 Starting imports...
INFO:root:✅ Imported os
INFO:root:✅ Imported torch
INFO:root:✅ Imported xarray
INFO:root:✅ Imported numpy
INFO:root:✅ Imported pandas
INFO:root:✅ Imported matplotlib
INFO:root:✅ Imported quantus
INFO:root:✅ Imported xbatcher
INFO:root:✅ Imported torch.utils.data
INFO:root:✅ Imported IPython.display
INFO:root:✅ Imported cartopy.crs
INFO:root:✅ Imported cartopy.feature
INFO:root:✅ Imported local models
INFO:root:📦 Using device: cuda


## 📥 Load Test Data and Models

In this section, we load the preprocessed test data from ERA5 (inputs) and CERRA (targets), and load the trained models stored on disk.

In [2]:
from source.generate_dataloader import load_netcdf_pair

logging.info("Generating test dataloader...")

# Paths
test_era5 = "../data/test_era5.nc"
test_cerra = "../data/test_cerra.nc"

# Read the coordinate axes once per file and close each file afterwards
with xr.open_dataset(test_era5) as ds_era5:
    input_lats = ds_era5["lat"].values
    input_lons = ds_era5["lon"].values

with xr.open_dataset(test_cerra) as ds_cerra:
    output_lats = ds_cerra["lat"].values
    output_lons = ds_cerra["lon"].values

test_dataloader = load_netcdf_pair(test_era5, test_cerra, batch_size=1)
input_sample, target_sample = test_dataloader.dataset[0]
logging.info("Test dataloader created successfully")
logging.info(f"Loaded one test sample with shape: {input_sample.shape}")

INFO:root:Generating test dataloader...
INFO:root:Test dataloader created successfully
INFO:root:Loaded one test sample with shape: torch.Size([1, 1, 63, 65])


## 🔧 Load Trained DeepESD and U-Net Models

Now we initialize both architectures and load their pre-trained weights. These models were trained for temperature downscaling in the WS4 training notebook.


In [3]:
# Paths to trained models
model_path_deepesd = "../models/model_deepesd.pt"
model_path_unet = "../models/model_unet.pt"

# Infer shapes from test data
input_shape = input_sample.shape[-2:]
output_shape = target_sample.shape[-2:]

logging.info("🔧 Loading DeepESD model...")
deepesd_model = DeepESD(input_shape, output_shape, 1, 1)
deepesd_model.load_state_dict(torch.load(model_path_deepesd, map_location=DEVICE))
deepesd_model.to(DEVICE).eval()
logging.info("✅ DeepESD loaded.")

logging.info("🔧 Loading U-Net model...")
unet_model = UNet(input_shape, output_shape, 1, 1)
unet_model.load_state_dict(torch.load(model_path_unet, map_location=DEVICE))
unet_model.to(DEVICE).eval()
logging.info("✅ U-Net loaded.")

INFO:root:🔧 Loading DeepESD model...
INFO:root:✅ DeepESD loaded.
INFO:root:🔧 Loading U-Net model...
INFO:root:✅ U-Net loaded.


## 🎯 Objective: Understanding Model Behavior at High-Error Locations
In this section, we aim to analyze how our deep learning models make predictions for temperature downscaling at spatial points where the models exhibit the highest errors. We focus on two models:

- **DeepESD**: a custom deep learning architecture for spatial downscaling
- **U-Net**: a convolutional neural network architecture widely used in climate and image processing

By identifying points where models perform poorly, we can apply Explainable AI (XAI) techniques to better understand why those predictions may be unreliable, inconsistent, or driven by unexpected patterns in the input.


In [4]:
# Load predictions and target data
deepesd_ds = xr.open_dataset("../data/test_deepesd.nc")
unet_ds = xr.open_dataset("../data/test_unet.nc")
target_ds = xr.open_dataset("../data/test_cerra.nc")

# Extract the "t2m" temperature fields as NumPy arrays, expected as (time, lat, lon)
deepesd_np = deepesd_ds["t2m"].values
unet_np = unet_ds["t2m"].values
target_np = target_ds["t2m"].values

# If the shapes differ, assume the target is stored as (lat, lon, time)
# and move the time axis to the front
if target_np.shape != deepesd_np.shape:
    target_np = np.transpose(target_np, (2, 0, 1))

## 📊 How Are the High-Error Pixels Selected?
To identify the most relevant locations:

1. Compute the RMSE (Root Mean Squared Error) between model predictions and the ground truth (CERRA) across time, for each pixel.

2. Select the top 6 pixels with the highest RMSE, but enforce a minimum spatial separation between selected points (here, 10 grid cells) to avoid clustering around the same area.

3. Map each selected output location (high-resolution CERRA grid) to the nearest input grid cell in ERA5. This is necessary because input and output grids have different spatial resolutions.
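
As a toy illustration of steps 2 and 3, the sketch below runs a greedy selection with a minimum separation and a nearest-cell lookup. The 5×5 error map, the coordinate arrays, and the helper names `pick_separated_maxima` and `nearest_index` are made up for this example and are not taken from the notebook's data:

```python
import numpy as np


def pick_separated_maxima(error_map, num_pixels, min_distance):
    """Greedily pick the largest-error cells that are at least min_distance apart."""
    order = np.argsort(error_map.ravel())[::-1]  # flat indices, largest error first
    candidates = np.column_stack(np.unravel_index(order, error_map.shape))
    selected = []
    for cand in candidates:
        if all(np.linalg.norm(cand - np.array(p)) >= min_distance for p in selected):
            selected.append((int(cand[0]), int(cand[1])))
        if len(selected) == num_pixels:
            break
    return selected


def nearest_index(coords, value):
    """Index of the coordinate closest to value."""
    return int(np.abs(np.asarray(coords) - value).argmin())


# Hypothetical 5x5 error map: two neighbouring peaks and one distant peak.
toy_error = np.zeros((5, 5))
toy_error[0, 0] = 3.0
toy_error[0, 1] = 2.9  # adjacent to the top peak, so it is skipped
toy_error[4, 4] = 2.0

picked = pick_separated_maxima(toy_error, num_pixels=2, min_distance=3)

# Hypothetical fine (output) and coarse (input) latitude axes.
fine_lats = np.array([40.0, 40.05, 40.10, 40.15, 40.20])
coarse_lats = np.array([40.0, 40.25])
coarse_idx = [nearest_index(coarse_lats, fine_lats[i]) for i, _ in picked]
```

Here the second-largest value is rejected because it lies one cell from the first pick, so the selection falls through to the distant peak. Each selected fine-grid latitude is then matched to whichever coarse-grid latitude is closest.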

In [5]:
# Compute RMSE for each pixel (lat, lon)
rmse_deepesd = np.sqrt(np.mean((deepesd_np - target_np) ** 2, axis=0))
rmse_unet = np.sqrt(np.mean((unet_np - target_np) ** 2, axis=0))


def select_diverse_high_error_pixels(error_map, num_pixels=5, min_distance=10):
    flat_indices = np.argsort(error_map.ravel())[::-1]
    lat_lon_indices = np.array(np.unravel_index(flat_indices, error_map.shape)).T

    selected = []
    for lat_idx, lon_idx in lat_lon_indices:
        if all(
            np.linalg.norm(np.array([lat_idx, lon_idx]) - np.array(p)) >= min_distance
            for p in selected
        ):
            selected.append([lat_idx, lon_idx])
        if len(selected) == num_pixels:
            break

    return selected


# Apply to each RMSE map
coords_deepesd = select_diverse_high_error_pixels(
    rmse_deepesd, num_pixels=6, min_distance=10
)
coords_unet = select_diverse_high_error_pixels(rmse_unet, num_pixels=6, min_distance=10)


def find_nearest_index(array, value):
    return (np.abs(array - value)).argmin()


def build_pixel_df(model_name, coords, rmse_map):
    output_lat_idx = [lat for lat, lon in coords]
    output_lon_idx = [lon for lat, lon in coords]
    output_lat_val = [output_lats[lat] for lat in output_lat_idx]
    output_lon_val = [output_lons[lon] for lon in output_lon_idx]
    input_lat_idx = [find_nearest_index(input_lats, val) for val in output_lat_val]
    input_lon_idx = [find_nearest_index(input_lons, val) for val in output_lon_val]
    rmse_vals = [rmse_map[lat, lon] for lat, lon in coords]

    return pd.DataFrame(
        {
            "Model": [model_name] * len(coords),
            "RMSE": rmse_vals,
            "output_lat_idx": output_lat_idx,
            "output_lon_idx": output_lon_idx,
            "output_lat_value": output_lat_val,
            "output_lon_value": output_lon_val,
            "input_lat_idx": input_lat_idx,
            "input_lon_idx": input_lon_idx,
        }
    )


# Build separate DataFrames and concatenate
df_deepesd = build_pixel_df("DeepESD", coords_deepesd, rmse_deepesd)
df_unet = build_pixel_df("UNet", coords_unet, rmse_unet)
df = pd.concat([df_deepesd, df_unet], ignore_index=True)

display(df)

| | Model | RMSE | output_lat_idx | output_lon_idx | output_lat_value | output_lon_value | input_lat_idx | input_lon_idx |
|---|---|---|---|---|---|---|---|---|
| 0 | DeepESD | 2.140564 | 23 | 62 | 42.19 | 20.24 | 21 | 29 |
| 1 | DeepESD | 2.06592 | 30 | 123 | 42.54 | 23.29 | 22 | 41 |
| 2 | DeepESD | 2.029901 | 40 | 10 | 43.04 | 17.64 | 24 | 19 |
| 3 | DeepESD | 2.003727 | 15 | 68 | 41.79 | 20.54 | 19 | 30 |
| 4 | DeepESD | 1.979262 | 13 | 127 | 41.69 | 23.49 | 19 | 42 |
| 5 | DeepESD | 1.95981 | 34 | 152 | 42.74 | 24.74 | 23 | 47 |
| 6 | UNet | 3.168911 | 6 | 115 | 41.34 | 22.89 | 17 | 40 |
| 7 | UNet | 2.985608 | 13 | 127 | 41.69 | 23.49 | 19 | 42 |
| 8 | UNet | 2.957485 | 15 | 68 | 41.79 | 20.54 | 19 | 30 |
| 9 | UNet | 2.831704 | 0 | 100 | 41.04 | 22.14 | 16 | 37 |


# 🧠 Explainable AI Techniques Used

We use **saliency-based XAI methods** from the [Captum](https://captum.ai) library to explain how the models arrived at their predictions for each selected **high-error output pixel**.

The goal is to compute the **attribution of each input pixel** (from ERA5) to the prediction at a specific output location (in CERRA). This highlights **which parts of the input the model considers most influential** for a particular forecast.

---

## 🧭 Saliency Maps

**Saliency maps** compute the gradient of the output with respect to the input. They answer the question:

> *How much would a small change in each input pixel affect the prediction at the output pixel of interest?*

Key properties:
- ✅ Simple and fast to compute.
- ✅ Highlights the most influential regions in the input.
- ⚠️ Sensitive to noise and input perturbations.
- ⚠️ Can suffer from gradient saturation.

Mathematically, for an input $x$ and a model $f$:

![Saliency Equation](../outputs/figures/equations/saliency.jpg)
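
Written out, the standard gradient-based saliency for a chosen output pixel $(i, j)$ takes the form below. The notation $f_{i,j}$ for the model output at that pixel is introduced here for illustration, and the absolute value corresponds to the `abs=True` setting passed to Captum in this notebook:

$$
S(x) = \left| \frac{\partial f_{i,j}(x)}{\partial x} \right|
$$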

---

## 🧭 Integrated Gradients

**Integrated Gradients (IG)** aims to provide more stable attributions by accumulating gradients along a straight-line path from a **baseline input** (e.g., a tensor of zeros) to the actual input, then scaling the result by the difference between the input and the baseline.

Because the gradients are integrated over the whole path rather than evaluated at a single point, IG mitigates some limitations of raw gradients, such as saturation.

Formally, for an input $x$, baseline $x'$, and model $f$:

![Integrated Gradients Equation](../outputs/figures/equations/integrated_gradients.jpg)
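
In the usual formulation, the attribution assigned to input element $x_i$ is given below. The interpolation coefficient $\alpha$ along the straight-line path and the index $i$ over input pixels are notation introduced here:

$$
\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial f\left(x' + \alpha (x - x')\right)}{\partial x_i} \, d\alpha
$$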

Where:
- $x$ is the actual input.
- $x'$ is the baseline input.
- $f(x)$ is the model’s output at the selected output pixel.

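Key advantages:
- ✅ Typically less noisy than raw gradients.
- ✅ Provides **axiomatic guarantees** such as completeness: the attributions sum to $f(x) - f(x')$.
- ✅ Attributions are signed, showing whether an input pixel pushes the prediction up or down relative to the baseline.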
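
The completeness property can be checked numerically on a toy function. The sketch below uses a hypothetical scalar model $f(x) = \tanh(w \cdot x)$ with an analytic gradient and a midpoint Riemann sum. It is not the Captum implementation, only an illustration that the attributions sum to approximately $f(x) - f(x')$:

```python
import numpy as np


def toy_model(x, w):
    """Scalar toy model: f(x) = tanh(w . x)."""
    return np.tanh(np.dot(w, x))


def toy_model_grad(x, w):
    """Analytic gradient of tanh(w . x) with respect to x."""
    return (1.0 - np.tanh(np.dot(w, x)) ** 2) * w


def integrated_gradients(x, baseline, w, n_steps=200):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    # Midpoints of n_steps equal sub-intervals of [0, 1]
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = np.stack(
        [toy_model_grad(baseline + a * (x - baseline), w) for a in alphas]
    )
    # Average gradient along the path, scaled by the input-baseline difference
    return (x - baseline) * grads.mean(axis=0)


# Hypothetical weights, input, and all-zero baseline
w = np.array([0.5, -1.0, 2.0])
x = np.array([1.0, 0.3, -0.2])
baseline = np.zeros_like(x)

ig = integrated_gradients(x, baseline, w)
completeness_gap = abs(ig.sum() - (toy_model(x, w) - toy_model(baseline, w)))
print(f"Completeness gap: {completeness_gap:.2e}")
```

The gap shrinks as `n_steps` grows, which is the same trade-off between accuracy and cost that applies when approximating the integral for a real network.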

In [6]:
from captum.attr import Saliency, IntegratedGradients
import torch
import numpy as np
import logging

# Ensure models are in eval mode
deepesd_model.eval()
unet_model.eval()

# Get one batch
x_batch, y_batch = test_dataloader.dataset[0]
x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)

# Extract output-grid pixel indices from df
deepesd_pixels = df[df["Model"] == "DeepESD"][
    ["output_lat_idx", "output_lon_idx"]
].values
unet_pixels = df[df["Model"] == "UNet"][["output_lat_idx", "output_lon_idx"]].values

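# Spatial dimensions of the model output (high-resolution target grid)
H_out, W_out = y_batch.shape[-2:]


def compute_attributions(model, x, lat_idx, lon_idx):
    # Work on a detached copy so the shared batch tensor is not modified in place
    x = x.clone().detach().requires_grad_(True)
    # lat_idx / lon_idx index the output grid, so clamp against the output size
    # (clamping against the input size would silently move the target pixel)
    lat_idx = min(int(lat_idx), H_out - 1)
    lon_idx = min(int(lon_idx), W_out - 1)

    def forward_func(input_tensor):
        output = model(input_tensor)
        return output[:, 0, lat_idx, lon_idx]  # shape (B,)

    model.zero_grad()
    saliency = Saliency(forward_func).attribute(x, abs=True)
    intgrad = IntegratedGradients(forward_func).attribute(
        x, baselines=torch.zeros_like(x)
    )

    return saliency.detach().cpu().numpy(), intgrad.detach().cpu().numpy()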


# Store results
attributions = {
    "DeepESD": {"saliency": [], "ig": []},
    "UNet": {"saliency": [], "ig": []},
}

# DeepESD
logging.info("🔍 Generating attributions for DeepESD...")
for lat_idx, lon_idx in deepesd_pixels:
    sal, ig = compute_attributions(deepesd_model, x_batch, lat_idx, lon_idx)
    attributions["DeepESD"]["saliency"].append(sal)
    attributions["DeepESD"]["ig"].append(ig)

# UNet
logging.info("🔍 Generating attributions for UNet...")
for lat_idx, lon_idx in unet_pixels:
    sal, ig = compute_attributions(unet_model, x_batch, lat_idx, lon_idx)
    attributions["UNet"]["saliency"].append(sal)
    attributions["UNet"]["ig"].append(ig)

logging.info("✅ All saliency and IG maps computed.")

INFO:root:🔍 Generating attributions for DeepESD...
INFO:root:🔍 Generating attributions for UNet...
INFO:root:✅ All saliency and IG maps computed.


In [7]:
print("DeepESD saliency:", len(attributions["DeepESD"]["saliency"]))
print("UNet saliency:", len(attributions["UNet"]["saliency"]))
print("DeepESD IG:", len(attributions["DeepESD"]["ig"]))
print("UNet IG:", len(attributions["UNet"]["ig"]))

DeepESD saliency: 6
UNet saliency: 6
DeepESD IG: 6
UNet IG: 6


In [8]:
from source.plot_explanation_maps import plot_explanation_map

In [9]:
from IPython.display import Markdown, display

with open("../source/plot_explanation_maps.py", "r") as f:
    code = f.read()

display(Markdown(f"```python\n{code}\n```"))

```python
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import xarray as xr


def plot_explanation_map(
    attr,
    title,
    input_lats=None,
    input_lons=None,
    pixel_coords=None,
    filename=None,
    vmin=None,
    vmax=None,
    cbar_label="Attribution Value",
):

    # Take the 2D map from a (1, 1, H, W) attribution array
    data = attr[0, 0, :, :]
    if vmin is None:
        vmin = np.nanmin(data)
    if vmax is None:
        vmax = np.nanmax(data)

    if vmin < 0 < vmax:
        cmap = "RdBu_r"
        if abs(vmin) > abs(vmax):
            vmax = abs(vmin)
        else:
            vmin = -vmax
    else:
        cmap = "Oranges"

    da = xr.DataArray(
        data,
        dims=("lat", "lon"),
        coords={"lat": input_lats, "lon": input_lons},
        name="attribution",
    )

    fig = plt.figure(figsize=(8, 6))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Plot attribution
    mesh = da.plot.pcolormesh(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        add_colorbar=False,  # Disable automatic colorbar
    )

    # Add coastlines and borders
    ax.coastlines(resolution="10m", linestyle="-", linewidths=0.5)
    ax.add_feature(cfeature.BORDERS, linestyle="-", linewidth=0.5)

    # Optional pixel marker
    if pixel_coords is not None:
        lat_idx, lon_idx = pixel_coords
        lat_val = input_lats[lat_idx]
        lon_val = input_lons[lon_idx]
        ax.plot(
            lon_val,
            lat_val,
            marker="o",
            color="black",
            markersize=6,
            transform=ccrs.PlateCarree(),
        )

    # Add title
    ax.set_title(title, fontsize=15)

    # Add manual colorbar
    cbar_position = [0.15, 0.05, 0.7, 0.05]
    cbar_ax = fig.add_axes(cbar_position)
    cbar = fig.colorbar(mesh, cax=cbar_ax, orientation="horizontal")
    cbar.set_label(cbar_label, fontsize=12)

    # Save or show
    if filename:
        plt.savefig(filename, dpi=300, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

    return filename

```

In [10]:
# Save paths to reuse in display cells
paths_deepesd_saliency, paths_deepesd_ig = [], []
paths_unet_saliency, paths_unet_ig = [], []

os.makedirs("../outputs/figures/explainability", exist_ok=True)

for i in range(6):
    output_lat_idx, output_lon_idx = deepesd_pixels[i]
    input_coords = df[
        (df["output_lat_idx"] == output_lat_idx)
        & (df["output_lon_idx"] == output_lon_idx)
    ][["input_lat_idx", "input_lon_idx"]].values[0]
    input_lat_idx, input_lon_idx = input_coords

    # Saliency
    path = f"../outputs/figures/explainability/deepesd_saliency_high-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions["DeepESD"]["saliency"][i],
        title="DeepESD - Saliency",
        input_lats=input_lats,
        input_lons=input_lons,
        filename=path,
        pixel_coords=(input_lat_idx, input_lon_idx),
    )
    paths_deepesd_saliency.append(path)
    logging.info(f"📸 Figure saved to {path}")

    # Integrated Gradients
    path = f"../outputs/figures/explainability/deepesd_ig_high-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions["DeepESD"]["ig"][i],
        title="DeepESD - IntegratedGradients",
        input_lats=input_lats,
        input_lons=input_lons,
        filename=path,
        pixel_coords=(input_lat_idx, input_lon_idx),
    )
    paths_deepesd_ig.append(path)
    logging.info(f"📸 Figure saved to {path}")

for i in range(6):
    output_lat_idx, output_lon_idx = unet_pixels[i]
    input_coords = df[
        (df["output_lat_idx"] == output_lat_idx)
        & (df["output_lon_idx"] == output_lon_idx)
    ][["input_lat_idx", "input_lon_idx"]].values[0]
    input_lat_idx, input_lon_idx = input_coords

    # Saliency
    path = f"../outputs/figures/explainability/unet_saliency_high-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions["UNet"]["saliency"][i],
        title="UNet - Saliency",
        input_lats=input_lats,
        input_lons=input_lons,
        filename=path,
        pixel_coords=(input_lat_idx, input_lon_idx),
    )
    paths_unet_saliency.append(path)
    logging.info(f"📸 Figure saved to {path}")

    # Integrated Gradients
    path = f"../outputs/figures/explainability/unet_ig_high-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions["UNet"]["ig"][i],
        title="UNet - IntegratedGradients",
        input_lats=input_lats,
        input_lons=input_lons,
        filename=path,
        pixel_coords=(input_lat_idx, input_lon_idx),
    )
    paths_unet_ig.append(path)
    logging.info(f"📸 Figure saved to {path}")

INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_saliency_high-error-pixels-1.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_ig_high-error-pixels-1.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_saliency_high-error-pixels-2.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_ig_high-error-pixels-2.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_saliency_high-error-pixels-3.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_ig_high-error-pixels-3.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_saliency_high-error-pixels-4.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_ig_high-error-pixels-4.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_saliency_high-error-pixels-5.png
INFO:root:📸 Figure saved to ../outputs/figures/explainability/deepesd_ig_high-error-pixels-5.png


In [11]:
def display_image_grid(title, image_paths):
    html = f'<h3 style="margin-bottom: 0.5em;">{title}</h3>'
    html += '<table style="width:100%; border-collapse: collapse;">'
    for i in range(0, len(image_paths), 3):
        html += "<tr>"
        for j in range(3):
            if i + j < len(image_paths):
                path = image_paths[i + j]
                html += f"""
                <td style="text-align:center; padding:4px;">
                    <img src="{path}" style="width:100%; max-width:480px; height:auto; display:block; margin:auto;" />
                    <div style="font-size:18px;">Pixel {i + j + 1}</div>
                </td>
                """
        html += "</tr>"
    html += "</table>"
    display(HTML(html))

In [12]:
# Display each group
display_image_grid("DeepESD - Saliency Maps", paths_deepesd_saliency)
display_image_grid("DeepESD - Integrated Gradients", paths_deepesd_ig)
display_image_grid("UNet - Saliency Maps", paths_unet_saliency)
display_image_grid("UNet - Integrated Gradients", paths_unet_ig)

*[Output: four 2×3 image grids, one per model and method (DeepESD and UNet; saliency maps and Integrated Gradients), with panels labeled Pixel 1 to Pixel 6.]*


## 🔍 Explainability at Low-Error Locations

To complement our analysis of high-error predictions, we now turn our attention to **the 6 lowest-error locations** for both DeepESD and UNet. These points represent spatial grid cells where model performance was best, as measured by RMSE.

By comparing the attribution patterns at these low-error locations to those observed at high-error locations, we can better understand what kinds of input patterns lead to **accurate, reliable predictions**, and whether these are driven by:

- Clear and localized attribution around the output pixel
- Broader or more diffuse input patterns
- Differences in how each model uses spatial context

This side-by-side comparison helps shed light on **how** and **where** each model succeeds—and whether their behavior is consistent with domain knowledge.

In the sections below, we reproduce saliency maps and integrated gradients visualizations, but this time focused on the top-6 **low-RMSE** pixels for each model.
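
One way to put a number on "localized" versus "diffuse" attribution is the share of total absolute attribution that falls within a few grid cells of the marked input pixel. The helper `attribution_mass_within_radius`, the radius, and the two 9×9 maps below are hypothetical and serve only as a sketch of the idea:

```python
import numpy as np


def attribution_mass_within_radius(attr_2d, center, radius):
    """Fraction of total absolute attribution within `radius` grid cells of `center`."""
    abs_attr = np.abs(np.asarray(attr_2d, dtype=float))
    total = abs_attr.sum()
    if total == 0:
        return 0.0
    rows, cols = np.indices(abs_attr.shape)
    dist = np.hypot(rows - center[0], cols - center[1])  # Euclidean distance in cells
    return float(abs_attr[dist <= radius].sum() / total)


# Hypothetical attribution maps on a 9x9 input grid.
localized = np.zeros((9, 9))
localized[4, 4] = 1.0      # all attribution on the marked pixel
diffuse = np.ones((9, 9))  # attribution spread evenly over the grid

frac_localized = attribution_mass_within_radius(localized, center=(4, 4), radius=1)
frac_diffuse = attribution_mass_within_radius(diffuse, center=(4, 4), radius=1)
```

With the notebook's arrays, a call might look like `attribution_mass_within_radius(attributions_low["DeepESD"]["ig"][0][0, 0], (input_lat_idx, input_lon_idx), radius=3)`, where the radius is an arbitrary choice that would need tuning to the grid spacing.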


In [13]:
# Select lowest-error pixels
def select_diverse_low_error_pixels(error_map, num_pixels=6, min_distance=10):
    flat_indices = np.argsort(error_map.ravel())  # ascending order
    lat_lon_indices = np.array(np.unravel_index(flat_indices, error_map.shape)).T

    selected = []
    for lat_idx, lon_idx in lat_lon_indices:
        if all(
            np.linalg.norm(np.array([lat_idx, lon_idx]) - np.array(p)) >= min_distance
            for p in selected
        ):
            selected.append([lat_idx, lon_idx])
        if len(selected) == num_pixels:
            break
    return selected


# Apply for DeepESD and UNet
coords_deepesd_low = select_diverse_low_error_pixels(rmse_deepesd)
coords_unet_low = select_diverse_low_error_pixels(rmse_unet)

# Build dataframe
df_low_deepesd = build_pixel_df("DeepESD", coords_deepesd_low, rmse_deepesd)
df_low_unet = build_pixel_df("UNet", coords_unet_low, rmse_unet)
df_low = pd.concat([df_low_deepesd, df_low_unet], ignore_index=True)

display(df_low)

| | Model | RMSE | output_lat_idx | output_lon_idx | output_lat_value | output_lon_value | input_lat_idx | input_lon_idx |
|---|---|---|---|---|---|---|---|---|
| 0 | DeepESD | 0.710637 | 9 | 3 | 41.49 | 17.29 | 18 | 17 |
| 1 | DeepESD | 0.720826 | 3 | 11 | 41.19 | 17.69 | 17 | 19 |
| 2 | DeepESD | 0.747795 | 19 | 0 | 41.99 | 17.14 | 20 | 17 |
| 3 | DeepESD | 0.752813 | 0 | 21 | 41.04 | 18.19 | 16 | 21 |
| 4 | DeepESD | 0.7856 | 14 | 12 | 41.74 | 17.74 | 19 | 19 |
| 5 | DeepESD | 0.80803 | 10 | 22 | 41.54 | 18.24 | 18 | 21 |
| 6 | UNet | 0.76 | 96 | 68 | 45.84 | 20.54 | 35 | 30 |
| 7 | UNet | 0.77348 | 15 | 3 | 41.79 | 17.29 | 19 | 17 |
| 8 | UNet | 0.793121 | 10 | 12 | 41.54 | 17.74 | 18 | 19 |
| 9 | UNet | 0.804478 | 86 | 66 | 45.34 | 20.44 | 33 | 30 |


In [14]:

# Attribution extraction
x_batch, y_batch = test_dataloader.dataset[0]
x_batch = x_batch.to(DEVICE)

attributions_low = {
    "DeepESD": {"saliency": [], "ig": []},
    "UNet": {"saliency": [], "ig": []},
}

logging.info("🔎 Generating low-error attributions for DeepESD...")
for lat_idx, lon_idx in df_low_deepesd[["output_lat_idx", "output_lon_idx"]].values:
    sal, ig = compute_attributions(deepesd_model, x_batch, lat_idx, lon_idx)
    attributions_low["DeepESD"]["saliency"].append(sal)
    attributions_low["DeepESD"]["ig"].append(ig)

logging.info("🔎 Generating low-error attributions for UNet...")
for lat_idx, lon_idx in df_low_unet[["output_lat_idx", "output_lon_idx"]].values:
    sal, ig = compute_attributions(unet_model, x_batch, lat_idx, lon_idx)
    attributions_low["UNet"]["saliency"].append(sal)
    attributions_low["UNet"]["ig"].append(ig)

INFO:root:🔎 Generating low-error attributions for DeepESD...
INFO:root:🔎 Generating low-error attributions for UNet...


In [15]:
# Save and prepare for display
paths_low_deepesd_saliency, paths_low_deepesd_ig = [], []
paths_low_unet_saliency, paths_low_unet_ig = [], []

for i, row in df_low_deepesd.iterrows():
    input_coords = (row["input_lat_idx"], row["input_lon_idx"])
    path = f"../outputs/figures/explainability/deepesd_saliency_low-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions_low["DeepESD"]["saliency"][i],
        "DeepESD - Saliency (Low Error)",
        input_lats,
        input_lons,
        pixel_coords=input_coords,
        filename=path,
    )
    paths_low_deepesd_saliency.append(path)

    path = f"../outputs/figures/explainability/deepesd_ig_low-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions_low["DeepESD"]["ig"][i],
        "DeepESD - IntegratedGradients (Low Error)",
        input_lats,
        input_lons,
        pixel_coords=input_coords,
        filename=path,
    )
    paths_low_deepesd_ig.append(path)

for i, row in df_low_unet.iterrows():
    input_coords = (row["input_lat_idx"], row["input_lon_idx"])
    path = f"../outputs/figures/explainability/unet_saliency_low-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions_low["UNet"]["saliency"][i],
        "UNet - Saliency (Low Error)",
        input_lats,
        input_lons,
        pixel_coords=input_coords,
        filename=path,
    )
    paths_low_unet_saliency.append(path)

    path = f"../outputs/figures/explainability/unet_ig_low-error-pixels-{i+1}.png"
    plot_explanation_map(
        attributions_low["UNet"]["ig"][i],
        "UNet - IntegratedGradients (Low Error)",
        input_lats,
        input_lons,
        pixel_coords=input_coords,
        filename=path,
    )
    paths_low_unet_ig.append(path)

In [16]:
# Display the low-error groups
display_image_grid("DeepESD - Saliency Maps (Low Error)", paths_low_deepesd_saliency)
display_image_grid("DeepESD - Integrated Gradients (Low Error)", paths_low_deepesd_ig)
display_image_grid("UNet - Saliency Maps (Low Error)", paths_low_unet_saliency)
display_image_grid("UNet - Integrated Gradients (Low Error)", paths_low_unet_ig)

*[Output: four 2×3 image grids for the low-error pixels, one per model and method (DeepESD and UNet; saliency maps and Integrated Gradients), with panels labeled Pixel 1 to Pixel 6.]*
