In [1]:
import itertools

import dask
import dask.bag as db
import numpy as np
import xarray as xr

import preprocess_obs

## Read and stack data

In [2]:
# Zipped zarr store holding the CAFE forecast data (precipitation, judging by the file name)
cafe_file = '/g/data/xv83/dbi599/precip_cafe-c5-d60-pX-f6_19901101-19931101_3650D_cafe-grid.zarr.zip'
# Open the store; the repr below reports dask task/chunk counts, i.e. the variables are lazily loaded
ds_cafe = xr.open_zarr(cafe_file)
ds_cafe                

Unnamed: 0,Array,Chunk
Bytes,116.80 kB,116.80 kB
Shape,"(3650, 4)","(3650, 4)"
Count,2 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 116.80 kB 116.80 kB Shape (3650, 4) (3650, 4) Count 2 Tasks 1 Chunks Type datetime64[ns] numpy.ndarray",4  3650,

Unnamed: 0,Array,Chunk
Bytes,116.80 kB,116.80 kB
Shape,"(3650, 4)","(3650, 4)"
Count,2 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.24 GB,11.10 MB
Shape,"(4, 3650, 96, 17, 17)","(1, 50, 96, 17, 17)"
Count,293 Tasks,292 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 3.24 GB 11.10 MB Shape (4, 3650, 96, 17, 17) (1, 50, 96, 17, 17) Count 293 Tasks 292 Chunks Type float64 numpy.ndarray",3650  4  17  17  96,

Unnamed: 0,Array,Chunk
Bytes,3.24 GB,11.10 MB
Shape,"(4, 3650, 96, 17, 17)","(1, 50, 96, 17, 17)"
Count,293 Tasks,292 Chunks
Type,float64,numpy.ndarray


In [3]:
# Collapse the ensemble, init_date and lead_time dimensions into a single 'sample' dimension
ds_stacked = ds_cafe.stack({'sample': ['ensemble', 'init_date', 'lead_time']})
ds_stacked

Unnamed: 0,Array,Chunk
Bytes,11.21 MB,11.21 MB
Shape,"(1401600,)","(1401600,)"
Count,6 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 11.21 MB 11.21 MB Shape (1401600,) (1401600,) Count 6 Tasks 1 Chunks Type datetime64[ns] numpy.ndarray",1401600  1,

Unnamed: 0,Array,Chunk
Bytes,11.21 MB,11.21 MB
Shape,"(1401600,)","(1401600,)"
Count,6 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.24 GB,33.76 MB
Shape,"(17, 17, 1401600)","(17, 17, 14600)"
Count,1818 Tasks,96 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 3.24 GB 33.76 MB Shape (17, 17, 1401600) (17, 17, 14600) Count 1818 Tasks 96 Chunks Type float64 numpy.ndarray",1401600  17  17,

Unnamed: 0,Array,Chunk
Bytes,3.24 GB,33.76 MB
Shape,"(17, 17, 1401600)","(17, 17, 14600)"
Count,1818 Tasks,96 Chunks
Type,float64,numpy.ndarray


In [7]:
# Register the dataset's own `time` variable as a coordinate.
# NOTE: this rebinds `ds_stacked`, so later cells see the version with `time` as a coordinate.
ds_stacked = ds_stacked.assign_coords({'time': ds_stacked['time']})

In [12]:
# Display the stacked dataset after `time` has been assigned as a coordinate
ds_stacked

Unnamed: 0,Array,Chunk
Bytes,11.21 MB,11.21 MB
Shape,"(1401600,)","(1401600,)"
Count,6 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 11.21 MB 11.21 MB Shape (1401600,) (1401600,) Count 6 Tasks 1 Chunks Type datetime64[ns] numpy.ndarray",1401600  1,

Unnamed: 0,Array,Chunk
Bytes,11.21 MB,11.21 MB
Shape,"(1401600,)","(1401600,)"
Count,6 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.24 GB,33.76 MB
Shape,"(17, 17, 1401600)","(17, 17, 14600)"
Count,1818 Tasks,96 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 3.24 GB 33.76 MB Shape (17, 17, 1401600) (17, 17, 14600) Count 1818 Tasks 96 Chunks Type float64 numpy.ndarray",1401600  17  17,

Unnamed: 0,Array,Chunk
Bytes,3.24 GB,33.76 MB
Shape,"(17, 17, 1401600)","(17, 17, 14600)"
Count,1818 Tasks,96 Chunks
Type,float64,numpy.ndarray


