## Notebook for tasmax's northern hemisphere summer 95th percentile maps in GDPCIR paper

#### last updated : 2022/11/15, by Emile Tenezakis (e.tenezakis@gmail.com). 
#### scaling : uses a remote dask cluster as the backend for the xarray datasets. With the cluster scaling parameters left as is, the notebook should take around 10 minutes, including normal cluster spin-up time.
#### output : the notebook saves the figure to the user-specified `figure_3_output_file_path` defined below.
#### library dependencies : This ran on rhodium's onyx environment, with the open source rhodium `rhg_compute_tools` and open source `dodola` packages pip-installed in editable mode (if using onyx, install both without their dependencies with the `--no-deps` flag).  
#### data dependencies : publicly available GDPCIR datasets stored on Google Cloud, and a yaml file containing the URLs to these datasets. This yaml file is available in the GDPCIR GitHub repository; specify your local path to it below with `fps_yaml_path`.

In [None]:
# --- user configuration ---
# variable, future scenario and GCM the figure is built for
var = 'tasmax'
fut_scenario = 'ssp370'
model = 'NorESM2-LM'
# local path to the yaml file that holds the URLs of the GDPCIR datasets
fps_yaml_path = '/home/jovyan/repositories/downscaleCMIP6/notebooks/downscaling_pipeline/post_processing_and_delivery/data_paths.yaml'
# destination of the saved figure; put that wherever you want
figure_3_output_file_path = '/home/jovyan/tests/fig3_v0.png'

In [None]:
import dask
import xarray as xr
import numpy as np
import pandas as pd
import yaml
from rhg_compute_tools import kubernetes as rhgk
from dodola.services import xesmf_regrid
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

In [None]:
# Start a remote dask cluster and scale it to 50 workers.
# Each of the 50 workers gets half of a 48 GiB node;
# half a node per worker is necessary because of the reanalysis data.
client, cluster = rhgk.get_big_cluster()
cluster.scale(50)

In [None]:
# show the cluster object as the cell output
cluster

In [None]:
def _gcm_seasonal_subset(da, years, months, pix=None):
    """Restrict `da` to an optional pixel, then to `years`, then to `months`."""
    if pix is not None:
        da = da.isel(lat=pix['lat'], lon=pix['lon'], drop=True)
    da = da.where(da.time.dt.year.isin(years), drop=True)
    return da.where(da.time.dt.month.isin(months), drop=True)


def gcm_q_trend(
    model,
    fut_scenario,
    var,
    step_label,
    fut_period=range(2080, 2100+1),
    hist_period=range(1995, 2014+1),
    months=(6, 7, 8),
    quantile=0.95,
    pix=None,
) -> xr.DataArray:
    """
    Compute the change in a seasonal quantile of `var` between a future and a historical period.

    The GCM data at the step `step_label` (raw, bias corrected, downscaled etc) is opened from
    the URLs listed in the yaml file at the module-level `fps_yaml_path`. The returned trend is
    the quantile over `fut_period` minus the quantile over `hist_period`.

    Parameters
    ----------
    model : GCM name; combined with `var` as `f'{model}-{var}'` to look up the yaml file.
    fut_scenario : scenario key of the future run in the yaml file (e.g. 'ssp370').
    var : variable name, used both in the yaml lookup and to select the zarr variable.
    step_label : key of the pipeline step in the yaml file.
    fut_period, hist_period : iterables of years kept for the future / historical series.
    months : iterable of month numbers defining the season.
    quantile : quantile taken along the `time` dimension.
    pix : optional dict with `lat` and `lon` integer indices. When given, only that pixel is
        used and the data is loaded in memory, so the function can be tested locally without
        the dask cluster.

    Returns
    -------
    xr.DataArray holding `quantile(future) - quantile(historical)`.
    """
    # NOTE(review): yaml.Loader can build arbitrary Python objects; acceptable for a trusted
    # local file, but consider yaml.safe_load if the file only holds plain mappings.
    with open(fps_yaml_path, 'r') as f:
        fps = yaml.load(f, yaml.Loader)
    hist_scenario_label = 'historical'
    urls = fps[f'{model}-{var}']

    fut = _gcm_seasonal_subset(
        xr.open_zarr(urls[fut_scenario][step_label])[var], fut_period, months, pix
    )
    hist = _gcm_seasonal_subset(
        xr.open_zarr(urls[hist_scenario_label][step_label])[var], hist_period, months, pix
    )

    if pix is not None:
        fut = fut.load()
        hist = hist.load()
    else:
        # move chunks to space to take temporal quantile
        fut = fut.chunk({'time': -1, 'lat': 360, 'lon': 360})
        hist = hist.chunk({'time': -1, 'lat': 360, 'lon': 360})

    fut_q = fut.quantile(q=quantile, dim='time')
    hist_q = hist.quantile(q=quantile, dim='time')
    return fut_q - hist_q

