# Make Figures: Notebook for tasmax, pr seasonal quantile trend maps for all GCMs in GDPCIR paper

* Copy `tasmax_seasonal_quantile_trend_allmodels.ipynb` to this notebook and copy functionality from `precip_seasonal_quantile_trend.ipynb`. Make it work for tasmax and pr for all seasons, quantiles, and GCMs.
* Intermediate data is now prepped and saved in `fig_3-4_makedata_tasmax_pr_seasonal_quantile_trend_allmodels.ipynb`


In [None]:
import os

# Root directory for every figure written by this notebook; point it wherever you like.
# NOTE: no {var}/{kwstr} sub-directories are applied -- all figures land directly here.
FIGURE_OUTPUT_DIR = "/gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4/images"

# Relative path from this notebook to the repository root; fail fast if it is wrong.
REPO_ROOT = "../../"
assert "notebooks" in os.listdir(REPO_ROOT)


def _figure_path(filename_template):
    """Return `filename_template` joined onto FIGURE_OUTPUT_DIR.

    The {placeholders} in the template are left unformatted; callers fill them
    later with ``str.format(...)``.
    """
    return os.path.join(FIGURE_OUTPUT_DIR, filename_template)


# Output filename templates.
figure_3_output_file_path = _figure_path(
    "figure_3_{var}_summer_q{quant}_trend_with_biascorrected_clipped_{model}.png"
)
figure_3diagnostic_output_file_path = _figure_path(
    "figure_3-4_{var}_{season}_q{quant}_trend_with_biascorrected_downscaled_clipped_{model}.png"
)
figure_3diagnostic_withdownscaled_output_file_path = _figure_path(
    "figure_3_{var}_{season}_q{quant}_trend_with_biascorrected_downscaledvraw_clipped_{model}.png"
)
figure_a2_output_file_path = _figure_path(
    "figure_a2_tasmax_summer_q{quant}_trend_diff_linear_{model}.png"
)

# Path to the pipeline's data-paths YAML, relative to the repository root.
fps_yaml_path = os.path.join(
    REPO_ROOT,
    "notebooks/downscaling_pipeline/post_processing_and_delivery/data_paths.yaml",
)

# Maps each old bucket name (key) to its new bucket name (value).
bucket_mapping_oregon_trail = {
    "biascorrected-492e989a": "biascorrected-4a21ed18",
    "clean-b1dbca25": "clean-f1e04ef5",
    "downscaled-288ec5ac": "downscaled-48ec31ab",
    "raw-305d04da": "raw-957d115e",
    "support-c23ff1a3": "support-f8a48a9e",
}

# The NEW support and downscaled buckets, looked up from the mapping above.
BUCKET = bucket_mapping_oregon_trail["support-c23ff1a3"]
DS_BUCKET = bucket_mapping_oregon_trail["downscaled-288ec5ac"]


In [None]:
import cartopy.crs as ccrs
import dask
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from datatree import DataTree
import yaml
from cartopy.feature import NaturalEarthFeature
from rhg_compute_tools import kubernetes as rhgk

import sys
sys.path.insert(1, '../downscaling_pipeline/post_processing_and_delivery/')
import dc6_functions


# Start here if the intermediate data has already been saved by `fig_3-4_makedata_tasmax_pr_seasonal_quantile_trend_allmodels.ipynb`

In [None]:
# Which variable, future scenario and threshold-variant filename suffix to load below.
# Alternatives used in this project: var="tasmax", threshstr="wetthresh1_".
var, fut_scenario, threshstr = "pr", "ssp370", "drythresh10_"

In [None]:
import datatree as dt

# Open the intermediate data written by the makedata notebook. This is slow (~10-11 GB
# total), most likely because of loading the datatree.
SCRATCH = "/gcs/impactlab-data/climate/downscaling/qc/kelly_diagnostics/figure3-4"  # alternative: impactlab-data-scratch
SAVEDIR = f"{SCRATCH}"  # alternative: f"{SCRATCH}/gdpcir-diagnostics/figure3-4"

# Filename prefix of the intermediate products; the same prefix is used for tasmax and pr.
fignum = "3-4"

rawclean_dt = dt.open_datatree(
    f"{SAVEDIR}/figure_3-4_{fut_scenario}_{var}_allseason_{threshstr}quantiles_trends_rawcleaned_allgcms.zarr",
    engine="zarr",
).load()

# xr.load_dataset reads each file fully into memory and then closes it;
# open_dataset(...).load() would leave every file handle open.
diffbcrgraw_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_{fignum}_{fut_scenario}_{var}_allseason_{threshstr}quantiles_trends_biascorrected_v_cleaned_allgcms.nc")
diffdnscbc_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_{fignum}_{fut_scenario}_{var}_allseason_{threshstr}quantiles_trends_downscaled_v_biascorrected_allgcms.nc")
diffdnrgraw_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_{fignum}_{fut_scenario}_{var}_allseason_{threshstr}quantiles_trends_downscaled_v_cleaned_allgcms.nc")
# NOTE: this product is stored under the "figure_3_" prefix, not figure_{fignum}.
rawclncoarse_ds = xr.load_dataset(
    f"{SAVEDIR}/figure_3_{fut_scenario}_{var}_allseason_{threshstr}quantiles_trends_raw_cleaned_coarse_allgcms.nc")

In [None]:
# Historical raw-model quantiles (file suffix "raw_cleaned_fine") are only needed for pr.
if var == "pr":
    # load_dataset loads into memory and closes the file handle.
    rawclnhistrg_ds = xr.load_dataset(
        f"{SAVEDIR}/figure_{fignum}_historical_{var}_allseason_{threshstr}quantiles_raw_cleaned_fine_allgcms.nc"
    )

