# Quantile-quantile (QQ) scaled climate projections

In [None]:
import glob
import calendar
import sys
import gc

import xarray as xr
from xclim import sdba
import matplotlib.pyplot as plt
import numpy as np
import cartopy.crs as ccrs
import xesmf as xe
import dask.diagnostics
import cmdline_provenance as cmdprov
import dask
import cmocean

from calc_adjustment import read_data

In [None]:
# Register a dask progress bar globally so dask computations report progress.
dask.diagnostics.ProgressBar().register()

In [None]:
# Parameters
# Location and calendar month highlighted in the example plots further down.
# The location is matched to the nearest grid point wherever it is used.
example_lat = -42.9
example_lon = 147.3
example_month = 6

In [None]:
# Required parameters
# Every name below must already be defined before this cell runs (the messages
# describe the `-p name value` option that supplies each one). Checking them
# up front gives a helpful message instead of a NameError further down.
assert 'hist_var' in locals(), "Must provide an historical variable name (option -p hist_var {name})"
assert 'ref_var' in locals(), "Must provide a reference variable name (option -p ref_var {name})"
assert 'target_var' in locals(), "Must provide a target variable name (option -p target_var {name})"
assert 'hist_units' in locals(), "Must provide historical units (option -p hist_units {units})"
assert 'ref_units' in locals(), "Must provide reference units (option -p ref_units {units})"
assert 'target_units' in locals(), "Must provide target units (option -p target_units {units})"
# qq_units is passed as output_units to every read_data call below.
assert 'qq_units' in locals(), "Must provide units for the qq-scaled data (option -p qq_units {units})"
assert 'adjustment_file' in locals(), "Must provide an adjustment factors file (option -p adjustment_file {file path})"
assert 'hist_files' in locals(), """Must provide historical data files (option -p hist_files {"file paths"})"""
assert 'ref_files' in locals(), """Must provide reference data files (option -p ref_files {"file paths"})"""
assert 'target_files' in locals(), """Must provide target data files (option -p target_files {"file paths"})"""
assert 'qq_file' in locals(), "Must provide an qq-scaled data file (option -p qq_file {file path})"
assert 'hist_time_bounds' in locals(), """Must provide time bounds for historical data (option -p hist_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})"""
assert 'ref_time_bounds' in locals(), """Must provide time bounds for reference data (option -p ref_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})"""
# The name checked here must match the one the rest of the notebook uses
# (target_time_bounds), not the old obs_time_bounds name.
assert 'target_time_bounds' in locals(), """Must provide time bounds for target data (option -p target_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})"""

In [None]:
# The file lists and time bounds each arrive as a single whitespace-separated
# string; break them up into lists.
hist_files, ref_files, target_files = (
    paths.split() for paths in (hist_files, ref_files, target_files)
)

hist_time_bounds, ref_time_bounds, target_time_bounds = (
    bounds.split()
    for bounds in (hist_time_bounds, ref_time_bounds, target_time_bounds)
)

# Pick the colour maps from the historical variable name: `regular_cmap` is
# used for plots of the quantiles themselves and `diverging_cmap` for plots
# of adjustment factors and differences.
if hist_var in ['tasmax']:
    regular_cmap = 'hot_r'
    diverging_cmap = 'RdBu_r'
elif hist_var in ['tasmin']:
    regular_cmap = cmocean.cm.ice
    diverging_cmap = 'RdBu_r'
elif hist_var in ['pr']:
    regular_cmap = cmocean.cm.rain
    diverging_cmap = 'BrBG'
else:
    # Report the variable that was actually tested above.
    raise ValueError(f'No cmap defined for {hist_var}')

## Read data

In [None]:
# Read the target data over the requested period, passing its native units
# and the units wanted for the QQ-scaled comparison to read_data.
ds_target = read_data(
    target_files, target_var,
    time_bounds=target_time_bounds, input_units=target_units, output_units=qq_units,
)

