In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import collections
import os 

import warnings
warnings.filterwarnings("ignore")

import xesmf as xe

from utils import _convert_lons, _remove_leap_days, compute_daily_climo
from regridding import apply_weights

import dask.distributed as dd
import dask_kubernetes as dk
import dask
import rhg_compute_tools.kubernetes as rhgk

### This notebook is a test of all the steps for Spatial Disaggregation to get a handle on the total CPU time it will take for this part of BCSD. 

Once-off steps: 

1. compute multi-decade daily climatologies of ERA-5 at obs-res and coarsen them to model-res (other implementations, e.g. NASA-NEX, do not say how they coarsen; we use bilinear interpolation for consistency with the later interpolation step)

Per model/scenario/experiment steps:

1. subtract the obs climo at model resolution from the BC’ed model data at model-res (or, for precip, divide the model data by the obs climo) to calculate a “scaling factor” 
2. bilinearly interpolate “scaling factor” (using xESMF) from the model grid to the obs grid 
3. Apply scaling factor by adding (for temp) and multiplying (for precip) the “scaling factor” to the obs-res daily climatology 

NOTE: For the purpose of being conservative with timing, the "coarsen obs climatology step to model-res" is in the per model/scenario/experiment step, since we don't know for sure how/if CMIP6 models will be at exactly the same resolution. 

Currently this workflow is only built out for temperature, not precipitation. All steps are included, but the last step (applying the interpolated scale factor to the obs-res daily climatology) has not yet been tested. All other parts of the workflow have been tested. The second-to-last step, the interpolation of the scaling factor from coarse to fine, is the most memory intensive, so I have only tested it for a subset of timesteps. 

In [2]:
# Request a cluster and a client connected to it from the rhg_compute_tools helper.
client, cluster = rhgk.get_standard_cluster()
# Bare last expression so the notebook renders the cluster widget as the cell output.
cluster