# NOTE: opening the datatree zarr store in the previous cell emits a RuntimeWarning
# ("Failed to open Zarr store with consolidated metadata, but successfully read with
# non-consolidated metadata"). To silence it, either consolidate the store's metadata
# with zarr.consolidate_metadata(), or pass consolidated=False (or consolidated=True to
# make it an error instead) when opening the store.

## Make figures
Figure files are saved to disk for each GCM and for the mean across GCMs.

Cells that set `quant` directly must be run manually once per quantile to save figures for different quantiles; the paper-figure cells below loop over quantiles themselves.

### GCM mean

#### This is the figure from the initially submitted paper (but showing the ensemble mean instead of a single GCM)

In [None]:
# quant = 0.95
# sea = "JJA"

# if var == "tasmax":
#     plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
#     diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
#     abs_plot_kwargs = dict(vmin=230, vmax=315)
#     titles = [
#         "a. change in raw model",
#         "b. difference in change (bias adjusted - raw model)",
#         "c. difference in change (downscaled - bias adjusted)",
#     ]
#     titles = ["a. change in raw model", "b. difference in change (bias adjusted - raw model)"]
#     clabel = "temperature (C)"
# elif var == "pr":
#     plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
#     diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
#     abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
#     titles = [
#         "a. change in raw model",
#         "b. difference in change (bias adjusted / raw model)",
#         "c. difference in change (downscaled / bias adjusted)",
#     ]
#     titles = ["a. change in raw model", "b. difference in change (bias adjusted / raw model)"]
#     clabel = "precipitation (mm/day)"
    
# kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
# from copy import copy

# from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable


# all_pieces = [
#     rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
#     diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant, season=sea), 
#     diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant, season=sea)
# ]
    

# all_pieces_copy = copy(all_pieces)
# if 1:    
    
#     coastline_feature = NaturalEarthFeature(
#         "physical", "coastline", "10m", edgecolor="black", facecolor="none"
#     )
    
#     fig, axes = plt.subplots(
#         ncols=2,
#         nrows=1,
#         figsize=(18, 4),
#         subplot_kw={"projection": ccrs.Robinson()},
#         dpi=200,
#     )
    
#     for i, _ in enumerate(all_pieces_copy[:2]):
#         v, k, kw, ax = all_pieces_copy[i], titles[i], kwargs_list[i], axes[i]
#         v["lat"].attrs = dict()
#         v["lon"].attrs = dict()
#         ax.add_feature(coastline_feature, linewidth=0.2)
#         da = all_pieces[i]
#         # vmax = da.max().item()
#         # vmin = da.min().item()

# #         if i == 1:
# #             vmin, vmax, amax = -3, 3, 3
# #         elif i == 0:
# #             vmin, vmax, amax = -12, 12, 12

# #         amax = max(abs(vmax), abs(vmin))
# #         norm = matplotlib.colors.Normalize(vmin=-amax, vmax=amax)

#         im = da.plot(
#             add_colorbar=True,
#             ax=ax,
#             transform=ccrs.PlateCarree(),
#             # clim=(vmin, vmax),
#             # vmin=-amax,
#             # vmax=amax,
#             # cmap="RdBu_r",
#             cbar_kwargs=dict(
#                 fraction=0.046,
#                 pad=0.04,
#                 orientation="vertical",
#                 extend="both",
#                 label=clabel,
#             ),
#             **kw
#         )
#         ax.set_title(titles[i])
#         ax.set_xlabel("")
#         ax.set_ylabel("")
#     fig.set_facecolor("white")
#     # fig.savefig(figure_3_output_file_path.format(var=var,model="GCMmean",quant=quant), 
#     #             facecolor="white", bbox_inches="tight")

#### Add a panel showing the ratio of the bias-adjustment error to the raw trend

In [None]:
quant = 0.95
sea = "JJA"

# Per-variable colour limits, titles and colour-bar label. Every name used below is
# defined in both branches, so the cell runs on a fresh kernel for either variable.
if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    abs_plot_kwargs = dict(vmin=230, vmax=315)
    # Colour limits for the error/trend ratio panel.
    err_plot_kwargs = dict(cmap='RdBu_r', vmin=-.25, vmax=.25)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted - raw model)"]
    clabel = "temperature (C)"
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    err_plot_kwargs = dict(cmap="RdBu", vmin=-2, vmax=2)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted / raw model)"]
    clabel = "precipitation (mm/day)"
else:
    raise ValueError(f"unsupported var: {var!r} (expected 'tasmax' or 'pr')")

kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
from copy import copy

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# Ensemble-mean (over GCMs) trend products for the chosen quantile and season.
all_pieces = [
    rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
    diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
    diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
]

coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "10m", edgecolor="black", facecolor="none"
)
cbar_common = dict(fraction=0.046, pad=0.04, orientation="vertical", extend="both")

fig, axes = plt.subplots(
    ncols=3,
    nrows=1,
    figsize=(21, 4),
    subplot_kw={"projection": ccrs.Robinson()},
    dpi=200,
)

# Panels a-b: raw-model change and bias-adjustment difference.
for i, title in enumerate(titles):
    ax = axes[i]
    da = all_pieces[i]
    # Clear lat/lon coordinate attrs before plotting.
    da["lat"].attrs = dict()
    da["lon"].attrs = dict()
    ax.add_feature(coastline_feature, linewidth=0.2)
    im = da.plot(
        add_colorbar=True,
        ax=ax,
        transform=ccrs.PlateCarree(),
        cbar_kwargs=dict(cbar_common, label=clabel),
        **kwargs_list[i]
    )
    ax.set_title(title)
    ax.set_xlabel("")
    ax.set_ylabel("")