In [None]:
# Display a summary of the target variable as the cell output.
ds_target[target_var]

In [None]:
# Read the historical data over the requested period, passing its native
# units and the units wanted for the QQ-scaled comparison to read_data.
# The variable requested must be the same one that is indexed below.
ds_hist = read_data(
    hist_files,
    hist_var,
    time_bounds=hist_time_bounds,
    input_units=hist_units,
    output_units=qq_units)
ds_hist[hist_var]

In [None]:
# Read the reference data over the requested period, passing its native units
# and the units wanted for the QQ-scaled comparison to read_data, then
# display a summary of the reference variable.
ds_ref = read_data(
    ref_files, ref_var,
    time_bounds=ref_time_bounds, input_units=ref_units, output_units=qq_units,
)
ds_ref[ref_var]

In [None]:
# Open the adjustment factors file and rebuild the xclim
# QuantileDeltaMapping object from the dataset stored in it.
ds_adjust = xr.open_dataset(adjustment_file)
qm = sdba.QuantileDeltaMapping.from_dataset(ds_adjust)

In [None]:
# Open the QQ-scaled data file.
ds_qq = xr.open_dataset(qq_file)

In [None]:
# Spatial extent (min/max latitude and longitude) of the QQ-scaled data.
qq_lat_values = ds_qq['lat'].values
qq_lon_values = ds_qq['lon'].values
lat_min_target, lat_max_target = qq_lat_values.min(), qq_lat_values.max()
lon_min_target, lon_max_target = qq_lon_values.min(), qq_lon_values.max()

In [None]:
# Show the spatial extent of the QQ-scaled data computed from ds_qq.
print(lat_min_target, lat_max_target, lon_min_target, lon_max_target)

## Quantile mapping

In [None]:
def quantile_month_plot(da, cmap=None, diverging=False, point=None):
    """Plot quantile values on a two dimensional month/quantile grid.

    Parameters
    ----------
    da : xarray.DataArray
        Data with 'month' and 'quantiles' dimensions. If it also has a 'lat'
        dimension, the grid point nearest to the notebook-level
        (example_lat, example_lon) is plotted.
    cmap : optional
        Colour map passed through to the plot call.
    diverging : bool, default False
        If True, make the colour limits symmetric about zero.
    point : optional
        Accepted but not used by this function.
    """

    fig, ax = plt.subplots(figsize=[16, 6])
    point_selection = (
        {'lat': example_lat, 'lon': example_lon} if 'lat' in da.dims else {}
    )
    monthly_quantiles = da.sel(point_selection, method='nearest')

    # Symmetric colour limits keep zero at the centre of a diverging cmap.
    if diverging:
        vmax = np.abs(monthly_quantiles.values).max()
        vmin = -1 * vmax
    else:
        vmin = vmax = None

    month_by_quantile = monthly_quantiles.transpose('month', 'quantiles')
    month_by_quantile.plot(cmap=cmap, vmax=vmax, vmin=vmin)

    # Label the month axis with abbreviated month names, January at the top.
    month_numbers = np.arange(1, 13)
    month_labels = [calendar.month_abbr[number] for number in month_numbers]
    plt.yticks(month_numbers, month_labels)
    ax.invert_yaxis()
    plt.show()

