## Notebook for tasmax and pr seasonal quantile trend maps for all GCMs in the GDPCIR paper

* last updated: 2023/05/10 by Kelly McCusker. Copy `tasmax_seasonal_quantile_trend_allmodels.ipynb` to this notebook and copy functionality from `precip_seasonal_quantile_trend.ipynb`. Make it work for tasmax and pr for all seasons, quantiles, and GCMs.


In [None]:
import os

# Directory that receives the rendered figure PNGs; put this wherever you want.
FIGURE_OUTPUT_DIR = "/gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4/images"#{var}/{kwstr}" 

# Relative path to the repository root; the assert fails fast when the notebook
# is run from a working directory where "../../" does not contain `notebooks`.
REPO_ROOT = "../../"
assert "notebooks" in os.listdir(REPO_ROOT)

# Figure file-name templates. The {var}/{season}/{quant}/{model} placeholders are
# left unformatted here and are meant to be filled in with str.format later.
figure_3_output_file_path = os.path.join(
    FIGURE_OUTPUT_DIR,
    "figure_3_{var}_summer_q{quant}_trend_with_biascorrected_clipped_{model}.png",
)
figure_3diagnostic_output_file_path = os.path.join(
    FIGURE_OUTPUT_DIR,
    "figure_3-4_{var}_{season}_q{quant}_trend_with_biascorrected_downscaled_clipped_{model}.png",
)
figure_3diagnostic_withdownscaled_output_file_path = os.path.join(
    FIGURE_OUTPUT_DIR,
    "figure_3_{var}_{season}_q{quant}_trend_with_biascorrected_downscaledvraw_clipped_{model}.png",
)


# NOTE: this template hard-codes "tasmax" and "summer" rather than using
# {var}/{season} placeholders like the templates above.
figure_a2_output_file_path = os.path.join(
    FIGURE_OUTPUT_DIR,
    "figure_a2_tasmax_summer_q{quant}_trend_diff_linear_{model}.png",
)

# YAML file that maps "<model>-<variable>" -> scenario -> pipeline step -> zarr URL
# (this is how `read_data` indexes it).
fps_yaml_path = os.path.join(
    REPO_ROOT,
    "notebooks/downscaling_pipeline/post_processing_and_delivery/data_paths.yaml",
)


# Mapping between two sets of bucket names.
# NOTE(review): presumably old -> new bucket names after a migration; it is not
# referenced anywhere in the visible part of this notebook — confirm before removing.
bucket_mapping_oregon_trail = {
    "biascorrected-492e989a": "biascorrected-4a21ed18",
    "clean-b1dbca25": "clean-f1e04ef5",
    "downscaled-288ec5ac": "downscaled-48ec31ab",
    "raw-305d04da": "raw-957d115e",
    "support-c23ff1a3": "support-f8a48a9e",
}

# NEW support BUCKET
# NOTE: the functions below hard-code "gs://support-f8a48a9e/..." instead of
# reading these two constants.
BUCKET = 'support-f8a48a9e'
DS_BUCKET = 'downscaled-48ec31ab'


In [None]:
%%capture
try:
    import rhg_compute_tools
except ModuleNotFoundError:
    ! pip install rhg_compute_tools

In [None]:
# Make `xesmf_regrid` from dodola available, installing dodola from GitHub
# (without its dependencies) when the import fails.
try:
    import dodola.services
    from dodola.services import xesmf_regrid

except ModuleNotFoundError:
    print("pip install")
    %pip install --no-deps git+https://github.com/ClimateImpactLab/dodola

    # retry the import now that the package has been installed
    from dodola.services import xesmf_regrid

In [None]:
import cartopy.crs as ccrs
import dask
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from datatree import DataTree
import yaml
from cartopy.feature import NaturalEarthFeature
from rhg_compute_tools import kubernetes as rhgk

import sys
sys.path.insert(1, '../downscaling_pipeline/post_processing_and_delivery/')
import dc6_functions


# FIRST SECTION: Save out intermediate data for figure-making later in notebook

In [None]:
# Start a dask-gateway cluster and a client connected to it.
# each 50 workers get 1/2 a 48GiB Node.
# necessary to use half a node b/c of reanalysis.
from dask_gateway import Gateway

gateway = Gateway()
# run scheduler and workers on the same image as this notebook server
img = os.environ["JUPYTER_IMAGE"]
# with gateway.new_cluster(worker_image=img, scheduler_image=img) as cluster:

cluster = gateway.new_cluster(worker_image=img, scheduler_image=img, profile="big")
# NOTE: `client` is used as a global by `gcm_pr_q_trend_biascorrected`.
client = cluster.get_client()
# img = "pangeo/pangeo-notebook:2023.04.15"
# client, cluster = rhgk.get_big_cluster(name=img,
#                                        scheduler=img,
#                                        # extra_pip_packages='xarray xclim rhg_compute_tools --no-deps git+https://github.com/ClimateImpactLab/dodola'
#                                       )
# extra_pip_packages='cdsapi'
# request 130 workers; the trailing `;` suppresses the cell output
cluster.scale(130); # used 100 workers to process all models

In [None]:
cluster

In [None]:
# cluster.scale(0)
# client.restart()
# del cluster, client

In [None]:
# Variable names of interest, plus CMIP6 metadata from the project helper module.
# `models` is iterated by key (model name) and indexed as `models[mod]` to check
# which scenarios a model has.
# NOTE(review): `variables`, `institutions` and `ensemble_members` are not used in
# the visible part of this notebook.
variables = ['tasmax', 'tavg', 'pr']
models = dc6_functions.get_cmip6_models()
institutions = dc6_functions.get_cmip6_institutions()
ensemble_members = dc6_functions.get_cmip6_ensemble_members()



In [None]:
# Run parameters for the first section.
# `mod` is the first model, used for the single-model smoke test of compute_trends.
mod = list(models.keys())[0]
fut_scenario = "ssp370"
# variable to process in this run: "pr" or "tasmax"
var = "pr" #"pr" #"tasmax"

# NOTE: the functions below loop over all four seasons themselves and do not
# read this module-level `months`.
months = [6, 7, 8]
# quantiles computed for every season
quantiles = [0.01, 0.05, 0.5, 0.95,0.99]


In [None]:
# Load the data-path catalogue; `read_data` indexes it as
# fps["<model>-<variable>"][scenario][step_label].
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is fine
# for a file tracked in this repository, but consider yaml.safe_load if the file
# could ever come from an untrusted source.
with open(fps_yaml_path, "r") as f:
    fps = yaml.load(f, yaml.Loader)

In [None]:

def read_data(model, scenario, variable, step_label, pix, fps=fps, ref=False,chunks="auto"):
    """
    Open one variable from a zarr store as a DataArray.

    Parameters
    ----------
    model, scenario : str or None
        Keys into `fps`; ignored when `ref` is True.
    variable : str
        Name of the variable to read from the store.
    step_label : str
        Pipeline step key in `fps`. When `ref` is True it selects the reference
        store: 'downscaled_delivered' (QPLAD fine reference) or 'biascorrected'
        (QDM reference).
    pix : dict or None
        Optional positional indexers with keys "lat" and "lon", applied with `isel`.
    fps : dict
        Path catalogue indexed as fps["<model>-<variable>"][scenario][step_label].
    ref : bool
        Read the reference store instead of model output.
    chunks
        Passed through to `xr.open_zarr`.

    Raises
    ------
    ValueError
        If `ref` is True and `step_label` has no reference store.
    """
    if ref:
        if step_label == 'downscaled_delivered':
            # "QPLAD fine reference zarr"
            store = f'gs://support-f8a48a9e/qplad-fine-reference/{variable}/v20220201000555.zarr'
        elif step_label == 'biascorrected':
            # "QDM reference zarr"
            store = f'gs://support-f8a48a9e/qdm-reference/{variable}/v20220201000555.zarr'
        else:
            raise ValueError(
                f"no reference data for step_label={step_label!r}; "
                "expected 'downscaled_delivered' or 'biascorrected'"
            )
        da = xr.open_zarr(store, chunks=chunks)[variable]
    else:
        da = xr.open_zarr(fps[f'{model}-{variable}'][scenario][step_label], chunks=chunks)[variable]

    if pix is not None:
        da = da.isel(lat=pix["lat"], lon=pix["lon"], drop=True)

    # The tasmin lookup goes through `fps` with model/scenario keys, so it only
    # makes sense for model output; skip it for reference data.
    if (step_label == "biascorrected") and (variable == "tasmax") and not ref:
        # apply the swap to be consistent with downscaling:
        # wherever tasmin exceeds tasmax, use tasmin
        da_tasmin = xr.open_zarr(
            fps[f"{model}-tasmin"][scenario][step_label], chunks=chunks
        )["tasmin"]
        if pix is not None:
            da_tasmin = da_tasmin.isel(lat=pix["lat"], lon=pix["lon"], drop=True)
        da = xr.where(da_tasmin > da, da_tasmin, da)
    return da


