In [1]:
%matplotlib inline

import os
from subprocess import call
from glob import glob
import yaml

from tqdm import tqdm_notebook as tqdm

import xarray as xr
import numpy as np

import cftime
from datetime import datetime
import dask 

USER = os.environ['USER']

In [2]:
PROJECT = os.environ['PBS_ACCOUNT']

## spin up dask cluster

In [4]:
#-- dask cluster backed by PBS batch jobs; it is scaled and connected to
#   in the cells below
from dask.distributed import Client
from dask_jobqueue import PBSCluster

project = PROJECT  # account name read from $PBS_ACCOUNT above

cluster = PBSCluster(
    queue='economy',
    project=project,
    walltime='01:00:00',
    cores=36,
    processes=1,
    memory='40GB',
    local_directory=f'/glade/scratch/{USER}/dask-tmp',
)

In [64]:
cluster.scale(30)

In [4]:
# cluster.adapt(minimum=30, maximum=40)  # auto-scale between 30 and 40 workers

<distributed.deploy.adaptive.Adaptive at 0x2aaae1900a20>

In [75]:
!qstat -u $USER


chadmin1: 
                                                            Req'd  Req'd   Elap
Job ID          Username Queue    Jobname    SessID NDS TSK Memory Time  S Time
--------------- -------- -------- ---------- ------ --- --- ------ ----- - -----
3140192.chadmin abanihi  economy  STDIN       37244   1   1    --  06:00 R 03:51
3141698.chadmin abanihi  economy  dask-worke  31925   1   1    --  01:00 R 01:00
3141700.chadmin abanihi  economy  dask-worke  47368   1   1    --  01:00 R 01:00
3141805.chadmin abanihi  economy  dask-worke  13350   1   1    --  01:00 R 00:01
3141806.chadmin abanihi  economy  dask-worke  65321   1   1    --  01:00 R 00:01
3141807.chadmin abanihi  economy  dask-worke  65060   1   1    --  01:00 R 00:01
3141808.chadmin abanihi  economy  dask-worke   1695   1   1    --  01:00 R 00:01
3141809.chadmin abanihi  economy  dask-worke  22133   1   1    --  01:00 R 00:01
3141810.chadmin abanihi  economy  dask-worke  20627   1   1    --  01:00 R 00:01
3141811.chadmin ab

In [76]:
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://10.148.0.173:46024  Dashboard: http://10.148.0.173:8787/status,Cluster  Workers: 30  Cores: 1080  Memory: 1.20 TB


## what to do?

In [77]:
var = 'O2' # O2 (3D var), HMXL (2D var)
varout = var
isel = {}

#varout = 'npac_O2'
#isel = {'nlat':slice(187,331),'nlon':slice(137,276)}

rename_var = False

n_ensemble_max = 10 # use only the first `n_ensemble_max` members

mean = 'ann'
n_forecast_lead = 10 # related to expected freq from "mean"

diro = f'/glade/scratch/{USER}/calcs/o2-prediction'

## more advanced options

In [78]:
file_format = 'zarr' # nc or zarr for output

file_out = {'mean': f'{diro}/CESM-DP-LE.{varout}.{mean}.mean.{file_format}',
            'drift': f'{diro}/CESM-DP-LE.{varout}.{mean}.mean.drift.{file_format}',
            'anom': f'{diro}/CESM-DP-LE.{varout}.{mean}.mean.anom.{file_format}'}

first_start_year, last_start_year = 1954, 2015
first_clim_year, last_clim_year = 1964, 2014

dir_dple = '/glade/p_old/decpred/CESM-DPLE'
case_prefix = 'b.e11.BDP.f09_g16'


#### Define function for writing output based on file extension (zarr or nc)

For better performance (speed, parallel writes), we make this function lazy with `dask.delayed`.

