In [1]:
import os
import socket

import dask
import dask.distributed
import ncar_jobqueue

import xarray as xr
import numpy as np
import esmlab

import intake
import intake_esm

In [2]:
# total number of dask workers requested (9 * 10 = 90; the factoring presumably
# reflects nodes x workers-per-node -- confirm against the jobqueue config)
n_workers = 9 * 10

# start the NCAR job-queue cluster and attach a distributed client to it,
# then ask the cluster to scale up to the requested worker count
cluster = ncar_jobqueue.NCARCluster()
client = dask.distributed.Client(cluster)
cluster.scale(n_workers)

In [3]:
# display the client summary (scheduler address, dashboard link, worker/core/memory counts)
client

0,1
Client  Scheduler: tcp://10.148.10.13:49566  Dashboard: http://10.148.10.13/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [4]:
# open the CESM1-LE catalog described by the YAML file (reusing an existing
# collection rather than rebuilding it) and summarise its file table
col = intake.open_esm_metadatastore(
    overwrite_existing=False,
    collection_input_file='cesm1-le-collection.yml',
)
col.df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 116276 entries, 0 to 116275
Data columns (total 18 columns):
resource            116276 non-null object
resource_type       116276 non-null object
direct_access       116276 non-null bool
experiment          116276 non-null object
case                116276 non-null object
component           116276 non-null object
stream              116276 non-null object
variable            116276 non-null object
date_range          116276 non-null object
ensemble            116276 non-null int64
files               116276 non-null object
files_basename      116276 non-null object
files_dirname       116276 non-null object
ctrl_branch_year    0 non-null float64
year_offset         15145 non-null float64
sequence_order      116276 non-null int64
has_ocean_bgc       116276 non-null bool
grid                13013 non-null object
dtypes: bool(2), float64(2), int64(2), object(12)
memory usage: 15.3+ MB


In [5]:
def weighted_mean_ds(ds, weights, dim):
    """Weighted mean of every variable in `ds` that spans all of `dim`.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset.
    weights : xarray.DataArray or callable
        Weights passed to ``esmlab.statistics.weighted_mean``. If callable, it
        is called as ``weights(ds)`` to build them (e.g. ``vol3d``).
    dim : str or list of str
        Dimension(s) to average over. A single string is treated as one
        dimension name.

    Returns
    -------
    xarray.Dataset
        Variables that have every dimension in `dim` are replaced by their
        weighted mean. All other variables are copied unchanged: coordinates
        stay coordinates, and the global attributes of `ds` are kept.
    """
    if isinstance(dim, str):
        # a bare string would otherwise be iterated character by character
        dim = [dim]

    if callable(weights):
        weights = weights(ds)

    # find variables where all `dim` appear (possibly more restrictive than needed)
    reduce_vars = [name for name, var in ds.variables.items()
                   if all(d in var.dims for d in dim)]

    # copy the remaining variables, keeping coordinates as coordinates
    # (a Dataset built from `ds.variables` alone would turn non-index
    # coordinates such as 2-D lat/lon into data variables)
    dso = xr.Dataset(
        data_vars={name: ds.variables[name] for name in ds.data_vars
                   if name not in reduce_vars},
        coords={name: ds.variables[name] for name in ds.coords
                if name not in reduce_vars},
        attrs=ds.attrs,
    )

    for name in reduce_vars:
        # The same unmasked `weights` object is passed on every iteration, so the
        # NaN mask must be requested for each variable; turning it off after the
        # first one would leave later variables' NaN cells in the weight
        # normalisation (assuming esmlab masks weights per call -- see its docs).
        dso[name] = esmlab.statistics.weighted_mean(ds[name], weights=weights, dim=dim,
                                                    apply_nan_mask=True)

    return dso

In [6]:
def vol3d(ds):
    """Return the volume of every ocean cell, with cells below the sea floor zeroed.

    A 0/1 mask is built from the 0-based level index ``k`` along ``z_t``. It is
    1 where ``k <= KMT - 1`` and 0 elsewhere (including where KMT is NaN). The
    result is ``dz * TAREA * mask``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``z_t``, ``KMT``, ``dz`` and ``TAREA``. KMT is presumably
        the number of active levels in each column (POP convention -- confirm).
        Operands are aligned by dimension name, not position, so KMT may carry
        extra dimensions besides ``nlat``/``nlon``; they are broadcast into the
        result.

    Returns
    -------
    xarray.DataArray
        Masked cell volume with ``units='cm^3'`` and ``long_name='masked volume'``.
        The units label assumes dz is in cm and TAREA in cm^2 -- confirm.
    """
    # 0-based vertical level index, labelled with the dataset's z_t coordinate
    k = xr.DataArray(np.arange(ds.z_t.size), dims='z_t', coords={'z_t': ds.z_t})

    # 1 above the sea floor, 0 below. Broadcasting by dimension name avoids the
    # positional KMT.shape[0]/[1] assumption and never materialises a full
    # np.ones((nk, nlat, nlon)) array or forces `.values`.
    MASK = xr.where(k <= ds.KMT - 1, 1., 0.)

    MASKED_VOL = ds.dz * ds.TAREA * MASK
    MASKED_VOL.attrs['units'] = 'cm^3'
    MASKED_VOL.attrs['long_name'] = 'masked volume'

    return MASKED_VOL

In [7]:
# ensemble members that have ocean BGC output for the 20C/RCP85 experiments,
# trimmed to a small subset (positions 1-3 in catalog order, not sorted order)
ensembles = (
    col.search(experiment=['20C', 'RCP85'], has_ocean_bgc=True)
    .results.ensemble
    .unique()
    .tolist()[1:4]
)
print(ensembles)

[2, 9, 10]


In [None]:
# Volume-weighted global-mean annual `variable` for the selected ensemble
# members over the 20C + RCP85 experiments, written to netCDF.
variable = ['O2']
query = dict(ensemble=ensembles, experiment=['20C', 'RCP85'],
             stream='pop.h', variable=variable)

col_subset = col.search(**query)

# get a dataset for the matching files
ds = col_subset.to_xarray()

# optional: restrict to a time range before averaging. This must rebind `ds`;
# assigning to `dso` here would be discarded by the annual-mean line below.
#ds = esmlab.utils.time.sel_time(ds, slice('1920', '2100'))

# compute annual means
dso = esmlab.climatology.compute_ann_mean(ds)

# volume-weighted mean over depth and the horizontal grid
dso = weighted_mean_ds(dso, weights=vol3d, dim=['z_t', 'nlat', 'nlon'])

# compute the dataset
dso = dso.compute()

# write to the current user's scratch space rather than a hard-coded user
# directory, creating the directory if it does not exist yet
out_dir = os.path.join('/glade/scratch', os.environ['USER'], 'tmp')
os.makedirs(out_dir, exist_ok=True)
dso.to_netcdf(os.path.join(out_dir, f'le-{"-".join(variable)}-global-mean.nc'))

In [None]:
# disconnect the client first, then shut down the cluster it was attached to,
# so the client does not keep trying to reach a scheduler that has gone away
client.close()
cluster.close()

In [None]:
# inspect the dataset returned by col_subset.to_xarray()
ds

In [None]:
# names of coordinates that are not dimensions (non-index coordinates)
set(ds.coords).difference(ds.dims)