In [None]:
def cap_precip(ref, hist, fut):
    """
    Cap historical and future precipitation at a per-year, per-gridcell maximum.

    The cap is the reference maximum over the training period, scaled by
    max(1, factor), where factor is the 21-year centred rolling maximum of the
    annual maxima of the concatenated hist+fut series divided by the historical
    training-period maximum.

    Parameters
    ----------
    ref, hist, fut : xr.DataArray
        Reference, historical and future precipitation with a `time` dimension.

    Returns
    -------
    (capped_hist, capped_fut) : tuple of xr.DataArray
    """
    training_period = slice('1994-12-16', '2015-01-15')

    # compute ref max over training period
    ref_max_pr = ref.sel(time=training_period).max('time').compute()
    # compute hist max over training period
    hist_max_pr = hist.sel(time=training_period).max('time').compute()
    # compute max over the 21 year window each year falls in
    max_pr_per_year = (
        xr.concat([hist, fut], dim='time')
        .groupby('time.year')
        .max(dim='time')
        .rolling(year=21, center=True, min_periods=21)
        .max()
        .compute()
    )

    # compute the factor making sure not to leave nans around:
    # years without a full 21-year window are all-NaN and are dropped here
    factor = max_pr_per_year.dropna(dim='year', how='all') / hist_max_pr
    # compute the cap with all this
    cap_values = ref_max_pr * np.maximum(1, factor)

    def _apply_cap(da):
        """Broadcast the yearly cap onto `da`'s time axis and clip `da` to it."""
        with dask.config.set(**{'array.slicing.split_large_chunks': False}):
            years = da.time.dt.year
            cap_per_timestep = (
                cap_values
                # years dropped above take the cap of the nearest remaining year
                .reindex(year=np.unique(years), method='nearest')
                .sel(year=years)
                .drop_vars('year')
            )
            return np.minimum(da, cap_per_timestep)

    capped_hist = _apply_cap(hist)
    capped_fut = _apply_cap(fut)

    return capped_hist, capped_fut

In [None]:
def gcm_pr_q_trend_biascorrected_pix(pix, model, fut_scenario, 
                                  quantile,
                                  fut_period=range(2080, 2100+1), 
                                  hist_period=range(1995, 2014+1), 
                                    ) -> xr.DataArray:
    """
    Seasonal-quantile change in bias-corrected precipitation for one pixel subset.

    Loads bias-corrected "pr" for `model` (scenario `fut_scenario` and
    "historical") plus the bias-correction reference, caps the model data with
    `cap_precip`, and returns the ratio future/historical of the `quantile` of
    daily values in each season (DJF, MAM, JJA, SON).

    Parameters
    ----------
    pix : dict
        Positional indexers with keys "lat" and "lon" (see `read_data`).
    model, fut_scenario : str
        Keys into the data-path catalogue.
    quantile : float or sequence of float
        Quantile(s) to compute.
    fut_period, hist_period : iterable of int
        Years that define the future and historical periods.
    """
    seasons = (
        ("DJF", [12, 1, 2]),
        ("MAM", [3, 4, 5]),
        ("JJA", [6, 7, 8]),
        ("SON", [9, 10, 11]),
    )

    def _stack_by_season(da):
        """Replace the `time` dimension by a `season` dimension labelled per timestep."""
        per_season = []
        for label, season_months in seasons:
            subset = da.where(da.time.dt.month.isin(season_months), drop=True)
            subset = subset.assign_coords(season=('time', [label] * len(subset.time)))
            per_season.append(subset.swap_dims({'time': 'season'}))
        return xr.concat(per_season, dim="season")

    fut = read_data(model=model, scenario=fut_scenario, variable="pr", 
                    step_label='biascorrected', pix=pix, chunks=None)
    hist = read_data(model=model, scenario='historical', variable="pr",
                     step_label='biascorrected', pix=pix, chunks=None)

    fut = fut.load()
    hist = hist.load()

    ref = read_data(variable="pr", step_label='biascorrected', ref=True,
                    pix=pix, chunks=None, model=None, scenario=None)
    ref = ref.load()
    # the cap needs the full series, so apply it before subsetting to the periods
    hist, fut = cap_precip(ref, hist, fut)

    fut = fut.where(fut.time.dt.year.isin(fut_period), drop=True)
    fut = _stack_by_season(fut)

    hist = hist.where(hist.time.dt.year.isin(hist_period), drop=True)
    hist = _stack_by_season(hist)

    fut_q = fut.groupby("season").quantile(q=quantile)
    hist_q = hist.groupby("season").quantile(q=quantile)
    # precipitation change is expressed as a ratio
    trend = fut_q/hist_q

    return trend

def gcm_pr_q_trend_biascorrected(**kwargs): 
    """
    Run `gcm_pr_q_trend_biascorrected_pix` once per latitude row on the cluster.

    Builds 180 pixel subsets (one latitude index each, all 360 longitude
    indices), maps them over the global dask `client`, gathers the results and
    combines them by coordinates. `kwargs` are forwarded to the per-row function.
    Returns the `pr` variable of the combined result.
    """
    all_lon_indices = list(range(0, 360))
    lat_rows = []
    for lat_index in range(0, 180):
        lat_rows.append({'lat': [lat_index], 'lon': all_lon_indices})

    pending = client.map(gcm_pr_q_trend_biascorrected_pix, lat_rows, **kwargs)
    row_results = client.gather(pending)
    combined = xr.combine_by_coords(row_results)
    return combined.pr

In [None]:
def gcm_historical_q(model, step_label, 
                     variable,
                     quantile, 
                     hist_period=range(1995, 2014+1), 
                     pix=None) -> xr.DataArray:
    """
    Seasonal quantiles of historical GCM data.

    Loads `variable` for `model` from the "historical" scenario at pipeline step
    `step_label`, keeps the years in `hist_period`, and returns the `quantile`
    of the values in each season (DJF, MAM, JJA, SON).

    You can locally, without the dask cluster, test the function with `pix`
    (the data are then loaded into memory instead of being rechunked).
    """
    seasons = (
        ("DJF", [12, 1, 2]),
        ("MAM", [3, 4, 5]),
        ("JJA", [6, 7, 8]),
        ("SON", [9, 10, 11]),
    )

    def _stack_by_season(da):
        """Replace the `time` dimension by a `season` dimension labelled per timestep."""
        per_season = []
        for label, season_months in seasons:
            subset = da.where(da.time.dt.month.isin(season_months), drop=True)
            subset = subset.assign_coords(season=('time', [label] * len(subset.time)))
            per_season.append(subset.swap_dims({'time': 'season'}))
        return xr.concat(per_season, dim="season")

    hist = read_data(model=model, scenario='historical', variable=variable, 
                     step_label=step_label, pix=pix)

    hist = hist.where(hist.time.dt.year.isin(hist_period), drop=True)
    hist = _stack_by_season(hist)

    if pix is not None:
        hist = hist.load()
    else:
        # move chunks to space to take temporal quantile
        hist = hist.chunk({'season': -1, 'lat': 360, 'lon': 360})

    hist_q = hist.groupby("season").quantile(q=quantile)

    return hist_q

def ref_q(variable,
          quantile, 
          period=range(1995, 2014+1),
          pix=None) -> xr.DataArray:
    """
    Seasonal quantiles of the fine-resolution reference data.

    Loads `variable` from the reference store that `read_data` uses for
    step 'downscaled_delivered', keeps the years in `period`, and returns the
    `quantile` of the values in each season (DJF, MAM, JJA, SON).

    You can locally, without the dask cluster, test the function with `pix`
    (the data are then loaded into memory instead of being rechunked).
    """
    seasons = (
        ("DJF", [12, 1, 2]),
        ("MAM", [3, 4, 5]),
        ("JJA", [6, 7, 8]),
        ("SON", [9, 10, 11]),
    )

    def _stack_by_season(da):
        """Replace the `time` dimension by a `season` dimension labelled per timestep."""
        per_season = []
        for label, season_months in seasons:
            subset = da.where(da.time.dt.month.isin(season_months), drop=True)
            subset = subset.assign_coords(season=('time', [label] * len(subset.time)))
            per_season.append(subset.swap_dims({'time': 'season'}))
        return xr.concat(per_season, dim="season")

    ref_da = read_data(variable=variable, ref=True, step_label='downscaled_delivered', 
                       pix=pix, model=None, scenario=None)

    ref_da = ref_da.where(ref_da.time.dt.year.isin(period), drop=True)
    ref_da = _stack_by_season(ref_da)

    if pix is not None:
        ref_da = ref_da.load()
    else:
        # move chunks to space to take temporal quantile
        ref_da = ref_da.chunk({'season':-1, 'lat':360, 'lon':360})

    # named differently from the function so the local does not shadow `ref_q`
    ref_quantiles = ref_da.groupby("season").quantile(q=quantile)
    return ref_quantiles

In [None]:

