## Explainability of the DeepESD model

### Set the data

In [1]:
DATA_PATH = './data/input'
FIGURES_PATH = './figures'
MODELS_PATH = './models'

In [2]:
import xarray as xr
import torch
import captum

import sys; sys.path.append('/home/jovyan/deep4downscaling')
import deep4downscaling.viz
import deep4downscaling.trans
import deep4downscaling.deep.models
import deep4downscaling.deep.xai

  from .autonotebook import tqdm as notebook_tqdm


Explain here why we need to reproduce all the data preprocessing. We will focus on the test set.

In [3]:
# Load predictor
predictor_filename = f'{DATA_PATH}/ERA5_NorthAtlanticRegion_1-5dg_full.nc'
predictor = xr.open_dataset(predictor_filename)

# Load predictand
predictand_filename = f'{DATA_PATH}/pr_AEMET.nc'
predictand = xr.open_dataset(predictand_filename)

# Remove days with nans in the predictor
predictor = deep4downscaling.trans.remove_days_with_nans(predictor)

# Align both datasets in time
predictor, predictand = deep4downscaling.trans.align_datasets(predictor, predictand, 'time')

# Split data into training and test sets
years_train = ('1980', '2010')
years_test = ('2011', '2020')

x_train = predictor.sel(time=slice(*years_train))
y_train = predictand.sel(time=slice(*years_train))

x_test = predictor.sel(time=slice(*years_test))
y_test = predictand.sel(time=slice(*years_test))

# Standardize the test predictors w.r.t. to the training ones
x_test_stand = deep4downscaling.trans.standardize(data_ref=x_train, data=x_test)

# Set predictand masking
y_mask = deep4downscaling.trans.compute_valid_mask(y_train) 

y_train_stack = y_train.stack(gridpoint=('lat', 'lon'))
y_mask_stack = y_mask.stack(gridpoint=('lat', 'lon'))

y_mask_stack_filt = y_mask_stack.where(y_mask_stack==1, drop=True)
y_train_stack_filt = y_train_stack.where(y_train_stack['gridpoint'] == y_mask_stack_filt['gridpoint'],
                                             drop=True)
# Convert data from xarray to numpy
x_test_stand_arr = deep4downscaling.trans.xarray_to_numpy(x_test_stand)
y_train_arr = deep4downscaling.trans.xarray_to_numpy(y_train_stack_filt)

There are no observations containing null values


Set device

In [5]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')

Load the model to explain

In [6]:
model_name = 'deepesd_pr'
model = deep4downscaling.deep.models.DeepESDpr(x_shape=x_test_stand_arr.shape,
                                               y_shape=y_train_arr.shape,
                                               filters_last_conv=1,
                                               stochastic=False)
model.load_state_dict(torch.load(f'{MODELS_PATH}/{model_name}.pt'))

  model.load_state_dict(torch.load(f'{MODELS_PATH}/{model_name}.pt'))


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

### Explainability

XAI technique used

In [8]:
xai_method = captum.attr.Saliency(model)

#### Integrated Saliency Map (ISM)

In [12]:
spatial_coord = (43.125797, -8.087920)

ism = deep4downscaling.deep.xai.compute_ism(data=x_test_stand,
                                            mask=y_mask.copy(deep=True),
                                            model=model, device=device,
                                            xai_method=xai_method,
                                            coord=spatial_coord,
                                            postprocess=True)

Computing ISMs...


In [14]:
time_to_plot = '01-02-2018'
deep4downscaling.viz.multiple_map_plot(data=ism.sel(time=time_to_plot),
                                       colorbar='hot_r',
                                       output_path=f'./{FIGURES_PATH}/ism.pdf')

#### Aggregated Saliency Map (ASM)

In [15]:
time_slice = ('01-01-2011', '03-01-2011')
asm = deep4downscaling.deep.xai.compute_asm(data=x_test_stand.sel(time=slice(*time_slice)),
                                            mask=y_mask.copy(deep=True),
                                            model=model, device=device,
                                            xai_method=xai_method,
                                            batch_size=1024,
                                            postprocess=True)

Computing ASMs...


100%|██████████| 60/60 [02:22<00:00,  2.37s/it]


In [47]:
deep4downscaling.viz.multiple_map_plot(data=asm,
                                       colorbar='hot_r',
                                       output_path=f'./{FIGURES_PATH}/asm.pdf')

Saliency Dispersion Map (SDM)

In [None]:
time_slice = ('01-01-2011', '03-01-2011')
asm = deep4downscaling.deep.xai.compute_sdm(data=x_test_stand.sel(time=slice(*time_slice)),
                                            mask=y_mask.copy(deep=True),
                                            model=model, device=device,
                                            xai_method=xai_method,
                                            batch_size=1024,
                                            postprocess=True)