In [1]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt

from metpy.calc import dewpoint_from_relative_humidity
from metpy.units import units
from metpy.plots import SkewT

import sounding_utils
from xhistogram.xarray import histogram

from importlib import reload
from functools import partial
from joblib import dump

import sys
#sys.path.append('../../') # lets us import ptype package from the subdir

#import ptype.

from dask.distributed import Client, LocalCluster
from dask_jobqueue import PBSCluster
import dask
import glob
from os.path import join

In [44]:
client.shutdown()

In [45]:
cluster = PBSCluster(account='NAML0001',
                     queue='casper',
                     walltime='01:00:00',
                     memory="1000 GB",
                     resource_spec='select=1:ncpus=16:mem=50GB', # Specify resources
                     interface='ib0',
                     local_directory='/glade/work/dkimpara/dask/',
                     log_directory="/glade/work/dkimpara/dask_logs/")

# Change your url to the dask dashboard so you can see it
#dask.config.set({'distributed.dashboard.link':'https://jupyterhub.hpc.ucar.edu/stable/user/{USER}/proxy/{port}/status'})
print(f"Use this link to monitor the workload: {cluster.dashboard_link}")
client = Client(cluster)
cluster.scale(jobs=50)

Use this link to monitor the workload: https://jupyterhub.hpc.ucar.edu/stable/user/dkimpara/phil/proxy/8787/status


In [46]:
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/dkimpara/phil/proxy/8787/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/dkimpara/phil/proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.56:34430,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/dkimpara/phil/proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [47]:
def load_dask(model):
    with dask.config.set(**{'array.slicing.split_large_chunks': True}):
        ds = xr.open_mfdataset(f"/glade/campaign/cisl/aiml/ptype/ptype_case_studies/kentucky/{model}/*/*/*.nc", 
                               parallel=True, engine='netcdf4', 
                               decode_cf=False, concat_dim='valid_time', combine='nested', 
                               chunks={'time':1, 'heightAboveGround': 21, 'isobaricInhPa': 37})
        ds.attrs['nwp'] = model
        return ds

## aggregation code

In [48]:
def agg_stats(ds, save_dir='/glade/scratch/dkimpara/composite_calcs'):
    with dask.config.set(**{'array.slicing.split_large_chunks': True}):
        try:
            print(ds.attrs['nwp'])
        except: 
            raise ValueError('dataset must have nwp attr set')

        ds = sounding_utils.filter_latlon(ds)

        precip_mask = (
            (ds["crain"] == 1)
            | (ds["csnow"] == 1)
            | (ds["cicep"] == 1)
            | (ds["cfrzr"] == 1)
            )

        ds = ds.where(precip_mask)
        if 'wb_h' not in list(ds.keys()):
            ds = sounding_utils.wet_bulb_from_rel_humid(ds)

        ptypes = ['rain', 'snow', 'icep', 'frzr']
        prof_vars = ['t_h', 'dpt_h', 'wb_h']
        bins = np.arange(-40, 40, 0.5)
        quantiles = np.arange(0.0, 1.01, 0.1)

        persist_vars = (prof_vars + 
                        [f'ML_c{var}' for var in ptypes] +
                        [f'c{var}' for var in ptypes])

        ds[persist_vars].persist() 
        total_obs = ds.t_h.isel(heightAboveGround=0).count(dim=('x','y','time','valid_time'))

        res_dict = {'mean': [],
                    'quantiles': [],
                    'hist': [],
                    }
        metadata = {'total_obs': total_obs}

        for ptype in ptypes:
            for model in ['ML_c', 'c']:
                predtype = model + ptype
                subset = ds[prof_vars].where(ds[predtype] == 1)

                ### num_obs per hr
                counts = subset.t_h.count(dim=('x','y'))
                obs_per_hr = counts.isel(heightAboveGround=0).mean(dim=('time', 'valid_time'))
                metadata[f'{predtype}_obs_per_hr'] = obs_per_hr

                # num_obs of predtype==1
                num_obs = subset.t_h.isel(heightAboveGround=0).count(dim=('x','y','time','valid_time'))
                metadata[f'{predtype}_num_obs'] = num_obs

                # num_obs w frac abv zero
                for var in prof_vars:
                    metadata[f"{predtype}_{var}_frac_abv_zero"] = (
                        sounding_utils.frac_abv_zero(subset, var, num_obs)
                    )

                # means and quantiles
                mean = subset.mean(dim=('valid_time', 'time', 'x', 'y')) #returns dataset objects
                mean = mean.rename({var: f'{var}_mean' for var in prof_vars})
                
                qs = subset.quantile(quantiles, dim=('valid_time', 'time', 'x', 'y')) #returns dataset objects
                qs = qs.rename({var: f'{var}_qs' for var in prof_vars})
                
                #### densities ####
                densities = ({f'{var}_hist': (
                        histogram(subset[var], bins=bins, dim=['valid_time', 'time', 'x', 'y'], density=True)
                        .rename({f'{var}_bin': 'bin'})
                        ) for var in prof_vars})
                densities = xr.Dataset(densities) #arrays already named histograms

                res_datasets = {'mean': mean,
                            'quantiles': qs,
                            'hist': densities}

                for k, res_ds_list in res_dict.items():
                    res_ds_list.append(res_datasets[k].expand_dims({'predtype': [predtype]}))
                
        ds_concat = [xr.concat(res_ds_list, dim='predtype') for res_ds_list in res_dict.values()]
        result = xr.merge(ds_concat)
        #save
        result.to_netcdf(path=join(save_dir, ds.attrs['nwp']))
        dump(metadata, join(save_dir,f"{ds.attrs['nwp']}_metadata"))

        return result