def gcm_q_trend(
    model,
    scenario,
    variable,
    step_label,
    quantile,
    fut_period=range(2080, 2100 + 1),
    hist_period=range(1995, 2014 + 1),
    pix=None,
) -> xr.DataArray:
    """
    Calculate the change between two periods in seasonal quantiles.

    Loads `variable` for `model` at pipeline step `step_label` (raw, bias
    corrected, downscaled etc.) for `scenario` and for "historical", keeps the
    years in `fut_period` / `hist_period`, and computes the `quantile` (a scalar
    or list) of the values in each season (DJF, MAM, JJA, SON).

    The change is the ratio future/historical when `variable` is "pr" and the
    difference future - historical otherwise.

    You can locally, without the dask cluster, test the function with `pix`.
    """
    seasons = (
        ("DJF", [12, 1, 2]),
        ("MAM", [3, 4, 5]),
        ("JJA", [6, 7, 8]),
        ("SON", [9, 10, 11]),
    )

    def _stack_by_season(da):
        """Replace the `time` dimension by a `season` dimension labelled per timestep."""
        per_season = []
        for label, season_months in seasons:
            subset = da.where(da.time.dt.month.isin(season_months), drop=True)
            subset = subset.assign_coords(season=('time', [label] * len(subset.time)))
            per_season.append(subset.swap_dims({'time': 'season'}))
        return xr.concat(per_season, dim="season")

    # read data, subset time series for fut
    fut = read_data(model=model, scenario=scenario, variable=variable,
                    step_label=step_label, pix=pix)
    fut = fut.where(fut.time.dt.year.isin(fut_period), drop=True)
    fut = _stack_by_season(fut)

    # same for hist
    hist = read_data(model=model, scenario="historical", variable=variable,
                     step_label=step_label, pix=pix)
    hist = hist.where(hist.time.dt.year.isin(hist_period), drop=True)
    hist = _stack_by_season(hist)

    # if a pixel subset was requested, load it into memory
    if pix is not None:
        fut = fut.load()
        hist = hist.load()
    else:
        # move chunks to space to take temporal quantile
        fut = fut.chunk({"season": -1, "lat": 360, "lon": 360})
        hist = hist.chunk({"season": -1, "lat": 360, "lon": 360})

    fut_q = fut.groupby("season").quantile(q=quantile)
    hist_q = hist.groupby("season").quantile(q=quantile)
    if variable == "pr":
        trend = fut_q/hist_q
    else:
        trend = fut_q - hist_q

    return trend

def regrid_to_coarse_resolution(da):
    """
    regrid a lat/lon dataarray to the 1x1 domain using dodola's `xesmf_regrid`
    (nearest source-to-destination, output cast to float32).
    """
    target_domain = xr.open_zarr("gs://support-f8a48a9e/domain.1x1.zarr", chunks=None)
    # xesmf_regrid is given a Dataset, so wrap the array under a throwaway name
    wrapped = xr.Dataset({"da": da})
    regridded = xesmf_regrid(
        x=wrapped,
        domain=target_domain,
        method="nearest_s2d",
        astype=np.float32,
        add_cyclic=None,
        keep_attrs=True,
    )
    return regridded["da"]


def regrid_to_downscaled_resolution(da):
    """
    regrid a lat/lon dataarray to the 0.25x0.25 domain (the downscaling
    resolution) using dodola's `xesmf_regrid` (nearest source-to-destination,
    output cast to float32).
    """
    target_domain = xr.open_zarr("gs://support-f8a48a9e/domain.0p25x0p25.zarr", chunks=None)
    # xesmf_regrid is given a Dataset, so wrap the array under a throwaway name
    wrapped = xr.Dataset({"da": da})
    regridded = xesmf_regrid(
        x=wrapped,
        domain=target_domain,
        method="nearest_s2d",
        astype=np.float32,
        add_cyclic=None,
        keep_attrs=True,
    )
    return regridded["da"]

def cyclic_lon(da):
    """
    Shift longitudes from [0, 360] to [-180, 180] and sort by longitude.

    Values greater than 180 have 360 subtracted; all others are left as they are.
    """
    lon = da.lon
    wrapped_lon = xr.where(lon > 180, lon - 360, lon)
    relabelled = da.assign_coords({"lon": wrapped_lon})
    return relabelled.sortby("lon")

def compute_trends(model, scenario, variable, 
                   # months, 
                   quantile):
    """
    Compute all intermediate fields for the quantile-trend figures for one GCM.

    For `variable` == "pr" the comparisons between pipeline steps are ratios;
    otherwise they are differences.

    Returns
    -------
    list of xr.DataArray, in this order:
        0. raw cleaned trend (native model grid, lon in [-180, 180])
        1. bias corrected trend vs raw cleaned trend regridded to 1 degree
        2. downscaled trend vs bias corrected trend regridded to 0.25 degree
        3. downscaled trend vs raw cleaned trend regridded to 0.25 degree
        4. raw cleaned trend regridded to 1 degree
        5. raw cleaned historical quantile (native model grid)
        6. raw cleaned historical quantile regridded to 0.25 degree
    """
    print("biascorrected_trend")
    if variable=="pr":
        # use the `scenario` argument here, not the notebook-level `fut_scenario`,
        # so all three pipeline steps are computed for the same scenario
        biascorrected_trend = gcm_pr_q_trend_biascorrected(model=model, 
                                                           fut_scenario=scenario,
                                                           quantile=quantile)
    else:
        biascorrected_trend = gcm_q_trend(
            model=model, scenario=scenario, variable=variable, step_label="biascorrected",
            quantile=quantile
        ).compute()

    biascorrected_trend = cyclic_lon(biascorrected_trend)
    # make the array C-contiguous before regridding
    biascorrected_trend.values = np.ascontiguousarray(biascorrected_trend.values)
    biascorrected_fine_trend = regrid_to_downscaled_resolution(biascorrected_trend)

    print("downscaled_trend")
    with dask.config.set(**{"array.slicing.split_large_chunks": False}):
        downscaled_trend = gcm_q_trend(
            model=model,
            scenario=scenario,
            variable=variable,
            step_label="downscaled_delivered",
            quantile=quantile
        ).compute()

    print("raw_cleaned_trend")
    raw_cleaned_trend = gcm_q_trend(
        model=model,
        scenario=scenario,
        variable=variable,
        step_label="clean",
        quantile=quantile
    ).compute()

    print("regridding and cyclic lons")
    raw_cleaned_trend = cyclic_lon(raw_cleaned_trend)
    raw_cleaned_trend.values = np.ascontiguousarray(raw_cleaned_trend.values)

    # raw -> 1 degree -> 0.25 degree, the same grid sequence the pipeline steps use
    raw_cleaned_trend_regridded_coarse = regrid_to_coarse_resolution(raw_cleaned_trend)
    raw_cleaned_trend_regridded_coarse.values = np.ascontiguousarray(raw_cleaned_trend_regridded_coarse.values)
    raw_cleaned_trend_regridded_fine = regrid_to_downscaled_resolution(
        raw_cleaned_trend_regridded_coarse
    )

    if variable == "pr":
        diff_downscaled_regriddedraw = downscaled_trend / raw_cleaned_trend_regridded_fine
        diff_downscaled_biascorrected = downscaled_trend / biascorrected_fine_trend

        diff_biascorrected_regriddedraw = (
            biascorrected_trend / raw_cleaned_trend_regridded_coarse
        )
    else:
        diff_downscaled_regriddedraw = downscaled_trend - raw_cleaned_trend_regridded_fine
        diff_downscaled_biascorrected = downscaled_trend - biascorrected_fine_trend

        diff_biascorrected_regriddedraw = (
            biascorrected_trend - raw_cleaned_trend_regridded_coarse
        )

    print("historical quantile")
    raw_cleaned_historical_q = gcm_historical_q(model=model,
                                                variable=variable,
                                                quantile=quantile, 
                                                step_label='clean').compute()
    raw_cleaned_historical_q_regridfine = regrid_to_downscaled_resolution(cyclic_lon(raw_cleaned_historical_q))

    all_plot_pieces = [
        raw_cleaned_trend,
        diff_biascorrected_regriddedraw,
        diff_downscaled_biascorrected,
        diff_downscaled_regriddedraw,
        raw_cleaned_trend_regridded_coarse, # use this for averaging across GCMs
        raw_cleaned_historical_q, # use this for precip Fig (fig4) panel 2
        raw_cleaned_historical_q_regridfine, # use this in case comparing historical q to ref q
    ]

    return all_plot_pieces

In [None]:
# only need to do this once (not per GCM)
def get_reference_quantile(variable, quantile):
    """
    Compute the seasonal reference quantiles from `ref_q` and shift longitudes
    with `cyclic_lon`.
    """
    chunk_setting = {'array.slicing.split_large_chunks': False}
    with dask.config.set(**chunk_setting):
        lazy_quantiles = ref_q(variable, quantile)

    computed = lazy_quantiles.compute()
    return cyclic_lon(computed)

