In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import zarr

import myfuncs

## Prepare data

In [2]:
# Path to the zipped zarr store read by the cells below.
# NOTE(review): presumably daily precipitation forecast output (per the file name) — confirm.
# NOTE(review): this is an absolute, machine-specific path; consider a configurable data directory.
infile = '/g/data/xv83/dbi599/precip_c5-d60-pX-f6_19861101-19871101_atmos_isobaric_daily.zarr.zip'

In [3]:
# Open the store by handing the path straight to xarray.
ds = xr.open_zarr(infile)
# Alternative kept for reference: open through an explicit zarr.ZipStore.
# NOTE(review): the `with` block would close the store on exit, which may break
# later reads if the data are loaded lazily — confirm before re-enabling.
#with zarr.ZipStore(infile, mode='r') as store:
#    ds = xr.open_zarr(store)

#### Convert units

In [5]:
# Scale precipitation by the number of seconds in a day to obtain daily totals.
# NOTE(review): assumes ds['precip'] is a per-second rate (e.g. kg m-2 s-1) — confirm.
SECONDS_PER_DAY = 86400
pr_daily = ds['precip'] * SECONDS_PER_DAY
pr_daily.attrs.update(units='mm')
pr_daily

Unnamed: 0,Array,Chunk
Bytes,4.05 GB,7.76 MB
Shape,"(4, 3653, 96, 19, 19)","(1, 28, 96, 19, 19)"
Count,1049 Tasks,524 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 4.05 GB 7.76 MB Shape (4, 3653, 96, 19, 19) (1, 28, 96, 19, 19) Count 1049 Tasks 524 Chunks Type float64 numpy.ndarray",3653  4  19  19  96,

Unnamed: 0,Array,Chunk
Bytes,4.05 GB,7.76 MB
Shape,"(4, 3653, 96, 19, 19)","(1, 28, 96, 19, 19)"
Count,1049 Tasks,524 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,116.90 kB,116.90 kB
Shape,"(3653, 4)","(3653, 4)"
Count,2 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 116.90 kB 116.90 kB Shape (3653, 4) (3653, 4) Count 2 Tasks 1 Chunks Type datetime64[ns] numpy.ndarray",4  3653,

Unnamed: 0,Array,Chunk
Bytes,116.90 kB,116.90 kB
Shape,"(3653, 4)","(3653, 4)"
Count,2 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray


## Annual rainfall

#### Switch lead_time for time (so annual totals can be calculated)

In [6]:
def reindex_forecast(ds, dropna=False):
    """Swap a forecast's indexing dimension between ``lead_time`` and ``time``.

    If ``lead_time`` is a dimension of ``ds`` it is replaced by ``time``;
    otherwise, if ``time`` is a dimension it is replaced by ``lead_time``.
    Each ``init_date`` is processed separately and the results are
    concatenated back along ``init_date``.

    Parameters
    ----------
    ds : xarray object
        Forecast data with an ``init_date`` coordinate to iterate over, and
        ``lead_time`` and ``time`` coordinates (one of them a dimension).
    dropna : bool, default False
        If True, pass the concatenated result through
        ``where(notnull(), drop=True)`` before returning it.

    Raises
    ------
    ValueError
        If neither ``lead_time`` nor ``time`` is a dimension of ``ds``.
    """

    dims = ds.dims
    if 'lead_time' in dims:
        current_dim, target_dim = 'lead_time', 'time'
    elif 'time' in dims:
        current_dim, target_dim = 'time', 'lead_time'
    else:
        raise ValueError("Neither a time nor lead_time dimension can be found")

    def _swap_single_forecast(init_date):
        """Re-index the forecast for one init_date onto the target dimension."""
        single = ds.sel({'init_date': init_date})
        # Discard steps that have no valid value on the axis being swapped in
        has_target = single[target_dim].notnull()
        single = single.where(has_target, drop=True)
        integer_leads = single['lead_time'].astype(int)
        single = single.assign_coords({'lead_time': integer_leads})
        return single.swap_dims({current_dim: target_dim})

    swapped = [_swap_single_forecast(init_date) for init_date in ds['init_date']]
    reindexed = xr.concat(swapped, dim='init_date')
    if not dropna:
        return reindexed
    return reindexed.where(reindexed.notnull(), drop=True)

In [7]:
pr_daily_reindexed = reindex_forecast(pr_daily)
pr_daily_reindexed

Unnamed: 0,Array,Chunk
Bytes,5.27 GB,7.76 MB
Shape,"(4, 4748, 96, 19, 19)","(1, 28, 96, 19, 19)"
Count,14113 Tasks,2084 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 5.27 GB 7.76 MB Shape (4, 4748, 96, 19, 19) (1, 28, 96, 19, 19) Count 14113 Tasks 2084 Chunks Type float64 numpy.ndarray",4748  4  19  19  96,

