# Quantile-quantile (QQ) scaled climate projections

In [None]:
import glob
import calendar
import sys
import gc

import xarray as xr
from xclim import sdba
import matplotlib.pyplot as plt
import numpy as np
import cartopy.crs as ccrs
import xesmf as xe
import dask.diagnostics
import cmdline_provenance as cmdprov
import dask

from calc_adjustment import read_data
from apply_adjustment import check_units

In [None]:
# Register a dask progress bar globally so that subsequent dask computations
# in this notebook report their progress.
dask.diagnostics.ProgressBar().register()

In [None]:
# Parameters
# Location and month used for the single-point / single-month example plots.
# NOTE(review): presumably these defaults can be overridden at run time in the
# same way as the "-p" options named in the required-parameter checks - confirm.
example_lat = -42.9  # latitude of the example point (the nearest grid point is selected)
example_lon = 147.3  # longitude of the example point (the nearest grid point is selected)
example_month = 6  # calendar month number; used to index calendar.month_name

In [None]:
# Required parameters
# Every name checked here is used later in the notebook without a default, so
# fail early with a message that says which option is missing.
assert 'cmip_var' in locals(), "Must provide a CMIP variable name (option -p cmip_var {name})"
assert 'obs_var' in locals(), "Must provide a observations variable name (option -p obs_var {name})"
assert 'cmip_units' in locals(), "Must provide CMIP units name (option -p cmip_units {units})"
# obs_units and qq_units are passed straight to read_data below; without these
# checks a missing value only shows up later as a bare NameError.
assert 'obs_units' in locals(), "Must provide observations units name (option -p obs_units {units})"
assert 'qq_units' in locals(), "Must provide QQ-scaled data units name (option -p qq_units {units})"
assert 'adjustment_file' in locals(), "Must provide an adjustment factors file (option -p adjustment_file {file path})"
assert 'hist_files' in locals(), """Must provide historical data files (option -p hist_files {"file paths"})"""
assert 'fut_files' in locals(), """Must provide future data files (option -p fut_files {"file paths"})"""
assert 'obs_files' in locals(), """Must provide observational data files (option -p obs_files {"file paths"})"""
assert 'qq_file' in locals(), "Must provide an qq-scaled data file (option -p qq_file {file path})"
assert 'obs_time_bounds' in locals(), """Must provide time bounds for observations (option -p obs_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})"""
assert 'hist_time_bounds' in locals(), """Must provide time bounds for historical GCM data (option -p hist_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})"""
assert 'fut_time_bounds' in locals(), """Must provide time bounds for future GCM data (option -p fut_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})"""

In [None]:
# The file-list and time-bound parameters are whitespace-delimited strings;
# turn each one into a list of its items.
hist_files, fut_files, obs_files = (
    file_list.split() for file_list in (hist_files, fut_files, obs_files)
)

obs_time_bounds, hist_time_bounds, fut_time_bounds = (
    bounds.split()
    for bounds in (obs_time_bounds, hist_time_bounds, fut_time_bounds)
)

## Read data

In [None]:
# Read the observational dataset for the requested period.
# NOTE(review): input_units/output_units are handed to read_data, presumably
# so it converts the data from obs_units to qq_units - confirm in calc_adjustment.
obs_read_options = {
    'time_bounds': obs_time_bounds,
    'input_units': obs_units,
    'output_units': qq_units,
}
ds_obs = read_data(obs_files, obs_var, **obs_read_options)

In [None]:
# Display the observational variable as the cell output.
ds_obs[obs_var]

In [None]:
# Latitude / longitude extent of the observational data; used further down to
# crop the model data and the spatial plots to the same region.
obs_lat_values = ds_obs['lat'].values
obs_lon_values = ds_obs['lon'].values
lat_min_obs, lat_max_obs = obs_lat_values.min(), obs_lat_values.max()
lon_min_obs, lon_max_obs = obs_lon_values.min(), obs_lon_values.max()

In [None]:
# Report the observational extent (south, north, west, east order of the variables).
print(lat_min_obs, lat_max_obs, lon_min_obs, lon_max_obs)

