## Downscaling with the DeepESD model

### Train the model

In [49]:
DATA_PATH = './data/input'
FIGURES_PATH = './figures'
MODELS_PATH = './models'
ASYM_PATH = './data/asym'

In [66]:
import xarray as xr
import torch
from torch.utils.data import DataLoader, random_split

import sys; sys.path.append('/home/jovyan/deep4downscaling')
import deep4downscaling.viz
import deep4downscaling.trans
import deep4downscaling.deep.loss
import deep4downscaling.deep.utils
import deep4downscaling.deep.models
import deep4downscaling.deep.train
import deep4downscaling.deep.pred
import deep4downscaling.metrics_ccs

In [16]:
# Load predictors
predictor_filename = f'{DATA_PATH}/ERA5_NorthAtlanticRegion_1-5dg_full.nc'
predictor = xr.open_dataset(predictor_filename)

In [17]:
predictor

In [18]:
# Visualize predictors
deep4downscaling.viz.multiple_map_plot(data=predictor.mean('time'),
                                       output_path=f'./{FIGURES_PATH}/predictor_climatology.pdf')

In [19]:
# Load predictand
predictand_filename = f'{DATA_PATH}/pr_AEMET.nc'
predictand = xr.open_dataset(predictand_filename)

In [20]:
predictand

In [21]:
# Visualize predictand
day_to_viz = '10-04-2015'
deep4downscaling.viz.simple_map_plot(data=predictand.sel(time=day_to_viz),
                                     colorbar='hot_r', var_to_plot='pr',
                                     output_path=f'./{FIGURES_PATH}/predictand_day.pdf')

In [22]:
# Remove days with nans in the predictor
predictor = deep4downscaling.trans.remove_days_with_nans(predictor)

There are no observations containing null values


In [23]:
# Align both datasets in time
predictor, predictand = deep4downscaling.trans.align_datasets(predictor, predictand, 'time')

In [24]:
# Subset into training and test sets
years_train = ('1980', '2010')
years_test = ('2011', '2020')

x_train = predictor.sel(time=slice(*years_train))
y_train = predictand.sel(time=slice(*years_train))

x_test = predictor.sel(time=slice(*years_test))
y_test = predictand.sel(time=slice(*years_test))

In [25]:
# Standardize the predictors
x_train_stand = deep4downscaling.trans.standardize(data_ref=x_train, data=x_train)

In [26]:
# Compute a mask of non-NaN values. This is required to reshape the deep learning model's prediction
# into a valid format
y_mask = deep4downscaling.trans.compute_valid_mask(y_train) 

In [27]:
# Plot the mask
deep4downscaling.viz.simple_map_plot(data=y_mask, var_to_plot='pr',
                                     output_path=f'./{FIGURES_PATH}/predictand_mask.pdf')

In [28]:
# Stack in one dimension (gridpoint)
y_train_stack = y_train.stack(gridpoint=('lat', 'lon'))
y_mask_stack = y_mask.stack(gridpoint=('lat', 'lon'))

In [29]:
# Remove NaNs following y_mask. This is useful for models with a
# final fully-connected layer
y_mask_stack_filt = y_mask_stack.where(y_mask_stack==1, drop=True)

In [30]:
# Remove grid points full of nans (sea-points)
y_train_stack_filt = y_train_stack.where(y_train_stack['gridpoint'] == y_mask_stack_filt['gridpoint'],
                                             drop=True) # Filter y_train w.r.t. y_mask

In [31]:
# It is possible to preprocess the precipitation to ....

In [32]:
# There are plenty of loss function available, in this case we follow \citep{} and focus on ASYM
loss_function = deep4downscaling.deep.loss.Asym(ignore_nans=True,
                                                asym_path=ASYM_PATH)

# Load or compute the Gamma distributions required to compute this loss function
if loss_function.parameters_exist():
    loss_function.load_parameters()
else:
# It is important to always compute the ASYM parameters using the full
# predictand domain (including NaNs) to avoid shape issues
# when computing the loss function during model training
    loss_function.compute_parameters(data=y_train_stack,
                                     var_target='pr')

In [33]:
# Convert data from xarray to numpy
x_train_stand_arr = deep4downscaling.trans.xarray_to_numpy(x_train_stand)
y_train_arr = deep4downscaling.trans.xarray_to_numpy(y_train_stack_filt)

In [34]:
# Create Dataset
train_dataset = deep4downscaling.deep.utils.StandardDataset(x=x_train_stand_arr,
                                                            y=y_train_arr)

In [35]:
# Split into training and validation sets
train_dataset, valid_dataset = random_split(train_dataset,
                                            [0.9, 0.1])

In [36]:
# Create DataLoaders
batch_size = 64

train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size,
                              shuffle=True)

In [37]:
# Load the DL model (explain some parameters, or may reference the help command)
model_name = 'deepesd_pr'
model = deep4downscaling.deep.models.DeepESDpr(x_shape=x_train_stand_arr.shape,
                                               y_shape=y_train_arr.shape,
                                               filters_last_conv=1,
                                               stochastic=False)

In [38]:
# No documentation is avaialble at this moment but functions are properly documented one by one
?deep4downscaling.deep.models.DeepESDpr