In [None]:
def stack_super_ensemble(ds, period, super_ensemble_dims=('init_date','ensemble'), new_dim='super_ensemble'):
    """ Stack multiple dims along a common dimension keeping only data in a given period and removing nans

        Parameters
        ----------
        ds : xarray DataArray or Dataset
            Forecast data. Must have a 'lead_time' dimension and, unless period is None, a 'time'
            coordinate or variable that can be compared against the bounds in period
        period : sequence of length 2 or None
            Inclusive (start, end) bounds applied to ds['time']. If None, no period filtering is applied
        super_ensemble_dims : sequence of str, optional
            Dimensions to combine. Of these, 'init_date' and 'ensemble' are stacked into new_dim. If
            'lead_time' is also listed, the per-lead results are concatenated along new_dim as well;
            otherwise they are concatenated along 'lead_time'
        new_dim : str, optional
            Name of the stacked dimension

        Returns
        -------
        stacked : xarray DataArray or Dataset
            Per-lead-time slices with null entries dropped, relabelled with integer positions along
            new_dim and concatenated
    """
    # The previous body read from `ds_reindex` without ever assigning it: the lines that created it
    # were commented out and depended on a helper (`reindex_forecast_hack`) that is not defined in
    # this notebook. Calls therefore failed with NameError (unless a global of that name existed) and
    # `period` was ignored. The period filter is now applied to `ds` directly.
    # NOTE(review): if the reindexing helper becomes available again, confirm whether it should be
    # applied before this filter.
    if period is None:
        ds_period = ds
    else:
        time = ds['time']
        ds_period = ds.where((time >= period[0]) & (time <= period[1]), drop=True)

    # Only these two dims are ever stacked; keep them in the order they appear on the data
    stackable = ('init_date', 'ensemble')
    stack_dims = [dim for dim in ds_period.dims
                  if dim in stackable and dim in super_ensemble_dims]
    ds_stacked = ds_period.stack({new_dim: stack_dims})

    include_lead = 'lead_time' in super_ensemble_dims
    per_lead = []
    offset = 0
    for lead in ds_stacked.lead_time:
        if include_lead:
            # Drop the scalar lead_time coordinate so slices can be joined along new_dim
            at_lead = ds_stacked.sel(lead_time=lead, drop=True)
        else:
            at_lead = ds_stacked.sel(lead_time=lead)
        valid = at_lead.where(at_lead.notnull(), drop=True)
        n_valid = len(valid[new_dim])
        if include_lead:
            # Running offset keeps the labels unique across lead times after concatenation
            labels = offset + np.arange(n_valid)
            offset += n_valid
        else:
            labels = range(n_valid)
        per_lead.append(valid.assign_coords({new_dim: labels}))

    return xr.concat(per_lead, dim=new_dim if include_lead else 'lead_time')

In [5]:
# Zarr store with the AWAP observations (daily rainfall on the CAFE grid, judging by the file name)
awap_file = '/g/data/xv83/ds0092/data/csiro-dcfp-csiro-awap/rain_day_19000101-20201202_cafe-grid.zarr/'

In [6]:
# read_obs is a project helper from preprocess_obs; the two 'precip' arguments are presumably the
# variable name in the store and the name to give the result — TODO confirm against preprocess_obs
da_awap = preprocess_obs.read_obs(awap_file, 'precip', 'precip')
# Keep 1990-01-01 through 1993-12-31 (label-based slice, both ends inclusive)
da_awap = da_awap.sel(time=slice('1990-01-01', '1993-12-31'))
da_awap

Unnamed: 0,Array,Chunk
Bytes,3.38 MB,3.38 MB
Shape,"(1460, 17, 17)","(1460, 17, 17)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 3.38 MB 3.38 MB Shape (1460, 17, 17) (1460, 17, 17) Count 5 Tasks 1 Chunks Type float64 numpy.ndarray",17  17  1460,

Unnamed: 0,Array,Chunk
Bytes,3.38 MB,3.38 MB
Shape,"(1460, 17, 17)","(1460, 17, 17)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray


## Fidelity test

