In [None]:
import os
import sys

import numpy as np
import xarray as xr

sys.path.append("..") # Set to path of codebase

from qme_train import *
from qme_apply import *

In [None]:
# NOTE(review): dask.diagnostics is imported but not used in any cell visible here.
import dask.diagnostics
from dask.distributed import Client, LocalCluster

# Start a local Dask cluster with default settings and connect a client to it,
# so the lazy xarray/dask computations below run on the distributed scheduler.
cluster = LocalCluster()
client = Client(cluster)
# Last expression of the cell: displays the client summary in the notebook.
client

In [None]:
# Select the data
var = "pr"

# Training period (inclusive years) used to build the model and observation distributions
st_year_train = 1980
en_year_train = 2019

# Period of model data to bias correct (may include the training period)
st_year_apply = 1980
en_year_apply = 2019

# Future period of the same model to bias correct
st_year_fut = 2080
en_year_fut = 2099

# Model files up to and including this year use the "historical" experiment name;
# later years use "ssp370"
last_historical_year = 2014

obs_path = '/g/data/ia39/npcp/data/{var}/observations/AGCD/raw/task-reference/{var}_NPCP-20i_AGCD_v1-0-1_day_{year}0101-{year}1231.nc'
mdl_path = '/g/data/ia39/npcp/data/{var}/CSIRO-ACCESS-ESM1-5/BOM-BARPA-R/raw/task-reference/{var}_NPCP-20i_CSIRO-ACCESS-ESM1-5_{empat}_r6i1p1f1_BOM-BARPA-R_v1_day_{year}0101-{year}1231.nc'


def experiment_for_year(year):
    """Return the experiment name ("historical" or "ssp370") used in the model file name for `year`."""
    return "ssp370" if year > last_historical_year else "historical"


def model_file_list(first_year, last_year):
    """Return one model file path per year from `first_year` to `last_year` (both inclusive)."""
    return [mdl_path.format(var = var, year = y, empat = experiment_for_year(y))
            for y in range(first_year, last_year + 1)]


obs_file_list = [obs_path.format(var = var, year = y)
                 for y in range(st_year_train, en_year_train + 1)]

# The model data loaded from these files is later sub-selected for both the training
# and the application period, so the list must span the union of the two periods
# (not just start-of-training to end-of-application).
mdl_file_list = model_file_list(min(st_year_train, st_year_apply),
                                max(en_year_train, en_year_apply))

fut_file_list = model_file_list(st_year_fut, en_year_fut)

In [None]:
# Keyword arguments forwarded to calc_qme through **params.
# NOTE(review): the meaning of each option is defined by calc_qme in the codebase,
# not in this notebook; two of them switch on whether the variable is "pr".
params = dict(
    xtr = 3,
    cal_smth = 21,
    mthd = '_quick',
    mn_smth = '_3mn' if var == "pr" else '',
    ssze_lim = 50,
    mltp = False,
    lmt = 1.5 if var == "pr" else -1,
    lmt_thresh = 10,
)

In [None]:
def standardise_latlon(ds, digits=4):
    """
    Round the latitude / longitude coordinates of a dataset to `digits` decimal places.

    Some datasets carry stray digits in their coordinates (e.g. 50.00000001 instead
    of 50.0), which prevents merging of data from different files.

    Parameters
    ----------
    ds : dataset exposing `lat` and `lon` coordinates and an `assign_coords` method
        (in this notebook: each file's dataset, passed in via `preprocess`).
    digits : int, default 4
        Number of decimal places to keep.

    Returns
    -------
    The object returned by `ds.assign_coords`, carrying the rounded `lat` / `lon`.
    """
    # Both coordinates are replaced in a single assign_coords call.
    return ds.assign_coords({
        "lat": np.round(ds.lat, digits),
        "lon": np.round(ds.lon, digits),
    })

In [None]:
# Load all data and chunk (chunk sizes for lat and lon may be adjusted here to help performance)
lat_chunk_size = 25
lon_chunk_size = 25


def _open_chunked(file_list):
    """Open `file_list` as one dataset, pick the variable `var` and rechunk it.

    Each file is passed through standardise_latlon before combining; the result
    is a single chunk along time and lat/lon chunks of the sizes set above.
    """
    combined = xr.open_mfdataset(file_list, preprocess = standardise_latlon)
    return combined[var].chunk(time = -1, lat = lat_chunk_size, lon = lon_chunk_size)


obs_data = _open_chunked(obs_file_list)
mdl_data = _open_chunked(mdl_file_list)
fut_data = _open_chunked(fut_file_list)
obs_data

In [None]:
# Inclusive year ranges of the training and application periods
train_years = range(st_year_train, en_year_train + 1)
apply_years = range(st_year_apply, en_year_apply + 1)

# Training subsets of model and observations, each kept as one chunk along time
mdl_in_training = mdl_data.time.dt.year.isin(train_years)
obs_in_training = obs_data.time.dt.year.isin(train_years)
mdl_training = mdl_data.sel(time = mdl_in_training).chunk({"time": -1})
obs_training = obs_data.sel(time = obs_in_training).chunk({"time": -1})

# Model subset that the bias correction will be applied to
mdl_in_apply = mdl_data.time.dt.year.isin(apply_years)
mdl_apply = mdl_data.sel(time = mdl_in_apply).chunk({"time": -1})

In [None]:
# Create the distribution histograms for the model and observation training data.
# NOTE(review): make_dist comes from the imported codebase; its output is assumed to
# have "values" and "month" dimensions, as implied by the chunk call — confirm.
# A chunk size of -1 keeps each of those dimensions in a single chunk.
dist_mdl = make_dist(var, mdl_training).chunk({"values": -1, "month": -1})
dist_obs = make_dist(var, obs_training).chunk({"values": -1, "month": -1})

In [None]:
# Apply QME to create adjustment factors from the model and observation distributions;
# the options in `params` are passed through as keyword arguments.
# Using .persist() will start the calculation in the background and keep it in distributed memory.
# Without using persist, dask may calculate these twice (once for current and once for future data).
dist_bc = calc_qme(var, dist_mdl, dist_obs, **params).chunk({"values": -1, "month": -1}).persist()
# Last expression of the cell: displays the result in the notebook.
dist_bc

In [None]:
# Apply bias correction factors to model data (the application period selected above).
# The result is renamed to the variable name so the saved output keeps that name.
# NOTE(review): apply_bc comes from the imported codebase; presumably lazy until written — confirm.
mdl_bc = apply_bc(var, mdl_apply, dist_bc.biascorr).rename(var)
mdl_bc

In [None]:
# Apply bias correction factors to future data, reusing the same factors
# (dist_bc.biascorr) that were derived from the training period.
# The result is renamed to the variable name so the saved output keeps that name.
fut_bc = apply_bc(var, fut_data, dist_bc.biascorr).rename(var)
fut_bc

In [None]:
# Set output directory as appropriate.
# The empty string writes the output files to the current working directory.
outdir = ""

In [None]:
%%time
# Save bias corrected model data as netCDF
mdl_bc.to_netcdf(outdir + f'{var}_sample_out_historical.nc')

In [None]:
%%time
# Save bias corrected future data as netCDF
fut_bc.to_netcdf(outdir + f'{var}_sample_out_future.nc')