# Panel c: ratio of the bias-adjustment difference to the raw-model change.
# `da` is left bound to this ratio after the cell runs.
ax = axes[2]
ax.add_feature(coastline_feature, linewidth=0.2)
da = all_pieces[1] / all_pieces[0]
im = da.plot(
    add_colorbar=True,
    ax=ax,
    transform=ccrs.PlateCarree(),
    cbar_kwargs=dict(cbar_common, label="ratio"),
    **err_plot_kwargs
)
ax.set_title("ratio bias-adjustment error/raw trend")
ax.set_xlabel("")
ax.set_ylabel("")

fig.set_facecolor("white")
# fig.savefig(figure_3_output_file_path.format(var=var,model="GCMmean",quant=quant),
#             facecolor="white", bbox_inches="tight")

#### Try histograms instead
**TODO:** restrict to land grid cells only.

In [None]:
# Exploratory: mask `da` (the ratio computed in the cell above) where all_pieces[3] > 1.
# NOTE(review): the cell above builds a 3-element `all_pieces`, so all_pieces[3] raises
# IndexError on a fresh top-to-bottom run. A 4th element (historical raw-model quantiles,
# pr only) exists only after the histogram cell below has run, and that cell notes that
# this same masking threw an exception there. TODO: fix or remove.
da.where(all_pieces[3]>1)

In [None]:
quant = 0.95
dolog = True  # log-scale the histogram y-axis (also recorded in the output filename)

if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    abs_plot_kwargs = dict(vmin=230, vmax=315)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted - raw model)"]
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = ["a. change in raw model", "b. difference in change (bias adjusted / raw model)"]

kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
from copy import copy

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

seasons = ["DJF", "MAM", "JJA", "SON"]

# One row per season: two histograms plus a raw-trend vs. error scatter.
fig, axes = plt.subplots(
    ncols=3,
    nrows=len(seasons),
    figsize=(36, 30),
    dpi=200,
)
for s, sea in enumerate(seasons):

    all_pieces = [
        rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
        diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
        diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant, season=sea),
    ]
    # rawclnhistrg_ds is only loaded for pr (see the loading cells).
    if var == "pr":
        all_pieces.append(
            rawclnhistrg_ds[var].mean(dim="model").sel(quantile=quant, season=sea)
        )

    # Columns 0-1: histograms of the raw-model change and the bias-adjustment difference.
    for i in range(2):
        ax = axes[s, i]
        da = all_pieces[i]
        da["lat"].attrs = dict()
        da["lon"].attrs = dict()

        # Drop +/-inf values (e.g. from ratios) so the histogram bins stay finite.
        is_inf = np.isinf(da)
        n_inf = int(is_inf.sum())
        if n_inf:
            print(f"{sea} panel {i}: dropping {n_inf} infinite values")
        im = da.where(~is_inf).plot.hist(
            ax=ax,
            bins=150,
            density=True,
            log=dolog,
        )  # .where(all_pieces[3]>1) <--- adding this threw an exception

        ax.set_title(sea + ": " + titles[i])
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Column 2: scatter of the bias-adjustment error against the raw GCM trend,
    # coloured by longitude.
    ax = axes[s, 2]
    raw_trend = all_pieces[0]
    # Broadcast the 1-D lon coordinate onto the data's own dimension order, so that
    # after flattening each point's colour is the longitude of its own grid cell.
    # (np.meshgrid(lat, lon) yields (lon, lat) order, which only matched (lon, lat) data.)
    lon_grid = raw_trend["lon"].broadcast_like(raw_trend).transpose(*raw_trend.dims)
    im = ax.scatter(
        raw_trend,
        all_pieces[1],
        s=2,
        alpha=0.5,
        c=lon_grid.values,
        cmap="twilight",
    )
    ax.set_title(f"{sea}: raw trend vs bias adjustment error")
    ax.set_ylabel("bias adjustment error")
    ax.set_xlabel("raw GCM trend")
    fig.colorbar(
        im,
        ax=ax,
        fraction=0.046,
        pad=0.04,
        orientation="vertical",
        extend="both",
        label="longitude",
    )

fig.set_facecolor("white")
fpath = FIGURE_OUTPUT_DIR.format(var=var)
fn = f"{fpath}/figure_3-4_{var}_allseason_q{quant}_trend_vs_biascorrectederror_histograms_GCMmean_log{str(dolog)}.png"
fig.savefig(fn, facecolor="white", bbox_inches="tight")

#### Paper Fig 3 (tasmax) and Fig 4 (pr): make the versions of the figures that go into the paper (the tasmax and pr layouts differ slightly). All seasons are included, and the cells loop over quantiles.

In [None]:
# ERA5 reference quantiles (file suffix "fine_reference") are only used by the pr figures.
if var == "pr":
    # load_dataset loads into memory and closes the file handle.
    ref_quantile_ds = xr.load_dataset(
        f"{SAVEDIR}/figure_3-4_ERA5_pr_allseason_{threshstr}quantiles_fine_reference.nc"
    )

### `pr` (Figure 4 & supplementary)

In [None]:
# Display the threshold-variant suffix in use; it is embedded in the loaded and saved file names.
threshstr

In [None]:
# Figure 4 (pr): ensemble-mean maps for every season, one figure per quantile.
# The loaded datasets hold only the variable selected when they were read, and the tasmax
# cell below overwrites `var`, so check that this cell really has pr data.
assert var == "pr", f"this cell plots pr, but var={var!r}; reload the data with var='pr'"

printtofile = True
nokw = False  # True: let xarray choose colour limits instead of kwargs_list
quantiles = [0.95, 0.99]

fontsize = 10

