In [1]:
#translates Steve Yeager's NCL scripts into python w/dask
#Processes and bias corrects CESM-DPLE output and writes out as a netcdf4
#Will work for CAM and POP fields
#Should be able to (eventually) handle annual, seasonal, and monthly means
#will need a separate script to handle 3D fields

#I've marked in ALL CAPS places that need to be altered
#-Liz Maroon 9/3/2018

#import packages
import dask                           #for using multiple cores 
import xarray as xr                   #for netcdf manipulation
from dask.distributed import Client   #for distributing job across cores
from dask_jobqueue import PBSCluster  #for cheyenne pbs scheduler
import numpy as np                    #for numerics
import dask.array as da               #for out-of-memory array setup
import datetime                       #to correct time issue so get annual means right
from collections import OrderedDict   #for setting netcdf attributes
import os                             #these last three packages used to detect username/script location
import pwd
import sys
import glob


import project as P

%matplotlib notebook

In [2]:
#HERE IS WHERE TO SET WHICH FILES TO PROCESS
VAR='TS'# VAR='TS'
MODEL='ATM'# MODEL='ATM' #SET HERE IF PROCESSING CAM OR POP OUTPUT - can write catches for LND/ICE later as needed
ISEL = {} #{'z_t': 19}

VARO = 'TS'

WHICHMEAN='ANN'

#WHERE ARE DPLE FILES CURRENTLY?
DPLE_DIR='/glade/p_old/decpred/CESM-DPLE/'

#WHERE AND WHAT DO YOU WANT TO CALL OUTPUT FILES?
RAWDPOUT=f'{P.dirt}/maroon/CESM-DP-LE.{VARO}.{WHICHMEAN.lower()}.mean.nc'
DRIFTOUT=f'{P.dirt}/maroon/CESM-DP-LE.{VARO}.{WHICHMEAN.lower()}.mean.drift.nc'
ANOMOUT=f'{P.dirt}/maroon/CESM-DP-LE.{VARO}.{WHICHMEAN.lower()}.mean.anom.nc'

#THINGS TO SPECIFY FOR CHEYENNE REQUEST GO HERE
projectCode='NCGD0011'
ncpu=36
numNodes=2
memory='80GB' #memory per node/worker

In [3]:
#a couple of commands to setup the dask stuff
#dask.config.set({'distributed.dashboard.link':'http://localhost:{port}/status'})


In [4]:
from dask.distributed import Client
from dask_jobqueue import PBSCluster

USER = os.environ['USER']

# Lots of arguments to this command are set in ~/.config/dask/jobqueue.yaml
cluster = PBSCluster(queue='regular',
                     cores = 36,
                     processes = 9,
                     memory = '100GB',                     
                     project = 'NCGD0033',
                     walltime = '04:00:00',
                     local_directory=f'/glade/scratch/{USER}/dask-tmp')
client = Client(cluster)

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


In [5]:
Nnodes = 2
cluster.scale(9*Nnodes)

In [10]:
#Make array for start years
first_syear=1954
last_syear=2015
S=np.arange(first_syear+1,last_syear+1.01,1,dtype='int32')  #hacky addition of 1.01 to end year b/c roundoff weirdness
prefix='b.e11.BDP.f09_g16.'

#hacky function to correct for xarray not working well with times outside the "allowed" range
def mfload(thesefiles,isel={}):
    d1800=1800*365 #assumes noleap calender
    if MODEL=='OCN':
        ds=xr.open_mfdataset(thesefiles,concat_dim="ensemble", chunks={"time": 122},decode_times=False)
        ds.time.values = ds.time_bound.mean(ds.time_bound.dims[-1]).isel(ensemble=0)
        ds['time'].values=ds['time'].values-d1800
        ds['time'].attrs['units']='days since 1800-01-01 00:00:00'
        ds['time_bound'].attrs['units']='days since 1800-01-01 00:00:00'
        ds['time_bound'].values=ds['time_bound'].values-d1800
        ds=xr.decode_cf(ds,decode_times=True)
    
    else:
        ds=xr.open_mfdataset(thesefiles,concat_dim="ensemble", chunks={"time": 122},decode_times=False)
        ds.time.values = ds.time_bnds.mean(ds.time_bnds.dims[-1]).isel(ensemble=0)
        ds=xr.decode_cf(ds,decode_times=True)
    if isel:
        ds=ds.isel(**isel)
        
    # MCL: eliminated below, taking average of time_bound above
    #ds['time']=ds['time']-np.timedelta64(15,'D')  #subtracting 15 days from 'time'
    #hacky correction for time vs time_bnds issue. time is for end of month, not middle, 
    #xarray treats that as the following month, not the previous one. 
    #Without correcting the annual means will be offset by 1 month 
    #Will test on POP later. I assume NCL is smarter and also reads the time_bnds.
    return ds
    