In [None]:
var,quantiles

### Compute the reference seasonal quantiles and save to disk

In [None]:
# Directory that receives the intermediate data files.
SCRATCH = "/gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4"#impactlab-data-scratch"
SAVEDIR = f"{SCRATCH}"#/gdpcir-diagnostics/figure3-4/"

# The reference quantiles are only computed and saved for precipitation.
if var=="pr":
    reference_quantile = get_reference_quantile(variable=var,
                                                quantile=quantiles)

    attrs = {
        "Created by": "Kelly McCusker <kmccusker@rhg.com>",
        "Date created": "May 9, 2023",
        "Description": (
            "Figure 4, panel 1 (as submitted, ERA5 reference quantiles). "+\
            f"This is seasonal daily {var} quantile 1995-2014 (0.25deg resolution)."
        )
    }
    reference_quantile.attrs.update(attrs)
    reference_quantile.to_netcdf(f"{SAVEDIR}/figure_3-4_ERA5_{var}_allseason_quantiles_fine_reference.nc", 
                            encoding={var: {"zlib": True}})


In [None]:
# ! ls -ltrh $SAVEDIR

In [None]:
client.restart()

In [None]:
# test wrapper function: run compute_trends for a single model (`mod`) before
# looping over all GCMs. `all_pieces` is a list of 7 DataArrays (see the return
# list at the end of compute_trends).
all_pieces = compute_trends(model=mod, scenario=fut_scenario, variable=var, 
                            # months=months,
                            quantile=quantiles
                           )

In [None]:
        # raw_cleaned_trend,
        # diff_biascorrected_regriddedraw,
        # diff_downscaled_biascorrected,
        # diff_downscaled_regriddedraw,
        # raw_cleaned_trend_regridded_coarse, # use this for averaging across GCMs
        # raw_cleaned_historical_q, # use this for precip Fig panel 2
        # raw_cleaned_historical_q_regridfine, # use this in case comparing historical q to ref q

all_pieces

In [None]:
# in-memory size of each returned piece, in GiB
for it in all_pieces:
    print(it.nbytes/(1024)**3)
    
        # raw_cleaned_trend,
        # diff_biascorrected_regriddedraw,
        # diff_downscaled_biascorrected,
        # diff_downscaled_regriddedraw,
        # raw_cleaned_trend_regridded_coarse, # use this for averaging across GCMs
        # raw_cleaned_historical_q, # use this for precip Fig panel 2
        # raw_cleaned_historical_q_regridfine, # use this in case comparing historical q to ref q


In [None]:

## TEST THE PR DATA COMING OUT...This figure code from precip_seasonal_quantile_trend.ipynb

quant=0.95
sea="JJA"

plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
# The titles/labels/kwargs lists describe five panels (A-E) from the notebook this
# code was copied from. compute_trends returns neither (A) reference nor (B) model
# in those positions; its first three pieces are panels (C), (D) and (E).
titles = ['(A) reference', '(B) model', '(C) multiplicative change in model', '(D) ratio of changes (biascorrected/model)', '(E) ratio of changes (downscaled/biascorrected)']
labels = ['precipitation [mm]', 'precipitation [mm]', 'ratio [mm/mm]', 'ratio [mm/mm]', 'ratio [[mm/mm]/[mm/mm]]']
kwargs_list = [abs_plot_kwargs, abs_plot_kwargs, plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
from copy import copy
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
all_pieces_copy = copy(all_pieces)

# Plot only the panels that have both data and a title. Indexing titles[i+2] for
# every piece after the second runs past the end of `titles`.
first_panel = 2
n_panels = len(titles) - first_panel
panel_pieces = all_pieces_copy[:n_panels]

fig, axes = plt.subplots(ncols=n_panels, nrows=1, figsize=(36, 10), subplot_kw={'projection': ccrs.Robinson()})
coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                       edgecolor='black', facecolor='none')
for i, piece in enumerate(panel_pieces):
    panel = first_panel + i
    ax = axes[i]
    # clear coordinate attrs on the array that is actually plotted
    # (this modifies the coordinates of the arrays in `all_pieces` in place)
    piece['lat'].attrs = dict()
    piece['lon'].attrs = dict()
    ax.add_feature(coastline_feature)
    im = piece.sel(quantile=quant,season=sea).plot(add_colorbar=True, ax=ax, transform=ccrs.PlateCarree(), 
                            cbar_kwargs=dict(fraction=0.046, pad=0.04,orientation='vertical', 
                                             extend='both', label=labels[panel]), 
                            **kwargs_list[panel])
    ax.set_title(titles[panel])
    ax.set_xlabel('')
    ax.set_ylabel('')
#fig.suptitle(TITLE)


### Loop through all GCMs and compute the summary stats

In [None]:
# Run compute_trends for every model that has the future scenario.
# allmods[i] holds the pieces for the model named modsdone[i].
allmods = []
modsdone = []
for mod in models:
    print(mod)
    # this takes about 2 mins per model with 60 big workers, 1:20 with 100 workers
    if fut_scenario not in models[mod]:
        print("skipping..")
        continue
    model_pieces = compute_trends(
        model=mod,
        scenario=fut_scenario,
        variable=var,
        quantile=quantiles,
    )
    allmods.append(model_pieces)
    modsdone.append(mod)

print("done")
# TODO silence the large chunk warnings
# with dask.config.set(**{'array.slicing.split_large_chunks': False}):
#     ...     array[indexer]

# TODO change to C_CONTIGUOUS on all arrays that get regridded to fix that warning


In [None]:
len(allmods)

# This took 33 minutes for 20 models (tasmax) with 100 workers.
# It took longer for 20 models (pr) with 120 workers, though I didn't time it properly.

### Reorganize the intermediate outputs: Compile the metrics at different stages for all GCMs

In [None]:
# Per-model containers. Pieces that stay on each model's native grid go into
# dicts keyed by model name; pieces on a common grid go into lists that are
# concatenated along a new "model" dimension below.
rawcln = {}
rawclncoarse = []
diffbcregridraw = []
diffdnscbc = []
diffdnregridraw = []

rawclnhist = {}
rawclnhistrg = []

for model_name, model_pieces in zip(modsdone, allmods):
    (
        raw_trend,           # raw_cleaned_trend, (different grids)
        bc_vs_raw,           # diff_biascorrected_regriddedraw
        downscaled_vs_bc,    # diff_downscaled_biascorrected
        downscaled_vs_raw,   # diff downscaled v regridded raw
        raw_trend_coarse,    # raw cleaned regridded to coarse
        raw_hist_q,          # raw cleaned historical quantile
        raw_hist_q_fine,     # raw cleaned historical quantile fine
    ) = model_pieces

    rawcln[model_name] = raw_trend
    rawclncoarse.append(raw_trend_coarse)
    diffbcregridraw.append(bc_vs_raw)
    diffdnscbc.append(downscaled_vs_bc)
    diffdnregridraw.append(downscaled_vs_raw)

    rawclnhist[model_name] = raw_hist_q
    rawclnhistrg.append(raw_hist_q_fine)


def _stack_models(arrays):
    """Concatenate per-model arrays along a new "model" dimension as a Dataset named `var`."""
    return xr.concat(arrays, dim=pd.Index(modsdone, name="model")).to_dataset(name=var)


rawclean_dt = DataTree.from_dict(rawcln)
diffbcrgraw_ds = _stack_models(diffbcregridraw)
diffdnscbc_ds = _stack_models(diffdnscbc)
diffdnrgraw_ds = _stack_models(diffdnregridraw)
rawclncoarse_ds = _stack_models(rawclncoarse)

rawcleanhist_dt = DataTree.from_dict(rawclnhist)
rawclnhistrg_ds = _stack_models(rawclnhistrg)


### Save intermediate data for the figures

In [None]:
# SCRATCH = "/gcs/impactlab-data-scratch"
# SAVEDIR = f"{SCRATCH}/gdpcir-diagnostics/figure3-4/"
# Directory that receives the intermediate data files.
SCRATCH = "/gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4"#impactlab-data-scratch"
SAVEDIR = f"{SCRATCH}"#/gdpcir-diagnostics/figure3-4/"


In [None]:
created = "May 25, 2023"
author = "Kelly McCusker <kmccusker@rhg.com>"

# compute_trends compares pipeline steps with a ratio for pr and with a
# difference for every other variable; describe the saved data accordingly.
comparison = "ratio" if var == "pr" else "difference"
# periods match the defaults of gcm_q_trend / gcm_historical_q
fut_years = "2080-2100"
hist_years = "1995-2014"


def _save_netcdf(ds, filename, description):
    """Attach provenance attributes to `ds` and write it to SAVEDIR as compressed netCDF."""
    ds.attrs.update(
        {
            "Created by": author,
            "Date created": created,
            "Description": description,
        }
    )
    ds.to_netcdf(f"{SAVEDIR}/{filename}", encoding={var: {"zlib": True}})


