# Validation for quantile scaling

In [1]:
import sys
import warnings
warnings.filterwarnings('ignore')

import xarray as xr
import cmocean
import geopandas as gp

sys.path.append('/g/data/xv83/quantile-mapping/qq-workflows')
import validation
sys.path.append('/g/data/xv83/quantile-mapping/qqscale')
import utils

Matplotlib is building the font cache; this may take a moment.


In [None]:
# Required parameters. The messages reference ``-p`` options, so these are
# presumably injected as notebook globals by papermill — confirm.
# Explicit checks are used instead of ``assert`` so validation still runs
# under ``python -O``, and every missing parameter is reported at once.
_required_params = {
    'nquantiles': "Must provide the number of quantiles (option -p nquantiles {number})",
    'scaling': "Must provide the scaling method (option -p scaling {name})",
    'hist_var': "Must provide an historical variable name (option -p hist_var {name})",
    'ref_var': "Must provide a reference variable name (option -p ref_var {name})",
    'target_var': "Must provide a target variable name (option -p target_var {name})",
    'hist_units': "Must provide historical units (option -p hist_units {units})",
    'ref_units': "Must provide reference units (option -p ref_units {units})",
    'target_units': "Must provide target units (option -p target_units {units})",
    'output_units': "Must provide output units (option -p output_units {units})",
    'adjustment_file': "Must provide an adjustment factors file (option -p adjustment_file {file path})",
    'hist_files': 'Must provide historical data files (option -p hist_files {"file paths"})',
    'ref_files': 'Must provide reference data files (option -p ref_files {"file paths"})',
    'target_files': 'Must provide target data files (option -p target_files {"file paths"})',
    'qq_file': "Must provide an qq-scaled data file (option -p qq_file {file path})",
    'hist_time_bounds': 'Must provide time bounds for historical data (option -p hist_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})',
    'ref_time_bounds': 'Must provide time bounds for reference data (option -p ref_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})',
    'target_time_bounds': 'Must provide time bounds for target data (option -p target_time_bounds {"YYYY-MM-DD YYYY-MM-DD"})',
}
# globals() is the notebook namespace (what locals() returned at top level);
# unlike locals(), it is unaffected by the comprehension's own scope.
_missing_messages = [
    message for name, message in _required_params.items() if name not in globals()
]
if _missing_messages:
    raise ValueError('\n'.join(_missing_messages))

In [None]:
def _split_param(value):
    """Return ``value`` as a list, splitting on whitespace if it is a string.

    Values that are already sequences are copied into a list unchanged. This
    lets the cell be re-run (the globals are lists after the first run) and
    lets list-valued parameters be passed directly.
    """
    return value.split() if isinstance(value, str) else list(value)


# File lists and "start end" time bounds are supplied as whitespace-separated
# strings.
hist_files = _split_param(hist_files)
ref_files = _split_param(ref_files)
target_files = _split_param(target_files)

hist_time_bounds = _split_param(hist_time_bounds)
ref_time_bounds = _split_param(ref_time_bounds)
target_time_bounds = _split_param(target_time_bounds)

In [None]:
# Settings shared by every supported variable; each branch below adds the
# variable-specific entries and overrides any default that differs.
plot_config = {
    'plot_pdfs_flag': False,
    'plot_1d_quantiles_flag': True,
    'plot_1d_values_flag': True,
    'pdf_ybounds': None,
    'q_xbounds': (0, 100),
    'diverging_cmap': 'RdBu_r',
    'af_levels': None,
    'extreme_for_values': 'max',
}
mask_ocean = True

if hist_var == 'tasmin':
    plot_config.update(
        pdf_xbounds=(-10, 30),
        regular_cmap=cmocean.cm.thermal,
        general_levels=[-4.0, -2.5, -1, 0.5, 2, 3.5, 5, 6.5, 8, 9.5, 11, 12.5, 14, 15.5, 17, 18.5, 20, 21.5],
        difference_levels=[-2.5, -2.0, -1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5, 2.0, 2.5],
    )
    clim_extend = 'both'