In [49]:
def agg_delayed(ds, save_dir='/glade/scratch/dkimpara/composite_calcs'):
    with dask.config.set(**{'array.slicing.split_large_chunks': True}):
        try:
            print(ds.attrs['nwp'])
        except: 
            raise ValueError('dataset must have nwp attr set')

        ds = sounding_utils.filter_latlon(ds)

        precip_mask = (
            (ds["crain"] == 1)
            | (ds["csnow"] == 1)
            | (ds["cicep"] == 1)
            | (ds["cfrzr"] == 1)
            )

        ds = ds.where(precip_mask)
        print('filtered')
        #if 'wb_h' not in list(ds.keys()):
        #    ds = sounding_utils.wet_bulb_from_rel_humid(ds)
        #print('wb computed')
        ptypes = ['rain', 'snow', 'icep', 'frzr']
        prof_vars = ['t_h', 'dpt_h']#, 'wb_h']
        bins = np.arange(-40, 40, 0.5)
        quantiles = np.arange(0.0, 1.01, 0.1)

        persist_vars = (prof_vars + 
                        [f'ML_c{var}' for var in ptypes] +
                        [f'c{var}' for var in ptypes])

        #ds[persist_vars].persist() 
        print('persisted')
        total_obs = ds.t_h.isel(heightAboveGround=0).count(dim=('x','y','time','valid_time'))

        res_dict = {'mean': [],
                    quantiles': [],
                    'hist': [],
                    }
        metadata = {'total_obs': total_obs}
        print('total obs')

        lazy_results = []
        ####################################
        remote_ds = client.scatter(ds)
        for ptype in ptypes:
            for model in ['ML_c', 'c']:
                predtype = model + ptype
                
                lazy_result = dask.delayed(agg_parallel)(predtype, remote_ds) #this fn returns a dict of datasets
                lazy_results.append(lazy_result)
                
        print('computing')
        for i in range(len(ptypes)):
            res, meta = lazy_results[i].compute()
            metadata = metadata | meta #merge metadata dictionary
            for k in res.keys():
                res_dict[k].append(res[k])
        print('extracted')
        ds_concat = [xr.concat(res_ds_list, dim='predtype') for res_ds_list in res_dict.values()]
        result = xr.merge(ds_concat)
        #save
        result.to_netcdf(path=join(save_dir, ds.attrs['nwp']))
        dump(metadata, join(save_dir, f"{ds.attrs['nwp']}_metadata"))

        return result

def agg_parallel(predtype, ds):
    metadata = {}
    prof_vars = ['t_h', 'dpt_h']
    bins = np.arange(-40, 40, 0.5)
    quantiles = np.arange(0.0, 1.01, 0.1)
    
    subset = ds[prof_vars].where(ds[predtype] == 1)

    ### num_obs per hr
    counts = subset.t_h.count(dim=('x','y'))
    obs_per_hr = counts.isel(heightAboveGround=0).mean(dim=('time', 'valid_time'))
    metadata[f'{predtype}_obs_per_hr'] = obs_per_hr

    # num_obs of predtype==1
    num_obs = subset.t_h.isel(heightAboveGround=0).count(dim=('x','y','time','valid_time'))
    metadata[f'{predtype}_num_obs'] = num_obs

    # num_obs w frac abv zero
    for var in prof_vars:
        metadata[f"{predtype}_{var}_frac_abv_zero"] = (
            frac_abv_zero(subset, var, num_obs)
        )

    # means and quantiles
    mean = subset.mean(dim=('valid_time', 'time', 'x', 'y')) #returns dataset objects
    mean = mean.rename({var: f'{var}_mean' for var in prof_vars})

    qs = subset.chunk(dict(valid_time=-1, time=-1)).quantile(quantiles, dim=('valid_time', 'time', 'x', 'y')) #returns dataset objects
    qs = qs.rename({var: f'{var}_qs' for var in prof_vars})

    #### densities ####
    densities = ({f'{var}_hist': (
            histogram(subset[var], bins=bins, dim=['valid_time', 'time', 'x', 'y'], density=True)
            .rename({f'{var}_bin': 'bin'})
            ) for var in prof_vars})
    densities = xr.Dataset(densities) #arrays already named histograms

    res_datasets = {'mean': mean,
                'quantiles': qs,
                'hist': densities}
    return res_datasets, metadata