#function that opens all the datasets for one start year and does the annual mean
#will implement seasonal mean later
def readyear(year,whichmean):
    #reads in ensemble for one startyear
    loadthesefiles=sorted(glob.glob(f"{DPLE_DIR}/monthly/{VAR}/{prefix}{year}*.nc"))
    ds=mfload(loadthesefiles,isel=ISEL)
    #ds=xr.open_mfdataset(f"{DPLE_DIR}/monthly/{VAR}/{prefix}{year}*.nc",concat_dim="ensemble", chunks={"time": 122})
    #chunk size is set to length of each monthly mean file (122 months)
    #will need to change if working with daily files
    if whichmean=='ANN':
        #we're using the ".data" attribute at the end to only return a dask array, not xarray
        return (ds[VAR].groupby('time.year').mean('time')).isel(year=slice(1,11)).data
    else:    
        return ds[VAR].data

#here's where the function for reading/meaning is called, but it is "delayed"
#dask makes a graph for how to split the computation up across nodes but does not compute until
#explicitly called later
#the array for each start year are read into a list
#pulling them in here as dask arrays (not xarrays) b/c want to make one big xarray of all startyears
#and figured out how to concatenate this way first. There might be a better way to do this.
dask_arrays=[dask.delayed(readyear)(ss,WHICHMEAN) for ss in S-1]
print('after delayed call')

after delayed call


In [11]:
#because reading in rest of arrays as dask array, can't pull dim sizes and attrs from dask_arrays
#need to open one startyear in as an xarray to grab dim sizes and attrs
loadthesefiles=sorted(glob.glob(f"{DPLE_DIR}/monthly/{VAR}/{prefix}1954*.nc"))
oneds=mfload(loadthesefiles,isel=ISEL)
#oneds=xr.open_mfdataset(f"{DPLE_DIR}/monthly/{VAR}/{prefix}1954*.nc",concat_dim="ensemble",chunks={"time": 122})

#catch for ATM vs OCN lat/lon dims
if MODEL=='ATM':
    lat=oneds['lat']
    lon=oneds['lon']
elif MODEL=='OCN':
    nlat=oneds['nlat']
    nlon=oneds['nlon']
mems=oneds['ensemble']

#getting attributes from original file
ncattrs=oneds.attrs
varattrs=oneds[VAR].attrs
dimattrs={}
for dd in oneds.dims:
    dimattrs[dd]=oneds[dd].attrs
print('got array for dimensions')



got array for dimensions


In [12]:
#here's one of the heavy-lifting steps. The * is for the list format
dask_arrays=dask.compute(*dask_arrays)
print('did first compute step')

did first compute step


In [13]:
#The list of dask arrays is now "stacked" into one big dask array
wholedaskarray=da.stack(dask_arrays,axis=0)
print('stacked into one big dask array')

stacked into one big dask array


In [14]:
#function for turning dask array into xarray w/ dims consistent w/ past processed files
def makedparray(daskarray,model):
    if model=='ATM':
        x=('lon',lon.values)
        y=('lat',lat.values)
    elif model=='OCN':
        x=('nlon',nlon.values)
        y=('nlat',nlat.values)
    Lnew=np.arange(1,numL+1,1,dtype='int32')
    Mnew=np.arange(1,len(mems.values)+1,1,dtype='int32')
    newarray=xr.DataArray(daskarray.transpose([0, 2, 1, 3, 4]),\
              coords={'S':S,'L':Lnew,\
                      'M':Mnew,y[0]:y[1],x[0]:x[1]},\
              dims=['S','L','M', y[0],x[0]])
    return newarray

if WHICHMEAN=='ANN':
    numL=10
    array=makedparray(wholedaskarray,MODEL)         
print('made into an xarray')

made into an xarray


In [15]:
#a couple of functions for adding attributes to the dataarray
def add_varattrs(da):
    da.attrs=varattrs
    for cc in da.coords:
        for aa in dimattrs[cc]:
            da[cc].attrs=dimattrs[cc]
    return da

def add_ncattrs(ds):
    ds.attrs=ncattrs
    ds.attrs['script']=os.path.basename(sys.argv[0])
    now=datetime.datetime.now()
    ds.attrs['history']='created by '+pwd.getpwuid(os.getuid()).pw_name+' on '+str(now)
    return ds