[0;31mInit signature:[0m
[0mdeep4downscaling[0m[0;34m.[0m[0mdeep[0m[0;34m.[0m[0mmodels[0m[0;34m.[0m[0mDeepESDpr[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mx_shape[0m[0;34m:[0m [0mtuple[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0my_shape[0m[0;34m:[0m [0mtuple[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfilters_last_conv[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mstochastic[0m[0;34m:[0m [0mbool[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlast_relu[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
DeepESD model as proposed in Baño-Medina et al. 2024 for precipitation
downscaling. This implementation allows for a deterministic (MSE-based)
and stochastic (NLL-based) definition.

Baño-Medina, J., Manzanas, R., Cimadevilla, E., Fernández, J., González-Abad,
J., Cofiño, A. S., and Gutiérrez, J. M.: Downscaling multi-model cl

In [39]:
# Set the training hyperparameters.
num_epochs = 10000
patience_early_stopping = 60

learning_rate = 0.0001
optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate)

In [40]:
# Get device (mention both .yml)
device = ('cuda' if torch.cuda.is_available() else 'cpu')

In [41]:
# Move ASYM paramters to device
loss_function.prepare_parameters(device=device)

In [44]:
Train the model / How the model is saved and so on so forth
train_loss, val_loss = deep4downscaling.deep.train.standard_training_loop(
                            model=model, model_name=model_name, model_path=MODELS_PATH,
                            device=device, num_epochs=num_epochs,
                            loss_function=loss_function, optimizer=optimizer,
                            train_data=train_dataloader, valid_data=valid_dataloader,
                            patience_early_stopping=patience_early_stopping)

### Downscale the test set

In [43]:
# Compute the predictions on the test set. To do so, first test data is standardized using the training set as
# a reference. To avoid any out-of-memory (OOM) errors, predictions are computed in batches of 16. The prediction
# is saved in preds_path.
model.load_state_dict(torch.load(f'{MODELS_PATH}/{model_name}.pt'))

x_test_stand = deep4downscaling.trans.standardize(data_ref=x_train, data=x_test)

pred_test = deep4downscaling.deep.pred.compute_preds_standard(
                                x_data=x_test_stand, model=model,
                                device=device, var_target='pr',
                                mask=y_mask, batch_size=16)

  model.load_state_dict(torch.load(f'{MODELS_PATH}/{model_name}.pt'))


In [46]:
# Visualize the predictions
deep4downscaling.viz.simple_map_plot(data=pred_test.mean('time'),
                                     colorbar='hot_r', var_to_plot='pr',
                                     output_path=f'./{FIGURES_PATH}/prediction_test_mean.pdf')

### Downscale a Global Climate Model

In [58]:
# Load GCM data
gcm_hist = xr.open_dataset(f'{DATA_PATH}/EC-Earth3-Veg_r1i1p1f1_ssp370_hist.nc')

In [59]:
gcm_hist

In [60]:
gcm_fut = xr.open_dataset(f'{DATA_PATH}/EC-Earth3-Veg_r1i1p1f1_ssp370_fut.nc')

In [61]:
gcm_fut

In [62]:
#Before feeding the data to the model, and as explained in the manuscript, GCM predictors are first bias-corrected
# and then standardized.

gcm_hist_corrected = deep4downscaling.trans.scaling_delta_correction(data=gcm_hist,
                                                                     gcm_hist=gcm_hist, obs_hist=x_train)
gcm_fut_corrected = deep4downscaling.trans.scaling_delta_correction(data=gcm_fut,
                                                                    gcm_hist=gcm_hist, obs_hist=x_train)

gcm_hist_corrected_stand = deep4downscaling.trans.standardize(data_ref=x_train, data=gcm_hist_corrected)
gcm_fut_corrected_stand = deep4downscaling.trans.standardize(data_ref=x_train, data=gcm_fut_corrected)

In [64]:
# Compute the projections for the historical and future periods in a manner similar to the predictions for the test set

proj_historical = deep4downscaling.deep.pred.compute_preds_standard(
                    x_data=gcm_hist_corrected_stand, model=model,
                    device=device, var_target='pr',
                    mask=y_mask, batch_size=16)

proj_future = deep4downscaling.deep.pred.compute_preds_standard(
                    x_data=gcm_fut_corrected_stand, model=model,
                    device=device, var_target='pr',
                    mask=y_mask, batch_size=16)

In [65]:
# Visualize the projection
deep4downscaling.viz.simple_map_plot(data=proj_future.mean('time'),
                                     colorbar='hot_r', var_to_plot='pr',
                                     output_path=f'./{FIGURES_PATH}/proj_gcm_fut_mean.pdf')

### Compute climate change signals

In [69]:
reduction_function = deep4downscaling.metrics_ccs.mean
ccs_mean = deep4downscaling.metrics_ccs.compute_ccs(hist_data=proj_historical, fut_data=proj_future,
                                                    reduction_function=reduction_function,
                                                    relative=True)

In [73]:
# Visualize the climate change signal
deep4downscaling.viz.simple_map_plot(data=ccs_mean,
                                     colorbar='BrBG', var_to_plot='pr',
                                     vlimits=(-40, 40), num_levels=16,
                                     output_path=f'./{FIGURES_PATH}/ccs_mean.pdf')