In [None]:
def quantile_spatial_plot(
    da, month, lat_bounds=None, lon_bounds=None, cmap=None, diverging=False,
):
    """Map the quantiles nearest to 0.1, 0.5 and 0.9 for one month.

    Parameters
    ----------
    da : xarray.DataArray
        Data with 'quantiles', 'month', 'lat' and 'lon' coordinates.
    month : int
        Calendar month to plot (also used for the figure title).
    lat_bounds, lon_bounds : sequence of two values, optional
        (lower, upper) limits used to crop the map via a coordinate slice.
    cmap : optional
        Colour map passed through to the plot call.
    diverging : bool, default False
        If True, make the colour limits symmetric about zero.
    """

    selected = da.sel(
        {'quantiles': [.1, .5, .9], 'month': month}, method='nearest'
    )

    # Crop each horizontal dimension for which bounds were supplied.
    for dim_name, bounds in (('lat', lat_bounds), ('lon', lon_bounds)):
        if bounds:
            lower, upper = bounds
            selected = selected.sel({dim_name: slice(lower, upper)})

    # Symmetric colour limits keep zero at the centre of a diverging cmap.
    if diverging:
        vmax = np.abs(selected.values).max()
        vmin = -1 * vmax
    else:
        vmin = vmax = None

    p = selected.plot(
        col='quantiles',
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        figsize=[20, 5.5],
        subplot_kws={'projection': ccrs.PlateCarree()},
        vmax=vmax,
        vmin=vmin,
    )

    # Add coastlines and mark the example location on every panel.
    for panel in p.axes.flat:
        panel.coastlines()
        panel.plot(
            example_lon, example_lat, 'go', zorder=5, transform=ccrs.PlateCarree()
        )
    plt.suptitle(calendar.month_name[month])
    plt.show()

In [None]:
# Map the `hist_q` quantiles for the example month, cropped to the spatial
# extent of the QQ-scaled data. Skipped when there is no 'lat' dimension.
if 'lat' in qm.ds.dims:
    quantile_spatial_plot(
        qm.ds['hist_q'],
        example_month,
        lat_bounds=[lat_min_target, lat_max_target],
        lon_bounds=[lon_min_target, lon_max_target],
        cmap=regular_cmap,
    )

In [None]:
# Month/quantile plot of the `hist_q` variable.
quantile_month_plot(qm.ds['hist_q'], cmap=regular_cmap)

In [None]:
# Display a summary of the `af` variable (presumably the adjustment factors).
qm.ds['af']

In [None]:
# Map the `af` variable for the example month, cropped to the spatial extent
# of the QQ-scaled data, with colour limits symmetric about zero.
# Skipped when there is no 'lat' dimension.
if 'lat' in qm.ds.dims:
    quantile_spatial_plot(
        qm.ds['af'],
        example_month,
        lat_bounds=[lat_min_target, lat_max_target],
        lon_bounds=[lon_min_target, lon_max_target],
        cmap=diverging_cmap,
        diverging=True,
    )

In [None]:
# Month/quantile plot of the `af` variable with colour limits symmetric about zero.
quantile_month_plot(qm.ds['af'], cmap=diverging_cmap, diverging=True)

## QQ-scaled data

In [None]:
#model_sel = {'lat': slice(lat_min_obs, lat_max_obs), 'lon': slice(lon_min_obs, lon_max_obs)}

In [None]:
# Time-mean of each dataset, keeping the variable attributes.
hist_clim, ref_clim, target_clim, qq_clim = (
    dataset[var_name].mean('time', keep_attrs=True)
    for dataset, var_name in (
        (ds_hist, hist_var),
        (ds_ref, ref_var),
        (ds_target, target_var),
        (ds_qq, target_var),
    )
)

In [None]:
# Bilinearly regrid the historical and reference climatologies onto the
# QQ-scaled grid whenever their number of latitudes differs from it.
qq_lat_count = len(qq_clim['lat'])

if len(hist_clim['lat']) != qq_lat_count:
    hist_clim = xe.Regridder(hist_clim, qq_clim, "bilinear")(hist_clim)

if len(ref_clim['lat']) != qq_lat_count:
    ref_clim = xe.Regridder(ref_clim, qq_clim, "bilinear")(ref_clim)

In [None]:
# Mean difference between the reference and historical data, loaded into memory.
hist_ref_mean_diff = (ref_clim - hist_clim).compute()

In [None]:
# Mean change that QQ-scaling applied to the target data, loaded into memory.
qq_mean_change = (qq_clim - target_clim).compute()

In [None]:
# How far the mean change applied by QQ-scaling departs from the
# ref - hist mean difference.
difference = qq_mean_change - hist_ref_mean_diff