figsize = (20.5, 12)
plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
diffdiff_plot_kwargs = dict(cmap='RdBu', vmin=.75, vmax=1.25)
# The dry-threshold variant gets a tighter colour range for the absolute panels.
if threshstr == "drythresh10_":
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=15)
else:
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
titles = ['Reference \n',
          'Raw GCM \n',
          'Change in raw GCM \n(2080-2100 / 1995-2014)',
          'Difference in change \n(bias adjusted / raw GCM)',
          'Difference in change \n(downscaled / bias adjusted)']
# Braces are required so mathtext superscripts the whole "-1"; "$^-1$" raises only the "-".
labels = ['Precipitation (mm day$^{-1}$)',
          'Precipitation (mm day$^{-1}$)',
          'Ratio',
          'Ratio of differences',
          'Ratio of differences']
kwargs_list = [abs_plot_kwargs,
               abs_plot_kwargs,
               plot_kwargs,
               diff_plot_kwargs,
               diffdiff_plot_kwargs]

from string import ascii_lowercase as alc

from copy import copy
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                        edgecolor='black', facecolor='none')
seasons = ["DJF", "MAM", "JJA", "SON"]
# NOTE(review): figure_3diagnostic_output_file_path has no {kwstr} placeholder, so kwstr
# never reaches the filename, and nokw=True output overwrites the default figure.
kwstr = "_noclim" if nokw else ""

for quant in quantiles:
    print(f"Doing quantile={quant}")

    all_pieces = [
        ref_quantile_ds[var].sel(quantile=quant),
        rawclnhistrg_ds[var].mean(dim="model").sel(quantile=quant),
        rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant),
        diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant),
        diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant),
    ]

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=len(seasons), figsize=figsize,
                             subplot_kw={'projection': ccrs.Robinson()})
    for s, sea in enumerate(seasons):
        for i, piece in enumerate(all_pieces):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            clim_kwargs = {} if nokw else kwargs_list[i]
            im = piece.sel(season=sea).plot(
                add_colorbar=True, ax=ax, transform=ccrs.PlateCarree(),
                cbar_kwargs=dict(fraction=0.046, pad=0.04, orientation='vertical',
                                 extend='both', label=labels[i]),
                **clim_kwargs)
            # Panels are lettered row-major: (a)-(e) for DJF, (f)-(j) for MAM, ...
            ax.set_title(f"({alc[i + s * len(all_pieces)]}) " + sea + ": " + titles[i], fontsize=fontsize)
            ax.set_xlabel('')
            ax.set_ylabel('')
    if printtofile:
        fn = figure_3diagnostic_output_file_path.format(var=var,
                                                        season=f"allseason_{threshstr}",
                                                        quant=quant,
                                                        model="GCMmean2",
                                                        kwstr=kwstr)
        fig.savefig(fn, facecolor='white', bbox_inches='tight', dpi=300)
        print(f"saved {fn}")

### `tasmax` (Figure 3 & supplementary) - code from Brewster

In [None]:
from string import ascii_lowercase as alc

# Figure 3 (tasmax): ensemble-mean change maps for every season, one figure per quantile.
# The datasets loaded above must have been read with var="tasmax" for this cell to work.
var = "tasmax"
printtofile = True
nokw = False  # True: let xarray choose colour limits instead of kwargs_list
quantiles = [0.95, 0.99]

fontsize = 10
figsize = (13, 13)
plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.2, vmax=.2)
titles = [
    'Change in raw GCM\n',
    'Difference in change \n(bias adjusted - raw GCM)',
    'Difference in change \n(downscaled - bias adjusted)',
]
labels = ['Max Temperature (°C)'] * 3
kwargs_list = [plot_kwargs, diff_plot_kwargs, diffdiff_plot_kwargs]

coastline_feature = NaturalEarthFeature('physical', 'coastline', '110m',
                                        edgecolor='black', facecolor='none')

colorbar_opts = dict(fraction=0.046, pad=0.04, orientation='vertical', extend='both')
kwstr = "_noclim" if nokw else ""

for quant in quantiles:
    print(f"Doing quantile={quant}")

    # Raw-model change, bias-adjusted minus raw, downscaled minus bias-adjusted.
    all_pieces = [
        source[var].mean(dim="model").sel(quantile=quant)
        for source in (rawclncoarse_ds, diffbcrgraw_ds, diffdnscbc_ds)
    ]

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=4, figsize=figsize,
                             subplot_kw={'projection': ccrs.Robinson()})
    for s, sea in enumerate(["DJF", "MAM", "JJA", "SON"]):
        for i, piece in enumerate(all_pieces):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            clims = {} if nokw else kwargs_list[i]
            im = piece.sel(season=sea).plot(
                add_colorbar=True, ax=ax, transform=ccrs.PlateCarree(),
                cbar_kwargs=dict(colorbar_opts, label=labels[i]),
                **clims)
            # Row-major panel letters: a-c for DJF, d-f for MAM, ...
            panel_letter = alc[i + s * len(all_pieces)]
            ax.set_title(f"{panel_letter}. " + sea + ": " + titles[i], fontsize=fontsize)
            ax.set_xlabel('')
            ax.set_ylabel('')

    if not printtofile:
        continue
    fn = figure_3diagnostic_output_file_path.format(var=var,
                                                    season="allseason",
                                                    quant=quant,
                                                    model="GCMmean",
                                                    kwstr=kwstr)
    fig.savefig(fn, facecolor='white', bbox_inches='tight', dpi=300)
    print(f"saved {fn}")
    plt.close(fig)

In [None]:
# Larger combined version of figures 3-4 for the currently selected `var`.
printtofile = False
nokw = False  # True: let xarray choose colour limits instead of kwargs_list
quantiles = [0.01, 0.05, 0.5, 0.95, 0.99]

fontsize = 16