In [79]:
@dask.delayed
def write_output(ds,file_out,attrs=None):
    '''Write `ds` to `file_out`, replacing any output that already exists.

    Parameters
    ----------
    ds : dataset-like
        Object providing `to_netcdf` / `to_zarr` (an xarray Dataset in
        this notebook).
    file_out : str
        Destination path. The extension selects the writer:
        '.nc' -> `ds.to_netcdf`, '.zarr' -> `ds.to_zarr`.
    attrs : dict, optional
        If non-empty, a copy of it replaces `ds.attrs` before writing.

    Raises
    ------
    ValueError
        If the extension of `file_out` is neither '.nc' nor '.zarr'.
        The check happens before anything on disk is created or removed.
    '''
    import shutil

    #-- validate first so a bad extension never deletes existing output
    ext = os.path.splitext(file_out)[1]
    if ext not in ('.nc', '.zarr'):
        raise ValueError(f'Unknown output file extension: {ext}')

    #-- make sure the destination directory exists
    #   (dirname is '' for a bare file name: nothing to create)
    diro = os.path.dirname(file_out)
    if diro:
        os.makedirs(diro, exist_ok=True)

    #-- clobber previous output; it may be a directory tree or a single file
    if os.path.isdir(file_out) and not os.path.islink(file_out):
        shutil.rmtree(file_out)
    elif os.path.lexists(file_out):
        os.remove(file_out)

    #-- optionally set file-level attributes (copy: do not alias caller's dict)
    if attrs:
        ds.attrs = dict(attrs)

    print(f'writing {file_out}')
    if ext == '.nc':
        ds.to_netcdf(file_out,compute=True)
    else:
        ds.to_zarr(file_out,compute=True)
        

#### Some helper functions for reading datasets

In [80]:
def time_bound_var(ds):
    '''Locate the time-bounds variable of `ds`.

    The `bounds` attribute of `ds['time']` is consulted first; if it is
    absent, a variable literally named `time_bound` is used.

    Returns
    -------
    (name, dim) : the variable name and the name of its last dimension.

    Raises
    ------
    ValueError
        If neither lookup succeeds.
    '''
    time_attrs = ds['time'].attrs
    if 'bounds' not in time_attrs and 'time_bound' not in ds:
        raise ValueError('No time_bound variable found')

    name = time_attrs['bounds'] if 'bounds' in time_attrs else 'time_bound'
    return name, ds[name].dims[-1]

def fix_time(ds):
    '''Set `time` to the mean of each time-bounds interval and drop the bounds.

    The bounds variable is located with `time_bound_var`, averaged over
    its last dimension, and converted with `cftime.num2date` using the
    `units` and `calendar` attributes of `ds.time`. The result is assigned
    into `ds.time.values` (so the dataset passed in is modified), and the
    bounds variable is then dropped.

    Parameters
    ----------
    ds : dataset with a `time` coordinate carrying `units` and `calendar`
        attributes (opened with decode_times=False in this notebook).

    Returns
    -------
    The dataset without the time-bounds variable.
    '''
    tb_name,tb_dim = time_bound_var(ds)

    #-- when the bounds carry the ensemble dimension `M`, take the first
    #   member (assumes all members share the same time axis -- TODO confirm);
    #   datasets without an `M` dimension are used as they are
    time_bounds = ds[tb_name]
    if 'M' in time_bounds.dims:
        time_bounds = time_bounds.isel(M=0)

    time = cftime.num2date(time_bounds.mean(tb_dim),
                           units = ds.time.attrs['units'],
                           calendar = ds.time.attrs['calendar'])

    ds.time.values = time
    ds = ds.drop([tb_name])
    return ds

# compute drift
Drift is assessed by computing a lead-time dependent climatology

1. Assemble forecast ensemble variable into `start_year x lead_time x lat x lon` array. 
1. Compute ensemble mean
1. Filter start dates outside of the "verification window"
1. Compute mean over start dates as a function of lead time: this is the "drift"
1. Subtract drift from full field


First step, generate a `start_year` coordinate.