VBox(children=(HTML(value='<h2>KubeCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .…

load test bias corrected output from global bias correction prototype notebook (BC'ed NASA GISS CMIP6 data)

In [3]:
# Bias-corrected model test output; its coordinates and data variable are renamed to the
# 'latitude' / 'longitude' / 'tasmax' names used throughout the rest of the notebook.
_model_renames = {
    '__xarray_dataarray_variable__': 'tasmax',
    'lat': 'latitude',
    'lon': 'longitude',
}
tmax_model = xr.open_dataset('/home/jovyan/global_bias_corrected_tenyears.nc')
tmax_model = tmax_model.rename(_model_renames)

In [4]:
# Open every obs file matching the 199* pattern as one dataset concatenated along 'time',
# then drop any length-1 dimensions (and their coordinates).
_obs_file_pattern = os.path.join('/gcs/rhg-data/climate/source_data/GMFD/tmax',
                                 'tmax_0p25_daily_199*')
tmax_obs = xr.open_mfdataset(_obs_file_pattern, concat_dim='time', parallel=True)
tmax_obs = tmax_obs.squeeze(drop=True)

In [5]:
# standardize longitudes
# (delegates to utils._convert_lons; the target longitude convention is defined there)
# NOTE(review): tmax_obs is rebound to the converted result, so re-running this cell feeds
# already-converted data back into the helper — confirm the helper is idempotent.
tmax_obs = _convert_lons(tmax_obs)

Remove leap days from obs 

In [6]:
# remove leap days
# Inline selection kept for reference; presumably equivalent to the helper — confirm against utils:
# tmax_obs = tmax_obs.sel(time=~((tmax_obs.time.dt.month == 2) & (tmax_obs.time.dt.day == 29)))
tmax_obs = _remove_leap_days(tmax_obs)

In [7]:
# Number of horizontal grid cells (n_lat * n_lon) on the model grid and on the obs grid;
# used further down as the dimensions of the sparse regridding weight matrix.
size_model = len(tmax_model['latitude']) * len(tmax_model['longitude'])
size_obs = len(tmax_obs['latitude']) * len(tmax_obs['longitude'])

In [None]:
# Chunk layouts (dask convention: -1 means one chunk spanning the whole dimension).
# space_chunks: full time series per 75x75 spatial tile.
space_chunks = dict(time=-1, latitude=75, longitude=75)
# day_chunks: one day-of-year per chunk, full spatial extent.
day_chunks = dict(dayofyear=1, latitude=-1, longitude=-1)

Compute obs climatology (obs res)

Note: in real workflow, this will be pre-computed (only need to do this once) and loaded. 

In [9]:
# Disabled cell: the code is wrapped in a string literal so it does not execute.
# It computes the daily obs climatology with compute_daily_climo, rechunks it by day,
# and writes it to /home/jovyan/gmfd_test_climo.nc (the file the next cell loads).
'''%%time 
tmax_obs = tmax_obs.chunk(space_chunks)
tmax_obs_lazy = compute_daily_climo(tmax_obs['tmax'])
climo_obs_fine = tmax_obs_lazy.persist()
# rechunk climo fine into day chunks 
climo_obs_fine = climo_obs_fine.chunk(day_chunks)
climo_obs_fine = climo_obs_fine.compute()
climo_attrs = {"file description": "daily climatology for 1990s, without leap years, from GMFD", 
               "author": "Diana Gergel", "contact": "dgergel@rhg.com"}
climo_obs_ds = climo_obs_fine.to_dataset(name='tmax')
climo_obs_ds.attrs.update(climo_attrs)
climo_obs_ds.to_netcdf("/home/jovyan/gmfd_test_climo.nc")'''

In [None]:
# Load the pre-computed obs-resolution daily climatology from disk.
# NOTE(review): open_dataset returns a Dataset, whereas the disabled cell that wrote this file
# held the climatology as a DataArray before saving — confirm downstream steps expect a Dataset.
climo_obs_fine = xr.open_dataset("/home/jovyan/gmfd_test_climo.nc")

### Interpolate obs climo: fine -> coarse 

In [14]:
%%time 
# Bilinear regridder from the obs grid to the model grid.
# Coordinates are renamed to 'lat'/'lon' before being handed to xESMF, and a single time
# slice of each dataset is used to describe its grid.
# reuse_weights=True: the weights file below is reused when it already exists.
obs_to_mod_weights = '/home/jovyan/obs_to_mod_bilinear_spatial_disagg.nc'
regridder_obs_to_mod = xe.Regridder(tmax_obs.isel(time=0).rename({'latitude': 'lat', 'longitude': 'lon'}), 
                                    tmax_model.isel(time=0).rename({'latitude': 'lat', 'longitude': 'lon'}), 
                         'bilinear', filename=obs_to_mod_weights, reuse_weights=True)



Reuse existing file: /home/jovyan/obs_to_mod_bilinear_spatial_disagg.nc
CPU times: user 32.8 ms, sys: 12.4 ms, total: 45.2 ms
Wall time: 26.1 ms




In [15]:
%%time
# Regrid the fine obs climatology to the model grid using the apply_weights helper imported
# from the regridding module.
# NOTE(review): xr.map_blocks is documented as map_blocks(func, obj, args=...), where obj is the
# xarray object to be split into blocks. Here the regridder is passed as obj and the data as an
# extra argument — confirm this matches what regridding.apply_weights expects.
climo_obs_coarse_lazy = xr.map_blocks(apply_weights, regridder_obs_to_mod, 
                                args=[climo_obs_fine.rename({'latitude': 'lat', 'longitude': 'lon'})])

CPU times: user 30.3 ms, sys: 2.96 ms, total: 33.3 ms
Wall time: 26.1 ms


In [16]:
%%time 
# Trigger the regridding and keep the coarse climatology around for the following steps.
climo_obs_coarse = climo_obs_coarse_lazy.persist()



CPU times: user 16.3 s, sys: 426 ms, total: 16.7 s
Wall time: 16.4 s




### Compute the scaling factor: subtract the obs climo at model-res from the BC'ed model data at model-res (for precip, divide the model data by the obs climo instead). 

In [17]:
def _calculate_anomaly(ds, climo, var_name='temperature'):
    # Necessary workaround to xarray's check with zero dimensions
    # https://github.com/pydata/xarray/issues/3575
    da = ds[var_name]
    if sum(da.shape) == 0:
        return da
    groupby_type = ds.day_of_year
    gb = da.groupby(groupby_type)
    
    return gb - climo

def compute_scale_factor(spec):
    """Compute the scale factor at the coarse level.

    Parameters
    ----------
    spec : tuple
        ``(data, obs_climo_coarse, var_name)``: the data to map over, the coarse obs
        climatology, and the name of the variable to process.

    Returns
    -------
    The result of ``xr.map_blocks`` applying ``_calculate_anomaly`` to ``data``.
    """
    model_adjusted, climo_coarse, variable = spec
    extra_args = [climo_coarse, variable]
    return xr.map_blocks(_calculate_anomaly, model_adjusted, args=extra_args)



In [None]:
# Chunk the model data into blocks of 1500 time steps, each covering the full spatial grid.
chunks = dict(
    time=1500,
    latitude=len(tmax_model.latitude),
    longitude=len(tmax_model.longitude),
)

tmax_model = tmax_model.chunk(chunks)

In [None]:
# Show the chunked model dataset to check its dimensions and chunking.
display(tmax_model)

In [None]:
# Disabled alternative, kept for reference: build the coarse scale factor through the
# compute_scale_factor wrapper instead of calling xr.map_blocks directly.
# It is held in `#` comments because the docstring's triple quotes would terminate an
# enclosing triple-quoted string and make the cell a SyntaxError.
# NOTE: compute_scale_factor unpacks three items (data, climatology, variable name), so the
# spec also needs the variable name, e.g. (tmax_model, climo_obs_coarse, 'tasmax').
#
# spec = (tmax_model, climo_obs_coarse)
# scale_factor_coarse = compute_scale_factor(spec)
#
# def compute_scale_factor(spec):
#     '''
#     computes scale factor at the coarse level
#     '''
#     da_adj, da_obs_climo_coarse, var_name = spec
#
#     return xr.map_blocks(_calculate_anomaly, da_adj,
#                          args=[da_obs_climo_coarse, var_name])

In [None]:
%%time 
# Build the coarse scale factor block by block: for each chunk of model data,
# _calculate_anomaly groups 'tasmax' by day of year and subtracts the coarse obs climatology.
scale_factor_coarse = xr.map_blocks(_calculate_anomaly, tmax_model, args=[climo_obs_coarse, 'tasmax'])

In [None]:
# Overwrite the result's time coordinate with a copy of the model's time axis.
# NOTE(review): presumably needed because the time coordinate does not survive the step above
# intact — confirm.
scale_factor_coarse['time'] = tmax_model.time.copy()

In [None]:
%%time 
# Trigger computation of the coarse scale factor and keep the result for the next steps.
sfc = scale_factor_coarse.persist()

In [None]:
# Disabled cell (wrapped in a string literal so it does not execute): would pull the
# persisted scale factor into local memory.
'''%%time 
sfc = sfc.compute()'''

In [None]:
# if starting from this point
# (i.e. skipping the scale-factor computation: load a previously saved coarse scale factor)
# NOTE(review): this rebinds `sfc`; when the notebook is run top to bottom it replaces the
# persisted result from the earlier cell with the contents of this file.
sfc = xr.open_dataset('/home/jovyan/sfc_test.nc')

### Interpolate scaling factor: coarse (model grid) -> fine (obs grid)

In [None]:
%%time
# Bilinear regridder from the model grid to the obs grid; the weights file written or reused
# here is read back later to build the sparse weight matrix.
# Unlike the obs->model regridder, the full datasets (not a single time slice) are passed in.
mod_to_obs_weights = '/home/jovyan/mod_to_obs_bilinear_weights.nc'
regridder_mod_to_obs = xe.Regridder(tmax_model.rename({'latitude': 'lat', 'longitude': 'lon'}), 
                                    tmax_obs.rename({'latitude': 'lat', 'longitude': 'lon'}), 
                         'bilinear', filename=mod_to_obs_weights, reuse_weights=True)

The functions below replicate xESMF's regridding so that it can run on the workers: the full array is too large for notebook memory, and xESMF is not set up to be used in conjunction with dask. 

In [None]:
def read_xesmf_weights_coo_matrix(weights_file, size_in, size_out):
    """Load a regridding weights file as a SciPy sparse COO matrix.

    Parameters
    ----------
    weights_file : str
        Path to a weights file containing the variables 'row', 'col' and 'S'.
    size_in : int
        Number of cells on the source grid (matrix columns).
    size_out : int
        Number of cells on the destination grid (matrix rows).

    Returns
    -------
    scipy.sparse.coo_matrix
        Matrix of shape ``(size_out, size_in)`` holding the weights 'S'.
    """
    # Imported locally because the notebook's import cell does not bring coo_matrix into scope.
    from scipy.sparse import coo_matrix

    # Context manager so the file handle is released once the arrays are in memory.
    with xr.open_dataset(weights_file) as ds:
        # Indices are shifted by one to obtain 0-based positions for SciPy.
        col = ds['col'].values - 1
        row = ds['row'].values - 1
        S = ds['S'].values

    return coo_matrix((S, (row, col)), shape=(size_out, size_in))

def apply_weights(spec):
    """Regrid one array by multiplying it with a pre-computed sparse weight matrix.

    Takes a single tuple so that it can be mapped over a list of jobs.

    NOTE(review): this definition rebinds the name ``apply_weights`` that the import cell
    pulls in from ``regridding``; any cell executed after this one gets this version.

    Parameters
    ----------
    spec : tuple
        ``(weights, da, shape_in, shape_out, lats_out, lons_out)``:

        - ``weights``: sparse matrix of shape (n_cells_out, n_cells_in).
        - ``da``: object exposing ``.values`` (and ``.time`` when the input is 3-D).
        - ``shape_in``: shape of the input data. The last two entries are the horizontal
          dimensions; a third, leading entry is treated as time.
        - ``shape_out``: output shape; only the last two (horizontal) entries are used.
        - ``lats_out`` / ``lons_out``: coordinates of the output grid.

    Returns
    -------
    xr.DataArray
        Dims ``(latitude, longitude)`` for 2-D input, ``(time, latitude, longitude)`` for
        3-D input.
    """
    weights, da, shape_in, shape_out, lats_out, lons_out = spec
    indata = da.values

    n_cells_in = shape_in[-2] * shape_in[-1]
    ny_out, nx_out = shape_out[-2], shape_out[-1]
    has_time = len(shape_in) > 2

    # Collapse the horizontal dims so every row is one flattened field: (n_fields, n_cells_in).
    indata_flat = indata.reshape(-1, n_cells_in)

    # (n_cells_out, n_cells_in) . (n_cells_in, n_fields) -> transpose -> (n_fields, n_cells_out)
    outdata_flat = weights.dot(indata_flat.T).T

    if has_time:
        # Keep the leading time axis; the original reshaped to 2-D here, which could not
        # match the three dims declared for this branch.
        outdata = outdata_flat.reshape(-1, ny_out, nx_out)
        dims = ['time', 'latitude', 'longitude']
        coords = {'time': da.time, 'latitude': lats_out, 'longitude': lons_out}
    else:
        outdata = outdata_flat.reshape(ny_out, nx_out)
        dims = ['latitude', 'longitude']
        coords = {'latitude': lats_out, 'longitude': lons_out}

    return xr.DataArray(outdata, dims=dims, coords=coords)

In [None]:
%%time 
# make scipy sparse weight matrix
# (model -> obs weights, shape (size_obs, size_model): rows index obs cells, columns model cells)
weights_coo_mat = read_xesmf_weights_coo_matrix(mod_to_obs_weights, size_model, size_obs)

Note: the method with `da.map_blocks` only works on a few years, beyond that we run out of memory. But I believe that this is the better way to do it - needs work. 

In [None]:
# One job per time step: (weights, 2-D scale-factor slice, input shape, output shape,
# output latitudes, output longitudes), in the tuple layout apply_weights unpacks.
# NOTE(review): assumes the saved scale factor stores its data under 'temperature' — confirm
# against the file (the model variable used to compute it is named 'tasmax').
sfc_temperature = sfc['temperature']

# Grid shapes come from the same coordinates used to size the weight matrix, so they always
# agree with it (instead of hard-coded (180, 360) / (720, 1440)).
shape_model = (len(tmax_model.latitude), len(tmax_model.longitude))
shape_obs = (len(tmax_obs.latitude), len(tmax_obs.longitude))

# isel takes integer positions, so iterate over positions rather than over the time values.
JOBS = [(weights_coo_mat, sfc_temperature.isel(time=timestep).drop('time'), shape_model, shape_obs,
         tmax_obs.latitude, tmax_obs.longitude) for timestep in range(sfc_temperature.sizes['time'])]

In [None]:
# Submit only the first 20 jobs (a subset of time steps) to the cluster; each call returns a future.
sff_futures = client.map(apply_weights, JOBS[:20])

In [None]:
# Show a progress display for the submitted futures.
dd.progress(sff_futures)

In [None]:
# Collect the regridded slices from the workers into a local list (one item per future).
sffs = client.gather(sff_futures)

In [None]:
# Stack the gathered slices along a new 'time' dimension. Only a subset of the jobs may have
# been submitted, so the index is trimmed to the number of results actually gathered; using
# the full time axis would not match the number of slices.
sff = xr.concat(sffs, pd.Index(sfc.time.values[:len(sffs)], name='time'))

### Last step: add (or multiply for precip) the scaling factor to the obs-res daily climatology

NOTE: this step has not been tested given that the memory for the above step needs to be further worked out, but it is essentially the inverse of the step above where we compute the scaling factor. 

In [None]:
def apply_scale_factor(scale_factor_fine, obs_climo):
    """Add the fine-resolution scale factor to the obs climatology, matched by day of year.

    Parameters
    ----------
    scale_factor_fine : xarray object
        Scale factor on the obs grid; must expose a ``day_of_year`` attribute to group on.
    obs_climo : xarray object
        Obs-resolution daily climatology.

    Returns
    -------
    The result of ``obs_climo + scale_factor_fine.groupby(day_of_year)``.
    """
    # rename() returns a new object, so the result has to be re-bound to take effect.
    if 'dayofyear' in scale_factor_fine.dims:
        scale_factor_fine = scale_factor_fine.rename({'dayofyear': 'day_of_year'})
    # The climatology's day dimension must carry the same name as the grouping variable.
    if 'dayofyear' in obs_climo.dims:
        obs_climo = obs_climo.rename({'dayofyear': 'day_of_year'})

    sff_daily = scale_factor_fine.groupby(scale_factor_fine.day_of_year)

    return obs_climo + sff_daily

def apply_scale_factor_wrapper(spec):
    """Apply the scale factor to the obs climatology, block by block.

    Parameters
    ----------
    spec : tuple
        ``(scale_factor_fine, da_obs_climo_fine)``: the fine-resolution scale factor to map
        over, and the obs-resolution daily climatology passed to each block.

    Returns
    -------
    The result of ``xr.map_blocks`` applying ``apply_scale_factor`` to the scale factor.
    """
    scale_factor_fine, da_obs_climo_fine = spec

    # Map over the scale factor unpacked from spec (the original referenced an undefined name).
    return xr.map_blocks(apply_scale_factor, scale_factor_fine, args=[da_obs_climo_fine])

In [None]:
# chunk by year: 365 time steps per chunk, each chunk covering the full obs grid
sff_chunk = sff.chunk({'time': 365, 'latitude': len(tmax_obs.latitude), 'longitude': len(tmax_obs.longitude)})

In [None]:
%%time 
# Bundle the chunked fine scale factor with the fine obs climatology, in the order
# apply_scale_factor_wrapper unpacks them, and apply the scale factor.
spec = (sff_chunk, climo_obs_fine)
model_ds = apply_scale_factor_wrapper(spec)

In [None]:
# Trigger the computation and keep the result for the next step.
model_ds = model_ds.persist()

In [None]:
# Pull the downscaled result into local (notebook) memory.
model_downscaled = model_ds.compute()

### Apply standardizing functions for final output and save (probably as zarr array)