In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import collections
import os 

import warnings
warnings.filterwarnings("ignore")

import xesmf as xe

from utils import _convert_lons, _remove_leap_days, compute_daily_climo
from regridding import apply_weights

import dask.distributed as dd
import dask_kubernetes as dk
import dask
import rhg_compute_tools.kubernetes as rhgk

### This notebook is a test of all the steps for Spatial Disaggregation to get a handle on the total CPU time it will take for this part of BCSD. 

Once-off steps: 

1. Compute multi-decade daily climatologies of ERA-5 at obs-res and coarsen them to model-res (other implementations, e.g. NASA-NEX, do not say how they coarsen; we will use bilinear interpolation for consistency with the later step)

Per model/scenario/experiment steps:

1. Subtract the obs climo at model-res from the BC’ed model data at model-res (or, for precip, divide the model data by the obs climo) to calculate a “scaling factor” 
2. bilinearly interpolate “scaling factor” (using xESMF) from the model grid to the obs grid 
3. Apply the “scaling factor” by adding it to (for temp) or multiplying it by (for precip) the obs-res daily climatology 

NOTE: To be conservative with timing, the "coarsen obs climatology to model-res" step is included in the per model/scenario/experiment steps, since we don't know for sure whether the CMIP6 models will all be at exactly the same resolution. 

Currently this workflow is only built out for temperature, not precipitation. All steps are included, but the last step (applying the interpolated scale factor to the obs-res daily climatology) has not yet been tested; all other parts of the workflow have been tested. The second-to-last step, the interpolation of the scaling factor from coarse to fine, is the most memory intensive, so I have only tested it for a subset of timesteps. 

In [2]:
# Start a Dask cluster through the rhg_compute_tools Kubernetes helper, which
# returns both a client and the cluster object.
client, cluster = rhgk.get_standard_cluster()
# Bare expression: shows the cluster widget as the cell output.
cluster

