In [1]:
%matplotlib inline

import os
from subprocess import call
from glob import glob
import yaml

from tqdm import tqdm_notebook as tqdm

import xarray as xr
import numpy as np

import cftime
from datetime import datetime

USER = os.environ['USER']  # value of $USER; used to build the dask scratch and output paths below

In [2]:
PROJECT = os.environ['PBS_ACCOUNT']  # PBS account, passed as `project` to PBSCluster below

## spin up dask cluster

In [3]:
if True:  # toggle: set to False to skip launching the PBS-backed dask cluster
    from dask.distributed import Client
    from dask_jobqueue import PBSCluster

    project = PROJECT

    # each PBS job hosts one dask worker process using 36 cores and 100GB of memory
    cluster = PBSCluster(
        queue='economy',
        cores=36,
        processes=1,
        memory='100GB',
        project=project,
        walltime='02:00:00',
        local_directory=f'/glade/scratch/{USER}/dask-tmp',
    )

In [4]:
# let dask submit/cancel PBS worker jobs according to load
cluster.adapt(minimum=5, maximum=20)  # auto-scale between 5 and 20 workers

<distributed.deploy.adaptive.Adaptive at 0x2aaae184f198>

In [10]:
!qstat -u $USER


chadmin1: 
                                                            Req'd  Req'd   Elap
Job ID          Username Queue    Jobname    SessID NDS TSK Memory Time  S Time
--------------- -------- -------- ---------- ------ --- --- ------ ----- - -----
3139991.chadmin abanihi  economy  STDIN       48467   1   1    --  06:00 R 00:36
3140098.chadmin abanihi  economy  dask-worke  29007   1   1    --  02:00 R 00:00
3140099.chadmin abanihi  economy  dask-worke  37364   1   1    --  02:00 R 00:00
3140100.chadmin abanihi  economy  dask-worke  11972   1   1    --  02:00 R 00:00
3140101.chadmin abanihi  economy  dask-worke    --    1   1    --  02:00 Q   -- 
3140102.chadmin abanihi  economy  dask-worke    --    1   1    --  02:00 Q   -- 
3140104.chadmin abanihi  economy  dask-worke    --    1   1    --  02:00 Q   -- 


In [13]:
# attach a dask client to the cluster; it becomes the default scheduler for later dask computations
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://10.148.2.69:46251  Dashboard: http://10.148.2.69:8787/status,Cluster  Workers: 5  Cores: 180  Memory: 500.00 GB


## Configure the calculation

In [14]:
var = 'O2' # O2 (3D var), HMXL (2D var)
varout = var  # name used in the output file names (and for the variable itself if rename_var)
isel = {}  # optional index subset applied to every dataset, e.g. a regional box (see below)

#varout = 'npac_O2'
#isel = {'nlat':slice(187,331),'nlon':slice(137,276)}

rename_var = False  # if True, the written variable is renamed (varout / 'anom' / 'climo')

n_ensemble_max = 2 # use only the first `n_ensemble_max` members

mean = 'ann'  # only used in output file names here; the assembly cell always computes calendar-year means
n_forecast_lead = 10 # related to expected freq from "mean"; number of lead times kept per start year

diro = f'/glade/scratch/{USER}/calcs/o2-prediction'  # output directory

## more advanced options

In [15]:
file_format = 'nc' # nc or zarr for output

# all outputs share one stem; drift/anom insert a suffix before the extension
file_out = {
    key: f'{diro}/CESM-DP-LE.{varout}.{mean}.mean{suffix}.{file_format}'
    for key, suffix in (('mean', ''), ('drift', '.drift'), ('anom', '.anom'))
}

# start years as they appear in the case file names (inclusive)
first_start_year, last_start_year = 2011, 2015
# verification years that define the drift climatology (inclusive)
first_clim_year, last_clim_year = 2013, 2014

dir_dple = '/glade/p_old/decpred/CESM-DPLE'
case_prefix = 'b.e11.BDP.f09_g16'


#### Define function for writing output based on file extension (zarr or nc)

In [16]:
def write_output(ds, file_out, attrs=None):
    '''Write a dataset to netCDF or zarr, selected by the file extension.

    Parameters
    ----------
    ds : dataset-like
        Object providing ``to_netcdf`` / ``to_zarr`` (e.g. an xarray Dataset).
    file_out : str
        Output path. ``.nc`` -> ``ds.to_netcdf``, ``.zarr`` -> ``ds.to_zarr``.
    attrs : dict, optional
        If non-empty, assigned to ``ds.attrs`` before writing.

    Raises
    ------
    ValueError
        If the extension is neither ``.nc`` nor ``.zarr``. This is checked
        before touching the filesystem, so an existing file is never deleted
        for an unsupported extension.

    Notes
    -----
    The parent directory is created if needed, and any existing ``file_out``
    (a file, or a directory such as a zarr store) is removed first.
    '''
    import shutil

    ext = os.path.splitext(file_out)[1]
    if ext not in ('.nc', '.zarr'):
        raise ValueError(f'Unknown output file extension: {ext}')

    # dirname is '' for a bare filename: nothing to create then
    dirname = os.path.dirname(file_out)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # remove stale output; zarr stores are directories, netCDF outputs are files
    if os.path.isdir(file_out) and not os.path.islink(file_out):
        shutil.rmtree(file_out)
    elif os.path.lexists(file_out):
        os.remove(file_out)

    if attrs:
        ds.attrs = attrs

    print(f'writing {file_out}')
    if ext == '.nc':
        ds.to_netcdf(file_out, compute=True)
    else:
        ds.to_zarr(file_out, compute=True)
        

#### Some helper functions for reading datasets

In [17]:
def time_bound_var(ds):
    '''Find the time-bounds variable of ``ds``.

    The name comes from the ``bounds`` attribute of ``time`` when present,
    otherwise a variable literally called ``time_bound`` is used.

    Returns
    -------
    (name, dim) : tuple
        The bounds variable name and its last dimension (the bounds axis).

    Raises
    ------
    ValueError
        If neither source of a bounds variable exists.
    '''
    time_attrs = ds['time'].attrs
    if 'bounds' in time_attrs:
        bounds_name = time_attrs['bounds']
    elif 'time_bound' in ds:
        bounds_name = 'time_bound'
    else:
        raise ValueError('No time_bound variable found')
    return bounds_name, ds[bounds_name].dims[-1]

def fix_time(ds, member_dim='M'):
    '''Replace the numeric ``time`` coordinate by decoded bound midpoints.

    The bounds variable (found with ``time_bound_var``) is averaged over its
    last dimension and converted with ``cftime.num2date`` using the ``units``
    and ``calendar`` attributes of ``time``. The bounds variable is then dropped.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened with ``decode_times=False``.
    member_dim : str, default 'M'
        If the bounds variable has this dimension (ensemble members stacked by
        ``open_mfdataset``), only the first member's bounds are used. This
        presumes the bounds are identical across members (TODO confirm).

    Returns
    -------
    xarray.Dataset
        A new dataset; the input ``ds`` is not modified.
    '''
    tb_name, tb_dim = time_bound_var(ds)

    bounds = ds[tb_name]
    if member_dim in bounds.dims:
        bounds = bounds.isel(**{member_dim: 0})

    time = cftime.num2date(bounds.mean(tb_dim).values,
                           units=ds.time.attrs['units'],
                           calendar=ds.time.attrs['calendar'])

    # build a new coordinate (keeping the time attrs) rather than writing into
    # ds.time.values, which mutated the caller's dataset and is rejected for
    # index coordinates by newer xarray versions
    ds = ds.assign_coords(time=('time', time, ds.time.attrs))
    return ds.drop([tb_name])

# compute drift
Drift is estimated as a lead-time-dependent climatology:

1. Assemble the forecast ensemble variable into a `start_year x lead_time x member x [depth x] lat x lon` array.
1. Compute the ensemble mean.
1. Mask out (start year, lead) pairs whose verification time falls outside the climatology ("verification") window.
1. Average over start years as a function of lead time: this is the "drift".
1. Subtract the drift from the full field (every ensemble member) to obtain bias-corrected anomalies.


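First, generate the `start_year` coordinate `S`. Note that `S` is one more than the year in the case names (see the `+1` below and the `-1` used when searching for files).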

In [18]:
#-- start-year coordinate. Note the +1: S is one more than the year in the case
#   names (the file search below uses S - 1, and lead L=1 is treated as verifying
#   in year S). Presumably the forecasts start late in year S-1 -- confirm.
S = xr.DataArray(np.arange(first_start_year,last_start_year+1,1,dtype='int32')+1,
                 dims='S',
                 attrs={'long_name':'start year'})
S

<xarray.DataArray (S: 5)>
array([2012, 2013, 2014, 2015, 2016], dtype=int32)
Dimensions without coordinates: S
Attributes:
    long_name:  start year

Generate `lead_time` array

In [19]:
#-- lead-time coordinate 1 .. n_forecast_lead; with the calendar-year means computed
#   when assembling the ensemble, lead L verifies in year S + L - 1
L = xr.DataArray(np.arange(1,n_forecast_lead+1,1,dtype='int32'),
                 dims='L',
                 attrs={'long_name':'forecast lead'})
L

<xarray.DataArray (L: 10)>
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
Dimensions without coordinates: L
Attributes:
    long_name:  forecast lead

## Find the files for each start year

In [20]:
files_by_year = {}
n = np.zeros(len(S), dtype='int32')

#-- S is (case-name year + 1), so subtract 1 to build the file pattern
for i, s in enumerate(S.values):
    year = s - 1
    files_by_year[year] = sorted(glob(f'{dir_dple}/monthly/{var}/{case_prefix}.{year}*.nc'))[:n_ensemble_max]
    n[i] = len(files_by_year[year])

#-- ensure that files were found, and the same number for each start year
n_ensemble = n[0]
if n_ensemble == 0:
    raise FileNotFoundError(f'no files match {dir_dple}/monthly/{var}/{case_prefix}.<year>*.nc '
                            f'for start year {S.values[0] - 1}')
np.testing.assert_equal(n_ensemble, n)

print(f'{n_ensemble} files for {len(S)} start years.')

2 files for 5 start years.


In [21]:
L

<xarray.DataArray (L: 10)>
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
Dimensions without coordinates: L
Attributes:
    long_name:  forecast lead

## Assemble forecast ensemble  

Generate a `start_year x lead_time x member x [depth x] lat x lon` array

In [22]:
%%time

grid_vars = []
ds_list = []

#-- loop over files for each year
for year, files in tqdm(files_by_year.items()):

    #-- open the ensemble members of this start year, stacked along a new `M` dim
    ds = xr.open_mfdataset(files,
                           concat_dim='M',
                           decode_times=False,
                           decode_coords=False,
                           engine='netcdf4')

    #-- store "grid" variables (no time dim), attributes, and encoding (first year only)
    if not grid_vars:
        grid_vars = [v for v, da in ds.variables.items() if 'time' not in da.dims]
        ds_grid = ds.drop([v for v in ds.variables if v not in grid_vars]).isel(M=0)
        dsattrs = dict(ds.attrs)  # own copy: later cells add 'history'/'climatology' to it
        attrs = {v: da.attrs for v, da in ds.variables.items()}
        encoding = {v: {key: val for key, val in da.encoding.items()
                        if key in ['dtype', '_FillValue', 'missing_value']}
                    for v, da in ds.variables.items()}

    #-- fix time and drop extraneous variables (keep coordinates and `var` only)
    ds = fix_time(ds)
    ds = ds.drop([v for v in ds.variables if v not in ds.coords and v != var])

    #-- subset?
    if isel:
        ds = ds.isel(**isel)
        # NOTE(review): presumably removed because it names variables dropped above -- confirm
        if 'coordinates' in attrs[var]:
            del attrs[var]['coordinates']
        if 'coordinates' in ds[var].attrs:
            del ds[var].attrs['coordinates']

    #-- calendar-year means (unweighted mean of the monthly values)
    #   TODO: ds = calc.compute_{freq}_mean(ds).rename({'time':'L'})
    ds = ds.groupby('time.year').mean('time').rename({'year': 'L'})

    #-- with one extra year, drop the first (the case-name year, S - 1) so that L=1 is year S
    n_years = len(ds.L)
    if n_years == n_forecast_lead + 1:
        ds = ds.isel(L=slice(1, n_forecast_lead + 1))
    elif n_years != n_forecast_lead:
        raise ValueError(f'start year {year}: found {n_years} years of output, '
                         f'expected {n_forecast_lead} or {n_forecast_lead + 1}')
    ds['L'] = L

    #-- restore attributes and encoding captured from the first opened dataset
    for v in ds.variables:
        if v in attrs:
            ds[v].attrs = attrs[v]
        if v in encoding:
            ds[v].encoding = encoding[v]

    ds_list.append(ds)

#-- assemble into single dataset with a start-year dimension
ds = xr.concat(ds_list, dim='S')
ds['S'] = S

#-- rechunk: one chunk per start year holding all leads and members
new_chunks = {'S': 1, 'L': len(L), 'M': n_ensemble}

# alternative chunking kept for reference (not active):
#if 'z_t' in ds[var].dims:
#    new_chunks = {'S':len(S),'L':len(L),'M':n_ensemble,'nlat':16,'nlon':16}

ds = ds.chunk(new_chunks)

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))


CPU times: user 1.96 s, sys: 120 ms, total: 2.08 s
Wall time: 3.62 s


In [23]:
ds_list

[<xarray.Dataset>
 Dimensions:       (L: 10, M: 2, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
 Coordinates:
   * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
   * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
   * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
   * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
   * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
   * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
   * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
   * L             (L) int32 1 2 3 4 5 6 7 8 9 10
 Dimensions without coordinates: M, nlat, nlon
 Data variables:
     O2            (L, M, z_t, nlat, nlon) float32 dask.array<shape=(10, 2, 60, 384, 320), chunksize=(1, 1, 60, 384, 320)>,
 <xarray.Dataset>
 Dimensions:       (L: 1

In [24]:
ds

<xarray.Dataset>
Dimensions:       (L: 10, M: 2, S: 5, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
  * L             (L) int32 1 2 3 4 5 6 7 8 9 10
  * S             (S) int32 2012 2013 2014 2015 2016
Dimensions without coordinates: M, nlat, nlon
Data variables:
    O2            (S, L, M, z_t, nlat, nlon) float32 dask.array<shape=(5, 10, 2, 60, 384, 320), chunksize=(1, 10, 2, 60,

Optionally call `persist` on this dataset to load it into distributed memory (currently commented out; the cell below only prints `ds.info()`).

In [25]:
from dask.distributed import wait

In [26]:
%%time
#ds = ds.persist()
#wait(ds)
ds.info()

xarray.Dataset {
dimensions:
	L = 10 ;
	M = 2 ;
	S = 5 ;
	lat_aux_grid = 395 ;
	moc_z = 61 ;
	nlat = 384 ;
	nlon = 320 ;
	z_t = 60 ;
	z_t_150m = 15 ;
	z_w = 60 ;
	z_w_bot = 60 ;
	z_w_top = 60 ;

variables:
	float32 z_t(z_t) ;
		z_t:units = centimeters ;
		z_t:long_name = depth from surface to midpoint of layer ;
		z_t:valid_min = 500.0 ;
		z_t:valid_max = 537500.0 ;
		z_t:positive = down ;
	float32 z_w(z_w) ;
		z_w:units = centimeters ;
		z_w:long_name = depth from surface to top of layer ;
		z_w:valid_min = 0.0 ;
		z_w:valid_max = 525000.9375 ;
		z_w:positive = down ;
	float32 moc_z(moc_z) ;
		moc_z:units = centimeters ;
		moc_z:long_name = depth from surface to top of layer ;
		moc_z:valid_min = 0.0 ;
		moc_z:valid_max = 549999.0625 ;
		moc_z:positive = down ;
	float32 z_w_top(z_w_top) ;
		z_w_top:units = centimeters ;
		z_w_top:long_name = depth from surface to top of layer ;
		z_w_top:valid_min = 0.0 ;
		z_w_top:valid_max = 525000.9375 ;
		z_w_top:positive = down ;
	float32 z_w_bot

### generate a `verification_time` matrix

In [27]:
# mid-year verification time of each (S, L) pair: lead L verifies in year S + L - 1
verification_time = S + L - 0.5
verification_time

<xarray.DataArray (S: 5, L: 10)>
array([[2012.5, 2013.5, 2014.5, 2015.5, 2016.5, 2017.5, 2018.5, 2019.5,
        2020.5, 2021.5],
       [2013.5, 2014.5, 2015.5, 2016.5, 2017.5, 2018.5, 2019.5, 2020.5,
        2021.5, 2022.5],
       [2014.5, 2015.5, 2016.5, 2017.5, 2018.5, 2019.5, 2020.5, 2021.5,
        2022.5, 2023.5],
       [2015.5, 2016.5, 2017.5, 2018.5, 2019.5, 2020.5, 2021.5, 2022.5,
        2023.5, 2024.5],
       [2016.5, 2017.5, 2018.5, 2019.5, 2020.5, 2021.5, 2022.5, 2023.5,
        2024.5, 2025.5]])
Dimensions without coordinates: S, L

### compute ensemble mean across forecast ensemble

In [28]:
# ensemble mean: average over the member dimension `M`
dse = ds.mean('M')
dse

<xarray.Dataset>
Dimensions:       (L: 10, S: 5, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
  * L             (L) int32 1 2 3 4 5 6 7 8 9 10
  * S             (S) int32 2012 2013 2014 2015 2016
Dimensions without coordinates: nlat, nlon
Data variables:
    O2            (S, L, z_t, nlat, nlon) float32 dask.array<shape=(5, 10, 60, 384, 320), chunksize=(1, 10, 60, 384, 320)>

### Compute "drift"
Mask out start dates whose verification time lies outside the climatology window, then average over start dates as a function of lead time: this is the "drift".

In [29]:
# keep only (S, L) pairs whose verification year is within [first_clim_year, last_clim_year]
# (verification_time is mid-year, hence the strict bounds), then average over start years
drift = dse.where((verification_time > first_clim_year) &
                  (verification_time < last_clim_year + 1)).mean('S')

# copy back the encoding and attributes captured when the files were opened
for v in drift.variables:
    if v in encoding:
        drift[v].encoding = encoding[v]
    if v in attrs:
        drift[v].attrs = attrs[v]

drift

<xarray.Dataset>
Dimensions:       (L: 10, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * L             (L) int64 1 2 3 4 5 6 7 8 9 10
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
Dimensions without coordinates: nlat, nlon
Data variables:
    O2            (L, z_t, nlat, nlon) float32 dask.array<shape=(10, 60, 384, 320), chunksize=(10, 60, 384, 320)>

### Apply bias correction (subtract the drift)

In [30]:
# bias-corrected field: subtract the lead-dependent drift (no S or M dims) from every
# start year and member of the full field; xarray broadcasts over S and M
anom = ds - drift

# restore attributes/encoding, which the arithmetic does not carry over
for v in anom.variables:
    if v in attrs:
        anom[v].attrs = attrs[v]
    if v in encoding:
        anom[v].encoding = encoding[v]
anom

<xarray.Dataset>
Dimensions:       (L: 10, M: 2, S: 5, lat_aux_grid: 395, moc_z: 61, nlat: 384, nlon: 320, z_t: 60, z_t_150m: 15, z_w: 60, z_w_bot: 60, z_w_top: 60)
Coordinates:
  * z_t           (z_t) float32 500.0 1500.0 2500.0 ... 512502.8 537500.0
  * z_w           (z_w) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * moc_z         (moc_z) float32 0.0 1000.0 2000.0 ... 525000.94 549999.06
  * z_w_top       (z_w_top) float32 0.0 1000.0 2000.0 ... 500004.7 525000.94
  * z_w_bot       (z_w_bot) float32 1000.0 2000.0 3000.0 ... 525000.94 549999.06
  * lat_aux_grid  (lat_aux_grid) float32 -79.48815 -78.952896 ... 89.47441 90.0
  * z_t_150m      (z_t_150m) float32 500.0 1500.0 2500.0 ... 13500.0 14500.0
  * L             (L) int32 1 2 3 4 5 6 7 8 9 10
  * S             (S) int32 2012 2013 2014 2015 2016
Dimensions without coordinates: M, nlat, nlon
Data variables:
    O2            (S, L, M, z_t, nlat, nlon) float32 dask.array<shape=(5, 10, 2, 60, 384, 320), chunksize=(1, 10, 2, 60,

### write output to file

In [31]:
# provenance stamp; replaces any 'history' entry copied from the input files' global attrs
dsattrs['history'] = f'created by {USER} on {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}'

In [32]:
%%time
#-- rename a copy for output instead of `ds` itself: the in-place rename (inplace=True is
#   deprecated in xarray) made this cell fail on re-run because `var` was already gone
if rename_var and var != varout:
    ds_mean_out = ds.rename({var: varout})
else:
    ds_mean_out = ds

write_output(ds_mean_out,
             file_out=file_out['mean'],
             attrs=dsattrs)

writing /glade/scratch/abanihi/calcs/o2-prediction/CESM-DP-LE.O2.ann.mean.nc
CPU times: user 7.77 s, sys: 596 ms, total: 8.37 s
Wall time: 2min 22s


In [33]:
# set after the mean file is written, so only the anom and drift outputs carry this attribute
dsattrs['climatology'] = f'{first_clim_year:d}-{last_clim_year}, computed separately for each lead time'

In [34]:
%%time
#-- rename a copy for output; `anom` is left untouched so the cell can be re-run
anom_out = anom.rename({var: 'anom'}) if rename_var else anom

write_output(anom_out,
             file_out=file_out['anom'],
             attrs=dsattrs)

writing /glade/scratch/abanihi/calcs/o2-prediction/CESM-DP-LE.O2.ann.mean.anom.nc
CPU times: user 6.98 s, sys: 556 ms, total: 7.54 s
Wall time: 1min 50s


In [35]:
%%time
#-- rename a copy for output; `drift` is left untouched so the cell can be re-run
drift_out = drift.rename({var: 'climo'}) if rename_var else drift

write_output(drift_out,
             file_out=file_out['drift'],
             attrs=dsattrs)

writing /glade/scratch/abanihi/calcs/o2-prediction/CESM-DP-LE.O2.ann.mean.drift.nc
CPU times: user 6.7 s, sys: 420 ms, total: 7.12 s
Wall time: 1min 46s


In [36]:
print(f'dataset size in GB {ds.nbytes / 1e9:0.2f}\n')  # in-memory size of the assembled ensemble

dataset size in GB 2.95



**Requires the `version_information` IPython extension (`pip install version_information`).**

In [37]:
%load_ext version_information
%version_information netcdf4, xarray, dask, numpy, zarr

Software,Version
Python,3.6.6 64bit [GCC 4.8.2 20140120 (Red Hat 4.8.2-15)]
IPython,7.0.1
OS,Linux 3.12.62 60.64.8 default x86_64 with SuSE 12 x86_64
netcdf4,1.4.1
xarray,0.10.9
dask,0.19.4
numpy,1.15.2
zarr,2.2.0
Mon Oct 29 00:26:37 2018 MDT,Mon Oct 29 00:26:37 2018 MDT