In [None]:
# Read the historical GCM data for the requested period, using the same
# output units as the observations (qq_units).
hist_read_options = {
    'time_bounds': hist_time_bounds,
    'input_units': cmip_units,
    'output_units': qq_units,
}
ds_hist = read_data(hist_files, cmip_var, **hist_read_options)
ds_hist[cmip_var]

In [None]:
# Read the future GCM data for the requested period, using the same output
# units as the observations (qq_units).
future_read_options = {
    'time_bounds': fut_time_bounds,
    'input_units': cmip_units,
    'output_units': qq_units,
}
ds_future = read_data(fut_files, cmip_var, **future_read_options)
ds_future[cmip_var]

In [None]:
# Open the adjustment factor file and rebuild the quantile delta mapping
# object from it; the plots below read its 'hist_q' and 'af' variables via qm.ds.
ds_adjust = xr.open_dataset(adjustment_file)
qm = sdba.QuantileDeltaMapping.from_dataset(ds_adjust)

In [None]:
# Open the QQ-scaled data file; obs_var is read from it in the comparisons below.
qq_obs = xr.open_dataset(qq_file)

## Quantile mapping

In [None]:
def quantile_month_plot(da, cmap=None, diverging=False, point=None):
    """Create a two dimensional month/quantile plot.

    Parameters
    ----------
    da : xarray.DataArray
        Data with 'month' and 'quantiles' dimensions. If it also has a 'lat'
        dimension, a single grid point is selected before plotting.
    cmap : str or matplotlib colormap, optional
        Colormap passed to the xarray plot call.
    diverging : bool, default False
        If True, use colour limits that are symmetric about zero.
    point : sequence of two numbers, optional
        (lat, lon) of the grid point to plot when `da` is gridded. Defaults
        to the notebook-level (example_lat, example_lon). Ignored when `da`
        has no 'lat' dimension.
    """

    fig, ax = plt.subplots(figsize=[16, 6])
    if 'lat' in da.dims:
        # Honour an explicitly requested point; fall back to the example point.
        lat, lon = (example_lat, example_lon) if point is None else point
        point_selection = {'lat': lat, 'lon': lon}
    else:
        point_selection = {}
    quantiles = da.sel(point_selection, method='nearest')
    if diverging:
        # nanmax so that missing values do not turn the colour limits into NaN.
        abs_max = np.nanmax(np.abs(quantiles.values))
        vmax = abs_max
        vmin = -1 * abs_max
    else:
        vmin = vmax = None
    # Months on the y-axis, quantiles on the x-axis; draw on the axes created above.
    quantiles.transpose('month', 'quantiles').plot(
        ax=ax, cmap=cmap, vmax=vmax, vmin=vmin
    )
    yticks = np.arange(1, 13)
    ytick_labels = [calendar.month_abbr[i] for i in yticks]
    plt.yticks(yticks, ytick_labels)
    # January at the top of the plot.
    ax.invert_yaxis()
    plt.show()

In [None]:
def quantile_spatial_plot(
    da, month, lat_bounds=None, lon_bounds=None, cmap=None, diverging=False,
):
    """Spatial plot of the 10th, 50th and 90th percentile.

    Parameters
    ----------
    da : xarray.DataArray
        Data with 'quantiles', 'month', 'lat' and 'lon' dimensions.
    month : int
        Calendar month number to plot (also used for the figure title).
    lat_bounds : sequence of two numbers, optional
        (min, max) latitude used to crop the data before plotting.
    lon_bounds : sequence of two numbers, optional
        (min, max) longitude used to crop the data before plotting.
    cmap : str or matplotlib colormap, optional
        Colormap passed to the xarray plot call.
    diverging : bool, default False
        If True, use colour limits that are symmetric about zero.
    """

    # Nearest available quantiles to 0.1, 0.5 and 0.9 for the requested month.
    da_selection = da.sel({'quantiles': [.1, .5, .9], 'month': month}, method='nearest')
    if lat_bounds:
        lat_lower, lat_upper = lat_bounds
        da_selection = da_selection.sel(lat=slice(lat_lower, lat_upper))
    if lon_bounds:
        lon_lower, lon_upper = lon_bounds
        da_selection = da_selection.sel(lon=slice(lon_lower, lon_upper))
    if diverging:
        # nanmax so that missing grid cells do not turn the colour limits into NaN.
        abs_max = np.nanmax(np.abs(da_selection.values))
        vmax = abs_max
        vmin = -1 * abs_max
    else:
        vmin = vmax = None
    # One map panel per selected quantile.
    p = da_selection.plot(
        col='quantiles',
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        figsize=[20, 5.5],
        subplot_kws={'projection': ccrs.PlateCarree(),},
        vmax=vmax,
        vmin=vmin,
    )
    for ax in p.axes.flat:
        ax.coastlines()
        # Mark the notebook's example point on every panel.
        ax.plot(example_lon, example_lat, 'go', zorder=5, transform=ccrs.PlateCarree())
    plt.suptitle(calendar.month_name[month])
    plt.show()