VBox(children=(HTML(value='<h2>KubeCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .…

Load the test bias-corrected output from the global bias correction prototype notebook (BC'ed NASA GISS CMIP6 data)

In [3]:
# Open the bias-corrected model test output (NetCDF).
# NOTE(review): hard-coded absolute path under /home/jovyan; the file must already exist.
var_model = xr.open_dataset('/home/jovyan/global_bias_correction_scaling_test.nc')

In [4]:
# Rename the model's spatial names to 'latitude'/'longitude', the names used by the obs data.
# NOTE(review): despite its name, `tmax_model` is the whole Dataset; later cells read its 'tasmax' variable.
tmax_model = var_model.rename({'lat': 'latitude', 'lon': 'longitude'})

In [6]:
# Open every GMFD tmax file matching 'tmax_0p25_daily_199*' as one dataset concatenated
# along 'time'; squeeze(drop=True) then removes any length-1 dimensions and their coordinates.
# NOTE(review): presumably the glob selects the 1990s files — confirm against the bucket contents.
tmax_obs = xr.open_mfdataset(os.path.join('/gcs/rhg-data/climate/source_data/GMFD/tmax', 
                                         'tmax_0p25_daily_199*'), concat_dim='time',
                              parallel=True).squeeze(drop=True)

In [7]:
# Standardize the obs longitudes with the project helper imported from utils.
# NOTE(review): the target longitude convention is defined in utils._convert_lons, not visible here.
tmax_obs = _convert_lons(tmax_obs)

Remove leap days from obs 

In [8]:
# Remove leap days from the obs with the project helper imported from utils.
# Earlier inline version, kept for reference (presumably equivalent — confirm against utils._remove_leap_days):
# tmax_obs = tmax_obs.sel(time=~((tmax_obs.time.dt.month == 2) & (tmax_obs.time.dt.day == 29)))
tmax_obs = _remove_leap_days(tmax_obs)

In [9]:
# Two candidate dask chunk layouts (-1 means a single chunk along that dimension):
#   space_chunks: the full time axis per 75x75 latitude/longitude tile
#   day_chunks:   one day-of-year per chunk covering the full grid
# NOTE(review): neither dict is referenced again in the visible part of this notebook.
space_chunks = {'time': -1, 'latitude': 75, 'longitude': 75}
day_chunks = {'dayofyear': 1, 'latitude': -1, 'longitude': -1}

Load daily obs climatology 

In [10]:
# Load the precomputed obs-resolution daily climatology (its 'tmax' variable is used below).
# NOTE(review): hard-coded absolute path; presumably produced beforehand from the GMFD obs — confirm.
climo_obs_fine = xr.open_dataset("/home/jovyan/gmfd_test_climo.nc")

### Interpolate obs climo: fine -> coarse 

In [11]:
%%time 
# Build the bilinear regridder from the obs grid to the model grid.
# Both inputs are renamed to 'lat'/'lon' for xESMF, and a single time slice of each
# is used to define the grids. With reuse_weights=True the weights are read from
# `obs_to_mod_weights` when that file already exists (as the recorded output shows).
obs_to_mod_weights = '/home/jovyan/obs_to_mod_bilinear_spatial_disagg.nc'
regridder_obs_to_mod = xe.Regridder(tmax_obs.isel(time=0).rename({'latitude': 'lat', 'longitude': 'lon'}), 
                                    tmax_model.isel(time=0).rename({'latitude': 'lat', 'longitude': 'lon'}), 
                         'bilinear', filename=obs_to_mod_weights, reuse_weights=True)

Reuse existing file: /home/jovyan/obs_to_mod_bilinear_spatial_disagg.nc
CPU times: user 224 ms, sys: 23.3 ms, total: 247 ms
Wall time: 365 ms


In [12]:
%%time
# Regrid the 'tmax' climatology from the obs grid to the model grid using the project's
# apply_weights helper together with the obs->model regridder.
# NOTE(review): per the recorded timings, this call took ~16 s while the following
# .compute() took under 1 ms, so despite the `_lazy` name the regridding appears to
# run eagerly here — confirm.
climo_obs_coarse_lazy = xr.map_blocks(apply_weights, regridder_obs_to_mod, 
                                args=[climo_obs_fine['tmax'].rename({'latitude': 'lat', 'longitude': 'lon'})])

CPU times: user 1.32 s, sys: 1.83 s, total: 3.14 s
Wall time: 16.1 s


In [13]:
%%time 
# Switch back to the 'latitude'/'longitude' names and load the coarse climatology into memory.
climo_obs_coarse = climo_obs_coarse_lazy.rename({'lat': 'latitude', 'lon': 'longitude'}).compute()

CPU times: user 709 µs, sys: 0 ns, total: 709 µs
Wall time: 716 µs


### Compute scaling factor by subtracting (for temperature) the obs climo at model-res from the BC'ed model data at model-res, or dividing the model data by it (for precip). 

In [14]:
def _calculate_anomaly(ds, climo, var_name):
    # Necessary workaround to xarray's check with zero dimensions
    # https://github.com/pydata/xarray/issues/3575
    da = ds[var_name]
    if sum(da.shape) == 0:
        return da
    groupby_type = ds.time.dt.dayofyear
    gb = da.groupby(groupby_type)
    
    return gb - climo

In [15]:
# Rechunk the coarse climatology into 75x75 latitude/longitude tiles before the
# block-wise scaling-factor calculation (no chunk size is given for other dims).
chunks = {'latitude': 75, 'longitude': 75}

climo_obs_coarse = climo_obs_coarse.chunk(chunks)

In [16]:
%%time 
# Scaling factor at model resolution: the model's 'tasmax' minus the coarse obs
# climatology, matched by day of year inside _calculate_anomaly.
scale_factor_coarse = xr.map_blocks(_calculate_anomaly, 
                                    tmax_model, args=[climo_obs_coarse, 'tasmax'])

CPU times: user 3.04 s, sys: 411 ms, total: 3.45 s
Wall time: 4.51 s


In [17]:
%%time 
# Evaluate the coarse scaling factor and load it into memory.
sfc = scale_factor_coarse.compute()

CPU times: user 25.9 s, sys: 2.53 s, total: 28.4 s
Wall time: 28 s


### Interpolate scaling factor: coarse (model grid) -> fine (obs grid)

In [19]:
%%time
# Build the bilinear regridder in the opposite direction: model grid -> obs grid.
# Both inputs are renamed to 'lat'/'lon' for xESMF, and a single time slice of each
# is used to define the grids. With reuse_weights=True the weights are read from
# `mod_to_obs_weights` when that file already exists (as the recorded output shows).
mod_to_obs_weights = '/home/jovyan/mod_to_obs_bilinear_spatial_disagg.nc'
regridder_mod_to_obs = xe.Regridder(tmax_model.isel(time=0).rename({'latitude': 'lat', 'longitude': 'lon'}), 
                                    tmax_obs.isel(time=0).rename({'latitude': 'lat', 'longitude': 'lon'}), 
                         'bilinear', filename=mod_to_obs_weights, reuse_weights=True)

Reuse existing file: /home/jovyan/mod_to_obs_bilinear_spatial_disagg.nc
CPU times: user 132 ms, sys: 82.6 ms, total: 214 ms
Wall time: 901 ms


In [20]:
# Drop the 'dayofyear' coordinate (presumably left on the result by the day-of-year
# groupby arithmetic) before regridding. `drop_vars` replaces the deprecated
# `DataArray.drop`, and errors='ignore' keeps this self-reassigning cell re-runnable:
# a second execution no longer fails because the coordinate is already gone.
sfc = sfc.drop_vars('dayofyear', errors='ignore')

In [21]:
%%time
# Regrid the coarse scaling factor onto the obs grid using the project's apply_weights
# helper together with the model->obs regridder.
# NOTE(review): per the recorded timings, this call took ~1 min 17 s while the following
# .compute() took under 1 ms, so despite the `_lazy` name the regridding appears to
# run eagerly here — confirm.
sff_lazy = xr.map_blocks(apply_weights, regridder_mod_to_obs, 
                                args=[sfc.rename({'latitude': 'lat', 'longitude': 'lon'})])

CPU times: user 44.8 s, sys: 39.1 s, total: 1min 23s
Wall time: 1min 17s


In [22]:
%%time
# Load the fine-resolution scaling factor into memory.
sff_lazy_compute = sff_lazy.compute()

CPU times: user 456 µs, sys: 0 ns, total: 456 µs
Wall time: 463 µs


### Add (or multiply for precip) the scaling factor to the obs-res daily climatology

In [25]:
# Wrap the fine-resolution scaling factor in a Dataset under 'scale_factor_fine',
# the variable name that apply_scale_factor reads.
sff_ds = sff_lazy_compute.to_dataset(name='scale_factor_fine')

In [26]:
%%time
# Chunk the scaling factor for block-wise application: the full time axis (-1)
# per 55x55 lat/lon tile.
sff_chunks = {'time': -1, 'lat': 55, 'lon': 55}
sff_ds = sff_ds.chunk(sff_chunks)

CPU times: user 48.7 s, sys: 361 ms, total: 49.1 s
Wall time: 46 s


In [27]:
%%time
# Tile the fine obs climatology with the same 55x55 spatial tile size used for sff_ds.
# NOTE(review): this rebinds `climo_obs_fine`, which was loaded in an earlier cell.
cof_chunks = {'latitude': 55, 'longitude': 55}
climo_obs_fine = climo_obs_fine.chunk(cof_chunks)

CPU times: user 6.54 ms, sys: 1.17 ms, total: 7.71 ms
Wall time: 6.88 ms


In [29]:
def apply_scale_factor(ds_sff, obs_climo):
    """Add the obs climatology to the fine-resolution scaling factor by day of year.

    Parameters
    ----------
    ds_sff
        Dataset-like object holding a 'scale_factor_fine' variable with
        'latitude', 'longitude' and 'time' dimensions, plus a ``time``
        coordinate that exposes ``.dt.dayofyear``.
    obs_climo
        Daily climatology added to the day-of-year groups.

    Returns
    -------
    The scaling factor (ordered latitude, longitude, time) plus ``obs_climo``,
    or the scaling factor unchanged when the sum of its dimension lengths is zero.
    """
    da = ds_sff['scale_factor_fine'].transpose("latitude", "longitude", "time")

    # Zero-size guard, as in _calculate_anomaly. It must test `da`: the name
    # `ds` does not exist in this function, so referencing it raised NameError
    # on every call.
    if sum(da.shape) == 0:
        return da

    day_of_year = ds_sff.time.dt.dayofyear
    return da.groupby(day_of_year) + obs_climo

In [30]:
%%time 
# Apply the scaling factor block-wise: each block of the renamed sff_ds is passed to
# apply_scale_factor together with the fine 'tmax' climatology.
# NOTE(review): the template is a Dataset (sff_ds renamed), while apply_scale_factor
# returns the result of DataArray arithmetic transposed to (latitude, longitude, time)
# — confirm the template matches what the function returns.
model_ds = xr.map_blocks(apply_scale_factor, sff_ds.rename({'lat': 'latitude', 'lon': 'longitude'}), 
                         args=[climo_obs_fine['tmax']], template=sff_ds.rename({'lat': 'latitude', 'lon': 'longitude'}))

CPU times: user 507 ms, sys: 13.6 ms, total: 520 ms
Wall time: 501 ms


In [31]:
# Start the computation and keep the result as a persisted dask collection,
# then show a progress bar for it.
model_downscaled = model_ds.persist()
dd.progress(model_downscaled)

VBox()

In [None]:
# Display the downscaled result (this cell has no execution count in the saved notebook).
model_downscaled

### Apply standardizing functions for final output and save (probably as zarr array)