elif hist_var == 'tasmax':
    plot_config.update(
        pdf_xbounds=(0, 45),
        regular_cmap=cmocean.cm.thermal,
        general_levels=[5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25, 27.5, 30, 32.5, 35],
        difference_levels=[-2.5, -2.0, -1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5, 2.0, 2.5],
    )
    clim_extend = 'both'
elif hist_var == 'pr':
    # The quantile plots zoom in on the upper tail, more tightly when more
    # quantiles are available (500 -> 96-100, 1000 -> 98-100, else 80-100).
    plot_config.update(
        pdf_xbounds=(5, 80),
        pdf_ybounds=(0, 0.02),
        regular_cmap=cmocean.cm.rain,
        diverging_cmap='BrBG',
        general_levels=[0, 0.01, 0.25, 0.5, 1, 2, 5, 10, 20, 40, 60],
        af_levels=[0.125, 0.25, 0.5, 0.67, 0.8, 1, 1.25, 1.5, 2, 4, 8],
        difference_levels=[-150, -130, -110, -90, -70, -50, -30, -10, 10, 30, 50, 70, 90, 110, 130, 150],
        q_xbounds={500: (96, 100), 1000: (98, 100)}.get(int(nquantiles), (80, 100)),
    )
    clim_extend = 'max'