if var == "tasmax":
    figsize = (36, 30)
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=230, vmax=315)
    titles = [
        'Change in raw GCM',
        'Difference in change \n(bias adjusted - raw GCM)',
        'Difference in change \n(downscaled - bias adjusted)']
    labels = ['max temperature [C]',
              'max temperature [C]',
              'max temperature [C]']
    kwargs_list = [plot_kwargs,
                   diff_plot_kwargs,
                   diffdiff_plot_kwargs]
elif var == "pr":
    figsize = (36, 28)
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diffdiff_plot_kwargs = dict(cmap='RdBu', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = ['Reference \n',
              'Raw GCM \n',
              'Change in raw GCM \n',
              'Difference in change \n(bias adjusted / raw GCM)',
              'Difference in change \n(downscaled / bias adjusted)']
    # The pr change in the raw GCM is a future/historical ratio (plotted on a 0.25-1.75
    # scale around 1), so its colour bar is labelled as a ratio, not an amount.
    labels = ['total precipitation [mm]',
              'total precipitation [mm]',
              'ratio [mm/mm]',
              'ratio [mm/mm]',
              'ratio [mm/mm]']
    kwargs_list = [abs_plot_kwargs,
                   abs_plot_kwargs,
                   plot_kwargs,
                   diff_plot_kwargs,
                   diff_plot_kwargs]
else:
    raise ValueError(f"unsupported var: {var!r} (expected 'tasmax' or 'pr')")

from string import ascii_lowercase as alc

from copy import copy
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                        edgecolor='black', facecolor='none')
seasons = ["DJF", "MAM", "JJA", "SON"]
# NOTE(review): figure_3diagnostic_output_file_path has no {kwstr} placeholder, so kwstr
# never reaches the filename, and nokw=True output overwrites the default figure.
kwstr = "_noclim" if nokw else ""

# Only the two highest quantiles are drawn.
for quant in quantiles[-2:]:
    print(f"Doing quantile={quant}")

    # Ensemble-mean trend / difference products for this quantile.
    all_pieces = [
        rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant),
        diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant),
        diffdnscbc_ds[var].mean(dim="model").sel(quantile=quant),
    ]
    if var == "pr":
        # pr additionally shows the reference and historical raw-GCM quantiles first.
        all_pieces = [
            ref_quantile_ds[var].sel(quantile=quant),
            rawclnhistrg_ds[var].mean(dim="model").sel(quantile=quant),
        ] + all_pieces

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=len(seasons), figsize=figsize,
                             subplot_kw={'projection': ccrs.Robinson()})
    for s, sea in enumerate(seasons):
        for i, piece in enumerate(all_pieces):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            clim_kwargs = {} if nokw else kwargs_list[i]
            im = piece.sel(season=sea).plot(
                add_colorbar=True, ax=ax, transform=ccrs.PlateCarree(),
                cbar_kwargs=dict(fraction=0.046, pad=0.04, orientation='vertical',
                                 extend='both', label=labels[i]),
                **clim_kwargs)
            # Panels are lettered row-major across seasons.
            ax.set_title(f"({alc[i + s * len(all_pieces)]}) " + sea + ": " + titles[i], fontsize=fontsize)
            ax.set_xlabel('')
            ax.set_ylabel('')
    if printtofile:
        fn = figure_3diagnostic_output_file_path.format(var=var,
                                                        season=f"allseason_{threshstr}",
                                                        quant=quant,
                                                        model="GCMmean",
                                                        kwstr=kwstr)
        fig.savefig(fn, facecolor='white', bbox_inches='tight')
        print(f"saved {fn}")

### Make a version of Fig 3-4 (ensemble mean) that swaps in downscaled − raw model for the last panel, with the colour-limit keywords turned off. (This figure is not used in the paper.)

In [None]:
nokw=True
quant = 0.99

# GCM-ensemble-mean figure: rows = seasons, columns = panels (model mean over
# the `model` dimension, except the pr reference panel). Expects `var`
# ("tasmax" or "pr") and `nokw` (True -> robust automatic color limits) to be
# set earlier in the notebook.
if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=230, vmax=315)
    titles = [
        '(A) change in raw model',
        '(B) difference in change (bias adjusted - raw model)',
        '(C) difference in change (downscaled - raw model)',
    ]
    labels = ['max temperature [K]'] * 3
    kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diffdiff_plot_kwargs = dict(cmap='RdBu', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = [
        '(A) reference',
        '(B) raw model',
        '(C) change in raw model',
        '(D) difference in change (bias adjusted / raw model)',
        '(E) difference in change (downscaled / raw model)',
    ]
    labels = ['total precipitation [mm]'] * 3 + ['ratio [mm/mm]'] * 2
    kwargs_list = [abs_plot_kwargs, abs_plot_kwargs, plot_kwargs,
                   diff_plot_kwargs, diff_plot_kwargs]
else:
    # Fail loudly rather than silently reusing titles/labels from another cell.
    raise ValueError(f"var must be 'tasmax' or 'pr', got {var!r}")

from copy import copy
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                        edgecolor='black', facecolor='none')
SEASONS = ["DJF", "MAM", "JJA", "SON"]
# Filename suffix for the model tag; same "_noclims" form as the per-GCM cell.
kwstr = "_noclims" if nokw else ""

