# Explainable AI for Earth Science: Practical Concepts, Workflows, and Pitfalls

*ELLIS Summer School: AI for Earth and Climate Sciences, Jena (Germany), September 1–5, 2025*  

https://github.com/ELLIS-Jena-Summer-School/XAI-tutorial

**Prepared by:** Shijie Jiang (Max Planck Institute for Biogeochemistry)

**Date:** 2025-09-02 

**Reference**: Jiang, S., Sweet, L.-b., Blougouras, G., Brenning, A., Li, W., Reichstein, M., et al. (2024). How interpretable machine learning can benefit process understanding in the geosciences. Earth's Future, 12, e2024EF004540. https://doi.org/10.1029/2024EF004540


This notebook is part of the hands-on tutorial for the summer school.  

This tutorial introduces **explainable AI (XAI)** methods for machine learning models, with a focus on **time series data in geoscience**.  

As a case study, we train a model to predict daily **Gross Primary Production (GPP)** from meteorological drivers such as **air temperature, shortwave radiation, precipitation, and vapour pressure deficit**. The goal is to learn how to explain the model's predictions.

<img src="img/slide1.jpg"/>

Although the example centers on time series, the concepts extend to other data types such as **tabular data** and **images**.  

The tutorial is structured in two parts:  

1. **Workflow** – how to apply XAI methods (e.g., Integrated Gradients) to a time series prediction task, and how to interpret the outputs.  
2. **Pitfalls** – common challenges in applying XAI to scientific problems, and strategies to avoid misinterpretation.  

By the end, you will have a practical understanding of how XAI methods can be applied in geoscience, and what is required to use them correctly.  

<img src="img/slide2.jpg"/>

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import random
import numpy as np
from tqdm import trange
from sklearn.metrics import r2_score

from captum.attr import Saliency, IntegratedGradients

# I. Workflow Overview

In this tutorial we will follow a clear workflow for applying XAI to a time series prediction task:  

**Build and train the model → Evaluate predictions → Apply XAI**

**Problem Context**  
> Gross Primary Production (**GPP**) is the total amount of carbon fixed by plants through photosynthesis.  
> It is a central measure of ecosystem productivity and plays a key role in the global carbon cycle.  
>   
> Our goal in this case study is to **predict daily GPP** using a time series machine learning approach.  
> Later, we will apply interpretability methods to examine what the model has learned.  
>
**Data**  
> The dataset contains **real-world observations** of daily values from **2000-01-01 to 2019-12-31**.  
> 
> Inputs (X): meteorological drivers of photosynthesis  
> - **Air Temperature (°C)** – influences plant metabolic rates and photosynthetic efficiency  
> - **Downward Shortwave Radiation (W/m2)** – sunlight energy available for photosynthesis  
> - **Precipitation (mm/day)** – affects soil water availability for plant growth  
> - **Vapour Pressure Deficit (kPa)** – indicator of atmospheric dryness; high values can trigger stomatal closure and reduce carbon assimilation  
> 
> Target (y): ecosystem carbon flux  
> - **Gross Primary Production (gC/m2/day)** – rate at which vegetation absorbs carbon from the atmosphere  
>
**Why Time Series?**  
> Photosynthesis depends not only on current conditions but also on the recent past.  

## 1. Load and Inspect the Data

We now load the dataset into a pandas DataFrame and take a first look at the variables and time span.  

In [None]:
# Load CSV with date as index
path = "xai_data.csv"
df = pd.read_csv(path, parse_dates=["date"], index_col="date")
df = df.sort_index()

# Select features and target
feature_cols = ["air_temp", "shortwave_rad", "precip", "vpd"]
target_col = "gpp"

# Quick check
print(f"Data shape: {df.shape}")

In [None]:
# Full descriptive names for titles
full_names = {
    "air_temp": "Air Temperature",
    "shortwave_rad": "Downward Shortwave Radiation",
    "precip": "Precipitation",
    "vpd": "Vapour Pressure Deficit",
    "gpp": "Gross Primary Production"
}

# Units
units = {
    "air_temp": "°C",
    "shortwave_rad": "W/m2",
    "precip": "mm/day",
    "vpd": "kPa",
    "gpp": "gC/m2/day"
}

vars_to_plot = feature_cols + [target_col]

fig, axes = plt.subplots(len(vars_to_plot), 1, figsize=(8, 7), sharex=True)
for ax, var in zip(axes, vars_to_plot):
    df[var].plot(ax=ax, linewidth=1)
    ax.set_title(f"{full_names[var]} ({units[var]})", fontsize=10, loc="left")
plt.tight_layout()
plt.show()

## 2. Prepare the Data

We will train a sequence-to-one model: the input is the past 90 days of meteorological drivers including the current day, and the output is the GPP for that same day.
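
To make the windowing concrete, here is a minimal sketch on a toy array rather than the tutorial data. It assumes NumPy's `sliding_window_view`; the array sizes and the names `drivers` and `target` are illustrative only.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Toy data (illustrative): 10 days, 2 drivers, 1 target value per day
n_days, n_feat, seq_len = 10, 2, 4
drivers = np.arange(n_days * n_feat, dtype=float).reshape(n_days, n_feat)
target = np.arange(n_days, dtype=float)

# Slide a window over the time axis; the window axis is appended last,
# so move it to position 1 to obtain [samples, time steps, features]
X = sliding_window_view(drivers, seq_len, axis=0).transpose(0, 2, 1)
y = target[seq_len - 1:]   # same-day target: the last day of each window

print(X.shape, y.shape)
```

Each window ends on the day it predicts, so a series of `n_days` rows yields `n_days - seq_len + 1` samples and the first `seq_len - 1` days have no prediction.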

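First, we split the dataset into training, validation, and test periods based on years. Because the evaluation years come strictly after the training years, no future information leaks into training. Sequences are built separately within each split, so no input window crosses a split boundary.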

Feature scaling and target scaling are **fit on the training period only** and applied to validation/test.
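
The following sketch illustrates what "fit on the training period only" means, using plain NumPy and synthetic numbers in place of the real drivers. The values and the assumption that the later period is warmer are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic "temperature" (assumption): the later period is warmer than training
train = rng.normal(loc=10.0, scale=5.0, size=1000)
later = rng.normal(loc=12.0, scale=5.0, size=300)

# Statistics come from the training period only
mu, sigma = train.mean(), train.std()

train_scaled = (train - mu) / sigma
later_scaled = (later - mu) / sigma   # reuse training statistics, do not refit

print(f"train mean/std after scaling: {train_scaled.mean():.2f} / {train_scaled.std():.2f}")
print(f"later mean after scaling:     {later_scaled.mean():.2f}")
```

The later period does not end up centred on zero, and that is intended: refitting the scaler on it would hide the shift from the model and leak information about the evaluation period.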

In [None]:
seq_len = 90  # number of days in each input sequence
train_years = (2000, 2013)
val_years   = (2014, 2016)
test_years  = (2017, 2019)  # the record ends on 2019-12-31

# Split by year
df_train = df.loc[str(train_years[0]):str(train_years[1])]
df_val   = df.loc[str(val_years[0]):str(val_years[1])]
df_test  = df.loc[str(test_years[0]):str(test_years[1])]

print(f"Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test)}")

In [None]:
X_scaler = StandardScaler().fit(df_train[feature_cols])
y_scaler = StandardScaler().fit(df_train[[target_col]])

def scale_df(df, X_scaler, y_scaler):
    out = df.copy()
    out[feature_cols] = X_scaler.transform(df[feature_cols])
    out[target_col] = y_scaler.transform(df[[target_col]])
    return out

df_train_scaled = scale_df(df_train, X_scaler, y_scaler)
df_val_scaled   = scale_df(df_val, X_scaler, y_scaler)
df_test_scaled  = scale_df(df_test, X_scaler, y_scaler)

In [None]:
def create_sequences(data, seq_len, feature_cols, target_col):
    X, y, dates = [], [], []
    arr_X = data[feature_cols].values
    arr_y = data[target_col].values
    idx = data.index
    for i in range(len(data) - seq_len + 1):
        X.append(arr_X[i:i+seq_len])
        y.append(arr_y[i+seq_len-1])    # same-day target
        dates.append(idx[i+seq_len-1])  # date of target
    return np.array(X), np.array(y), dates

# Create sequences
X_train, y_train, dates_train = create_sequences(df_train_scaled, seq_len, feature_cols, target_col)
X_val,   y_val,   dates_val   = create_sequences(df_val_scaled, seq_len, feature_cols, target_col)
X_test,  y_test,  dates_test  = create_sequences(df_test_scaled, seq_len, feature_cols, target_col)

# Convert to tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.float32).unsqueeze(-1)
X_val_t   = torch.tensor(X_val, dtype=torch.float32)
y_val_t   = torch.tensor(y_val, dtype=torch.float32).unsqueeze(-1)
X_test_t  = torch.tensor(X_test, dtype=torch.float32)
y_test_t  = torch.tensor(y_test, dtype=torch.float32).unsqueeze(-1)

print(f"Input shape:\n  Train - {X_train_t.shape}\n  Val   - {X_val_t.shape}\n  Test  - {X_test_t.shape}")
print(f"Target shape:\n  Train - {y_train_t.shape}\n  Val   - {y_val_t.shape}\n  Test  - {y_test_t.shape}")

## 3. Define the Model

We use a Long Short-Term Memory (**LSTM**) network for this task.

- LSTMs are a type of recurrent neural network (RNN) designed to handle sequential data.
- They can learn dependencies across time steps, making them well-suited for meteorological time series.
- Our model takes 90 days of weather inputs and outputs a single GPP value for the last day.

We use:
- One LSTM layer
- A fully connected output layer
- Mean squared error loss
- Adam optimizer

In [None]:
# --- Reproducibility ---
def set_seed(seed=42):
    import random, numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Model ---
class LSTMRegressor(nn.Module):
    def __init__(self, input_size, hidden_size=64, num_layers=1, dropout=0.0):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)        # [B, T, H]
        last = out[:, -1, :]         # last time step
        yhat = self.fc(last)         # [B, 1]
        return yhat

# --- Init function ---
def init_model(input_size, hidden_size=64, num_layers=1, dropout=0.0, seed=42, device=None):
    set_seed(seed)
    model = LSTMRegressor(input_size, hidden_size, num_layers, dropout)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)

# Example: initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = init_model(len(feature_cols), hidden_size=64, seed=42, device=device)

## 4. Train the Model

We train the LSTM using mini-batches and monitor performance on the validation set.
We include:

- **Early stopping**: stop training when validation loss does not improve for several epochs.
- **Learning rate scheduling**: automatically reduce learning rate when validation loss plateaus.
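
The early-stopping rule can be sketched without any deep learning library. This is a simplified illustration, not the training loop used in this notebook: the `weights` dictionary is a hypothetical stand-in for model parameters, and the loss values are invented.

```python
import copy

def run_with_early_stopping(val_losses, patience):
    """Return (epoch of best weights, epoch at which training stopped)."""
    weights = {"step": 0}                       # stand-in for model weights (hypothetical)
    best_loss, best_weights, waited = float("inf"), None, 0
    for epoch, loss in enumerate(val_losses, start=1):
        weights["step"] = epoch                 # "training" mutates the weights in place
        if loss < best_loss:
            best_loss, waited = loss, 0
            best_weights = copy.deepcopy(weights)   # snapshot, not a live reference
        else:
            waited += 1
            if waited >= patience:
                break
    return best_weights["step"], epoch

best_epoch, stop_epoch = run_with_early_stopping(
    [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9], patience=3
)
print(best_epoch, stop_epoch)
```

The snapshot matters: if `best_weights` only referenced the live `weights` object, it would silently follow later updates and the "best" model restored at the end would be the last one.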

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                device, max_epochs=30, patience=10):
    best_val_loss = float('inf')
    epochs_no_improve = 0
    train_losses, val_losses = [], []

    for epoch in trange(1, max_epochs+1, desc="Training"):
        # --- Training ---
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                preds = model(xb)
                loss = criterion(preds, yb)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        # Record
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Adjust LR
        scheduler.step(val_loss)

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
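            # state_dict() returns live references, so clone to keep a real snapshot
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}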
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping at epoch {epoch}")
                break

    model.load_state_dict(best_model_state)
    return model, train_losses, val_losses

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

# DataLoaders
batch_size = 64
train_loader = DataLoader(TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(TensorDataset(X_test_t, y_test_t), batch_size=batch_size, shuffle=False)


model, train_losses, val_losses = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device)

## 5. Evaluate the Model

We assess performance with the coefficient of determination ($R^2$) for the training, validation, and test periods.
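
As a reminder of what this score measures, here is a small hand-written version, assuming the usual definition $1 - SS_\mathrm{res}/SS_\mathrm{tot}$. The helper name `r2` and the toy numbers are illustrative.

```python
import numpy as np

def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

obs = np.array([1.0, 2.0, 3.0, 4.0])
print(r2(obs, obs))                      # perfect prediction
print(r2(obs, np.full(4, obs.mean())))   # always predicting the mean
print(r2(obs, obs[::-1]))                # worse than predicting the mean
```

A score of 1 is a perfect fit, 0 matches a constant prediction of the observed mean, and negative values are possible when the model does worse than that constant.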

In [None]:
def evaluate_model(model, X_t, y_t, dates, y_scaler):
    model.eval()
    with torch.no_grad():
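        # Move inputs to the model's device so this also works when training ran on a GPU
        preds_scaled = model(X_t.to(next(model.parameters()).device)).cpu().numpy().ravel()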
    # inverse transform
    preds = y_scaler.inverse_transform(preds_scaled.reshape(-1, 1)).ravel()
    truth = y_scaler.inverse_transform(y_t.cpu().numpy().ravel().reshape(-1, 1)).ravel()
    r2 = r2_score(truth, preds)
    return r2, pd.Series(truth, index=dates), pd.Series(preds, index=dates)

# Evaluate
r2_train, y_true_train, y_pred_train = evaluate_model(model, X_train_t, y_train_t, dates_train, y_scaler)
r2_val,   y_true_val,   y_pred_val   = evaluate_model(model, X_val_t, y_val_t, dates_val, y_scaler)
r2_test,  y_true_test,  y_pred_test  = evaluate_model(model, X_test_t, y_test_t, dates_test, y_scaler)

In [None]:
fig, ax = plt.subplots(figsize=(12, 4))

# Observed GPP for the whole period
ax.plot(df.index, df[target_col], color="black", linewidth=0.7, label="Observed")

# Predictions by split
ax.plot(y_true_train.index, y_pred_train, color="tab:blue", alpha=0.7, label=f"Predicted (Train, R2={r2_train:.2f})")
ax.plot(y_true_val.index,   y_pred_val,   color="tab:orange", alpha=0.7, label=f"Predicted (Val, R2={r2_val:.2f})")
ax.plot(y_true_test.index,  y_pred_test,  color="tab:green", alpha=0.7, label=f"Predicted (Test, R2={r2_test:.2f})")

# Highlight test period
ax.axvspan(y_true_test.index[0], y_true_test.index[-1], color="gray", alpha=0.1)

# Labels and title
ax.set_title("Observed vs Predicted GPP", fontsize=12)
ax.set_ylabel("GPP (gC/m2/day)")
ax.set_xlabel("Date")

# Legend
ax.legend(loc="lower left", fontsize=10, frameon=False, ncol=4)

plt.tight_layout()
plt.show()

## 6. Explaining the Model

So far, we treated the LSTM as a **black box**: it takes 90 days of inputs and produces a GPP prediction.  
But we have not asked *why* the model makes a certain prediction. Which drivers (e.g., radiation, precipitation, VPD) were most responsible? And which days in the 90-day window mattered most?

To answer these questions, we use **Integrated Gradients (IG)**.

IG is a method from explainable AI designed for neural networks.  
The idea is:

- We choose a **baseline input** (a reference with “no information”).  
- We compare the model’s prediction on the actual input against its prediction on this baseline.  
- IG attributes this difference to the individual input values, day by day and feature by feature.  

The result is an attribution array with the same shape as the input sequence.  
Each number tells us how much that particular feature on that particular day pushed the prediction up or down, relative to the baseline.

This way, we can move from “the model predicts GPP = 5.2” to “this prediction was mainly driven by high radiation in the last week, partly offset by high VPD earlier in the window.”
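
The idea can be sketched in a few lines of NumPy for a toy function whose gradient is known analytically. This is a hand-rolled illustration under simple assumptions (a midpoint Riemann sum, made-up weights `w`); library implementations such as Captum may use a different integration rule and step count.

```python
import numpy as np

w = np.array([1.5, -2.0, 0.5])              # toy weights (illustrative)

def f(x):
    return np.tanh(x @ w)                   # toy "model"

def grad_f(x):
    return (1.0 - np.tanh(x @ w) ** 2) * w  # analytic gradient of f

def integrated_gradients(x, baseline, n_steps=200):
    # Midpoint Riemann sum of the gradient along the straight path baseline -> x
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    path = baseline + alphas[:, None] * (x - baseline)   # [n_steps, n_features]
    grads = np.array([grad_f(p) for p in path])
    return (x - baseline) * grads.mean(axis=0)

x = np.array([0.8, 0.3, 1.2])
baseline = np.zeros_like(x)
attr = integrated_gradients(x, baseline)

print("attributions:       ", np.round(attr, 4))
print(f"sum of attributions: {attr.sum():.4f}")
print(f"f(x) - f(baseline):  {f(x) - f(baseline):.4f}")
```

The second input has a negative weight and a positive value, so its attribution is negative, while the sum over all inputs matches the prediction difference up to the integration error.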

In [None]:
import matplotlib.dates as mdates

def _date_cell_edges(dates):
    """Helper: compute left/right edges of each date cell for pcolormesh."""
    x = mdates.date2num(dates.to_pydatetime())  # centers
    return np.concatenate(([x[0] - 0.5], (x[:-1] + x[1:]) / 2, [x[-1] + 0.5]))

def plot_input_with_attribution(
    df, dates_window,
    input_seq, attribution_seq,
    feature_names,
    target_col=None, pred_value=None, obs_value=None,
    baseline_pred_value=None, baseline_input_seq=None,
    cmap="BrBG", show_colorbar=True
):
    """Plot drivers (features) with attributions as heatmaps.
    If target_col / predictions are provided, adds a bottom panel for them.
    """

    T, F = input_seq.shape
    show_target_panel = target_col is not None and pred_value is not None and obs_value is not None

    nrows = F + 1 if show_target_panel else F
    fig, axes = plt.subplots(
        nrows=nrows, ncols=1, figsize=(10, 6), sharex=True, gridspec_kw={'hspace': 0.4}
    )

    if nrows == 1:
        axes = [axes]  # ensure iterable

    vmax = np.quantile(np.abs(attribution_seq), 0.99) + 1e-8
    x_edges = _date_cell_edges(dates_window)

    # --- Feature panels ---
    for i in range(F):
        ax = axes[i]
        z = attribution_seq[:, i][None, :]  # (1, T)
        ymin, ymax = float(input_seq[:, i].min()), float(input_seq[:, i].max())
        X = np.vstack([x_edges, x_edges])   # (2, T+1)
        Y = np.vstack([np.full_like(x_edges, ymin), np.full_like(x_edges, ymax)])

        pcm = ax.pcolormesh(X, Y, z, cmap=cmap, vmin=-vmax, vmax=vmax, shading="flat")

        # Overlay input series
        ax.plot(dates_window, input_seq[:, i], color="black", linewidth=1.2)
        if baseline_input_seq is not None:
            ax.plot(dates_window, baseline_input_seq[:, i],
                    color="blue", linewidth=1.0, alpha=0.9, label="Baseline input")
            if i == 0:
                ax.legend(loc="center left", fontsize=9, frameon=False, bbox_to_anchor=(1.0, 0.5))

        # Title
        varname = feature_names[i]
        ax.set_title(f"{full_names[varname]} ({units[varname]})", fontsize=10, loc="left")
        ax.grid(True, linestyle="--", alpha=0.3)

    # --- Optional target panel ---
    if show_target_panel:
        ax = axes[-1]
        ax.plot(dates_window, df.loc[dates_window, target_col],
                color="black", linewidth=1.4, label="Observed")
        ax.scatter(dates_window[-1], pred_value, color="red", zorder=3, label="Pred (target)")
        if baseline_pred_value is not None:
            ax.scatter(dates_window[-1], baseline_pred_value, color="blue", zorder=3, label="Baseline pred")
        ax.set_title(f"{full_names[target_col]} ({units[target_col]})", fontsize=10, loc="left")
        ax.legend(fontsize=9, loc="lower left", frameon=False, bbox_to_anchor=(1.0, -0.1))

        # Date formatting
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(ax.xaxis.get_major_locator()))

    # Shared colorbar
    if show_colorbar:
        fig.colorbar(pcm, ax=axes, orientation="vertical", shrink=0.7, label="Attribution")

    return fig


### 6.1. A first look at IG on one sample

We apply IG to one test sequence (90 days of inputs).  

IG returns an **attribution array** with the same shape as the input: `[time steps, features]`.  
Each entry shows how much a given feature on a given day contributed to the prediction:

- Positive → pushed the prediction *up* (relative to baseline)  
- Negative → pushed the prediction *down* (relative to baseline)  

By design, the attributions add up to the prediction difference:

$$
\text{Prediction} - \text{Baseline prediction} \;\approx\; \sum_{t=1}^{T}\sum_{f=1}^{F} \mathrm{IG}_{t,f}
$$
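
For reference, the standard definition behind this property can be written as follows. The symbols $F$ (the model), $x$ (the input), and $x'$ (the baseline) are notation introduced here for illustration:

$$
\mathrm{IG}_{t,f}(x) \;=\; \left(x_{t,f} - x'_{t,f}\right) \int_{0}^{1} \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_{t,f}} \, d\alpha
$$

Summing over all $(t, f)$ gives the line integral of the gradient along the straight path from $x'$ to $x$, which equals $F(x) - F(x')$. The relation is written with $\approx$ because in practice the integral is evaluated numerically with a finite number of steps.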

A key part of IG is the **baseline**.  
This is the input we use as a reference, i.e., the model’s prediction when “nothing informative” is given.  
IG then attributes the change from this baseline prediction to the actual prediction.

Here we use a **zero baseline**: all input features are set to zero in standardized space.  
Because the inputs were standardized with training-period statistics, this holds every driver at its training-period mean, a “neutral” reference after scaling.  

Later we will also test other baselines (e.g. climatology) to see how the choice of baseline affects the explanation.

In [None]:
model = model.to("cpu").eval()  # the attribution cells below work with CPU tensors
ig = IntegratedGradients(model)

# --- 1. Pick a sample ---
idx = 150
x = X_test_t[idx:idx+1]                     # input sequence [1, T, F]
dates_win = df_test.index[idx:idx+seq_len]  # dates for the window

# --- 2. Define the baseline ---
# Here: all zeros in standardized space (we discuss baseline after seeing IG)
baseline = torch.zeros_like(x)

# --- 3. Compute attributions with IG ---
attr, _ = ig.attribute(x, baselines=baseline, return_convergence_delta=True)
attr = attr.squeeze(0).detach().cpu().numpy()   # shape [T, F]

# --- 4. Predictions in scaled space ---
with torch.no_grad():
    pred_scaled = model(x).cpu().numpy().ravel()[0]       # model prediction
    base_scaled = model(baseline).cpu().numpy().ravel()[0]# baseline prediction

diff_scaled = pred_scaled - base_scaled   # difference explained by IG
sum_attr = float(attr.sum())              # sum of all attributions
comp_err = abs(diff_scaled - sum_attr)    # completeness check

# --- 5. Convert to original units ---
pred = y_scaler.inverse_transform([[pred_scaled]])[0,0]
obs  = y_scaler.inverse_transform(y_test_t[idx:idx+1].cpu().numpy())[0,0]
base = y_scaler.inverse_transform([[base_scaled]])[0,0]

# A difference has no offset, so rescale it by the target's standard deviation only
diff_unscaled = diff_scaled * y_scaler.scale_[0]
sum_attr_unscaled = sum_attr * y_scaler.scale_[0]

# --- 6. Print results ---
print(f"Sample index: {idx}  |  Window: {dates_win[0].date()} → {dates_win[-1].date()}")
print(f"Observed GPP:          {obs:.3f} gC/m2/day")
print(f"Model prediction:      {pred:.3f} gC/m2/day")
print(f"Baseline prediction:   {base:.3f} gC/m2/day")
print()
print("IG completeness check:")
print(f"  Prediction − Baseline  = {diff_scaled:.5f} (scaled)  ≈  {diff_unscaled:.3f} (unscaled)")
print(f"  Sum of IG attributions = {sum_attr:.5f} (scaled)  ≈  {sum_attr_unscaled:.3f} (unscaled)")
print(f"  Error in equality      = {comp_err:.2e}")
print()
print(f"Attribution array shape: {attr.shape}  (time steps × features)")

### 6.2. Visualizing IG along inputs

To better understand the attributions, we can overlay them directly on the input time series.  
For each feature, the line shows the actual input values, while the background color shows the IG attribution:  

- Green → this day’s input **increased** the prediction (relative to baseline)  
- Brown → this day’s input **decreased** the prediction (relative to baseline)  

The bottom panel shows the observed GPP, the model’s prediction, and the baseline prediction for the target day.  

In [None]:
fig = plot_input_with_attribution(
    df=df_test,                      # DataFrame slice
    dates_window=dates_win,          # the 90-day window of dates
    input_seq=X_scaler.inverse_transform(x.squeeze(0).cpu().numpy()),   # inputs in original units
    attribution_seq=attr,            # IG attributions
    feature_names=feature_cols,      # list of input names
    target_col=target_col,           # "gpp"
    pred_value=pred,
    obs_value=obs,
    baseline_pred_value=base,
    baseline_input_seq=X_scaler.inverse_transform(baseline.squeeze(0).cpu().numpy()) # optional
)

### 6.3. Aggregating attributions: feature perspective  

Because IG attributions are **additive**, we can aggregate them to see which features mattered most over the whole 90-day window.  

Two useful summaries are:  

- **Net contribution** (sum of IG values): shows whether a feature overall pushed the prediction *up* or *down*.  
- **Mean absolute attribution** (importance): shows how strongly the model relied on a feature, regardless of direction.  
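
The two summaries can tell different stories. The sketch below uses a small hypothetical attribution array (values invented for illustration) in which one feature has large attributions that cancel over time.

```python
import numpy as np

# Hypothetical attribution array [time steps, features] for two features
attr_toy = np.array([
    [ 0.4, 0.1],
    [-0.4, 0.1],
    [ 0.3, 0.1],
    [-0.3, 0.1],
])

net = attr_toy.sum(axis=0)                  # signed sum over time
importance = np.abs(attr_toy).mean(axis=0)  # mean magnitude over time

print("net contribution:  ", net)
print("mean |attribution|:", importance)
```

The first feature has a net contribution of about zero yet the larger mean absolute attribution, so reporting only one of the two summaries would misrepresent how the model used it.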

In [None]:
feat_signed = attr.sum(axis=0)       # net effect per feature
feat_import = np.mean(np.abs(attr), axis=0)  # strength regardless of sign

tbl = pd.DataFrame({
    "net_contribution": feat_signed,
    "mean_abs_importance": feat_import
}, index=feature_cols)

display(tbl.round(4))

# --- Bar plots ---
fig, ax = plt.subplots(1, 2, figsize=(10, 3), constrained_layout=True)

ax[0].bar(feature_cols, feat_signed, color="tab:blue")
ax[0].axhline(0, color="black", linewidth=0.8)
ax[0].set_title("Net contribution (sum IG)")
ax[0].set_ylabel("IG (scaled output units)")
ax[0].tick_params(axis="x", rotation=45)

ax[1].bar(feature_cols, feat_import, color="tab:orange")
ax[1].set_title("Importance (mean |IG|)")
ax[1].set_ylabel("IG (scaled output units)")
ax[1].tick_params(axis="x", rotation=45)

plt.show()

### 6.4. Aggregating attributions: time perspective  

We can also summarize IG across features for each day.  
This shows *when* in the 90-day window the model relied most on the inputs.  

The result is a time series of daily importance values, highlighting which periods in the input sequence mattered most for this prediction.

In [None]:
# Aggregate IG attributions over features → importance per time step
time_import = np.mean(np.abs(attr), axis=1)   # length T (one value per day in the window)

# Plot
plt.figure(figsize=(10, 3))
plt.plot(dates_win, time_import, color="darkred")
plt.ylabel("Mean |IG| across features")
plt.title("Per-day importance within the 90-day window")

ax = plt.gca()
ax.xaxis.set_major_locator(mdates.AutoDateLocator())
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(ax.xaxis.get_major_locator()))
plt.show()

# Print top-5 most influential days
topk = np.argsort(-time_import)[:5]
print("Top-5 days by mean |IG|:")
for i in topk:
    print(f"  {dates_win[i].date()}   mean|IG|={time_import[i]:.5f}   signed_sum={attr[i].sum():.5f}")

### 6.5. Comparing a high-GPP and a low-GPP day  

To better understand what IG reveals, we compare two contrasting samples:  
- one where the observed GPP is high,  
- and one where the observed GPP is low.  

By looking at the attribution patterns side by side, we can see how the model emphasizes different drivers depending on ecosystem activity.  


In [None]:
high_idx = np.argmax(y_test_t.numpy())   # index of max GPP in test set
low_idx  = np.argmin(y_test_t.numpy())   # index of min GPP in test set

for idx, label in [(high_idx, "High GPP sample"), (low_idx, "Low GPP sample")]:
    
    # 1. Slice input and dates
    x = X_test_t[idx:idx+1]
    dates_win = df_test.index[idx:idx+seq_len]
    baseline = torch.zeros_like(x)
    
    # 2. Compute IG
    attr, _ = ig.attribute(x, baselines=baseline, return_convergence_delta=True)
    attr = attr.squeeze(0).detach().cpu().numpy()
    
    # 3. Predictions
    with torch.no_grad():
        pred_scaled = model(x).cpu().numpy().ravel()[0]
        base_scaled = model(baseline).cpu().numpy().ravel()[0]
    pred = y_scaler.inverse_transform([[pred_scaled]])[0,0]
    obs  = y_scaler.inverse_transform(y_test_t[idx:idx+1].cpu().numpy())[0,0]
    base = y_scaler.inverse_transform([[base_scaled]])[0,0]

    # 4. Plot with attribution overlay
    fig = plot_input_with_attribution(
        df=df_test,
        dates_window=dates_win,
        input_seq=X_scaler.inverse_transform(x.squeeze(0).numpy()),
        attribution_seq=attr,
        feature_names=feature_cols,
        target_col=target_col,
        pred_value=pred,
        obs_value=obs,
        baseline_pred_value=base,
        baseline_input_seq=X_scaler.inverse_transform(baseline.squeeze(0).numpy())
    )
    fig.suptitle(f"{label}", fontsize=12)

# II. Pitfalls and Good Practice in XAI

Applying an explainability method is straightforward: a few lines of code give you attributions, heatmaps, or importance scores. 
But interpreting these results is the real challenge. In scientific applications like geoscience, careless use can lead to misleading or even wrong conclusions.

Below we summarize several common pitfalls and points to keep in mind.

## 1. Baseline choice matters

Integrated Gradients requires a **baseline** (or reference input).  
All attributions are computed **relative to this baseline**.

So far, we used a **zero baseline** (all inputs = 0 in standardized space).  
This is convenient and mathematically valid, but it is not always scientifically meaningful.

Why does this matter?  
- The baseline defines what the model’s *“no information”* state is.  
- Changing the baseline can shift the interpretation of contributions.  
- In geoscience, a zero input may not correspond to any realistic condition: here it means every driver sits at its training-period mean on every day of the window.  

For example, we can compare two baselines for the same test sequence:  
- Zero (all inputs = 0 after scaling)  
- Climatology (the average seasonal cycle for each day of year)

Different baselines yield different attributions, and both are valid in the sense of the IG framework, but their **scientific meaning** differs.  
That’s why you must always document and reflect on your baseline choice.
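
A toy linear model makes the baseline dependence easy to see, because for a linear model the IG attribution has the closed form `w * (x - baseline)`. The weights, input, and "climatology" values below are hypothetical numbers chosen for illustration.

```python
import numpy as np

w = np.array([2.0, -1.0])             # toy linear model F(x) = w @ x (illustrative)
x = np.array([0.5, 0.5])              # input in standardized units

baseline_zero = np.zeros(2)           # training mean after standardization
baseline_clim = np.array([1.0, 0.2])  # hypothetical seasonal norm for these days

# For a linear model, IG reduces to w * (x - baseline)
attr_zero = w * (x - baseline_zero)
attr_clim = w * (x - baseline_clim)

print("zero baseline:       ", attr_zero)
print("climatology baseline:", attr_clim)
```

The first input is above the overall mean but below its seasonal norm, so its attribution is positive against the zero baseline and negative against climatology. Both results are correct answers to different questions.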

In [None]:
idx = 150
x = X_test_t[idx:idx+1]
dates_win = df_test.index[idx:idx+seq_len]

# --- Zero baseline ---
baseline_zero = torch.zeros_like(x)
attr_zero, _ = ig.attribute(x, baselines=baseline_zero, return_convergence_delta=True)
attr_zero = attr_zero.squeeze(0).detach().cpu().numpy()

with torch.no_grad():
    pred_scaled = model(x).cpu().numpy().ravel()[0]
    base_zero_scaled = model(baseline_zero).cpu().numpy().ravel()[0]
pred = y_scaler.inverse_transform([[pred_scaled]])[0,0]
obs  = y_scaler.inverse_transform(y_test_t[idx:idx+1].cpu().numpy())[0,0]
base_zero = y_scaler.inverse_transform([[base_zero_scaled]])[0,0]

fig = plot_input_with_attribution(
    df=df_test,
    dates_window=dates_win,
    input_seq=X_scaler.inverse_transform(x.squeeze(0).numpy()),
    attribution_seq=attr_zero,
    feature_names=feature_cols,
    target_col=target_col,
    pred_value=pred,
    obs_value=obs,
    baseline_pred_value=base_zero,
    baseline_input_seq=X_scaler.inverse_transform(baseline_zero.squeeze(0).numpy())
)
fig.suptitle(f"Sample (idx = {idx}): IG with Zero Baseline", fontsize=12)

# --- Climatology baseline ---
clim_values = df_train.groupby(df_train.index.dayofyear).mean()
clim_seq = clim_values.loc[dates_win.dayofyear, feature_cols]
baseline_clim = torch.tensor(
    X_scaler.transform(clim_seq),  # pass a DataFrame so column names match the fitted scaler
    dtype=torch.float32
).unsqueeze(0)

attr_clim, _ = ig.attribute(x, baselines=baseline_clim, return_convergence_delta=True)
attr_clim = attr_clim.squeeze(0).detach().cpu().numpy()

with torch.no_grad():
    base_clim_scaled = model(baseline_clim).cpu().numpy().ravel()[0]
base_clim = y_scaler.inverse_transform([[base_clim_scaled]])[0,0]

fig = plot_input_with_attribution(
    df=df_test,
    dates_window=dates_win,
    input_seq=X_scaler.inverse_transform(x.squeeze(0).numpy()),
    attribution_seq=attr_clim,
    feature_names=feature_cols,
    target_col=target_col,
    pred_value=pred,
    obs_value=obs,
    baseline_pred_value=base_clim,
    baseline_input_seq=X_scaler.inverse_transform(baseline_clim.squeeze(0).numpy())
)
fig.suptitle(f"Sample (idx = {idx}): IG with Climatology Baseline", fontsize=12)

Notice how the attribution patterns differ between the two baselines.  
With the zero baseline, certain inputs appear strongly influential, while with the climatology baseline their contribution is weaker or even changes sign.  

The key message:  
**Integrated Gradients explains predictions relative to the chosen baseline. Interpret results only in that context. An inappropriate baseline can shift the story and lead to misleading conclusions.**

## 2. Attribution method choice matters

Different attribution methods can give different results on the same sample.  
This is expected: each method has its own assumptions and sensitivity.

As good practice, it is worth checking at least two methods.  
If they give consistent signals, confidence in the interpretation increases.  
If they differ, it shows the explanation is method-dependent and should be treated with caution.

Here we compare **Integrated Gradients** with a simple **Saliency map**.

- **Saliency** computes the gradient of the output with respect to the input.  
  It tells us which small local changes in the input would most affect the prediction.  
  Saliency does *not* use a baseline.

- **Integrated Gradients** accumulates gradients along a path from a baseline input to the actual input.  
  It distributes the difference between baseline prediction and actual prediction across all input values.  
  IG therefore requires a baseline, and the attributions are additive by design.
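
The contrast is clearest where a response saturates. The sketch below uses a one-dimensional toy function, `tanh`, as an assumed stand-in for a saturating model response, and compares the local gradient with a hand-computed IG value.

```python
import numpy as np

def f(x):
    return np.tanh(x)                  # toy saturating response (illustrative)

def grad_f(x):
    return 1.0 - np.tanh(x) ** 2       # analytic derivative

x, baseline = 4.0, 0.0

saliency = grad_f(x)                   # local slope at the input

# IG: midpoint Riemann sum of the slope along the path baseline -> x
alphas = (np.arange(500) + 0.5) / 500
ig_attr = (x - baseline) * grad_f(baseline + alphas * (x - baseline)).mean()

print(f"saliency at x:      {saliency:.4f}")
print(f"IG attribution:     {ig_attr:.4f}")
print(f"f(x) - f(baseline): {f(x) - f(baseline):.4f}")
```

In the saturated regime the local slope is almost zero, so saliency suggests the input hardly matters, while IG credits it with nearly the whole change from the baseline.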

In [None]:
idx = 150
x = X_test_t[idx:idx+1]
dates_win = df_test.index[idx:idx+seq_len]

# --- IG attributions (baseline = zero) ---
ig = IntegratedGradients(model)
attr_ig, _ = ig.attribute(x, baselines=torch.zeros_like(x), return_convergence_delta=True)
attr_ig = attr_ig.squeeze(0).detach().cpu().numpy()

# --- Saliency attributions ---
sal = Saliency(model)
attr_sal = sal.attribute(x, abs=False)
attr_sal = attr_sal.squeeze(0).detach().cpu().numpy()

# --- Plot IG (inputs only) ---
fig = plot_input_with_attribution(
    df=df_test,
    dates_window=dates_win,
    input_seq=X_scaler.inverse_transform(x.squeeze(0).numpy()),
    attribution_seq=attr_ig,
    feature_names=feature_cols,
    show_colorbar=True
)
fig.suptitle(f"Sample (idx = {idx}): Integrated Gradients", fontsize=12)
plt.show()

# --- Plot Saliency (inputs only) ---
fig = plot_input_with_attribution(
    df=df_test,
    dates_window=dates_win,
    input_seq=X_scaler.inverse_transform(x.squeeze(0).cpu().numpy()),
    attribution_seq=attr_sal,
    feature_names=feature_cols,
    show_colorbar=True
)
fig.suptitle(f"Sample (idx = {idx}): Saliency", fontsize=12)
plt.show()

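Notice how the two maps differ. They also measure different things: IG values are contributions to the prediction relative to the baseline, whereas saliency values are local gradients. Which view is more useful depends on the question:  

- If your question is *“Which parts of this input sequence contributed most to the prediction compared to a reference state?”* → IG is more suitable.  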
- If your question is *“Where is the model most sensitive to small perturbations right now?”* → Saliency is appropriate.  

**Key message:** explanations are **method-dependent**.  
Always state which method you used and how it matches the scientific question.  
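
When comparing two attribution maps, it can help to quantify their agreement rather than rely on visual inspection alone. The sketch below is one possible set of summary scores, not an established standard: the helper `attribution_agreement`, its `top_frac` parameter, and the synthetic maps are hypothetical. Because IG and saliency have different units, the scores use ranks, signs, and top-k overlap instead of raw magnitudes.

```python
import numpy as np

def rank(values):
    """Ranks (0 = smallest) of a 1-D array; ties are broken by position."""
    order = np.argsort(values, kind="stable")
    ranks = np.empty(len(values), dtype=float)
    ranks[order] = np.arange(len(values))
    return ranks

def attribution_agreement(attr_a, attr_b, top_frac=0.1):
    """Compare two attribution maps of identical shape (e.g. seq_len x n_features)."""
    a, b = np.ravel(attr_a), np.ravel(attr_b)
    # Spearman-style rank correlation of absolute attributions
    rho = np.corrcoef(rank(np.abs(a)), rank(np.abs(b)))[0, 1]
    # Fraction of cells where both maps have the same sign
    sign_agreement = np.mean(np.sign(a) == np.sign(b))
    # Overlap between the top-k cells (by absolute value) of each map
    k = max(1, int(round(top_frac * a.size)))
    top_a = set(np.argsort(-np.abs(a))[:k])
    top_b = set(np.argsort(-np.abs(b))[:k])
    return {
        "spearman_abs": float(rho),
        "sign_agreement": float(sign_agreement),
        "topk_overlap": len(top_a & top_b) / k,
    }

# Synthetic example: two maps that share a pattern but differ by noise
rng = np.random.default_rng(0)
map_a = rng.normal(size=(30, 4))
map_b = map_a + 0.1 * rng.normal(size=map_a.shape)
scores = attribution_agreement(map_a, map_b)
print(scores)
```

In this notebook, the arrays `attr_ig` and `attr_sal` from the cell above could be passed to such a function in place of the synthetic maps.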

## 3. Explanations reflect the model, not the true process  

XAI methods explain how *this model* makes predictions, not how the real-world system works.  

To see this, we can remove an important driver (temperature) and retrain the model.  
If predictive performance remains acceptable, the attributions will still “explain” the new model, but those explanations may differ, because the model has changed.  

This shows that interpretability is always tied to the specific model you trained  
and does not guarantee a causal understanding of the system.


In [None]:
# --- Drop one feature: Air Temperature ---
drop_feature = "air_temp"
keep_features = [f for f in feature_cols if f != drop_feature]
keep_idx = [i for i, f in enumerate(feature_cols) if f != drop_feature]

# Directly subset tensors (already windowed)
X_train_red = X_train_t[:, :, keep_idx]
X_val_red   = X_val_t[:, :, keep_idx]
X_test_red  = X_test_t[:, :, keep_idx]

y_train_red, y_val_red, y_test_red = y_train_t, y_val_t, y_test_t

# --- Fit scaler on reduced features (for inverse_transform later) ---
X_scaler_red = StandardScaler().fit(df_train[keep_features].values)

# --- Train a new model with reduced features ---
model_red = init_model(len(keep_features), hidden_size=64, seed=42, device=device)

train_loader_red = DataLoader(TensorDataset(X_train_red, y_train_red), batch_size=batch_size, shuffle=True)
val_loader_red   = DataLoader(TensorDataset(X_val_red, y_val_red), batch_size=batch_size, shuffle=False)
test_loader_red  = DataLoader(TensorDataset(X_test_red, y_test_red), batch_size=batch_size, shuffle=False)

optimizer = torch.optim.Adam(model_red.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

model_red, train_losses_red, val_losses_red = train_model(
    model_red, train_loader_red, val_loader_red,
    criterion, optimizer, scheduler, device
)

# --- Evaluate ---
r2_train_red, y_true_train_red, y_pred_train_red = evaluate_model(model_red, X_train_red, y_train_red, dates_train, y_scaler)
r2_val_red,   y_true_val_red,   y_pred_val_red   = evaluate_model(model_red, X_val_red, y_val_red, dates_val, y_scaler)
r2_test_red,  y_true_test_red,  y_pred_test_red  = evaluate_model(model_red, X_test_red, y_test_red, dates_test, y_scaler)

print(f"Full model R2 (test): {r2_test:.3f}")
print(f"Reduced model R2 (test, without {drop_feature}): {r2_test_red:.3f}")

In [None]:
idx = 120
x_full = X_test_t[idx:idx+1]
x_red  = X_test_red[idx:idx+1]
dates_win = df_test.index[idx:idx+seq_len]

# ---------------- FULL MODEL ----------------
ig_full = IntegratedGradients(model)
attr_full, _ = ig_full.attribute(
    x_full.to(device),
    baselines=torch.zeros_like(x_full).to(device),
    return_convergence_delta=True
)
attr_full = attr_full.squeeze(0).detach().cpu().numpy()

with torch.no_grad():
    pred_scaled = model(x_full.to(device)).cpu().numpy().ravel()[0]
    base_scaled = model(torch.zeros_like(x_full).to(device)).cpu().numpy().ravel()[0]

pred_full = y_scaler.inverse_transform([[pred_scaled]])[0,0]
obs_full  = y_scaler.inverse_transform(y_test_t[idx:idx+1].cpu().numpy())[0,0]
base_full = y_scaler.inverse_transform([[base_scaled]])[0,0]

fig = plot_input_with_attribution(
    df=df_test,
    dates_window=dates_win,
    input_seq=X_scaler.inverse_transform(x_full.squeeze(0).cpu().numpy()),
    attribution_seq=attr_full,
    feature_names=feature_cols,
    target_col=target_col,
    pred_value=pred_full,
    obs_value=obs_full,
    baseline_pred_value=base_full,
    baseline_input_seq=X_scaler.inverse_transform(torch.zeros_like(x_full).squeeze(0).cpu().numpy())
)
fig.suptitle(f"Full model (idx={idx}): IG attributions", fontsize=12)
plt.show()


# ---------------- REDUCED MODEL ----------------
ig_red = IntegratedGradients(model_red)
attr_red, _ = ig_red.attribute(
    x_red.to(device),
    baselines=torch.zeros_like(x_red).to(device),
    return_convergence_delta=True
)
attr_red = attr_red.squeeze(0).detach().cpu().numpy()

with torch.no_grad():
    pred_scaled_red = model_red(x_red.to(device)).cpu().numpy().ravel()[0]
    base_scaled_red = model_red(torch.zeros_like(x_red).to(device)).cpu().numpy().ravel()[0]

pred_red = y_scaler.inverse_transform([[pred_scaled_red]])[0,0]
obs_red  = y_scaler.inverse_transform(y_test_red[idx:idx+1].cpu().numpy())[0,0]
base_red = y_scaler.inverse_transform([[base_scaled_red]])[0,0]

# X_scaler_red was already fitted on the reduced feature set in the previous cell

fig = plot_input_with_attribution(
    df=df_test,
    dates_window=dates_win,
    input_seq=X_scaler_red.inverse_transform(x_red.squeeze(0).cpu().numpy()),
    attribution_seq=attr_red,
    feature_names=keep_features,
    target_col=target_col,
    pred_value=pred_red,
    obs_value=obs_red,
    baseline_pred_value=base_red,
    baseline_input_seq=X_scaler_red.inverse_transform(torch.zeros_like(x_red).squeeze(0).cpu().numpy())
)
fig.suptitle(f"Reduced model (idx={idx}, without {drop_feature}): IG attributions", fontsize=12)
plt.show()

Notice how the attributions change when we retrain the model without *Air Temperature*.  

Even when predictive performance remains similar, the attribution patterns can shift considerably.  
This illustrates a key principle: XAI methods explain **how the model makes predictions, not how the real system works**.  
When an important driver is missing, the model redistributes importance to the remaining inputs, especially those that are correlated.  

Interpretability results should therefore always be understood in the context of the model’s design and available features.  
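
The redistribution effect can be isolated with a linear model on synthetic data. The sketch below is a hypothetical illustration: the variables `temp`, `vpd`, and `rad` are random numbers rather than the tutorial's measurements, and the 0.98 correlation and the coefficients are arbitrary choices. For a linear model with a zero baseline, the IG attribution of a feature is its coefficient times its value, so the fitted coefficients show directly where importance goes.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
rho = 0.98  # assumed correlation between the two synthetic drivers

temp = rng.normal(size=n)                                       # synthetic "temperature"
vpd = rho * temp + np.sqrt(1.0 - rho**2) * rng.normal(size=n)   # correlated with temp
rad = rng.normal(size=n)                                        # independent driver

# By construction the target depends on temp and rad only, not on vpd
y = 1.0 * temp + 0.5 * rad + 0.3 * rng.normal(size=n)

def fit_linear(X, y):
    """Ordinary least squares without intercept (inputs are roughly zero-mean)."""
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef

def r2(y_true, y_pred):
    """Coefficient of determination."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

X_full = np.column_stack([temp, vpd, rad])
X_red = np.column_stack([vpd, rad])  # "temperature" removed

coef_full = fit_linear(X_full, y)
coef_red = fit_linear(X_red, y)
r2_full = r2(y, X_full @ coef_full)
r2_reduced = r2(y, X_red @ coef_red)

print("Full model    coefficients [temp, vpd, rad]:", np.round(coef_full, 2), "R2:", round(r2_full, 3))
print("Reduced model coefficients [vpd, rad]:      ", np.round(coef_red, 2), "R2:", round(r2_reduced, 3))
```

In the full model the coefficient on `vpd` is close to zero, which matches how the data were generated. Once `temp` is removed, `vpd` takes over most of its weight while the fit degrades only slightly. An attribution map of the reduced model would therefore credit `vpd` for an effect that, by construction, it does not cause.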

## 4. Explanations depend on model randomness

Neural networks are trained with random initialization and stochastic optimization.  
Even with the same data and hyperparameters, changing the random seed can produce models with similar predictive accuracy but different attribution patterns.  

This means that explanations also carry uncertainty: they reflect *this particular trained model*, not a guaranteed truth about the data.  
For more robust insights, check whether attribution patterns are consistent across multiple training runs.


In [None]:
seeds = [42, 5555]
attributions = []

idx = 150
x = X_test_t[idx:idx+1]
dates_win = df_test.index[idx:idx+seq_len]

for seed in seeds:
    # --- Train a fresh model with a different seed ---
    model_seed = init_model(len(feature_cols), hidden_size=64, seed=seed, device=device)
    optimizer = torch.optim.Adam(model_seed.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    
    train_loader = DataLoader(TensorDataset(X_train_t, y_train_t), batch_size=64, shuffle=True)
    val_loader   = DataLoader(TensorDataset(X_val_t, y_val_t), batch_size=64, shuffle=False)
    
    model_seed, _, _ = train_model(
        model_seed, train_loader, val_loader,
        criterion, optimizer, scheduler, device
    )
    
    # --- Attribution (IG) ---
    ig = IntegratedGradients(model_seed)
    attr, _ = ig.attribute(
        x.to(device),
        baselines=torch.zeros_like(x).to(device),
        return_convergence_delta=True
    )
    attributions.append(attr.squeeze(0).detach().cpu().numpy())
    
    # --- Prediction ---
    with torch.no_grad():
        pred_scaled = model_seed(x.to(device)).cpu().numpy().ravel()[0]
    pred = y_scaler.inverse_transform([[pred_scaled]])[0,0]
    obs  = y_scaler.inverse_transform(y_test_t[idx:idx+1].cpu().numpy())[0,0]
    
    # --- Plot ---
    fig = plot_input_with_attribution(
        df=df_test,
        dates_window=dates_win,
        input_seq=X_scaler.inverse_transform(x.squeeze(0).cpu().numpy()),
        attribution_seq=attributions[-1],
        feature_names=feature_cols,
        target_col=target_col,
        pred_value=pred,
        obs_value=obs,
        baseline_pred_value=None,
        baseline_input_seq=None
    )
    fig.suptitle(f"Seed {seed}: IG attributions (idx={idx})", fontsize=12)
    plt.show()

Compare the plots for the two seeds: attribution maps may vary across runs, even when predictive accuracy is stable.  

**Good practice:**  
- Train models with multiple random seeds and compare explanations for consistency.  
- Repeat analyses with different data splits to see whether patterns hold.  
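
One way to make the seed comparison quantitative is to stack the attribution maps from several runs and summarize them per input cell. The sketch below is a minimal NumPy example under assumptions: the helper `summarize_across_seeds` is hypothetical, and the five maps are synthetic (a shared pattern plus noise) rather than outputs of trained models.

```python
import numpy as np

def summarize_across_seeds(attr_list):
    """Summarize attribution maps of identical shape from several training runs.

    Returns the per-cell mean, the per-cell standard deviation, and the fraction
    of runs whose sign matches the sign of the mean attribution.
    """
    stack = np.stack(attr_list)  # (n_seeds, seq_len, n_features)
    mean = stack.mean(axis=0)
    std = stack.std(axis=0)
    sign_consistency = (np.sign(stack) == np.sign(mean)).mean(axis=0)
    return mean, std, sign_consistency

# Synthetic stand-in for attributions from five seeds
rng = np.random.default_rng(0)
pattern = rng.normal(size=(30, 4))
maps = [pattern + 0.3 * rng.normal(size=pattern.shape) for _ in range(5)]

mean_attr, std_attr, sign_cons = summarize_across_seeds(maps)
robust = sign_cons == 1.0  # cells where all runs agree on the sign
print(f"Cells with a consistent sign across all runs: {robust.mean():.0%}")
```

In this notebook, the `attributions` list collected in the loop above could be passed to such a function, ideally after training with more than two seeds. Cells with a large mean, a small standard deviation, and a consistent sign are better candidates for interpretation than cells that change from run to run.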

<img src="img/slide3.jpg"/>

If you want to learn more about when, where, and how XAI can be used to understand processes in Earth and climate research, refer to Jiang et al. (2024). 

Jiang, S., Sweet, L.-b., Blougouras, G., Brenning, A., Li, W., Reichstein, M., et al. (2024). How interpretable machine learning can benefit process understanding in the geosciences. Earth's Future, 12, e2024EF004540. https://doi.org/10.1029/2024EF004540