In [None]:
# Three-panel comparison: the ref - hist mean difference, the mean change
# applied by QQ-scaling, and the difference between the two.
#
# hist_ref_mean_diff is built from climatologies that are regridded to the
# QQ-scaled grid when their latitude counts differ, so it is plotted directly
# with no extra spatial selection.

# Shared, symmetric colour limits so the first two panels are comparable.
bias_abs_max = np.max(np.abs(hist_ref_mean_diff.values))
qq_abs_max = np.max(np.abs(qq_mean_change.values))
abs_max = np.max([bias_abs_max, qq_abs_max])
vmax = abs_max
vmin = -1 * abs_max

fig = plt.figure(figsize=[24, 6])

ax1 = fig.add_subplot(131, projection=ccrs.PlateCarree())
hist_ref_mean_diff.plot(
    ax=ax1,
    transform=ccrs.PlateCarree(),
    cmap=diverging_cmap,
    vmax=vmax,
    vmin=vmin
)
ax1.set_title('ref - hist')

ax2 = fig.add_subplot(132, projection=ccrs.PlateCarree())
qq_mean_change.plot(
    ax=ax2,
    transform=ccrs.PlateCarree(),
    cmap=diverging_cmap,
    vmax=vmax,
    vmin=vmin
)
ax2.set_title('QQ-scaled - original')

ax3 = fig.add_subplot(133, projection=ccrs.PlateCarree())
difference.plot(
    ax=ax3,
    transform=ccrs.PlateCarree(),
    cmap=diverging_cmap,
    levels=13,
)
ax3.set_title('Difference')

for ax in [ax1, ax2, ax3]:
    ax.coastlines()
    ax.plot(example_lon, example_lat, 'go', zorder=5, transform=ccrs.PlateCarree())

# Give the first two panels the same map extent as the difference panel.
xmin, xmax = ax3.get_xlim()
ymin, ymax = ax3.get_ylim()
ax1.set_extent([xmin, xmax, ymin, ymax], crs=ccrs.PlateCarree())
ax2.set_extent([xmin, xmax, ymin, ymax], crs=ccrs.PlateCarree())

plt.show()

In [None]:
# Day-of-year mean of each dataset at the example location (nearest grid
# point), or of the data as-is when there is no 'lat' dimension.
fig, ax = plt.subplots(figsize=[14, 7])
point_selection = (
    {'lat': example_lat, 'lon': example_lon} if 'lat' in ds_qq.dims else {}
)
climatology_lines = [
    (ds_hist[hist_var], {'label': "historical", 'color': 'tab:orange', 'linestyle': ':'}),
    (ds_ref[ref_var], {'label': "reference", 'color': 'tab:orange'}),
    (ds_target[target_var], {'label': "original", 'color': 'tab:red', 'linestyle': ':'}),
    (ds_qq[target_var], {'label': "QQ-scaled", 'color': 'tab:red'}),
]
for series, line_kwargs in climatology_lines:
    daily_mean = (
        series.sel(point_selection, method='nearest')
        .groupby("time.dayofyear")
        .mean()
    )
    daily_mean.plot(**line_kwargs)
plt.legend()
plt.title('Daily climatology')
plt.show()

In [None]:
# Normalised histograms of the target data and the QQ-scaled data at the
# example location (nearest grid point), or of the data as-is when there is
# no 'lat' dimension.
fig = plt.figure(figsize=[10, 6])
point_selection = (
    {'lat': example_lat, 'lon': example_lon} if 'lat' in ds_qq.dims else {}
)
histogram_series = [
    (ds_target[target_var], {'label': 'observations'}),
    (ds_qq[target_var], {'label': 'QQ-scaled data', 'facecolor': 'green'}),
]
for series, hist_kwargs in histogram_series:
    series.sel(point_selection, method='nearest').plot.hist(
        bins=50, density=True, alpha=0.7, **hist_kwargs
    )
plt.ylabel('probability')
plt.legend()
plt.show()