Unnamed: 0,Array,Chunk
Bytes,5.27 GB,7.76 MB
Shape,"(4, 4748, 96, 19, 19)","(1, 28, 96, 19, 19)"
Count,14113 Tasks,2084 Chunks
Type,float64,numpy.ndarray


#### Calculate annual totals

In [8]:
def sum_min_samples(ds, dim, min_samples):
    """Sum ``ds`` along ``dim``, returning NaN unless ``dim`` has at least ``min_samples`` entries.

    Parameters
    ----------
    ds : xarray-like object
        Data to reduce, e.g. one group of a resampled array. Must provide
        ``sum``, ``coords`` and item access by coordinate name.
    dim : str
        Name of the dimension to sum over.
    min_samples : int
        Minimum length of ``dim`` required for an unmasked result.

    Returns
    -------
    The sum over ``dim`` (computed with ``skipna=False``). If ``dim`` is
    shorter than ``min_samples`` the sum is multiplied by NaN. Any
    ``lead_time`` or ``lead_year`` coordinate that varies along ``dim`` is
    attached to the result as its maximum over ``dim`` (the final lead in the
    sample), and is likewise multiplied by NaN when there are too few samples.
    """

    enough_samples = len(ds[dim]) >= min_samples
    total = ds.sum(dim, skipna=False)

    # Reference to final lead in sample. Coordinates that do not vary along
    # `dim` cannot be reduced over it, so they are skipped.
    for lead_coord in ('lead_time', 'lead_year'):
        if lead_coord in ds.coords and dim in ds[lead_coord].dims:
            final_lead = ds[lead_coord].max(dim, skipna=False)
            if not enough_samples:
                final_lead = np.nan * final_lead
            total = total.assign_coords({lead_coord: final_lead})

    return total if enough_samples else np.nan * total

In [9]:
# Total annual precipitation (mm): group the daily values into "A-DEC" annual bins
# (labelled by their right edge) and sum each bin, keeping only bins that hold
# at least 365 daily samples (see sum_min_samples).
# NOTE(review): this was originally described as a Dec-Nov year, but the "A-DEC"
# anchor suggests years ending in December — confirm which is intended.
annual_bins = pr_daily_reindexed.resample(time="A-DEC", label='right')
pr_annual = annual_bins.apply(sum_min_samples, dim='time', min_samples=365)

In [10]:
# Display the annual totals (one entry per annual bin along time).
pr_annual

Unnamed: 0,Array,Chunk
Bytes,15.53 MB,277.25 kB
Shape,"(4, 14, 96, 19, 19)","(1, 1, 96, 19, 19)"
Count,19329 Tasks,56 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 15.53 MB 277.25 kB Shape (4, 14, 96, 19, 19) (1, 1, 96, 19, 19) Count 19329 Tasks 56 Chunks Type float64 numpy.ndarray",14  4  19  19  96,

Unnamed: 0,Array,Chunk
Bytes,15.53 MB,277.25 kB
Shape,"(4, 14, 96, 19, 19)","(1, 1, 96, 19, 19)"
Count,19329 Tasks,56 Chunks
Type,float64,numpy.ndarray


In [11]:
# Inspect the time labels assigned to the annual bins.
pr_annual['time']

In [12]:
# Inspect lead_time: the largest daily lead within each annual bin, as attached by sum_min_samples.
pr_annual['lead_time']

Convert lead times from days since init_date to years since init_date... 

In [13]:
# Offset (in days) subtracted from the daily lead before converting it to whole years.
# NOTE(review): presumably the number of days between each init_date and the start of
# the first complete annual bin (i.e. lead year zero) — confirm against the init dates.
DAYS_TO_ZERO_LEAD = 60
DAYS_PER_YEAR = 365
LEAD_YEAR_UNITS = 'A'

# Guard so that re-running this cell does not convert the already-annual leads a
# second time (the conversion overwrites pr_annual's own lead_time coordinate).
if pr_annual['lead_time'].attrs.get('units') != LEAD_YEAR_UNITS:
    lead_time_years = np.floor((pr_annual['lead_time'] - DAYS_TO_ZERO_LEAD) / DAYS_PER_YEAR)
    pr_annual = pr_annual.assign_coords({'lead_time': lead_time_years})
    pr_annual['lead_time'].attrs['units'] = LEAD_YEAR_UNITS
pr_annual['lead_time']

In [14]:
# Inspect the init_date coordinate (one entry per forecast).
pr_annual['init_date']

#### Spatial averaging

In [15]:
# Spatial average of the total annual rainfall.
#
# TODO: decide whether this notebook should average spatially
# (only some notebooks do spatial averaging; others don't).
#
# local_funcs.py used to have a function for this, which applied a region mask.
# That function no longer exists, and it is unclear how to access the region mask file:
#
#def average_region(ds, region_mask):
#    return ds.where(region_mask).mean(['lat','lon'])
#
# Unmasked alternative, left disabled (`da` is not defined in the cells above):
#da = da.mean(['lat', 'lon'])
#da