# Save these outputs for later
_save_netcdf(
    diffbcrgraw_ds,
    f"figure_3-4_{fut_scenario}_{var}_allseason_quantiles_trends_biascorrected_v_cleaned_allgcms.nc",
    "Figure 3, panel 2 (as submitted) for all GCMs. "
    f"This is the {comparison} in: seasonal daily {var} quantile change {fut_years} "
    f"relative to {hist_years} at the bias corrected step and the "
    "quantile changes at the raw cleaned step (1deg resolution).",
)

_save_netcdf(
    diffdnscbc_ds,
    f"figure_3-4_{fut_scenario}_{var}_allseason_quantiles_trends_downscaled_v_biascorrected_allgcms.nc",
    "Figure 3, not in panels (as submitted) for all GCMs. "
    f"This is the {comparison} in: seasonal daily {var} quantile change {fut_years} "
    f"relative to {hist_years} at the downscaled step and the "
    "quantile changes at the bias corrected step (0.25deg resolution).",
)

_save_netcdf(
    diffdnrgraw_ds,
    f"figure_3-4_{fut_scenario}_{var}_allseason_quantiles_trends_downscaled_v_cleaned_allgcms.nc",
    "Figure 3, not in panels (as submitted) for all GCMs. "
    f"This is the {comparison} in: seasonal daily {var} quantile change {fut_years} "
    f"relative to {hist_years} at the downscaled step and the "
    "quantile changes at the raw cleaned step (0.25deg resolution).",
)

# NOTE: this file name starts with "figure_3_" (not "figure_3-4_"); the loading
# cell in the second section reads it back under the same name.
_save_netcdf(
    rawclncoarse_ds,
    f"figure_3_{fut_scenario}_{var}_allseason_quantiles_trends_raw_cleaned_coarse_allgcms.nc",
    "Figure 3, not in panels (as submitted) for all GCMs. "
    f"This is the seasonal daily {var} quantile change {fut_years} "
    f"relative to {hist_years} at the raw cleaned step regridded to 1deg resolution.",
)


# TODO ADD attributes to the datatrees
# the datatrees take awhile to save
print(f"Size of data in tree = {rawclean_dt.nbytes / 1e9 :.2f} GB")
rawclean_dt.to_zarr(f"{SAVEDIR}/figure_3-4_{fut_scenario}_{var}_allseason_quantiles_trends_rawcleaned_allgcms.zarr")

print(f"Size of data in tree = {rawcleanhist_dt.nbytes / 1e9 :.2f} GB")
rawcleanhist_dt.to_zarr(f"{SAVEDIR}/figure_3-4_historical_{var}_allseason_quantiles_rawcleaned_allgcms.zarr")


_save_netcdf(
    rawclnhistrg_ds,
    f"figure_3-4_historical_{var}_allseason_quantiles_raw_cleaned_fine_allgcms.nc",
    "Figure 3-4, panel 2 for precip, not in tasmax fig (as submitted) for all GCMs. "
    f"This is the seasonal daily {var} quantile in the historical sim "
    f"years {hist_years} at the raw cleaned step regridded to fine (0.25degree resolution).",
)

print("saved all")

In [None]:
# Release the dask workers once the intermediate data are saved: scale the
# cluster to zero workers, restart the client, and drop both references.
# NOTE(review): this does not call cluster.shutdown()/client.close() explicitly —
# confirm the gateway cluster is actually released after `del`.
cluster.scale(0)
client.restart()
del cluster, client

In [None]:
! ls -ltrh $SAVEDIR

# SECOND SECTION: Start here if intermediate data is saved above

In [None]:
# Parameters for the second section; they select which saved intermediate files
# are loaded below, so they must match the values used when the files were written.
var = "pr" #"tasmax" #"pr"
fut_scenario = "ssp370"

In [None]:
!ls -ltrh /gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4

In [None]:
import datatree as dt

# Open the intermediate files written by the first section and pull them into
# memory. This is slow (about 10-11GB in total), mostly because of the datatrees.
SCRATCH = "/gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4"#impactlab-data-scratch"
SAVEDIR = f"{SCRATCH}"#/gdpcir-diagnostics/figure3-4"

# Both variables share the "3-4" file prefix (tasmax previously used "3").
fignum = "3-4"

rawclean_dt = dt.open_datatree(
    f"{SAVEDIR}/figure_3-4_{fut_scenario}_{var}_allseason_quantiles_trends_rawcleaned_allgcms.zarr",
    engine="zarr",
).load()

# xr.load_dataset reads the whole file into memory and closes it again, so no
# netCDF file handles stay open for the rest of the session.
diffbcrgraw_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_{fignum}_{fut_scenario}_{var}_allseason_quantiles_trends_biascorrected_v_cleaned_allgcms.nc"
)
diffdnscbc_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_{fignum}_{fut_scenario}_{var}_allseason_quantiles_trends_downscaled_v_biascorrected_allgcms.nc"
)
diffdnrgraw_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_{fignum}_{fut_scenario}_{var}_allseason_quantiles_trends_downscaled_v_cleaned_allgcms.nc"
)
# The coarse raw-cleaned trends are saved under the plain "figure_3" prefix.
rawclncoarse_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_3_{fut_scenario}_{var}_allseason_quantiles_trends_raw_cleaned_coarse_allgcms.nc"
)

if var == "pr":  # haven't made these files for tasmax yet
    rawcleanhist_dt = dt.open_datatree(
        f"{SAVEDIR}/figure_{fignum}_historical_{var}_allseason_quantiles_rawcleaned_allgcms.zarr",
        engine="zarr",
    ).load()
    rawclnhistrg_ds = xr.load_dataset(
        f"{SAVEDIR}/figure_{fignum}_historical_{var}_allseason_quantiles_raw_cleaned_fine_allgcms.nc"
    )


## Make figures
Figure files are saved to disk for each GCM and for the mean across GCMs.

These cells must be run manually per-quantile to save different quantile figures.

### GCM mean

#### This is the figure from the initially submitted paper (but showing the ensemble mean instead of a single GCM)

In [None]:
from copy import copy

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# Quantile and season shown in this two-panel ensemble-mean figure.
quant = 0.95
sea = "JJA"

# Colour limits, panel titles and colourbar label depend on the variable:
# tasmax panels are differences, pr panels are ratios.
if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    abs_plot_kwargs = dict(vmin=230, vmax=315)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted - raw model)"]
    clabel = "temperature (C)"
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted / raw model)"]
    clabel = "precipitation (mm/day)"

kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]

# Ensemble mean of each dataset at the chosen quantile and season. Only the
# first two entries are drawn below.
all_pieces = [
    ds[var].mean(dim="model").sel(quantile=quant, season=sea)
    for ds in (rawclncoarse_ds, diffbcrgraw_ds, diffdnscbc_ds)
]
all_pieces_copy = copy(all_pieces)  # shallow copy: same DataArray objects

coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "10m", edgecolor="black", facecolor="none"
)

fig, axes = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(18, 4),
    dpi=200,
    subplot_kw={"projection": ccrs.Robinson()},
)

for da, title, kw, ax in zip(all_pieces_copy[:2], titles, kwargs_list, axes):
    # Clear the coordinate attributes before plotting.
    da["lat"].attrs = dict()
    da["lon"].attrs = dict()
    ax.add_feature(coastline_feature, linewidth=0.2)

    im = da.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        add_colorbar=True,
        cbar_kwargs=dict(
            fraction=0.046,
            pad=0.04,
            orientation="vertical",
            extend="both",
            label=clabel,
        ),
        **kw
    )
    ax.set_title(title)
    ax.set_xlabel("")
    ax.set_ylabel("")

fig.set_facecolor("white")
# fig.savefig(figure_3_output_file_path.format(var=var,model="GCMmean",quant=quant), 
#             facecolor="white", bbox_inches="tight")

#### add a panel showing ratio between error and raw trend

In [None]:
from copy import copy

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# Quantile and season shown in this three-panel ensemble-mean figure.
quant = 0.95
sea = "JJA"

# Everything the plotting code needs is defined for BOTH variables here, so the
# cell does not rely on `clabel` / `err_plot_kwargs` left over from another cell.
if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    abs_plot_kwargs = dict(vmin=230, vmax=315)
    # Colour scale for the error/trend ratio panel.
    err_plot_kwargs = dict(cmap='RdBu_r', vmin=-.25, vmax=.25)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted - raw model)"]
    clabel = "temperature (C)"
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    err_plot_kwargs = dict(cmap="RdBu", vmin=-2, vmax=2)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted / raw model)"]
    clabel = "precipitation (mm/day)"

kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]

# Ensemble mean of each dataset at the chosen quantile and season.
all_pieces = [
    rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
    diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
    diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
]
all_pieces_copy = copy(all_pieces)  # shallow copy: same DataArray objects

coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "10m", edgecolor="black", facecolor="none"
)