In [81]:
S = xr.DataArray(np.arange(first_start_year,last_start_year+1,1,dtype='int32')+1,
                 dims='S',
                 attrs={'long_name':'start year'})
S

<xarray.DataArray (S: 62)>
array([1955, 1956, 1957, 1958, 1959, 1960, 1961, 1962, 1963, 1964, 1965, 1966,
       1967, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978,
       1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990,
       1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002,
       2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014,
       2015, 2016], dtype=int32)
Dimensions without coordinates: S
Attributes:
    long_name:  start year

Generate `lead_time` array

In [82]:
L = xr.DataArray(np.arange(1,n_forecast_lead+1,1,dtype='int32'),
                 dims='L',
                 attrs={'long_name':'forecast lead'})
L

<xarray.DataArray (L: 10)>
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
Dimensions without coordinates: L
Attributes:
    long_name:  forecast lead

## Find the files for each start year

In [83]:
files_by_year = {}
n = np.zeros(len(S),dtype='int32')

for i, start_label in enumerate(S.values):
    #-- S is offset by +1 relative to the year that appears in the file names
    year = start_label - 1
    pattern = f'{dir_dple}/monthly/{var}/{case_prefix}.{year}*.nc'
    #-- keep at most `n_ensemble_max` files (first ones in sorted order)
    files_by_year[year] = sorted(glob(pattern))[:n_ensemble_max]
    n[i] = len(files_by_year[year])

#-- ensure that we have the same number of files for each start year
n_ensemble = n[0]
np.testing.assert_equal(n_ensemble,n)

#-- identical counts of zero would pass the check above: fail here instead of
#   later, when the (empty) file lists are opened
if n_ensemble == 0:
    raise FileNotFoundError(
        f'no files found matching {dir_dple}/monthly/{var}/{case_prefix}.<year>*.nc')

print(f'{n_ensemble} files for {len(S)} start years.')

10 files for 62 start years.


In [84]:
L

<xarray.DataArray (L: 10)>
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
Dimensions without coordinates: L
Attributes:
    long_name:  forecast lead

## Assemble forecast ensemble  

Generate a `start_year x lead_time x lat x lon` array

In [85]:
%%time

grid_vars = []
ds_list = []

def read_files(files, grid_vars=grid_vars):
    '''Open the ensemble files of one start year and reduce them to yearly means.

    Parameters
    ----------
    files : list of str
        netCDF files handed to `xr.open_mfdataset` with `concat_dim='M'`.
    grid_vars : list, optional
        When empty, the dataset attributes, per-variable attributes and the
        'dtype' / '_FillValue' / 'missing_value' encoding entries are stored
        in the module-level globals `dsattrs`, `attrs` and `encoding`.

    Returns
    -------
    Dataset holding `var` and the coordinates, with the `time` dimension
    replaced by the forecast-lead dimension `L`.

    Notes
    -----
    Reads the module-level names `var`, `isel`, `n_forecast_lead` and `L`,
    and calls `fix_time`.
    '''
    #-- metadata captured below is shared with later cells through globals
    global dsattrs
    global attrs
    global encoding

    #-- open the datasets
    ds = xr.open_mfdataset(files,
                           concat_dim='M',
                           decode_times = False,
                           decode_coords = False,
                           engine='netcdf4')

    #-- store "grid" variables, attributes, and encoding
    #   NOTE: `grid_vars` is rebound as a local name here, so the list used as
    #   the default argument stays empty and this branch runs on every call
    #   made with the default.
    if not grid_vars:
        grid_vars = [v for v,da in ds.variables.items() if 'time' not in da.dims]
        dsattrs = ds.attrs
        attrs = {v:da.attrs for v,da in ds.variables.items()}
        encoding = {v:{key:val for key,val in da.encoding.items()
                       if key in ['dtype','_FillValue','missing_value']}
                    for v,da in ds.variables.items()}

    #-- fix time and drop extraneous variables
    ds = fix_time(ds)
    ds = ds.drop([v for v in ds.variables if v not in ds.coords and v != var])

    #-- subset?
    if isel:
        ds = ds.isel(**isel)
        #-- remove the `coordinates` attribute from both the stored and the
        #   live attributes of `var` once the data has been subset
        if 'coordinates' in attrs[var]:
            del attrs[var]['coordinates']
        if 'coordinates' in ds[var].attrs:
            del ds[var].attrs['coordinates']

    #-- compute appropriate average TODO: ds = calc.compute_{freq}_mean(ds).rename({'time':'L'})
    ds = ds.groupby('time.year').mean('time').rename({'year':'L'})

    #-- one year more than the requested number of leads: discard the first
    #   (presumably a partial initialization year -- TODO confirm) and keep
    #   the following `n_forecast_lead` years
    if len(ds.L) == n_forecast_lead+1:
        ds = ds.isel(L=slice(1,n_forecast_lead+1))
    ds['L'] = L

    #-- restore the attributes and encoding captured above
    for v in ds.variables:
        if v in attrs:
            ds[v].attrs = attrs[v]
        if v in encoding:
            ds[v].encoding = encoding[v]

    return ds