Switch back to lead_time index instead of time...

In [16]:
# Swap the annual totals back onto the lead_time axis.
# reindex_forecast toggles between the two axes, so the guard keeps this cell safe to
# re-run: without it, a second run would switch the array back to time.
if 'time' in pr_annual.dims:
    pr_annual = reindex_forecast(pr_annual)
pr_annual

Unnamed: 0,Array,Chunk
Bytes,9.98 MB,277.25 kB
Shape,"(4, 9, 96, 19, 19)","(1, 1, 96, 19, 19)"
Count,19548 Tasks,36 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 9.98 MB 277.25 kB Shape (4, 9, 96, 19, 19) (1, 1, 96, 19, 19) Count 19548 Tasks 36 Chunks Type float64 numpy.ndarray",9  4  19  19  96,

Unnamed: 0,Array,Chunk
Bytes,9.98 MB,277.25 kB
Shape,"(4, 9, 96, 19, 19)","(1, 1, 96, 19, 19)"
Count,19548 Tasks,36 Chunks
Type,float64,numpy.ndarray


## Do the bias correction

In [None]:
## Shift leads back by one so that obsv dates match fcst dates
#obsv_stacked = obsv_stacked.assign_coords({'lead_time': obsv_stacked.lead_time-1}) \
#                           .where(obsv_stacked.lead_time > 0, drop=True)
#
#mask = (fcst_hack.time >= period[0]) & (fcst_hack.time <= period[1])
#    
#fcst_clim = fcst_hack.mean('ensemble').where(mask, drop=True).groupby('init_date.month').mean('init_date')
#obsv_clim = obsv_stacked.where(mask, drop=True).groupby('init_date.month').mean('init_date')
#
#model_mulbias_precip_ann = (fcst_clim / obsv_clim).chunk('auto').rename('precip_ann')
#model_mulbias_precip_ann = model_mulbias_precip_ann.xc.cache(f'CAFE-model-mulbias_precip-ann-accl_{REGION_NAME}.zarr', clobber=clobber)
#
#model_addbias_precip_ann = (fcst_clim - obsv_clim).chunk('auto').rename('precip_ann')
#model_addbias_precip_ann = model_addbias_precip_ann.xc.cache(f'CAFE-model-addbias_precip-ann-accl_{REGION_NAME}.zarr', clobber=clobber)

In [None]:
#period = (np.datetime64('1990-01-01'), 
#          np.datetime64('2020-12-31'))

# model_mulbias_precip = lf.estimate_model_biases(f5_precip, awap_precip, period, mode='multiplicative').rename('precip')
# model_addbias_precip = lf.estimate_model_biases(f5_precip, awap_precip, period, mode='additive').rename('precip')

#def estimate_model_biases(fcst, obsv, period, mode):
#    """ mode = 'additive' or 'multiplicative' """
#    stack_dates = fcst.sel(
#        init_date=slice(
#            (pd.to_datetime(period[0]) - pd.DateOffset(years=10)).strftime('%Y-%m-%d'),
#            pd.to_datetime(period[1]).strftime('%Y-%m-%d'))).init_date.values
#
#    obsv_stacked = stack_by_init_date(
#        obsv, stack_dates, len(fcst.lead_time))
#    
#    mask = (fcst.time >= period[0]) & (fcst.time <= period[1])
#    
#    fcst_clim = fcst.mean('ensemble').where(mask, drop=True).groupby('init_date.month').mean('init_date')
#    obsv_clim = obsv_stacked.where(mask, drop=True).groupby('init_date.month').mean('init_date')
#    
#    if mode == 'additive':
#        return (fcst_clim - obsv_clim).chunk('auto')
#    elif mode == 'multiplicative':
#        return (fcst_clim / obsv_clim).chunk('auto')
#    else: 
#        raise ValueError(f'Unrecognised mode {mode}')
        

In [None]:
#f5_precip_mbc = lf.remove_model_bias(f5_precip, model_mulbias_precip, mode='multiplicative').rename('precip')
#f5_precip_abc = lf.remove_model_bias(f5_precip, model_addbias_precip, mode='additive').rename('precip')

#def remove_model_bias(fcst, bias, mode):
#    """ mode = 'additive' or 'multiplicative' """
#    if mode == 'additive':
#        return (fcst.groupby('init_date.month') - bias).chunk(fcst.chunks)
#    elif mode == 'multiplicative':
#        return (fcst.groupby('init_date.month') / bias).chunk(fcst.chunks)
#    else: 
#        raise ValueError(f'Unrecognised mode {mode}')