fig, axes = plt.subplots(
    ncols=3,
    nrows=1,
    figsize=(21, 4),
    subplot_kw={"projection": ccrs.Robinson()},
    dpi=200,
)

# Panels a and b: raw trend and bias-adjustment error.
for i, da in enumerate(all_pieces_copy[:2]):
    ax = axes[i]
    da["lat"].attrs = dict()
    da["lon"].attrs = dict()
    ax.add_feature(coastline_feature, linewidth=0.2)
    im = da.plot(
        add_colorbar=True,
        ax=ax,
        transform=ccrs.PlateCarree(),
        cbar_kwargs=dict(
            fraction=0.046,
            pad=0.04,
            orientation="vertical",
            extend="both",
            label=clabel,
        ),
        **kwargs_list[i]
    )
    ax.set_title(titles[i])
    ax.set_xlabel("")
    ax.set_ylabel("")

# Panel c: bias-adjustment error divided by the raw trend.
ax = axes[2]
ax.add_feature(coastline_feature, linewidth=0.2)
da = all_pieces[1] / all_pieces[0]
im = da.plot(
    add_colorbar=True,
    ax=ax,
    transform=ccrs.PlateCarree(),
    cbar_kwargs=dict(
        fraction=0.046,
        pad=0.04,
        orientation="vertical",
        extend="both",
        label="ratio",
    ),
    **err_plot_kwargs
)
ax.set_title("ratio bias-adjustment error/raw trend")
ax.set_xlabel("")
ax.set_ylabel("")

fig.set_facecolor("white")
# fig.savefig(figure_3_output_file_path.format(var=var,model="GCMmean",quant=quant), 
#             facecolor="white", bbox_inches="tight")

#### try histograms instead
TODO: only grab land grid cells

In [None]:
# Scratch check: mask `da` wherever the fourth entry of `all_pieces` is <= 1.
# NOTE(review): `all_pieces` as built in the cell above has only three entries, so
# index 3 raises IndexError when the notebook is run top to bottom. It only exists
# after the histogram cell below has run (there it is the historical raw-cleaned
# quantile) -- confirm the intended mask, or remove this cell.
da.where(all_pieces[3]>1)

In [None]:
from copy import copy

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# Quantile to plot and whether the histogram counts use a log axis.
quant = 0.95
dolog = True

if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    abs_plot_kwargs = dict(vmin=230, vmax=315)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted - raw model)"]
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted / raw model)"]

kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]

# One row per season: two histograms plus a scatter of error against raw trend.
fig, axes = plt.subplots(
    ncols=3,
    nrows=4,
    figsize=(36, 30),
    dpi=200,
)
for s, sea in enumerate(["DJF", "MAM", "JJA", "SON"]):

    all_pieces = [
        rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
        diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
        diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
    ]
    # The historical raw-cleaned quantiles are only loaded for pr, so only add
    # them there; nothing below plots this fourth entry.
    if var == "pr":
        all_pieces.append(
            rawclnhistrg_ds[var].mean(dim="model").sel(quantile=quant, season=sea)
        )

    # Histograms of the raw trend and the bias-adjustment error.
    for i, da in enumerate(all_pieces[:2]):
        ax = axes[s, i]
        da["lat"].attrs = dict()
        da["lon"].attrs = dict()

        # Infinite values would break the binning, so report and mask them.
        n_inf = int(np.isinf(da).sum())
        if n_inf:
            print(f"{sea}, panel {i}: masking {n_inf} infinite values")
        im = da.where(~np.isinf(da)).plot.hist(
            ax=ax,
            bins=150,
            density=True,
            log=dolog,
        )

        ax.set_title(sea + ": " + titles[i])
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Scatter of bias-adjustment error against raw trend, coloured by longitude.
    ax = axes[s, 2]
    raw_trend = all_pieces[0]
    bc_error = all_pieces[1].transpose(*raw_trend.dims)
    # Broadcast longitude onto the data's own dimensions so that, once flattened,
    # each colour value lines up with its grid cell. np.meshgrid(lat, lon)
    # returns (lon, lat)-shaped arrays, which ravel in a different order.
    lon2d = raw_trend["lon"].broadcast_like(raw_trend).transpose(*raw_trend.dims)
    im = ax.scatter(
        raw_trend.values.ravel(),
        bc_error.values.ravel(),
        s=2,
        alpha=0.5,
        c=lon2d.values.ravel(),
        cmap="twilight",
    )
    ax.set_title(f"{sea}: raw trend vs bias adjustment error")
    ax.set_ylabel("bias adjustment error")
    ax.set_xlabel("raw GCM trend")
    # Attach the colourbar to this panel explicitly rather than the current axes.
    fig.colorbar(
        im,
        ax=ax,
        fraction=0.046,
        pad=0.04,
        orientation="vertical",
        extend="both",
        label="longitude",
    )

fig.set_facecolor("white")
fpath = FIGURE_OUTPUT_DIR.format(var=var)
fn = f"{fpath}/figure_3-4_{var}_allseason_q{quant}_trend_vs_biascorrectederror_histograms_GCMmean_log{str(dolog)}.png"
fig.savefig(fn, facecolor="white", bbox_inches="tight")

#### Paper Fig 3 (tasmax) and 4 (pr) - make the version of the figure that goes into the paper (tasmax and precip are slightly different). Include all seasons, loop through quantiles

In [None]:
# The ERA5 reference quantiles are only used by the precipitation figure.
if var == "pr":
    # xr.load_dataset reads the file into memory and closes it again, instead of
    # leaving the netCDF file handle open.
    ref_quantile_ds = xr.load_dataset(
        f"{SAVEDIR}/figure_3-4_ERA5_pr_allseason_quantiles_fine_reference.nc"
    )

In [None]:
from copy import copy
from string import ascii_uppercase as alc

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# Options for the paper figure (Fig 3 for tasmax, Fig 4 for pr).
printtofile = True   # write each figure to disk
nokw = False         # True: ignore the fixed colour limits in kwargs_list
quantiles = [0.01, 0.05, 0.5, 0.95, 0.99]
seasons = ["DJF", "MAM", "JJA", "SON"]

fontsize = 16

# Filename suffix that keeps clim'd and un-clim'd versions in separate files.
kwstr = "_noclim" if nokw else ""

# Per-variable layout: tasmax has three columns, pr has five (reference and raw
# historical panels come first).
if var == "tasmax":
    figsize = (36, 30)
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=230, vmax=315)
    titles = [
        'Change in raw GCM',
        'Difference in change \n(bias adjusted - raw GCM)',
        'Difference in change \n(downscaled - bias adjusted)',
    ]
    labels = [
        'max temperature [C]',
        'max temperature [C]',
        'max temperature [C]',
    ]
    kwargs_list = [
        plot_kwargs,
        diff_plot_kwargs,
        diffdiff_plot_kwargs,
    ]
elif var == "pr":
    figsize = (36, 28)
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diffdiff_plot_kwargs = dict(cmap='RdBu', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = [
        'Reference',
        'Raw GCM',
        'Change in raw GCM',
        'Difference in change \n(bias adjusted / raw GCM)',
        'Difference in change \n(downscaled / bias adjusted)',
    ]
    labels = [
        'total precipitation [mm]',
        'total precipitation [mm]',
        'total precipitation [mm]',
        'ratio [mm/mm]',
        'ratio [mm/mm]',
    ]
    kwargs_list = [
        abs_plot_kwargs,
        abs_plot_kwargs,
        plot_kwargs,
        diff_plot_kwargs,
        diff_plot_kwargs,
    ]

coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                       edgecolor='black', facecolor='none')

# Only the last two quantiles (0.95 and 0.99) are plotted.
for quant in quantiles[-2:]:
    print(f"Doing quantile={quant}")

    # Ensemble-mean fields for this quantile; the season is selected per panel.
    if var == "tasmax":
        all_pieces = [
            rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant),
            diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant),
            diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant),
        ]
    elif var == "pr":
        all_pieces = [
            ref_quantile_ds[var].sel(quantile=quant),
            rawclnhistrg_ds[var].mean(dim="model").sel(quantile=quant),
            rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant),
            diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant),
            diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant),
        ]

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=len(seasons), figsize=figsize,
                             subplot_kw={'projection': ccrs.Robinson()})
    for s, sea in enumerate(seasons):
        for i, piece in enumerate(all_pieces):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            # With nokw the colour limits are left to xarray's defaults.
            clim_kwargs = {} if nokw else kwargs_list[i]
            im = piece.sel(season=sea).plot(
                add_colorbar=True,
                ax=ax,
                transform=ccrs.PlateCarree(),
                cbar_kwargs=dict(fraction=0.046, pad=0.04, orientation='vertical',
                                 extend='both', label=labels[i]),
                **clim_kwargs
            )
            # Panel letters run row by row: (A), (B), ... across each season.
            ax.set_title(f"({alc[i + s * len(all_pieces)]}) " + sea + ": " + titles[i],
                         fontsize=fontsize)
            ax.set_xlabel('')
            ax.set_ylabel('')

    if printtofile:
        # The path template has no {kwstr} field, so the suffix is appended to the
        # model name (as the later figure cells do); otherwise a nokw run would
        # overwrite the figure made with fixed colour limits.
        fn = figure_3diagnostic_output_file_path.format(var=var,
                                                        season="allseason",
                                                        quant=quant,
                                                        model=f"GCMmean{kwstr}")
        fig.savefig(fn, facecolor='white', bbox_inches='tight')
        print(f"saved {fn}")