#preparing to turn the DataArray back into a DataSet (so it can be written out as a netcdf)
array.name=VAR
array.attrs=varattrs
dimattrs['S']=OrderedDict([('long_name','start year')])
dimattrs['L']=OrderedDict([('long_name','lead year')])
dimattrs['M']=OrderedDict([('long_name','ensemble member')])
array=add_varattrs(array)

#turning DataArray into DataSet and adding ncattrs
newds=array.to_dataset()
newds=add_ncattrs(newds)

In [16]:
#this is another .compute() slow step. More loading into memory
newds=newds.compute()
#this step is separate from the to_netcdf() call below b/c "newds" will be used for calculating drift.
#only want to read all of this into memory once, and if append .to_netcdf() after this .compute()
#the "newobj" cannot be used to calculate drift b/c it's not the right type of object
print('finished loading')
newobj=newds.to_netcdf(RAWDPOUT,engine='netcdf4',compute=False)
print('made an object for writing')
#for some reason, the netcdf write needs to be delayed (compute=False), or dask hangs
please_ncwrite = newobj.persist()  #here's where it actually writes to disk
print('DP array with not-bias corrected output written out')

finished loading
made an object for writing
DP array with not-bias corrected output written out


In [17]:
#DRIFT CALCULATION STARTS HERE
#years to compare against verification
climy0=1964
climy1=2014

#function that converts S,L to verification time for lead-time dependent climo calc
def make_verification_time(ds):
    ds['VER_TIME']=ds['S']+0.5+ds['L']-1
    return ds

#function to calculate drift 
def calc_drift(ds):
    ds_ver=make_verification_time(ds)
    vertime=ds_ver['VER_TIME']
    var=ds_ver[VAR]
    dummy=var.mean('M')    
    #here's the key step that creates an array of booleans that select which variable entries are used
    truefalse=np.squeeze([(vertime.values>climy0) & (vertime.values<(climy1+1))])
    #if you don't believe me (I didn't for a while), uncomment the line below and 
    #it'll print which array entries are used in the lead-time dependent climatology
    #print(truefalse)
    dummy.values[~truefalse,:,:]=np.nan #setting entries we don't want for mean to NaN
    print(dummy)
    drift=dummy.mean('S') #mean across start years, climo calculated here
    
    biascorr=ds_ver[VAR]-drift  #anomalies calculated here
    drift=add_varattrs(drift)   
    biascorr=add_varattrs(biascorr)   
    
    return drift,biascorr

#dataarrays of climo and anomalies created here:    
drift_da,biascorr_da=calc_drift(newds)  
print('done')

<xarray.DataArray 'TS' (S: 62, L: 10, lat: 192, lon: 288)>
array([[[[      nan, ...,       nan],
         ...,
         [      nan, ...,       nan]],

        ...,

        [[220.48404, ..., 220.78227],
         ...,
         [249.90042, ..., 249.89986]]],


       ...,


       [[[      nan, ...,       nan],
         ...,
         [      nan, ...,       nan]],

        ...,

        [[      nan, ...,       nan],
         ...,
         [      nan, ...,       nan]]]], dtype=float32)
Coordinates:
  * S        (S) int32 1955 1956 1957 1958 1959 ... 2012 2013 2014 2015 2016
  * L        (L) int32 1 2 3 4 5 6 7 8 9 10
  * lat      (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 87.17 88.12 89.06 90.0
  * lon      (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
done


In [18]:
#dataarray->dataset->netcdf for drift+climo
drift_ds=drift_da.persist().to_dataset(name='climo')
drift_ds=add_ncattrs(drift_ds)
drift_ds.attrs['climatology']=str(climy0)+"-"+str(climy1)+", computed separately for each lead time"
driftobj=drift_ds.to_netcdf(DRIFTOUT,engine='netcdf4',compute=False)
please_ncwrite = driftobj.persist()
print('written drift/climatology file')

written drift/climatology file


In [19]:
#same deal for anomalies
biascorr_ds=biascorr_da.persist().to_dataset(name='anom')
biascorr_ds=add_ncattrs(biascorr_ds)
biascorr_ds.attrs['climatology']=str(climy0)+"-"+str(climy1)+", computed separately for each lead time"
biasobj=biascorr_ds.to_netcdf(ANOMOUT,engine='netcdf4',compute=False)
please_ncwrite = biasobj.persist()
print('written anomalies file. phew, it worked. I hope.')

written anomalies file. phew, it worked. I hope.