In [None]:
def ref_q(period=range(1995, 2014+1), months=(6, 7, 8), quantile=0.95, pix=None) -> xr.DataArray:
    """
    Compute a seasonal quantile of the tasmax reanalysis data.

    Opens the `tasmax` variable of the ERA-5 zarr store hard-coded below, keeps the time steps
    whose year is in `period` and whose month is in `months`, and takes `quantile` along `time`.

    Parameters
    ----------
    period : iterable of years to keep.
    months : iterable of month numbers defining the season.
    quantile : quantile taken along the `time` dimension.
    pix : optional dict with `lat` and `lon` integer indices. When given, only that pixel is
        used and the data is loaded in memory, so the function can be tested locally without
        the dask cluster.

    Returns
    -------
    xr.DataArray holding the requested quantile.
    """
    ref_da = xr.open_zarr('gs://clean-b1dbca25/reanalysis/ERA-5/F320/tasmax.1995-2015.F320.zarr')['tasmax']
    if pix is not None:
        ref_da = ref_da.isel(lat=pix['lat'], lon=pix['lon'], drop=True)
    ref_da = ref_da.where(ref_da.time.dt.year.isin(period), drop=True)
    ref_da = ref_da.where(ref_da.time.dt.month.isin(months), drop=True)
    if pix is not None:
        ref_da = ref_da.load()
    else:
        # move chunks to space to take temporal quantile
        ref_da = ref_da.chunk({'time': -1, 'lat': 360, 'lon': 360})
    # returned directly: a local named `ref_q` would shadow this function's own name
    return ref_da.quantile(q=quantile, dim='time')

In [None]:
def regrid_to_downscaled_resolution(da):
    """
    Regrid the lat/lon dataarray `da` onto the downscaling domain with dodola's `xesmf_regrid`.

    The target grid is read from the domain zarr store below (presumably the 0.25 degree grid,
    judging by the store name). `da` is wrapped in a one-variable dataset for the call and the
    regridded variable is returned as a dataarray again.
    """
    target_domain = xr.open_zarr('gs://support-c23ff1a3/domain.0p25x0p25.zarr', chunks=None)
    wrapped = xr.Dataset({'da': da})
    regridded = xesmf_regrid(
        x=wrapped,
        domain=target_domain,
        method='nearest_s2d',
        astype=np.float32,
        add_cyclic=None,
        keep_attrs=True,
    )
    return regridded['da']

In [None]:
def cyclic_lon(da):
    """
    Relabel the `lon` coordinate from [0, 360] to [-180, 180] and sort by it.

    Longitudes strictly greater than 180 have 360 subtracted; all others are left unchanged.
    """
    lon = da.lon
    wrapped_lon = xr.where(lon > 180, lon - 360, lon)
    relabeled = da.assign_coords({'lon': wrapped_lon})
    return relabeled.sortby('lon')

In [None]:
# test with one pixel
# test_trend = gcm_q_trend(model=model, fut_scenario=fut_scenario, var=var, step_label='downscaled_delivered', pix={'lat':300, 'lon':300})

In [None]:
# Build the quantile trend for the 'downscaled_delivered' step.
# The dask option is only active while this call runs; presumably it keeps dask
# from splitting the large chunks produced by the time subsetting — TODO confirm.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    downscaled_trend = gcm_q_trend(model=model, fut_scenario=fut_scenario, var=var, step_label='downscaled_delivered')