ds_list = [read_files(files) for year,files in tqdm(files_by_year.items())]

HBox(children=(IntProgress(value=0, max=62), HTML(value='')))




CPU times: user 1min 4s, sys: 2.47 s, total: 1min 7s
Wall time: 1min 18s


In [86]:
#-- assemble into single dataset
ds = xr.concat(ds_list,dim='S')
ds['S'] = S

In [87]:
ds

<xarray.Dataset>
Dimensions:       (L: 10, M: 10, S: 62, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
  * L             (L) int32 1 2 3 4 5 6 7 8 9 10
  * S             (S) int32 1955 1956 1957 1958 1959 ... 2013 2014 2015 2016
Dimensions without coordinates: M, nlat, nlon
Data variables:
    O2            (S, L, M, z_t, nlat, nlon) float32 dask.array<shape=(62, 10, 10, 60, 384, 32

In [88]:
print('dataset size in GB {:0.2f}\n'.format(ds.nbytes / 1e9))

dataset size in GB 182.85



In [90]:
#-- rechunk to more suitable sizes
# new_chunks = {'S':1,'L':len(L),'M':n_ensemble}
# new_chunks = {'S':1,'L':1,'M':1}
# if 'z_t' in ds[var].dims:
#    new_chunks = {'S':len(S),'L':len(L),'M':n_ensemble,'nlat':16,'nlon':16}
 

# ds = ds.chunk(new_chunks)
# ds

Call `persist` on this dataset to load it into distributed memory.

In [91]:
from dask.distributed import wait

In [92]:
%%time
ds = ds.persist()
# wait(ds)
# ds.info()

CPU times: user 4.43 s, sys: 52 ms, total: 4.48 s
Wall time: 4.47 s


### generate a `verification_time` matrix

In [None]:
verification_time = S + 0.5 + L - 1
verification_time

<xarray.DataArray (S: 62, L: 10)>
array([[1955.5, 1956.5, 1957.5, ..., 1962.5, 1963.5, 1964.5],
       [1956.5, 1957.5, 1958.5, ..., 1963.5, 1964.5, 1965.5],
       [1957.5, 1958.5, 1959.5, ..., 1964.5, 1965.5, 1966.5],
       ...,
       [2014.5, 2015.5, 2016.5, ..., 2021.5, 2022.5, 2023.5],
       [2015.5, 2016.5, 2017.5, ..., 2022.5, 2023.5, 2024.5],
       [2016.5, 2017.5, 2018.5, ..., 2023.5, 2024.5, 2025.5]])
Dimensions without coordinates: S, L

### compute ensemble mean across forecast ensemble

In [None]:
dse = ds.mean('M')
dse

<xarray.Dataset>
Dimensions:       (L: 10, S: 62, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
  * L             (L) int32 1 2 3 4 5 6 7 8 9 10
  * S             (S) int32 1955 1956 1957 1958 1959 ... 2013 2014 2015 2016
Dimensions without coordinates: nlat, nlon
Data variables:
    O2            (S, L, z_t, nlat, nlon) float32 dask.array<shape=(62, 10, 60, 384, 320), chunksize=(1,

### Compute "drift"
Mask out the start-date/lead-time combinations whose verification time falls outside the "verification window", then compute the mean over start dates as a function of lead time: this is the "drift".

In [None]:
drift = dse.where((first_clim_year<verification_time) & 
                  (verification_time<last_clim_year+1) ).mean('S')

for v in drift.variables:
    if v in attrs:
        drift[v].attrs = attrs[v]
    if v in encoding:
        drift[v].encoding = encoding[v]
        
drift

<xarray.Dataset>
Dimensions:       (L: 10, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * L             (L) int64 1 2 3 4 5 6 7 8 9 10
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
Dimensions without coordinates: nlat, nlon
Data variables:
    O2            (L, z_t, nlat, nlon) float32 dask.array<shape=(10, 60, 384, 320), chunksize=(1, 60, 384, 320)>

### Compute bias correction

In [None]:
anom = ds - drift

for v in anom.variables:
    if v in attrs:
        anom[v].attrs = attrs[v]
    if v in encoding:
        anom[v].encoding = encoding[v]
anom

<xarray.Dataset>
Dimensions:       (L: 10, M: 10, S: 62, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
  * L             (L) int32 1 2 3 4 5 6 7 8 9 10
  * S             (S) int32 1955 1956 1957 1958 1959 ... 2013 2014 2015 2016
Dimensions without coordinates: M, nlat, nlon
Data variables:
    O2            (S, L, M, z_t, nlat, nlon) float32 dask.array<shape=(62, 10, 10, 60, 384, 32

### write output to file

In [None]:
dsattrs['history'] = f'created by {USER} on {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}'

In [None]:
%%time
#-- optionally store the variable under its output name; `rename` returns a
#   new dataset, so the result is assigned back (the `inplace` keyword is not
#   available in current xarray releases)
if var != varout and rename_var:
    ds = ds.rename({var:varout})

#-- `write_output` is wrapped in dask.delayed: this only builds the task
mean_output = write_output(ds,
                           file_out = file_out['mean'],
                           attrs = dsattrs)

CPU times: user 88 ms, sys: 4 ms, total: 92 ms
Wall time: 90.4 ms


In [None]:
dsattrs['climatology'] = f'{first_clim_year:d}-{last_clim_year}, computed separately for each lead time'

In [None]:
%%time
#-- optionally rename the variable to 'anom'; `rename` returns a new dataset,
#   so the result is assigned back (the `inplace` keyword is not available in
#   current xarray releases)
if rename_var:
    anom = anom.rename({var:'anom'})

#-- `write_output` is wrapped in dask.delayed: this only builds the task
anom_output = write_output(anom,
                           file_out = file_out['anom'],
                           attrs = dsattrs)

CPU times: user 920 ms, sys: 8 ms, total: 928 ms
Wall time: 924 ms


In [None]:
%%time
#-- optionally rename the variable to 'climo'; `rename` returns a new dataset,
#   so the result is assigned back (the `inplace` keyword is not available in
#   current xarray releases)
if rename_var:
    drift = drift.rename({var:'climo'})

#-- `write_output` is wrapped in dask.delayed: this only builds the task
climo_output = write_output(drift,
                            file_out = file_out['drift'],
                            attrs = dsattrs)

CPU times: user 268 ms, sys: 4 ms, total: 272 ms
Wall time: 270 ms


In [None]:
%%time
dask.compute([mean_output, anom_output, climo_output])



**pip install version_information**

In [None]:
%load_ext version_information
%version_information netcdf4, xarray, dask, numpy, zarr