elif hist_var == 'rsds':
    plot_config.update(
        pdf_xbounds=(0, 300),
        regular_cmap=cmocean.cm.solar,
        general_levels=[115, 130, 145, 160, 175, 200, 225, 250, 275],
        difference_levels=[-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    )
    clim_extend = 'both'
elif 'hurs' in hist_var:
    # Substring match, so every hurs* variant shares this configuration.
    plot_config.update(
        pdf_xbounds=(0, 100),
        regular_cmap=cmocean.cm.thermal,
        general_levels=[10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90],
        difference_levels=[-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
        extreme_for_values='min',
    )
    clim_extend = 'both'
elif hist_var in ['sfcWind', 'sfcWindmax']:
    plot_config.update(
        pdf_xbounds=(-5, 40),
        plot_pdfs_flag=True,
        regular_cmap=cmocean.cm.speed,
        general_levels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        difference_levels=[-1.0, -0.8, -0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1.0],
    )
    clim_extend = 'max'
else:
    raise ValueError(f'No plotting configuration defined for {hist_var}')

In [None]:
# Named locations as (latitude, longitude). They are marked on the maps and
# used as the sites for the single-point analysis (plot_city).
city_lat_lon = dict(
    [
        ('Hobart', (-42.9, 147.3)),
        ('Melbourne', (-37.8, 145.0)),
        ('Mildura', (-34.2, 142.1)),
        ('Thredbo', (-36.5, 148.3)),
        ('Sydney', (-33.9, 151.2)),
        ('Brisbane', (-27.5, 153.0)),
        ('Cairns', (-16.9, 145.8)),
        ('Darwin', (-12.5, 130.8)),
        ('Alice Springs', (-23.7, 133.9)),
        ('Port Hedland', (-20.3, 118.6)),
        ('Karlamilyi National Park', (-22.7, 122.2)),
        ('Perth', (-32.0, 115.9)),
        ('Adelaide', (-34.9, 138.6)),
    ]
)

## Read data

In [None]:
# Adjustment factors file (variables 'af' and 'hist_q' are used further down).
ds_adjust = xr.open_dataset(filename_or_obj=adjustment_file)

In [None]:
# Adjustment factors may have a 'month' dimension. In that case the
# point-analysis plots iterate over all 12 months and maxima of the factors
# are taken over both quantiles and month.
if 'month' not in ds_adjust.dims:
    plot_config['months'] = []
    af_dims = 'quantiles'
else:
    plot_config['months'] = list(range(1, 13))
    af_dims = ['quantiles', 'month']

In [None]:
# Historical data, read with units converted from hist_units to output_units
# (per utils.read_data's keyword arguments).
ds_hist = utils.read_data(
    hist_files,
    hist_var,
    time_bounds=hist_time_bounds,
    input_units=hist_units,
    output_units=output_units,
)
# Not every input carries a 'crs' variable; drop it only when present.
ds_hist = ds_hist.drop_vars('crs', errors='ignore')

In [None]:
# Reference data, read with units converted from ref_units to output_units
# (per utils.read_data's keyword arguments).
ds_ref = utils.read_data(
    ref_files,
    ref_var,
    time_bounds=ref_time_bounds,
    input_units=ref_units,
    output_units=output_units,
)
# Not every input carries a 'crs' variable; drop it only when present.
ds_ref = ds_ref.drop_vars('crs', errors='ignore')

In [None]:
# Target data, read with units converted from target_units to output_units
# (per utils.read_data's keyword arguments).
ds_target = utils.read_data(
    target_files,
    target_var,
    time_bounds=target_time_bounds,
    input_units=target_units,
    output_units=output_units,
)
# Not every input carries a 'crs' variable; drop it only when present.
ds_target = ds_target.drop_vars('crs', errors='ignore')

In [None]:
# QQ-scaled output. Its data variable is named after the target variable when
# present, otherwise after the reference variable.
ds_qq = xr.open_dataset(qq_file)
qq_vars = list(ds_qq.keys())
qq_var = target_var if target_var in qq_vars else ref_var
if qq_var not in qq_vars:
    # Fail here with a clear message rather than deep inside a later cell.
    raise KeyError(
        f'Neither {target_var!r} nor {ref_var!r} found in {qq_file} (variables: {qq_vars})'
    )
ds_qq = ds_qq.drop_vars('crs', errors='ignore')

# The clipped and change-matched QQ outputs are optional. They are read only
# when their file parameters exist; otherwise they are None so later code can
# test for them.
if 'qq_clipped_file' in locals():
    ds_qq_clipped = xr.open_dataset(qq_clipped_file).drop_vars('crs', errors='ignore')
    da_qq_clipped = ds_qq_clipped[qq_var]
else:
    ds_qq_clipped = None
    da_qq_clipped = None

if 'qq_cmatch_file' in locals():
    ds_qq_cmatch = xr.open_dataset(qq_cmatch_file).drop_vars('crs', errors='ignore')
    da_qq_cmatch = ds_qq_cmatch[qq_var]
else:
    ds_qq_cmatch = None
    da_qq_cmatch = None

## Australia-wide

#### QDM

GCM change = ref (ssp) - hist  
QQ change = qq - target (obs)

#### eCDFm

GCM change = target (ssp) - hist  
QQ change = qq - ref (obs)

In [None]:
def select_months(ds, months=None):
    """Select the time steps that fall in the given calendar months.

    Parameters
    ----------
    ds : xarray Dataset or DataArray
        Data with a ``time`` coordinate that supports the ``.dt`` accessor.
    months : list of int, optional
        Months to select (1-12). ``None`` (the default) or an empty list
        applies no filtering and returns ``ds`` itself.

    Returns
    -------
    ds_selection : xarray Dataset or DataArray
        Input data restricted to the requested months.
    """

    # ``None`` default instead of a shared mutable ``[]``; an empty selection
    # means "all months" (used for the annual case).
    if not months:
        return ds

    in_months = ds['time'].dt.month.isin(months)
    return ds.sel({'time': in_months})

In [None]:
def get_comparison_data(season='annual'):
    """Compute seasonal climatologies and the spatial comparisons between them.

    Parameters
    ----------
    season : {'annual', 'DJF', 'MAM', 'JJA', 'SON'}
        Season over which the time means are taken. An unknown season
        raises KeyError.

    Returns
    -------
    climatologies : dict
        Time-mean fields keyed 'hist', 'ref', 'target', 'qq', plus
        'qq_clipped' / 'qq_cmatch' when those optional datasets were loaded.
    comparisons : dict
        Results of ``validation.spatial_comparison_data`` for each pair that
        ``plot_comparisons`` draws.
    """

    season_months = {
        'DJF': [1, 2, 12],
        'MAM': [3, 4, 5],
        'JJA': [6, 7, 8],
        'SON': [9, 10, 11],
        'annual': [],
    }[season]

    def time_mean(ds, var):
        # Seasonal subset, averaged over time and loaded into memory.
        return select_months(ds, season_months)[var].mean('time', keep_attrs=True).compute()

    climatologies = {
        'hist': time_mean(ds_hist, hist_var),
        'ref': time_mean(ds_ref, ref_var),
        'target': time_mean(ds_target, target_var),
        'qq': time_mean(ds_qq, qq_var),
    }
    if ds_qq_clipped is not None:
        climatologies['qq_clipped'] = time_mean(ds_qq_clipped, qq_var)
    if ds_qq_cmatch is not None:
        climatologies['qq_cmatch'] = time_mean(ds_qq_cmatch, qq_var)

    # Climatology-vs-climatology comparisons all use the configured scaling.
    comparisons = {}
    for key, first, second in (
        ('ref_hist', 'ref', 'hist'),
        ('qq_target', 'qq', 'target'),
        ('target_hist', 'target', 'hist'),
        ('qq_ref', 'qq', 'ref'),
    ):
        comparisons[key] = validation.spatial_comparison_data(
            climatologies[first], climatologies[second], scaling
        )

    # Change-vs-change comparisons are always additive.
    comparisons['qdc_model_change'] = validation.spatial_comparison_data(
        comparisons['qq_target'], comparisons['ref_hist'], 'additive'
    )

    if ds_qq_clipped is not None:
        comparisons['qq_clipped'] = validation.spatial_comparison_data(
            climatologies['qq_clipped'], climatologies['qq'], scaling
        )

    if ds_qq_cmatch is not None:
        comparisons['qq_cmatch_target'] = validation.spatial_comparison_data(
            climatologies['qq_cmatch'], climatologies['target'], scaling
        )
        comparisons['qdc_cmatch_model_change'] = validation.spatial_comparison_data(
            comparisons['qq_cmatch_target'], comparisons['ref_hist'], 'additive'
        )

    return climatologies, comparisons

In [None]:
def plot_comparisons(climatologies, comparisons, season='annual'):
    """Draw the spatial comparison maps for one season.

    Parameters
    ----------
    climatologies, comparisons : dict
        As returned by ``get_comparison_data``.
    season : str
        Season label. Currently not used when plotting.
    """

    def draw(data1, data2, difference, label1, label2, change=False):
        # Climatology panels use the variable's own colour map, levels, scaling
        # and colour-bar extension. Change-vs-change panels are always
        # diverging, additive and extended both ways.
        if change:
            cmap = plot_config['diverging_cmap']
            levels = plot_config['difference_levels']
            method = 'additive'
            extend = 'both'
        else:
            cmap = plot_config['regular_cmap']
            levels = plot_config['general_levels']
            method = scaling
            extend = clim_extend
        validation.spatial_comparison_plot(
            data1,
            data2,
            difference,
            label1,
            label2,
            cmap,
            plot_config['diverging_cmap'],
            levels,
            plot_config['difference_levels'],
            method,
            city_lat_lon=city_lat_lon,
            land_only=mask_ocean,
            clim_extend=extend,
        )

    draw(climatologies['ref'], climatologies['hist'], comparisons['ref_hist'], 'reference', 'historical')
    draw(climatologies['qq'], climatologies['target'], comparisons['qq_target'], 'QQ', 'target')
    draw(climatologies['target'], climatologies['hist'], comparisons['target_hist'], 'target', 'historical')
    draw(climatologies['qq'], climatologies['ref'], comparisons['qq_ref'], 'QQ', 'reference')

    if 'qq_clipped' in comparisons:
        draw(climatologies['qq_clipped'], climatologies['qq'], comparisons['qq_clipped'], 'QQ clipped', 'QQ')

    if 'qdc_model_change' in comparisons:
        draw(
            comparisons['qq_target'],
            comparisons['ref_hist'],
            comparisons['qdc_model_change'],
            'QQ change',
            'Model change',
            change=True,
        )

    if 'qdc_cmatch_model_change' in comparisons:
        draw(
            climatologies['qq_cmatch'],
            climatologies['target'],
            comparisons['qq_cmatch_target'],
            'QQ (change matched)',
            'target',
        )
        draw(
            comparisons['qq_cmatch_target'],
            comparisons['ref_hist'],
            comparisons['qdc_cmatch_model_change'],
            'QDC change (change matched)',
            'Model change',
            change=True,
        )

### Annual

In [None]:
# Annual: no month filtering.
climatologies, comparisons = get_comparison_data(season='annual')

In [None]:
plot_comparisons(climatologies, comparisons, season='annual')

### DJF

In [None]:
climatologies, comparisons = get_comparison_data('DJF')

In [None]:
plot_comparisons(climatologies, comparisons, 'DJF')

### MAM

In [None]:
climatologies, comparisons = get_comparison_data('MAM')

In [None]:
plot_comparisons(climatologies, comparisons, 'MAM')

### JJA

In [None]:
climatologies, comparisons = get_comparison_data('JJA')

In [None]:
plot_comparisons(climatologies, comparisons, 'JJA')

### SON

In [None]:
climatologies, comparisons = get_comparison_data('SON')

In [None]:
plot_comparisons(climatologies, comparisons, 'SON')

## Implausible value check

In [None]:
# Subset the adjustment factors to the Australia shapefile region.
shape = gp.read_file(
    '/g/data/ia39/aus-ref-clim-data-nci/shapefiles/data/australia/australia.shp'
)
ds_adjust_shape = validation.subset_shape(ds_adjust, shape=shape)

In [None]:
# Largest adjustment factor anywhere in the shapefile subset (displayed as the
# cell output).
ds_adjust_shape['af'].max()

In [None]:
# Largest adjustment factor once points with hist_q <= 1 are replaced by 0.
# NOTE(review): presumably this screens out near-zero historical quantiles,
# which can produce very large multiplicative factors — confirm.
ds_adjust_shape['af'].where(ds_adjust_shape['hist_q'] > 1, 0).max()

In [None]:
# Plot of the same masked maximum, reduced over the adjustment-factor
# dimensions only (quantiles, plus month when the factors are monthly), so
# one value remains per remaining coordinate.
ds_adjust_shape['af'].where(ds_adjust_shape['hist_q'] > 1, 0).max(dim=af_dims).plot()

In [None]:
# Seasonal correlation between the target data and the QQ-scaled data.
seasonal_r = validation.calc_seasonal_correlation(
    ds_target,
    target_var,
    ds_qq,
    qq_var,
)

In [None]:
validation.plot_seasonal_correlation(
    seasonal_r,
    city_lat_lon=city_lat_lon,
    land_only=True,
)

In [None]:
# Only run when a change-matched QQ file was supplied.
if 'qq_cmatch_file' in locals():
    seasonal_r_cmatch = validation.calc_seasonal_correlation(
        ds_target,
        target_var,
        ds_qq_cmatch,
        qq_var,
    )

In [None]:
if 'qq_cmatch_file' in locals():
    validation.plot_seasonal_correlation(
        seasonal_r_cmatch,
        city_lat_lon=city_lat_lon,
        land_only=True,
    )

In [None]:
# Only run when a clipped QQ file was supplied.
if 'qq_clipped_file' in locals():
    seasonal_r_clipped = validation.calc_seasonal_correlation(
        ds_target,
        target_var,
        ds_qq_clipped,
        qq_var,
    )

In [None]:
if 'qq_clipped_file' in locals():
    validation.plot_seasonal_correlation(
        seasonal_r_clipped,
        city_lat_lon=city_lat_lon,
        land_only=True,
    )

In [None]:
# Arguments: hist, ref, target and QQ data arrays, then the scaling method.
seasonal_diff = validation.calc_seasonal_change_diff(
    ds_hist[hist_var], ds_ref[ref_var], ds_target[target_var], ds_qq[qq_var], scaling
)

In [None]:
# Contour levels roughly double at each step above 10.
validation.plot_seasonal_change_diff(
    seasonal_diff, land_only=True, city_lat_lon=city_lat_lon,
    levels=[0, 10, 20, 40, 80, 160, 320, 640, 1280],
)

In [None]:
# Same calculation for the change-matched QQ output, when supplied.
if 'qq_cmatch_file' in locals():
    seasonal_diff_cmatch = validation.calc_seasonal_change_diff(
        ds_hist[hist_var], ds_ref[ref_var], ds_target[target_var], ds_qq_cmatch[qq_var], scaling
    )

In [None]:
if 'qq_cmatch_file' in locals():
    validation.plot_seasonal_change_diff(
        seasonal_diff_cmatch, land_only=True, city_lat_lon=city_lat_lon,
        levels=[0, 10, 20, 40, 80, 160, 320, 640, 1280],
    )

In [None]:
# Same calculation for the clipped QQ output, when supplied.
if 'qq_clipped_file' in locals():
    seasonal_diff_clipped = validation.calc_seasonal_change_diff(
        ds_hist[hist_var], ds_ref[ref_var], ds_target[target_var], ds_qq_clipped[qq_var], scaling
    )

In [None]:
if 'qq_clipped_file' in locals():
    validation.plot_seasonal_change_diff(
        seasonal_diff_clipped, land_only=True, city_lat_lon=city_lat_lon,
        levels=[0, 10, 20, 40, 80, 160, 320, 640, 1280],
    )

In [None]:
# Monthly change-sign agreement for the model pair (historical -> reference).
validation.plot_monthly_change_sign_agreement(
    ds_hist[hist_var], ds_ref[ref_var], land_only=True, city_lat_lon=city_lat_lon
)

In [None]:
# Monthly change-sign agreement for the QQ pair (target -> QQ output).
validation.plot_monthly_change_sign_agreement(
    ds_target[target_var], ds_qq[qq_var], land_only=True, city_lat_lon=city_lat_lon
)

## Points of interest

In [None]:
def plot_city(city, n_values=50):
    """Run the single-point analysis plots for one named location.

    Parameters
    ----------
    city : str
        Key into ``city_lat_lon``; an unknown name raises KeyError.
    n_values : int, default 50
        Forwarded unchanged to ``validation.single_point_analysis``.
    """

    lat, lon = city_lat_lon[city]
    cfg = plot_config

    # Config keys are the keyword names with a "_flag" suffix.
    plot_switches = {
        name: cfg[f'{name}_flag']
        for name in ('plot_1d_quantiles', 'plot_1d_values', 'plot_pdfs')
    }
    axis_bounds = {key: cfg[key] for key in ('pdf_xbounds', 'pdf_ybounds', 'q_xbounds')}

    validation.single_point_analysis(
        ds_hist[hist_var], ds_ref[ref_var], ds_target[target_var], ds_qq[qq_var],
        ds_adjust,
        hist_var,
        scaling,
        city, lat, lon,
        cfg['regular_cmap'], cfg['diverging_cmap'],
        cfg['general_levels'], cfg['af_levels'],
        da_qq_clipped=da_qq_clipped,
        da_qq_cmatch=da_qq_cmatch,
        n_values=n_values,
        months=cfg['months'],
        seasonal_agg='mean',
        extreme_for_values=cfg['extreme_for_values'],
        **axis_bounds,
        **plot_switches,
    )

In [None]:
plot_city(city='Hobart')

In [None]:
plot_city(city='Melbourne')

In [None]:
plot_city(city='Mildura')

In [None]:
plot_city(city='Thredbo')

In [None]:
plot_city(city='Sydney')

In [None]:
plot_city(city='Brisbane')

In [None]:
plot_city(city='Cairns')

In [None]:
plot_city(city='Darwin')

In [None]:
plot_city(city='Alice Springs')

In [None]:
plot_city(city='Port Hedland')

In [None]:
plot_city(city='Karlamilyi National Park')

In [None]:
plot_city(city='Perth')

In [None]:
plot_city(city='Adelaide')