In [None]:
# List the saved figure files (oldest first, human-readable sizes).
!ls -ltrh /gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4/images

### Make a version of Fig 3-4 (ensemble mean) that swaps in downscaled - raw model for the last panel. Clim keywords are off. <-- Not using this figure in paper.

In [None]:
from copy import copy
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# Diagnostic variant of the ensemble-mean figure: the last panel compares the
# downscaled output with the raw model, and colour limits are switched off.
nokw = True
quant = 0.99

# Fallback colour settings; both branches below replace them.
plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.5, vmax=.5)
abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)

if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=230, vmax=315)
    titles = [
        '(A) change in raw model',
        '(B) difference in change (bias adjusted - raw model)',
        '(C) difference in change (downscaled - raw model)',
    ]
    labels = ['max temperature [K]'] * 3
    kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diffdiff_plot_kwargs = dict(cmap='RdBu', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = [
        '(A) reference',
        '(B) raw model',
        '(C) change in raw model',
        '(D) difference in change (bias adjusted / raw model)',
        '(E) difference in change (downscaled / raw model)',
    ]
    labels = ['total precipitation [mm]'] * 3 + ['ratio [mm/mm]'] * 2
    kwargs_list = [
        abs_plot_kwargs,
        abs_plot_kwargs,
        plot_kwargs,
        diff_plot_kwargs,
        diff_plot_kwargs,
    ]

coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                       edgecolor='black', facecolor='none')

for quant in [0.99]:  # quantiles:
    print(f"Doing quantile={quant}")

    # Ensemble-mean trend fields; pr additionally gets the reference and the raw
    # historical quantiles as its first two columns.
    if var == "tasmax":
        all_pieces = [
            ds[var].mean(dim="model").sel(quantile=quant)
            for ds in (rawclncoarse_ds, diffbcrgraw_ds, diffdnrgraw_ds)
        ]
    elif var == "pr":
        all_pieces = [ref_quantile_ds[var].sel(quantile=quant)] + [
            ds[var].mean(dim="model").sel(quantile=quant)
            for ds in (rawclnhistrg_ds, rawclncoarse_ds, diffbcrgraw_ds, diffdnrgraw_ds)
        ]

    # Filename tag and plot options for the chosen colour-limit mode.
    kwstr = "noclims" if nokw else ""

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=4, figsize=(36, 30),
                             subplot_kw={'projection': ccrs.Robinson()})
    for i, piece in enumerate(all_pieces):
        # Without fixed limits, let xarray pick robust (percentile-based) ones.
        plot_options = dict(robust=True) if nokw else kwargs_list[i]
        for s, sea in enumerate(["DJF", "MAM", "JJA", "SON"]):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            im = piece.sel(season=sea).plot(
                add_colorbar=True,
                ax=ax,
                transform=ccrs.PlateCarree(),
                cbar_kwargs=dict(fraction=0.046, pad=0.04, orientation='vertical',
                                 extend='both', label=labels[i]),
                **plot_options
            )
            ax.set_title(titles[i] + " " + sea)
            ax.set_xlabel('')
            ax.set_ylabel('')

    out_path = figure_3diagnostic_withdownscaled_output_file_path.format(
        season="allseason",
        var=var,
        quant=quant,
        model=f"GCMmean{kwstr}",
        kwstr=kwstr,
    )
    fig.savefig(out_path, facecolor='white', bbox_inches='tight')

### Figures for all GCMs. With clim keywords. The last panel plots downscaled vs bias adjusted (`diffdnscbc_ds`); note that the output filename still says `downscaledvraw`.

In [None]:
from copy import copy
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# One all-season figure per GCM, drawn with the fixed colour limits.
nokw = False
quant = 0.95

# Filename tag for the colour-limit mode.
kwstr = "_noclims" if nokw else ""

# The last panel plots diffdnscbc_ds (downscaled vs bias adjusted), so the titles
# name that comparison.
if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=230, vmax=315)
    titles = [
        '(A) change in raw model',
        '(B) difference in change (bias adjusted - raw model)',
        '(C) difference in change (downscaled - bias adjusted)',
    ]
    labels = [
        'max temperature [K]',
        'max temperature [K]',
        'max temperature [K]',
    ]
    kwargs_list = [
        plot_kwargs,
        diff_plot_kwargs,
        diff_plot_kwargs,
    ]
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diffdiff_plot_kwargs = dict(cmap='RdBu', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = [
        '(A) reference',
        '(B) raw model',
        '(C) change in raw model',
        '(D) difference in change (bias adjusted / raw model)',
        '(E) difference in change (downscaled / bias adjusted)',
    ]
    labels = [
        'total precipitation [mm]',
        'total precipitation [mm]',
        'total precipitation [mm]',
        'ratio [mm/mm]',
        'ratio [mm/mm]',
    ]
    kwargs_list = [
        abs_plot_kwargs,
        abs_plot_kwargs,
        plot_kwargs,
        diff_plot_kwargs,
        diff_plot_kwargs,
    ]

coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                       edgecolor='black', facecolor='none')

for mod in diffbcrgraw_ds.model.values:

    # Fields for this model at the chosen quantile; the season is selected per
    # panel. The reference dataset has no model dimension.
    if var == "tasmax":
        all_pieces = [
            rawclncoarse_ds[var].sel(model=mod, quantile=quant),
            diffbcrgraw_ds[var].sel(model=mod, quantile=quant),
            diffdnscbc_ds[var].sel(model=mod, quantile=quant),
        ]
    elif var == "pr":
        all_pieces = [
            ref_quantile_ds[var].sel(quantile=quant),
            rawclnhistrg_ds[var].sel(model=mod, quantile=quant),
            rawclncoarse_ds[var].sel(model=mod, quantile=quant),
            diffbcrgraw_ds[var].sel(model=mod, quantile=quant),
            diffdnscbc_ds[var].sel(model=mod, quantile=quant),
        ]

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=4, figsize=(36, 30),
                             subplot_kw={'projection': ccrs.Robinson()})
    for i, piece in enumerate(all_pieces):
        # Without fixed limits, let xarray pick robust (percentile-based) ones.
        plot_options = dict(robust=True) if nokw else kwargs_list[i]
        for s, sea in enumerate(["DJF", "MAM", "JJA", "SON"]):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            im = piece.sel(season=sea).plot(
                add_colorbar=True,
                ax=ax,
                transform=ccrs.PlateCarree(),
                cbar_kwargs=dict(fraction=0.046, pad=0.04, orientation='vertical',
                                 extend='both', label=labels[i]),
                **plot_options
            )
            ax.set_title(titles[i] + " " + sea)
            ax.set_xlabel('')
            ax.set_ylabel('')

    # NOTE(review): this path template contains "downscaledvraw" although the last
    # panel is downscaled vs bias adjusted; the filename is left unchanged so that
    # existing files keep their names -- confirm which is intended.
    fig.savefig(
        figure_3diagnostic_withdownscaled_output_file_path.format(
            season="allseason",
            var=var,
            quant=quant,
            model=f"{mod}{kwstr}",
        ),
        facecolor='white',
        bbox_inches='tight',
    )
    # Each figure is large (36x30 in) and already saved to disk; close it so that
    # one open figure per GCM does not accumulate in memory.
    plt.close(fig)

### Needs updating: Loop through GCMs - this just does `tasmax`. 3-panel. TODO update to be consistent with Fig3-4 GCM ensemble mean above that works on either `tasmax` or `pr` <-- only need this if adding panels for all GCMs into supplemental

In [None]:
# Stale three-panel per-GCM figure (the heading above marks it as needing an update).
# Quantile selected from each piece when plotting.
quant = 0.95

# Fixed colour limits for the trend panel and the two difference panels.
# `abs_plot_kwargs` is defined but not used by this cell.
plot_kwargs = dict(vmin=0, vmax=10)
diff_plot_kwargs = dict(vmin=-1, vmax=1)
abs_plot_kwargs = dict(vmin=230, vmax=315)
titles = [
    "a. change in model",
    "b. difference in change (biascorrected - model)",
    "c. difference in change (downscaled - biascorrected)",
]
kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
from copy import copy

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable


# loop through processed models