In [None]:
# Maps of the historical quantiles for the example month, cropped to the
# observational extent. Skipped when the adjustment data has no 'lat' dimension.
if 'lat' in qm.ds.dims:
    hist_q_map_options = {
        'lat_bounds': [lat_min_obs, lat_max_obs],
        'lon_bounds': [lon_min_obs, lon_max_obs],
        'cmap': 'hot_r',
    }
    quantile_spatial_plot(qm.ds['hist_q'], example_month, **hist_q_map_options)

In [None]:
# Month/quantile plot of the historical quantiles.
quantile_month_plot(qm.ds['hist_q'], cmap='hot_r')

In [None]:
# Display the adjustment factors as the cell output.
qm.ds['af']

In [None]:
# Maps of the adjustment factors for the example month, cropped to the
# observational extent and drawn with a colour scale symmetric about zero.
# Skipped when the adjustment data has no 'lat' dimension.
if 'lat' in qm.ds.dims:
    af_map_options = {
        'lat_bounds': [lat_min_obs, lat_max_obs],
        'lon_bounds': [lon_min_obs, lon_max_obs],
        'cmap': 'RdBu_r',
        'diverging': True,
    }
    quantile_spatial_plot(qm.ds['af'], example_month, **af_map_options)

In [None]:
# Month/quantile plot of the adjustment factors with a colour scale symmetric about zero.
quantile_month_plot(qm.ds['af'], cmap='RdBu_r', diverging=True)

## QQ-scaled projections

In [None]:
# Coordinate slices that crop model data to the lat/lon extent of the observations.
model_sel = dict(
    lat=slice(lat_min_obs, lat_max_obs),
    lon=slice(lon_min_obs, lon_max_obs),
)

In [None]:
# Time-mean change in the GCM: future climatology minus historical climatology.
model_mean_options = {'dim': 'time', 'keep_attrs': True}
hist_clim = ds_hist[cmip_var].mean(**model_mean_options)
future_clim = ds_future[cmip_var].mean(**model_mean_options)
model_mean_change = future_clim - hist_clim

In [None]:
# Evaluate the (possibly lazy) model change now, so later cells reuse the result.
model_mean_change = model_mean_change.compute()

In [None]:
# Time-mean change in the QQ-scaled data: QQ-scaled climatology minus
# observed climatology.
obs_mean_options = {'dim': 'time', 'keep_attrs': True}
obs_clim = ds_obs[obs_var].mean(**obs_mean_options)
qq_clim = qq_obs[obs_var].mean(**obs_mean_options)
qq_mean_change = qq_clim - obs_clim

In [None]:
# Evaluate the (possibly lazy) QQ-scaled change now, so later cells reuse the result.
qq_mean_change = qq_mean_change.compute()

In [None]:
# Bilinearly interpolate the model change onto the grid of the QQ-scaled
# change so the two fields can be differenced.
regridder = xe.Regridder(
    model_mean_change,
    qq_mean_change,
    method="bilinear",
)
model_mean_change_regridded = regridder(model_mean_change)

In [None]:
# Difference between the two change fields on the QQ-scaled grid (QQ-scaled minus GCM).
mean_change_difference = qq_mean_change - model_mean_change_regridded

