## Set up Workspace

In [ ]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate, stats
import xesmf as xe
from datetime import datetime, timedelta
from ngallery_utils import DATASETS

%matplotlib inline

# Name of the data variable read from every dataset in this notebook.
var = "uas"

## Look at Data

In [ ]:
# Fetch the historical run, the RCP8.5 run and the gridMET observations
# through the gallery's DATASETS registry (same order as before).
hist_file, rcp85_file, meas_file = (
    DATASETS.fetch(name)
    for name in (
        "uas.hist.CanESM2.CRCM5-UQAM.day.NAM-44i.raw.Colorado.nc",
        "uas.rcp85.CanESM2.CRCM5-UQAM.day.NAM-44i.raw.Colorado.nc",
        "uas.gridMET.NAM-44i.Colorado.nc",
    )
)

In [ ]:
# Open the historical model run and display its summary.
ds_hist = xr.open_dataset(filename_or_obj=hist_file)
ds_hist

In [ ]:
# Map of the first historical time step.
hist_first_step = ds_hist[var].isel(time=0)
hist_first_step.plot.contourf(
    x="lon", y="lat", cmap="PRGn", cbar_kwargs={"label": "m/s"}
)

In [ ]:
# Open the RCP8.5 model run and display its summary.
ds_rcp85 = xr.open_dataset(filename_or_obj=rcp85_file)
ds_rcp85

In [ ]:
# Map of the first RCP8.5 time step.
rcp85_first_step = ds_rcp85[var].isel(time=0)
rcp85_first_step.plot.contourf(
    x="lon", y="lat", cmap="PRGn", cbar_kwargs={"label": "m/s"}
)

In [ ]:
# Open the measured (gridMET) dataset and display its summary.
ds_meas = xr.open_dataset(filename_or_obj=meas_file)
ds_meas

In [ ]:
# Map of the first measured time step.
meas_first_step = ds_meas[var].isel(time=0)
meas_first_step.plot.contourf(
    x="lon", y="lat", cmap="PRGn", cbar_kwargs={"label": "m/s"}
)

# Filter Data

In [ ]:
wesn = [-110, -108, 39.5, 41.5]  # bounding box as [west, east, south, north]
west, east, south, north = wesn
# NOTE(review): label slices assume ascending lon/lat coordinates — confirm.
bbox = {'lon': slice(west, east), 'lat': slice(south, north)}
ds_meas_flt, ds_hist_flt, ds_rcp85_flt = (
    ds.sel(**bbox) for ds in (ds_meas, ds_hist, ds_rcp85)
)

## Align Time

In [ ]:
# Drop the 366th day of each leap year so every measured year has 365 days,
# presumably to match the model runs' no-leap calendar.
# NOTE(review): this removes 31 December rather than 29 February, so from
# 1 March on, leap-year days carry a day-of-year one higher than the same
# calendar date in the model — confirm this offset is acceptable.
is_day_366 = ds_meas_flt.time.dt.dayofyear == 366
ds_meas_noleap = ds_meas_flt.sel(time=~is_day_366)

In [ ]:
def _cfnoleap_to_datetime(da):
    """Re-index a cftime-indexed object on an equivalent datetime64 axis.

    ``CFTimeIndex.to_datetimeindex`` converts the ``time`` index so the model
    runs can be sliced with datetime64 timestamps (see
    ``_regroup_models_bytime``).

    Parameters
    ----------
    da : xarray.Dataset or xarray.DataArray
        Object whose ``time`` index is a ``CFTimeIndex`` (presumably a
        no-leap calendar, as the name suggests).

    Returns
    -------
    Same type as ``da``, with ``time_dt`` (datetime64) as the dimension and
    index; the original ``time`` values remain as a non-index coordinate.
    The input object is not modified.
    """
    datetimeindex = da.indexes['time'].to_datetimeindex()
    # assign_coords returns a new object; ``da['time_dt'] = ...`` would have
    # added the variable to the caller's dataset in place.
    with_dt = da.assign_coords(time_dt=('time', datetimeindex))
    return with_dt.swap_dims({'time': 'time_dt'})

# Give both model runs a datetime64 ``time_dt`` index.
ds_hist_dt, ds_rcp85_dt = map(_cfnoleap_to_datetime, (ds_hist_flt, ds_rcp85_flt))

In [ ]:
def _regroup_models_bytime(ds_meas, ds_hist_dt, ds_rcp_dt, var_name=None):
    """Split the model output into a measured-period part and a future part.

    Parameters
    ----------
    ds_meas : xarray.Dataset
        Measured data; its ``time`` values are compared with the models'
        ``time_dt`` index, so they are assumed to be datetime64.
    ds_hist_dt, ds_rcp_dt : xarray.Dataset
        Historical and RCP model runs indexed by ``time_dt`` with the
        original ``time`` values kept as a coordinate, as produced by
        ``_cfnoleap_to_datetime``.
    var_name : str, optional
        Variable concatenated into the future series. Defaults to the
        module-level ``var``.

    Returns
    -------
    ds_past : xarray.Dataset
        Historical run between the first and last measured timestamps
        (inclusive), indexed by ``time``.
    ds_fut : xarray.DataArray
        ``var_name`` from the historical run after the last measured
        timestamp, followed by the whole RCP run, indexed by ``time``.
    """
    if var_name is None:
        var_name = var

    meas_times = ds_meas.time.values
    t0_meas, tn_meas = meas_times[0], meas_times[-1]

    ds_past = (ds_hist_dt.sel(time_dt=slice(t0_meas, tn_meas))
               .swap_dims({'time_dt': 'time'}))

    # Every historical step strictly after the last measurement. Unlike
    # starting at ``tn_meas + 1 day``, this leaves no gap if the model and the
    # measurements are stamped at different times of day.
    after_meas = np.flatnonzero(ds_hist_dt['time_dt'].values > tn_meas)
    ds_fut_pt1 = ds_hist_dt.isel(time_dt=after_meas)

    ds_fut = xr.concat([ds_fut_pt1[var_name], ds_rcp_dt[var_name]], 'time_dt')
    ds_fut = ds_fut.swap_dims({'time_dt': 'time'})
    return ds_past, ds_fut

# Historical run over the measured period, and the future continuation.
ds_past, ds_fut = _regroup_models_bytime(
    ds_meas=ds_meas_noleap, ds_hist_dt=ds_hist_dt, ds_rcp_dt=ds_rcp85_dt
)

## Bias Correction
  
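The method is "Range", which works well for relative humidity. It needs an upper and a lower value that define the range.

Compute the differences between the measured and the historical-model lower and upper bounds, and apply those shifts to the future model's lower and upper bounds:

$$\text{result} = \frac{x_{fut} - \text{lower}_{fut}}{\text{upper}_{fut} - \text{lower}_{fut}}$$

$$\text{denormalized result} = \text{result} \cdot (\text{upper}_{new} - \text{lower}_{new}) + \text{lower}_{new}$$

where $\text{lower}_{new}$ and $\text{upper}_{new}$ are the shifted (bias-corrected) future bounds.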

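For relative humidity the bounds are not the data minimum and maximum but the fixed values 0 and 100: relative humidity is divided by 100 so that it ranges from 0 to 1, and it is multiplied by 100 again afterwards. (For `uas` in this notebook, the bounds are the moving-window minimum and maximum.)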

In [ ]:
def _reshape(ds, window_width):
    """Arrange daily data as a (year, day-of-year) grid padded for a moving window.

    Each year's ``time`` axis is renamed to ``day`` and relabelled with its
    day of year; the groupby stacks the years along a ``year`` dimension.
    So that a centred ``window_width``-day window can be evaluated at both
    ends, the last ``ceil(window_width / 2)`` days are copied in front and the
    first ``window_width // 2`` days are copied behind, both taken from the
    same row.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data with a ``time`` dimension supporting ``.dt.dayofyear``.
    window_width : int
        Width of the moving window, in days.

    Returns
    -------
    Same type as ``ds`` with ``year`` and ``day`` dimensions. The ``day``
    labels repeat inside the padded regions.
    """
    def _split(group):
        # One year: index by day of year instead of by timestamp.
        return (group.rename({'time': 'day'})
                     .assign_coords(day=group.time.dt.dayofyear.values))

    # GroupBy.map is the non-deprecated spelling of GroupBy.apply.
    ds2 = ds.groupby('time.year').map(_split)

    # ``-window_width//2`` floors a negative number, so the trailing copy has
    # ceil(window_width / 2) days: one more than the leading copy for odd
    # widths.
    early_Jans = ds2.isel(day=slice(None, window_width // 2))
    late_Decs = ds2.isel(day=slice(-window_width // 2, None))

    return xr.concat([late_Decs, ds2, early_Jans], dim='day')

In [None]:
def _calc_stats(ds, window_width):
    """Day-of-year climatological minimum, maximum and normalised mean.

    For each day of year (and each value of any other dimension, e.g. grid
    cell), the minimum and maximum are taken over all years and over a centred
    ``window_width``-day window built from ``_reshape``'s padded grid.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Daily data with a ``time`` dimension.
    window_width : int
        Width of the moving window, in days.

    Returns
    -------
    ds_min, ds_max : same type as ``ds``
        Windowed minimum / maximum, one entry per day of year along ``day``.
    ds_range : same type as ``ds``
        ``(mean - ds_min) / (ds_max - ds_min)``, where ``mean`` averages all
        years and days.
    """
    ds_rsh = _reshape(ds, window_width)

    # Each (padded) day gets a ``win_day`` axis holding its centred window.
    ds_rolled = ds_rsh.rolling(day=window_width, center=True).construct('win_day')

    # Locate the un-padded year from the labels instead of assuming a pad
    # width: _reshape pads one more day in front than behind for odd widths,
    # and a symmetric trim dropped the last day of the year. Assumes the
    # padding is shorter than a year, so the smallest and largest labels first
    # appear at the edges of the un-padded block.
    days = ds_rsh['day'].values
    start = int(np.argmax(days == days.min()))
    stop = days.size - int(np.argmax(days[::-1] == days.max()))
    core = slice(start, stop)

    ds_min = ds_rolled.min(dim=['year', 'win_day']).isel(day=core)
    ds_max = ds_rolled.max(dim=['year', 'win_day']).isel(day=core)

    # Average only the un-padded days so the copied days are not counted twice.
    ds_avyear = ds_rsh.isel(day=core).mean(dim=['year', 'day'])
    ds_range = (ds_avyear - ds_min) / (ds_max - ds_min)
    return ds_min, ds_max, ds_range

In [None]:
# Width (in days) of the centred moving window used for the climatologies.
window_width = 31
meas_min, meas_max, meas_range = _calc_stats(ds_meas_noleap, window_width=window_width)
hist_min, hist_max, hist_range = _calc_stats(ds_past, window_width=window_width)

In [ ]:
def _get_params(meas_min, meas_max, past_min, past_max):    
    min_shift = meas_min - past_min
    max_shift = meas_max - past_max
    return min_shift, max_shift

# Per-day-of-year offsets from the historical-model bounds to the measured ones.
min_shift, max_shift = _get_params(
    meas_min=meas_min, meas_max=meas_max, past_min=hist_min, past_max=hist_max
)

In [ ]:
def _calc_fut_stats(ds_fut, window_width):
    """Normalise the future series by its moving-window range.

    Implements ``result = (x_fut - lower_fut) / (upper_fut - lower_fut)``
    from the method description, with ``lower_fut``/``upper_fut`` the
    minimum/maximum over a centred ``window_width``-step window along
    ``time``.

    Parameters
    ----------
    ds_fut : xarray.DataArray or xarray.Dataset
        Future series with a ``time`` dimension.
    window_width : int
        Width of the moving window, in time steps.

    Returns
    -------
    ds_min, ds_max, ds_range
        Windowed minimum, maximum and the normalised series, all indexed by
        ``time``. ``ds_range`` is non-finite where ``ds_max == ds_min``.

    NOTE(review): these bounds come from a moving window over the future
    series itself, whereas ``_calc_stats`` uses day-of-year bounds over all
    years — confirm the two are meant to be combined.
    """
    # Stack each centred window along a new ``win_day`` dimension.
    ds_rolled = ds_fut.rolling(time=window_width, center=True).construct('win_day')

    ds_min = ds_rolled.min(dim=['win_day'])
    ds_max = ds_rolled.max(dim=['win_day'])

    # Normalise every value x_fut, not the time-mean: using the mean made the
    # corrected series lose all day-to-day variability.
    ds_range = (ds_fut - ds_min) / (ds_max - ds_min)
    return ds_min, ds_max, ds_range

# Moving-window bounds of the future series and its normalised values.
fut_min, fut_max, fut_range = _calc_fut_stats(ds_fut=ds_fut, window_width=window_width)

In [ ]:
# min_shift / max_shift are indexed by day of year (``day``), while the future
# bounds are indexed by ``time``. Look up each time step's day-of-year shift so
# the two combine elementwise, instead of broadcasting into a (time, day) product.
# method='nearest' falls back to the closest available day if a day of year is
# missing from the shift climatology.
fut_doy = fut_min.time.dt.dayofyear
fut_min_bc = fut_min + min_shift.sel(day=fut_doy, method='nearest')
fut_max_bc = fut_max + max_shift.sel(day=fut_doy, method='nearest')

In [ ]:
# Denormalise with the shifted bounds: result * (upper_new - lower_new) + lower_new.
fut_span_bc = fut_max_bc - fut_min_bc
fut_corrected = fut_range * fut_span_bc + fut_min_bc

# Visualize the Correction

In [ ]:
def gaus(mean, std, doy, lon_index=0, lat_index=0):
    """Sample a normal probability density for one day of year at one grid cell.

    Parameters
    ----------
    mean, std : xarray object with ``day``, ``lon`` and ``lat`` dimensions
        Mean and standard-deviation fields.
    doy : label on the ``day`` coordinate to select.
    lon_index, lat_index : int, optional
        Positional grid-cell indices; default to the first cell (0, 0).

    Returns
    -------
    x : numpy.ndarray
        100 evenly spaced points covering ``mu - 3 sigma`` to ``mu + 3 sigma``.
    y : numpy.ndarray
        Normal pdf with mean ``mu`` and standard deviation ``sigma`` at ``x``.
    """
    # float() turns the selected 0-d values into plain scalars for numpy/scipy.
    mu = float(mean.sel(day=doy).isel(lon=lon_index, lat=lat_index))
    sigma = float(std.sel(day=doy).isel(lon=lon_index, lat=lat_index))

    x = np.linspace(mu - 3 * sigma, mu + 3 * sigma, 100)
    y = stats.norm.pdf(x, mu, sigma)
    return x, y

In [ ]:
# The range method works with day-of-year minima/maxima, not a mean and a
# standard deviation, so compare those climatological envelopes
# (solid: minimum, dashed: maximum) at the first grid cell.
fut_typ_min, fut_typ_max, fut_typ_range = _calc_stats(ds_fut, window_width)
# Both operands are indexed by ``day`` here, so the shifts line up by day of year.
fut_typ_min_bc = fut_typ_min + min_shift
fut_typ_max_bc = fut_typ_max + max_shift

envelopes = [
    (hist_min[var], hist_max[var], 'orange', 'historical model'),
    (meas_min[var], meas_max[var], 'red', 'measured'),
    (fut_typ_min, fut_typ_max, 'blue', 'raw future model'),
    (fut_typ_min_bc[var], fut_typ_max_bc[var], 'green', 'corrected future model'),
]

plt.figure()
for lower, upper, color, label in envelopes:
    lower = lower.isel(lon=0, lat=0)
    upper = upper.isel(lon=0, lat=0)
    plt.plot(lower['day'], lower, color=color, label=label)
    plt.plot(upper['day'], upper, color=color, linestyle='--')
plt.xlabel('day of year')
plt.ylabel('m/s')
plt.legend()