for quant in [0.99]:  # swap in `quantiles` to render every quantile
    print(f"Doing quantile={quant}")

    if var == "tasmax":
        all_pieces = [
            rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant),
            diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant),
            diffdnrgraw_ds[var].mean(dim="model").sel(quantile=quant),
        ]
    else:  # var == "pr" (validated above)
        all_pieces = [
            ref_quantile_ds[var].sel(quantile=quant),
            rawclnhistrg_ds[var].mean(dim="model").sel(quantile=quant),
            rawclncoarse_ds[var].mean(dim="model").sel(quantile=quant),
            diffbcrgraw_ds[var].mean(dim="model").sel(quantile=quant),
            diffdnrgraw_ds[var].mean(dim="model").sel(quantile=quant),
        ]

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=len(SEASONS), figsize=(36, 30),
                             subplot_kw={'projection': ccrs.Robinson()})
    for i, piece in enumerate(all_pieces):
        # robust=True lets xarray choose percentile-based limits; otherwise use
        # the fixed per-panel limits configured above.
        color_kwargs = dict(robust=True) if nokw else kwargs_list[i]
        for s, sea in enumerate(SEASONS):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            piece.sel(season=sea).plot(
                add_colorbar=True, ax=ax, transform=ccrs.PlateCarree(),
                # Build a fresh dict for every call so no plot shares colorbar kwargs.
                cbar_kwargs=dict(fraction=0.046, pad=0.04, orientation='vertical',
                                 extend='both', label=labels[i]),
                **color_kwargs)
            ax.set_title(f"{titles[i]} {sea}")
            ax.set_xlabel('')
            ax.set_ylabel('')
    fig.savefig(
        figure_3diagnostic_withdownscaled_output_file_path.format(
            season="allseason", var=var, quant=quant, model=f"GCMmean{kwstr}"),
        facecolor='white', bbox_inches='tight')

### Figures for each GCM: fixed color limits, with downscaled vs. raw model in the last panel

In [None]:
# Per-GCM version of the ensemble-mean figure: one figure per model with
# rows = seasons and columns = panels. Expects `var` ("tasmax" or "pr") to be
# set earlier in the notebook.
nokw = False  # True -> xarray "robust" automatic color limits, filename suffix "_noclims"
quant = 0.95

if var == "tasmax":
    plot_kwargs = dict(cmap='RdBu_r', vmin=-12, vmax=12)
    diff_plot_kwargs = dict(cmap='RdBu_r', vmin=-3, vmax=3)
    diffdiff_plot_kwargs = dict(cmap='RdBu_r', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=230, vmax=315)
    titles = [
        '(A) change in raw model',
        '(B) difference in change (bias adjusted - raw model)',
        '(C) difference in change (downscaled - raw model)',
    ]
    labels = ['max temperature [K]'] * 3
    kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
elif var == "pr":
    plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diff_plot_kwargs = dict(cmap='RdBu', vmin=0.25, vmax=1.75)
    diffdiff_plot_kwargs = dict(cmap='RdBu', vmin=-.5, vmax=.5)
    abs_plot_kwargs = dict(cmap='viridis', vmin=0, vmax=50)
    titles = [
        '(A) reference',
        '(B) raw model',
        '(C) change in raw model',
        '(D) difference in change (bias adjusted / raw model)',
        '(E) difference in change (downscaled / raw model)',
    ]
    labels = ['total precipitation [mm]'] * 3 + ['ratio [mm/mm]'] * 2
    kwargs_list = [abs_plot_kwargs, abs_plot_kwargs, plot_kwargs,
                   diff_plot_kwargs, diff_plot_kwargs]
else:
    # Fail loudly rather than silently reusing titles/labels from another cell.
    raise ValueError(f"var must be 'tasmax' or 'pr', got {var!r}")

from copy import copy
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

coastline_feature = NaturalEarthFeature('physical', 'coastline', '50m',
                                        edgecolor='black', facecolor='none')
SEASONS = ["DJF", "MAM", "JJA", "SON"]
kwstr = "_noclims" if nokw else ""

for mod in diffbcrgraw_ds.model.values:
    # The last panel is downscaled vs. raw model, matching its title and the
    # "downscaledvraw" output filename.
    if var == "tasmax":
        all_pieces = [
            rawclncoarse_ds[var].sel(model=mod, quantile=quant),
            diffbcrgraw_ds[var].sel(model=mod, quantile=quant),
            diffdnrgraw_ds[var].sel(model=mod, quantile=quant),
        ]
    else:  # var == "pr" (validated above)
        all_pieces = [
            ref_quantile_ds[var].sel(quantile=quant),
            rawclnhistrg_ds[var].sel(model=mod, quantile=quant),
            rawclncoarse_ds[var].sel(model=mod, quantile=quant),
            diffbcrgraw_ds[var].sel(model=mod, quantile=quant),
            diffdnrgraw_ds[var].sel(model=mod, quantile=quant),
        ]

    fig, axes = plt.subplots(ncols=len(all_pieces), nrows=len(SEASONS), figsize=(36, 30),
                             subplot_kw={'projection': ccrs.Robinson()})
    for i, piece in enumerate(all_pieces):
        # robust=True lets xarray choose percentile-based limits; otherwise use
        # the fixed per-panel limits configured above.
        color_kwargs = dict(robust=True) if nokw else kwargs_list[i]
        for s, sea in enumerate(SEASONS):
            ax = axes[s, i]
            ax.add_feature(coastline_feature)
            piece.sel(season=sea).plot(
                add_colorbar=True, ax=ax, transform=ccrs.PlateCarree(),
                # Build a fresh dict for every call so no plot shares colorbar kwargs.
                cbar_kwargs=dict(fraction=0.046, pad=0.04, orientation='vertical',
                                 extend='both', label=labels[i]),
                **color_kwargs)
            ax.set_title(f"{titles[i]} {sea}")
            ax.set_xlabel('')
            ax.set_ylabel('')
    fig.savefig(
        figure_3diagnostic_withdownscaled_output_file_path.format(
            season="allseason", var=var, quant=quant, model=f"{mod}{kwstr}"),
        facecolor='white', bbox_inches='tight')
    # Figures are written to disk; close each one so a large figure per GCM
    # does not accumulate in memory.
    plt.close(fig)