In [None]:
# Shared colour limits, symmetric about zero, for both panels. nanmax is used
# so that missing values (e.g. masked grid cells) do not turn the limits into NaN.
model_abs_max = np.nanmax(np.abs(model_mean_change.sel(model_sel).values))
qq_abs_max = np.nanmax(np.abs(qq_mean_change.values))
abs_max = np.nanmax([model_abs_max, qq_abs_max])
vmax = abs_max
vmin = -1 * abs_max

fig = plt.figure(figsize=[16, 6])

# Left panel: GCM change cropped to the observational extent.
ax1 = fig.add_subplot(121, projection=ccrs.PlateCarree())
model_mean_change.sel(model_sel).plot(
    ax=ax1,
    transform=ccrs.PlateCarree(),
    cmap='RdBu_r',
    vmax=vmax,
    vmin=vmin
)
ax1.set_title('GCM')

# Right panel: change in the QQ-scaled data.
ax2 = fig.add_subplot(122, projection=ccrs.PlateCarree())
qq_mean_change.plot(
    ax=ax2,
    transform=ccrs.PlateCarree(),
    cmap='RdBu_r',
    vmax=vmax,
    vmin=vmin
)
ax2.set_title('QQ-scaled')

for ax in [ax1, ax2]:
    ax.coastlines()
    # Mark the example point on both maps.
    ax.plot(example_lon, example_lat, 'go', zorder=5, transform=ccrs.PlateCarree())
# Give the GCM panel the same map extent as the QQ-scaled panel.
xmin, xmax = ax2.get_xlim()
ymin, ymax = ax2.get_ylim()
ax1.set_extent([xmin, xmax, ymin, ymax], crs=ccrs.PlateCarree())

plt.suptitle('Projected change')
plt.show()

In [None]:
# Map of the difference between the QQ-scaled and (regridded) GCM mean change.
fig = plt.figure(figsize=[16, 6])
ax1 = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

difference_map_options = {
    'ax': ax1,
    'transform': ccrs.PlateCarree(),
    'cmap': 'RdBu_r',
    'levels': 13,
}
mean_change_difference.plot(**difference_map_options)

ax1.coastlines()
# Mark the example point.
ax1.plot(example_lon, example_lat, 'go', zorder=5, transform=ccrs.PlateCarree())
ax1.set_title('Mean change difference (QQ-scaled minus GCM)')

plt.show()

In [None]:
# Day-of-year climatology at the example point for the four datasets.
fig, ax = plt.subplots(figsize=[14, 7])
if 'lat' in qq_obs.dims:
    point_selection = {'lat': example_lat, 'lon': example_lon}
else:
    point_selection = {}

# (data, line style) pairs: orange for GCM data, red for observation-based
# data; dotted for the baseline period, solid for the future period.
climatology_lines = [
    (ds_hist[cmip_var], {'label': "historical GCM", 'color': 'tab:orange', 'linestyle': ':'}),
    (ds_future[cmip_var], {'label': "future GCM", 'color': 'tab:orange'}),
    (ds_obs[obs_var], {'label': "observations", 'color': 'tab:red', 'linestyle': ':'}),
    (qq_obs[obs_var], {'label': "QQ-scaled data", 'color': 'tab:red'}),
]
for source_da, line_style in climatology_lines:
    point_da = source_da.sel(point_selection, method='nearest')
    # Average all values sharing the same day of the year; draw on the axes above.
    point_da.groupby("time.dayofyear").mean().plot(ax=ax, **line_style)
plt.legend()
plt.title('Daily climatology')
plt.show()

In [None]:
# Normalised histograms (density=True) of the observed and QQ-scaled values
# at the example point, overlaid on the same axes.
fig = plt.figure(figsize=[10, 6])
point_selection = (
    {'lat': example_lat, 'lon': example_lon} if 'lat' in qq_obs.dims else {}
)
histogram_inputs = (
    (ds_obs, {'label': 'observations'}),
    (qq_obs, {'label': 'QQ-scaled data', 'facecolor': 'green'}),
)
for dataset, extra_style in histogram_inputs:
    dataset[obs_var].sel(point_selection, method='nearest').plot.hist(
        bins=50, density=True, alpha=0.7, **extra_style
    )
plt.ylabel('probability')
plt.legend()
plt.show()