In [None]:
def unseen_univariate_moment_fidelity(fcst, obsv, period, alpha=5, n_repeats=1000, n_block=1, by_lead=True):
    """ Bootstrap fidelity test of the first four moments of the forecasts against the observations

        The forecasts are stacked into a 'super_ensemble' dimension, resampled n_repeats times with
        as many samples as there are observed times in period, and the moments of each resample are
        compared with the moments of the observations.

        Parameters
        ----------
        fcst : xarray object
            Forecast data, passed to stack_super_ensemble
        obsv : xarray object
            Observations with a 'time' coordinate
        period : sequence of length 2
            Inclusive (start, end) bounds used to subset both forecasts and observations
        alpha : number, optional
            Two-sided significance level in percent. If falsy (e.g. None), no test is applied
        n_repeats : int, optional
            Number of bootstrap resamples
        n_block : int, optional
            Block size used when resampling
        by_lead : boolean, optional
            If True only 'init_date' and 'ensemble' are stacked; otherwise 'lead_time' is stacked too

        Returns
        -------
        If alpha is falsy, the tuple (fcst_moments, obsv_moments): the moments of every resample and
        the moments of the observations. Otherwise a single object that is
            0 where the observed moment lies between the alpha/2 % and 100-alpha/2 % quantiles
            1 where the observed moment lies below the alpha/2 % quantile
            -1 where the observed moment lies above the 100-alpha/2 % quantile
            NaN where the observed moment or either quantile is NaN
    """
    if by_lead:
        ensemble_dims = ('init_date','ensemble')
    else:
        ensemble_dims = ('init_date','ensemble', 'lead_time')
    fcst_stacked = stack_super_ensemble(fcst, period, super_ensemble_dims=ensemble_dims)

    in_period = (obsv.time >= period[0]) & (obsv.time <= period[1])
    obsv_in_period = obsv.where(in_period, drop=True)

    # Each resample draws as many forecast samples as there are observed times in the period
    resample_spec = {'super_ensemble': (len(obsv_in_period['time']), n_block)}
    fcst_moments = n_random_resamples(fcst_stacked,
                                      samples=resample_spec,
                                      n_repeats=n_repeats,
                                      function=get_first_four_moments,
                                      function_kwargs={'dim':'super_ensemble'})
    obsv_moments = get_first_four_moments(obsv_in_period, 'time')

    if not alpha:
        return fcst_moments, obsv_moments

    # alpha is a percentage split across both tails, hence the division by 200
    lower = fcst_moments.quantile(alpha/200, dim='k')
    upper = fcst_moments.quantile(1-alpha/200, dim='k')

    comparable = (obsv_moments * lower * upper).notnull()
    at_or_above_lower = obsv_moments >= lower
    at_or_below_upper = obsv_moments <= upper

    within = at_or_above_lower & at_or_below_upper
    below = ~(at_or_above_lower | within)
    above = ~(at_or_below_upper | within)

    # within -> 2, below -> 3, above -> 1; subtracting 2 maps these to 0, 1 and -1
    score = xr.where(within, 2, 0) + xr.where(below, 3, 0) + xr.where(above, 1, 0) - 2
    return score.where(comparable)

In [7]:
def get_first_four_moments(da, dim):
    """ Compute the mean, standard deviation, skewness and kurtosis of da along dim

        Skewness and kurtosis are the third and fourth central moments divided by the second
        central moment raised to the power 3/2 and 2 respectively (so kurtosis is not the
        excess form).

        Parameters
        ----------
        da : xarray DataArray
            Data to reduce
        dim : str
            Dimension to reduce over

        Returns
        -------
        moments : xarray Dataset
            Dataset with variables 'mean', 'std', 'skew' and 'kurt'
    """
    centre = da.mean(dim)
    anomalies = da - centre
    variance = (anomalies ** 2).mean(dim)

    moments = centre.to_dataset(name='mean')
    moments['std'] = da.std(dim)
    moments['skew'] = (anomalies ** 3).mean(dim) / variance ** (3/2)
    moments['kurt'] = (anomalies ** 4).mean(dim) / variance ** 2
    return moments

In [8]:
# Moments of the observations along the time dimension
awap_moments = get_first_four_moments(da_awap, 'time')

In [9]:
# Display the observed moments (mean, std, skew, kurt)
awap_moments

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,7 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 2.31 kB 2.31 kB Shape (17, 17) (17, 17) Count 7 Tasks 1 Chunks Type float64 numpy.ndarray",17  17,

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,7 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,8 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 2.31 kB 2.31 kB Shape (17, 17) (17, 17) Count 8 Tasks 1 Chunks Type float64 numpy.ndarray",17  17,

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,8 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,17 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 2.31 kB 2.31 kB Shape (17, 17) (17, 17) Count 17 Tasks 1 Chunks Type float64 numpy.ndarray",17  17,

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,17 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,17 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 2.31 kB 2.31 kB Shape (17, 17) (17, 17) Count 17 Tasks 1 Chunks Type float64 numpy.ndarray",17  17,

Unnamed: 0,Array,Chunk
Bytes,2.31 kB,2.31 kB
Shape,"(17, 17)","(17, 17)"
Count,17 Tasks,1 Chunks
Type,float64,numpy.ndarray