for mod in diffbcrgraw_ds.model.values:
    
    # NOTE(review): `rawcln` and `diffdnscdb_ds` are not defined in the second
    # section of this notebook; the load cell defines `rawclncoarse_ds` and
    # `diffdnscbc_ds`. `diffdnscdb_ds` looks like a typo -- confirm.
    # NOTE(review): the pieces are never indexed by `var` or `season` here, unlike
    # the other figure cells -- confirm what these objects are expected to contain.
    all_pieces = [
        rawcln[mod], 
        diffbcrgraw_ds.sel(model=mod), 
        diffdnscdb_ds.sel(model=mod)
    ]
    

    # Shallow copy: the list is new but its elements are the same objects.
    all_pieces_copy = copy(all_pieces)
    fig, axes = plt.subplots(
        ncols=len(all_pieces_copy),
        nrows=1,
        figsize=(24, 6),
        subplot_kw={"projection": ccrs.Robinson()},
    )
    coastline_feature = NaturalEarthFeature(
        "physical", "coastline", "50m", edgecolor="black", facecolor="none"
    )
    for i, _ in enumerate(all_pieces_copy):
        # `k` and `kw` are unpacked but unused; the plot call indexes kwargs_list directly.
        v, k, kw, ax = all_pieces_copy[i], titles[i], kwargs_list[i], axes[i]
        # Clear the lat/lon coordinate attributes before plotting.
        v["lat"].attrs = dict()
        v["lon"].attrs = dict()
        ax.add_feature(coastline_feature)
        im = all_pieces[i].sel(quantile=quant).plot(
            add_colorbar=True,
            ax=ax,
            transform=ccrs.PlateCarree(),
            cbar_kwargs=dict(
                fraction=0.046,
                pad=0.04,
                orientation="vertical",
                extend="both",
                label="temperature (K)",
            ),
            **kwargs_list[i]
        )
        ax.set_title(titles[i])
        ax.set_xlabel("")
        ax.set_ylabel("")
        
    # NOTE(review): the path template also contains {var} and {season} fields, so
    # formatting with only `model` and `quant` raises KeyError.
    fig.savefig(figure_3diagnostic_output_file_path.format(model=mod,quant=quant), 
            facecolor="white", bbox_inches="tight")

### Deprecated: Just `tasmax`. Loop through GCMs, 2-panel

In [None]:
# Deprecated per-GCM 2-panel map (tasmax only): (a) change in the model and
# (b) biascorrected - model, both on fixed symmetric colour limits.
coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "10m", edgecolor="black", facecolor="none"
)
titles = ["a. change in model", "b. difference in change (biascorrected - model)"]

# Fixed (vmin, vmax) per panel, in the same order as `titles`. The limits are
# hard-coded, so no data-wide min/max reduction is needed.
panel_limits = [(-12, 12), (-3, 3)]

for mod in diffbcrgraw_ds.model.values:

    fig, axes = plt.subplots(
        ncols=2,
        nrows=1,
        figsize=(18, 4),
        subplot_kw={"projection": ccrs.Robinson()},
        dpi=200,
    )

    all_pieces = [
        rawcln[mod],
        diffbcrgraw_ds.sel(model=mod),
        diffdnscdb_ds.sel(model=mod),
    ]
    # Shallow copy: same element objects as `all_pieces`.
    all_pieces_copy = copy(all_pieces)

    for i, (vmin, vmax) in enumerate(panel_limits):
        ax = axes[i]
        piece = all_pieces_copy[i]
        # Clear the coordinate attrs before plotting.
        piece["lat"].attrs = dict()
        piece["lon"].attrs = dict()
        ax.add_feature(coastline_feature, linewidth=0.2)

        da = all_pieces[i].sel(quantile=quant)
        # Largest absolute limit, used to keep the colour scale centred on 0.
        amax = max(abs(vmax), abs(vmin))

        da.plot(
            add_colorbar=True,
            ax=ax,
            transform=ccrs.PlateCarree(),
            clim=(vmin, vmax),
            vmin=-amax,
            vmax=amax,
            cmap="RdBu_r",
            cbar_kwargs=dict(
                fraction=0.046,
                pad=0.04,
                orientation="vertical",
                extend="both",
                label="temperature (K)",
            ),
        )
        ax.set_title(titles[i])
        ax.set_xlabel("")
        ax.set_ylabel("")

    fig.set_facecolor("white")
    # `figure_3_output_file_path` has {var}, {quant} and {model} fields;
    # omitting `var` raises KeyError. This deprecated cell is tasmax-only.
    fig.savefig(
        figure_3_output_file_path.format(var="tasmax", model=mod, quant=quant),
        facecolor="white",
        bbox_inches="tight",
    )

### Deprecated: Old Supp figure - 1-panel showing downscaling vs bias corrected

In [None]:

# Deprecated per-GCM 1-panel map of the third comparison
# (downscaled - biascorrected) on a fixed linear colour scale.

# Model-independent pieces are built once, outside the per-model loop.
coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "10m", edgecolor="black", facecolor="none"
)
titles = ["difference in change (downscaled - biascorrected)"]
# Index into `all_pieces` of the single panel drawn by this cell.
downscaled_minus_bc_idx = 2
# Fixed symmetric colour limits.
vmin, vmax, amax = -1, 1, 1

for mod in diffbcrgraw_ds.model.values:

    # `all_pieces` / `all_pieces_copy` are kept as module-level names because
    # the cell that follows reads them without defining them itself.
    all_pieces = [
        rawcln[mod],
        diffbcrgraw_ds.sel(model=mod),
        diffdnscdb_ds.sel(model=mod),
    ]
    all_pieces_copy = copy(all_pieces)  # shallow: same element objects

    fig, ax = plt.subplots(
        ncols=1,
        nrows=1,
        figsize=(9, 4),
        subplot_kw={"projection": ccrs.Robinson()},
        dpi=200,
    )

    piece = all_pieces_copy[downscaled_minus_bc_idx]
    # Clear the coordinate attrs before plotting.
    piece["lat"].attrs = dict()
    piece["lon"].attrs = dict()
    ax.add_feature(coastline_feature, linewidth=0.2)

    da = all_pieces[downscaled_minus_bc_idx].sel(quantile=quant)
    da.plot(
        add_colorbar=True,
        ax=ax,
        transform=ccrs.PlateCarree(),
        clim=(vmin, vmax),
        norm=matplotlib.colors.Normalize(vmin=-amax, vmax=amax),
        cmap="RdBu_r",
        cbar_kwargs=dict(
            fraction=0.046,
            pad=0.04,
            orientation="vertical",
            extend="both",
            label="temperature (K)",
        ),
    )
    ax.set_title(titles[0])
    ax.set_xlabel("")
    ax.set_ylabel("")

    fig.set_facecolor("white")
    fig.savefig(
        figure_a2_output_file_path.format(model=mod, quant=quant),
        facecolor="white",
        bbox_inches="tight",
    )

In [None]:
# Single-panel map of the entries of `all_pieces` from index 2 onward, drawn
# with a symmetric-log colour normalisation.
# NOTE(review): `all_pieces`, `all_pieces_copy`, `kwargs_list` and `quant` are
# not defined in this cell; it uses whatever previously-run cells left behind.
fig, axes = plt.subplots(
    nrows=1,
    ncols=1,
    figsize=(9, 4),
    dpi=200,
    subplot_kw={"projection": ccrs.Robinson()},
)
# Wrap the lone axes in a length-1 array so it can be indexed by panel number.
axes = np.array([axes]).reshape((1,))
coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "50m", edgecolor="black", facecolor="none"
)
titles = ["difference in change (downscaled - biascorrected)"]

colorbar_options = dict(
    fraction=0.046,
    pad=0.04,
    orientation="vertical",
    extend="both",
    label="temperature (K)",
)

for panel, piece_idx in enumerate(range(2, len(all_pieces_copy))):
    piece = all_pieces_copy[piece_idx]
    panel_title = titles[panel]
    panel_kwargs = kwargs_list[piece_idx]  # looked up but not passed to the plot call below
    ax = axes[panel]

    # Clear the coordinate attrs before plotting.
    piece["lat"].attrs = dict()
    piece["lon"].attrs = dict()
    ax.add_feature(coastline_feature, linewidth=0.2)

    da = all_pieces[piece_idx].sel(quantile=quant)

    # Fixed symmetric colour limits.
    amax = 1
    vmin, vmax = -amax, amax
    symlog_norm = matplotlib.colors.SymLogNorm(0.1, vmin=-amax, vmax=amax)

    im = da.plot(
        add_colorbar=True,
        ax=ax,
        transform=ccrs.PlateCarree(),
        clim=(vmin, vmax),
        norm=symlog_norm,
        cmap="RdBu_r",
        cbar_kwargs=colorbar_options,
    )
    ax.set_title(panel_title)
    ax.set_xlabel("")
    ax.set_ylabel("")
fig.set_facecolor("white")

In [None]:
# Display the current value of `quant` as this cell's output.
quant

In [None]:
# Close `client` and `cluster` (both are created outside this section).
# The cell output is the tuple of the two close() return values.
client.close(), cluster.close()