def frac_abv_zero(ds, x_col, total):
    num_over_zero = (ds[x_col] > 0).any(dim="heightAboveGround").sum()
    return num_over_zero / total

# run jobs

In [50]:
%%time
model = 'rap'
ds = load_dask(model)

CPU times: user 4min 49s, sys: 17.4 s, total: 5min 7s
Wall time: 12min 33s


%%time
reload(sounding_utils)
res_rap = agg_delayed(ds)
del ds

In [51]:
dask.config.set(**{'array.slicing.split_large_chunks': True})
try:
    print(ds.attrs['nwp'])
except: 
    raise ValueError('dataset must have nwp attr set')

ds = sounding_utils.filter_latlon(ds)

precip_mask = (
    (ds["crain"] == 1)
    | (ds["csnow"] == 1)
    | (ds["cicep"] == 1)
    | (ds["cfrzr"] == 1)
    )

ds = ds.where(precip_mask)
print('filtered')
#if 'wb_h' not in list(ds.keys()):
#    ds = sounding_utils.wet_bulb_from_rel_humid(ds)
#print('wb computed')
ptypes = ['rain', 'snow', 'icep', 'frzr']
prof_vars = ['t_h', 'dpt_h']#, 'wb_h']
bins = np.arange(-40, 40, 0.5)
quantiles = np.arange(0.0, 1.01, 0.1)

persist_vars = (prof_vars + 
                [f'ML_c{var}' for var in ptypes] +
                [f'c{var}' for var in ptypes])

#ds[persist_vars].persist() 
print('persisted')
total_obs = ds.t_h.isel(heightAboveGround=0).count(dim=('x','y','time','valid_time'))

res_dict = {'mean': [],
            #quantiles': [],
            'hist': [],
            }
metadata = {'total_obs': total_obs}
print('total obs')

rap
filtered
persisted
total obs


In [52]:
%%time
lazy_results = []
remote_ds = client.scatter(ds)

for ptype in ptypes:
    for model in ['ML_c', 'c']:
        predtype = model + ptype

        lazy_result = dask.delayed(agg_parallel)(predtype, remote_ds) #this fn returns a dict of datasets and a metadata dict
        lazy_results.append(lazy_result)
        
        #lazy_result.append(agg_parallel(predtype, ds))
#need to do client.gather to compute?

CPU times: user 7.08 s, sys: 3.38 s, total: 10.5 s
Wall time: 16.3 s


In [56]:
%%time
for i in range(len(ptypes)):
    res, meta = lazy_results[i].compute()
    metadata = metadata | meta #merge metadata dictionary
    for k in res.keys():
        res_dict[k].append(res[k])
print('extracted')


extracted
CPU times: user 19.7 s, sys: 3.45 s, total: 23.1 s
Wall time: 37.9 s


In [57]:
%%time
ds_concat = [xr.concat(res_ds_list, dim='predtype') for res_ds_list in res_dict.values()]
result = xr.merge(ds_concat)
#save
save_dir='/glade/scratch/dkimpara/composite_calcs'
result.to_netcdf(path=join(save_dir, ds.attrs['nwp']))
dump(metadata, join(save_dir, f"{ds.attrs['nwp']}_metadata"))

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


CPU times: user 14min 17s, sys: 35.4 s, total: 14min 52s
Wall time: 25min 50s


['/glade/scratch/dkimpara/composite_calcs/rap_metadata']

In [55]:
def timer(tic):
    toc = time.time()
    duration = toc - tic
    minutes = int(duration/60)
    print(f"Elapsed time: {str(minutes) + ' minutes, ' if minutes else ''}{int(duration % 60)} seconds")

In [None]:
import time
for model in ['rap', 'gfs', 'hrrr']:
    tic = time.time()
    ds = load_dask(model)
    timer(tic)
    
    tic = time.time()
    _ = agg_stats(ds)
    timer(tic)
    del ds
    

In [None]:
%%time
model = 'gfs'
ds = load_dask(model)

In [None]:
%%time
res_gfs = agg_stats(ds)
del ds

In [None]:
%%time
model = 'hrrr'
ds = load_dask(model)

In [None]:
%%time
res_hrrr = agg_stats(ds)

In [None]:
client.shutdown()
import subprocess
subprocess.run("qdel $PBS_JOBID", shell=True, capture_output=True, encoding='utf-8')

In [None]:
client.shutdown()