### Needs updating: loop through GCMs (`tasmax` only, 3 panels). TODO: make this consistent with the Fig. 3-4 GCM ensemble-mean cell above, which handles either `tasmax` or `pr`. Only needed if per-GCM panels are added to the supplement.

In [None]:
# Stale per-GCM, tasmax-only, 3-panel figure (see the section header above).
# NOTE(review): as written this cell likely cannot run to completion:
#   - figure_3diagnostic_output_file_path contains {var} and {season}
#     placeholders, but the savefig .format() call below passes only `model`
#     and `quant`, so str.format raises KeyError.
#   - diffdnscdb_ds is not defined in any visible cell; the maintained cells
#     above use diffdnscbc_ds -- possibly a stale name, TODO confirm.
#   - diffbcrgraw_ds is indexed with [var] in other cells (i.e. it is a
#     Dataset); calling .plot() directly on a Dataset presumably fails, so a
#     [var] selection (and probably a season selection) is needed -- TODO.
quant = 0.95

# Fixed color limits per panel (no cmap given, so xarray picks the default).
plot_kwargs = dict(vmin=0, vmax=10)
diff_plot_kwargs = dict(vmin=-1, vmax=1)
abs_plot_kwargs = dict(vmin=230, vmax=315)
titles = [
    "a. change in model",
    "b. difference in change (biascorrected - model)",
    "c. difference in change (downscaled - biascorrected)",
]
# One kwargs dict per panel, in the same order as `titles`.
kwargs_list = [plot_kwargs, diff_plot_kwargs, diff_plot_kwargs]
from copy import copy

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable


# loop through processed models

for mod in diffbcrgraw_ds.model.values:
    
    # Panels in `titles` order: model change, biascorrected - model,
    # downscaled - biascorrected.
    all_pieces = [
        rawcln[mod], 
        diffbcrgraw_ds.sel(model=mod), 
        diffdnscdb_ds.sel(model=mod)
    ]
    

    # Shallow copy: only the list is copied; its elements are the same
    # objects as in all_pieces.
    all_pieces_copy = copy(all_pieces)
    fig, axes = plt.subplots(
        ncols=len(all_pieces_copy),
        nrows=1,
        figsize=(24, 6),
        subplot_kw={"projection": ccrs.Robinson()},
    )
    coastline_feature = NaturalEarthFeature(
        "physical", "coastline", "50m", edgecolor="black", facecolor="none"
    )
    for i, _ in enumerate(all_pieces_copy):
        # `k` and `kw` are assigned but not used below.
        v, k, kw, ax = all_pieces_copy[i], titles[i], kwargs_list[i], axes[i]
        # Clears lat/lon coordinate attrs, presumably so xarray does not label
        # axes from metadata; this likely mutates the source object in place.
        v["lat"].attrs = dict()
        v["lon"].attrs = dict()
        ax.add_feature(coastline_feature)
        im = all_pieces[i].sel(quantile=quant).plot(
            add_colorbar=True,
            ax=ax,
            transform=ccrs.PlateCarree(),
            cbar_kwargs=dict(
                fraction=0.046,
                pad=0.04,
                orientation="vertical",
                extend="both",
                label="temperature (K)",
            ),
            **kwargs_list[i]
        )
        ax.set_title(titles[i])
        ax.set_xlabel("")
        ax.set_ylabel("")
        
    fig.savefig(figure_3diagnostic_output_file_path.format(model=mod,quant=quant), 
            facecolor="white", bbox_inches="tight")

### Deprecated: Just `tasmax`. Loop through GCMs, 2-panel

In [None]:
# Deprecated per-GCM, tasmax-only, 2-panel figure.
# Relies on `copy`, `quant` and `kwargs_list` defined in earlier cells.
# NOTE(review): figure_3_output_file_path contains a {var} placeholder, but the
# savefig .format() call below passes only `model` and `quant`, so str.format
# raises KeyError. The diffdnscdb_ds / Dataset.plot concerns noted for the
# previous cell apply here too -- TODO confirm before reviving.
coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "10m", edgecolor="black", facecolor="none"
)
titles = ["a. change in model", "b. difference in change (biascorrected - model)"]

for mod in diffbcrgraw_ds.model.values:
    
    fig, axes = plt.subplots(
        ncols=2,
        nrows=1,
        figsize=(18, 4),
        subplot_kw={"projection": ccrs.Robinson()},
        dpi=200,
    )
    
    all_pieces = [
        rawcln[mod], 
        diffbcrgraw_ds.sel(model=mod), 
        diffdnscdb_ds.sel(model=mod)
    ]

    all_pieces_copy = copy(all_pieces)
    # Only the first two pieces are drawn (one per axis).
    for i, _ in enumerate(all_pieces_copy[:2]):
        # `k` and `kw` are assigned but not used below.
        v, k, kw, ax = all_pieces_copy[i], titles[i], kwargs_list[i], axes[i]
        v["lat"].attrs = dict()
        v["lon"].attrs = dict()
        ax.add_feature(coastline_feature, linewidth=0.2)
        da = all_pieces[i].sel(quantile=quant)
        # Data-derived limits; always overwritten below because i is 0 or 1.
        vmax = da.max().item()
        vmin = da.min().item()

        if i == 1:
            vmin, vmax, amax = -3, 3, 3
        elif i == 0:
            vmin, vmax, amax = -12, 12, 12

        # Symmetric limit around zero for the diverging colormap.
        amax = max(abs(vmax), abs(vmin))
        # NOTE(review): `norm` is computed but not passed to .plot() below.
        norm = matplotlib.colors.Normalize(vmin=-amax, vmax=amax)

        im = da.plot(
            add_colorbar=True,
            ax=ax,
            transform=ccrs.PlateCarree(),
            clim=(vmin, vmax),
            vmin=-amax,
            vmax=amax,
            cmap="RdBu_r",
            cbar_kwargs=dict(
                fraction=0.046,
                pad=0.04,
                orientation="vertical",
                extend="both",
                label="temperature (K)",
            ),
        )
        ax.set_title(titles[i])
        ax.set_xlabel("")
        ax.set_ylabel("")
    fig.set_facecolor("white")
    fig.savefig(figure_3_output_file_path.format(model=mod,quant=quant), 
                facecolor="white", bbox_inches="tight")