In [None]:
# run the computation and keep the result in local memory
downscaled_trend = downscaled_trend.compute()

In [None]:
# same trend for the 'clean' step, shown as the model panel (A) of the figure
raw_cleaned_trend = gcm_q_trend(model=model, fut_scenario=fut_scenario, var=var, step_label='clean')

In [None]:
# run the computation and keep the result in local memory
raw_cleaned_trend = raw_cleaned_trend.compute()

In [None]:
# relabel longitudes from [0, 360] to [-180, 180] and sort by longitude
raw_cleaned_trend = cyclic_lon(raw_cleaned_trend)

In [None]:
# put the model trend on the downscaled grid so it can be subtracted from the downscaled trend
raw_cleaned_trend_regridded = regrid_to_downscaled_resolution(raw_cleaned_trend)

In [None]:
# difference in change: downscaled trend minus regridded model trend (panel C of the figure)
diff_downscaled_regriddedraw = (downscaled_trend - raw_cleaned_trend_regridded)

In [None]:
# test on one pixel
# test_ref_q = ref_q(pix={'lat':300, 'lon':300})

In [None]:
# Build the reanalysis quantile with ref_q's default period, months and quantile.
# The dask option is only active while this call runs; presumably it keeps dask
# from splitting the large chunks produced by the time subsetting — TODO confirm.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    reference_quantile = ref_q()

In [None]:
# run the computation and keep the result in local memory
reference_quantile = reference_quantile.compute()

In [None]:
# relabel longitudes from [0, 360] to [-180, 180] and sort by longitude
# NOTE(review): `reference_quantile` is not plotted by the figure cell visible in this notebook
reference_quantile = cyclic_lon(reference_quantile)

In [None]:
# Figure 3: three side-by-side maps (model change, downscaled change, their difference),
# each with its own colorbar, saved to `figure_3_output_file_path`.
from copy import copy  # no longer used in this cell; kept in case later cells rely on it
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

plot_kwargs = dict(cmap='viridis', vmin=0, vmax=10)
diff_plot_kwargs = dict(cmap='viridis', vmin=-1, vmax=1)
# NOTE(review): not used by any panel in this cell
abs_plot_kwargs = dict(cmap='viridis', vmin=230, vmax=315)

all_pieces = [raw_cleaned_trend, downscaled_trend, diff_downscaled_regriddedraw]
titles = ['(A) change in model', '(B) change in downscaled', '(C) difference in change (downscaled - model)']
kwargs_list = [plot_kwargs, plot_kwargs, diff_plot_kwargs]

fig, axes = plt.subplots(ncols=len(all_pieces), nrows=1, figsize=(8*len(all_pieces), 6))
# The previous shallow copy of `all_pieces` shared the same DataArray objects, so it gave
# no isolation; the panels are iterated directly instead.
for piece, title, panel_kwargs, ax in zip(all_pieces, titles, kwargs_list, axes):
    # clear the coordinate attrs before plotting (presumably so xarray does not use the
    # coordinate metadata when drawing the axes — TODO confirm)
    piece['lat'].attrs = dict()
    piece['lon'].attrs = dict()
    im = piece.plot(add_colorbar=False, ax=ax, **panel_kwargs)
    # dedicated narrow axis to the right of the map for the colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='3%', pad=0.15)
    fig.colorbar(im, cax=cax, orientation='vertical', extend='both', label='temperature (K)')
    ax.set_title(title)
    ax.set_xlabel('')
    ax.set_ylabel('')

# Build the title from the notebook parameters so it follows `model` / `fut_scenario`:
# 'ssp370' -> 'SSP3-7.0'; any other format is shown as given.
if fut_scenario.startswith('ssp') and len(fut_scenario) == 6 and fut_scenario[3:].isdigit():
    scenario_label = f'SSP{fut_scenario[3]}-{fut_scenario[4]}.{fut_scenario[5]}'
else:
    scenario_label = fut_scenario
fig.suptitle(f'95th Percentile JJA Maximum Temperature. {scenario_label}. {model}')
fig.savefig(figure_3_output_file_path, facecolor='white')