In [10]:
def random_resample(*objects, samples,
                    function=None, function_kwargs=None, replace=True):
    """
        Draw one random block-resample from the provided xarray objects and optionally reduce it \
        with a function

        Parameters
        ----------
        *objects : xarray DataArray or Dataset
            Objects to resample. Block positions are drawn using the length of each sampled dimension \
            on the first object, and the same blocks are taken from every object that has that dimension
        samples : dictionary
            Maps each dimension to resample to (n_samples, block_size), e.g. \
            {'dim1': (n_samples, block_size), 'dim2': (n_samples, block_size)}. int(n_samples / block_size) \
            contiguous blocks of length block_size are drawn along the dimension. The first object must \
            contain every listed dimension; the other objects are left untouched along dimensions they lack
        function : function object, optional
            Function applied to each resampled object
        function_kwargs : dictionary, optional
            Keyword arguments passed to function
        replace : boolean, optional
            Whether block start positions are drawn with replacement

        Returns
        -------
        sample : xarray DataArray or Dataset, or tuple of these
            The resampled objects (after function, if given). A single object is returned when only \
            one object was provided, otherwise a tuple with one entry per object
    """
    resampled = [obj.copy() for obj in objects]
    for dimension, (n_samples, block_size) in samples.items():
        n_blocks = int(n_samples / block_size)
        # A block may start anywhere that leaves room for block_size elements
        n_starts = len(resampled[0][dimension]) - block_size + 1
        block_starts = np.random.choice(n_starts, size=n_blocks, replace=replace)
        blocks = [slice(start, start + block_size) for start in block_starts]

        def _take_blocks(obj):
            # Objects without this dimension pass through unchanged
            if dimension not in obj.dims:
                return obj
            return xr.concat([obj.isel({dimension: block}) for block in blocks],
                             dim=dimension)

        resampled = [_take_blocks(obj) for obj in resampled]

    if function:
        kwargs = function_kwargs if function_kwargs else {}
        results = tuple(function(obj, **kwargs) for obj in resampled)
    else:
        results = tuple(resampled)

    return results[0] if len(results) == 1 else results
    

def n_random_resamples(*objects, samples, n_repeats,
                       function=None, function_kwargs=None,
                       replace=True, with_dask=True):
    """
    Repeatedly randomly resample from provided xarray objects and return
    the results of the subsampled dataset passed through a provided function

    Parameters
    ----------
    objects : xarray DataArray or Dataset
        Objects containing data to be resampled.
        The coordinates of the first object are used for resampling and
        the same resampling is applied to all objects
    samples : dictionary
        Dictionary containing the dimensions to subsample, the number of samples and the continuous block size \
        within the sample. Of the form {'dim1': (n_samples, block_size), 'dim2': (n_samples, block_size)}
    n_repeats : int
        Number of times to repeat the resampling process
    function : function object, optional
        Function to reduce the subsampled data
    function_kwargs : dictionary, optional
        Keyword arguments to provide to function
    replace : boolean, optional
        Whether the sample is with or without replacement
    with_dask : boolean, optional
        If True, parallelize across n_repeats with dask: a dask bag (100 partitions) when
        n_repeats > 1000, dask.delayed otherwise. If False, the repeats run in a plain loop

    Returns
    -------
    sample : xarray DataArray or Dataset, or tuple of these
        The results of every repeat concatenated along a new dimension 'k'. A single object is
        returned when one object was provided, otherwise a tuple with one entry per object
    """
    resample_kwargs = {'samples': samples, 'function': function,
                       'function_kwargs': function_kwargs, 'replace': replace}

    # Logical `and` rather than bitwise `&`: with `&`, a truthy non-bool flag such as
    # with_dask=2 evaluates to 0 here and silently skips this branch.
    if with_dask and (n_repeats > 1000):
        # The bag elements are the first object repeated; any further objects are passed as
        # extra positional arguments to each random_resample call
        n_objects = itertools.repeat(objects[0], times=n_repeats)
        b = db.from_sequence(n_objects, npartitions=100)
        rs_list = b.map(random_resample, *(objects[1:]), **resample_kwargs).compute()
    else:
        resample_ = dask.delayed(random_resample) if with_dask else random_resample
        rs_list = [resample_(*objects, **resample_kwargs) for _ in range(n_repeats)]
        if with_dask:
            rs_list = dask.compute(rs_list)[0]

    if len(objects) == 1:
        return xr.concat([r.unify_chunks() for r in rs_list], dim='k')
    else:
        # Each repeat yields one result per object; regroup by object before concatenating
        return tuple([xr.concat([r.unify_chunks() for r in rs], dim='k') for rs in zip(*rs_list)])

In [None]:
# Bootstrap settings
alpha = 5        # not used in this cell; presumably the significance level (%) for a later test
n_repeats = 100  # number of bootstrap resamples
n_block = 1      # block size used when resampling

# Resample the stacked forecasts along 'sample', drawing as many samples as there are observed
# times, and compute the four moments of each resample.
# NOTE(review): no random seed is set, so results will differ between runs.
fcst_moments = n_random_resamples(ds_stacked, 
                                  samples={'sample': (len(da_awap['time']), n_block)}, 
                                  n_repeats=n_repeats,
                                  function=get_first_four_moments,
                                  function_kwargs={'dim' : 'sample'})