### Deprecated: old supplementary figure — 1 panel showing downscaled vs. bias-corrected

In [None]:

# Deprecated supplementary figure: one single-panel map per GCM showing the
# third piece (downscaled - biascorrected), saved via figure_a2_output_file_path.
# Relies on `copy`, `quant` and `kwargs_list` from earlier cells.
# NOTE(review): diffdnscdb_ds is not defined in any visible cell, and the
# pieces are Datasets (.sel on the *_ds objects without [var]); both likely
# need updating before this runs -- TODO confirm.
for mod in diffbcrgraw_ds.model.values:
    
    
    all_pieces = [
        rawcln[mod], 
        diffbcrgraw_ds.sel(model=mod), 
        diffdnscdb_ds.sel(model=mod)
    ]

    all_pieces_copy = copy(all_pieces)

    fig, axes = plt.subplots(
        ncols=1,
        nrows=1,
        figsize=(9, 4),
        subplot_kw={"projection": ccrs.Robinson()},
        dpi=200,
    )
    # Wrap the single Axes in a 1-element array so axes[j] indexing works.
    axes = np.array([axes]).reshape((1,))
    coastline_feature = NaturalEarthFeature(
        "physical", "coastline", "10m", edgecolor="black", facecolor="none"
    )
    titles = ["difference in change (downscaled - biascorrected)"]
    # Only index 2 (the last piece) is plotted; j is its position on the figure.
    for j, (i, _) in enumerate(list(enumerate(all_pieces_copy))[2:]):
        # `k` and `kw` are assigned but not used below.
        v, k, kw, ax = all_pieces_copy[i], titles[j], kwargs_list[i], axes[j]
        v["lat"].attrs = dict()
        v["lon"].attrs = dict()
        ax.add_feature(coastline_feature, linewidth=0.2)
        da = all_pieces[i].sel(quantile=quant)
        # vmax = da.max().item()
        # vmin = da.min().item()
        # amax = max(abs(vmax), abs(vmin))
        # Fixed symmetric color range of +/-1.
        vmin, vmax, amax = -1, 1, 1

        im = da.plot(
            add_colorbar=True,
            ax=ax,
            transform=ccrs.PlateCarree(),
            clim=(vmin, vmax),
            norm=matplotlib.colors.Normalize(vmin=-amax, vmax=amax),
            cmap="RdBu_r",
            cbar_kwargs=dict(
                fraction=0.046,
                pad=0.04,
                orientation="vertical",
                extend="both",
                label="temperature (K)",
            ),
        )
        ax.set_title(titles[j])
        ax.set_xlabel("")
        ax.set_ylabel("")
    fig.set_facecolor("white")
    fig.savefig(figure_a2_output_file_path.format(model=mod,quant=quant), 
                facecolor="white", bbox_inches="tight")

In [None]:
# Scratch variant of the single-panel map above using a symmetric-log color
# scale. Not saved to disk.
# NOTE(review): relies on `all_pieces`, `all_pieces_copy`, `kwargs_list` and
# `quant` left over from earlier cells, so it plots whichever GCM the previous
# loop ended on -- hidden state; will not survive Restart & Run All reliably.
fig, axes = plt.subplots(
    ncols=1,
    nrows=1,
    figsize=(9, 4),
    subplot_kw={"projection": ccrs.Robinson()},
    dpi=200,
)
# Wrap the single Axes in a 1-element array so axes[j] indexing works.
axes = np.array([axes]).reshape((1,))
coastline_feature = NaturalEarthFeature(
    "physical", "coastline", "50m", edgecolor="black", facecolor="none"
)
titles = ["difference in change (downscaled - biascorrected)"]
for j, (i, _) in enumerate(list(enumerate(all_pieces_copy))[2:]):
    v, k, kw, ax = all_pieces_copy[i], titles[j], kwargs_list[i], axes[j]
    v["lat"].attrs = dict()
    v["lon"].attrs = dict()
    ax.add_feature(coastline_feature, linewidth=0.2)
    da = all_pieces[i].sel(quantile=quant)
    # vmax = da.max().item()
    # vmin = da.min().item()
    # amax = max(abs(vmax), abs(vmin))
    vmin, vmax, amax = -1, 1, 1

    im = da.plot(
        add_colorbar=True,
        ax=ax,
        transform=ccrs.PlateCarree(),
        clim=(vmin, vmax),
        # Linear within +/-0.1 (linthresh), logarithmic beyond, up to +/-amax.
        norm=matplotlib.colors.SymLogNorm(0.1, vmin=-amax, vmax=amax),
        cmap="RdBu_r",
        cbar_kwargs=dict(
            fraction=0.046,
            pad=0.04,
            orientation="vertical",
            extend="both",
            label="temperature (K)",
        ),
    )
    ax.set_title(titles[j])
    ax.set_xlabel("")
    ax.set_ylabel("")
fig.set_facecolor("white")

In [None]:
# Display the current value of `quant` (set by whichever plotting cell ran last).
quant

In [None]:
# Shut down `client` and then `cluster` (presumably the Dask client and cluster
# started earlier in the notebook). Two statements instead of a tuple
# expression, so the cell does not build and display a throwaway tuple.
client.close()
cluster.close()