In [1]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt

from metpy.calc import dewpoint_from_relative_humidity
from metpy.units import units
from metpy.plots import SkewT

import sounding_utils
from xhistogram.xarray import histogram

from importlib import reload
from functools import partial
from joblib import dump

import sys
sys.path.append('../../') # lets us import ptype package from the subdir

#import ptype.

from dask.distributed import Client, LocalCluster
from dask_jobqueue import PBSCluster
import dask
import glob
from os.path import join

# Per-job resource request for the PBS-backed dask cluster.
pbs_job_spec = dict(
    account='NAML0001',
    queue='casper',
    walltime='01:00:00',
    memory="200 GB",
    cores=2,
    interface='ib0',
    local_directory='/glade/work/dkimpara/dask/',
    log_directory="/glade/work/dkimpara/dask_logs/",
)
cluster = PBSCluster(**pbs_job_spec)

# To reach the dashboard through JupyterHub, set the link template before the
# link is printed, e.g.
# dask.config.set({'distributed.dashboard.link': 'https://jupyterhub.hpc.ucar.edu/stable/user/{USER}/proxy/{port}/status'})
dashboard_url = cluster.dashboard_link
print(f"Use this link to monitor the workload: {dashboard_url}")

client = Client(cluster)
# Request 50 PBS jobs for the workers.
cluster.scale(jobs=50, memory="200GB")

def load_dask(model, init_dir='*', date='20220223',
              base_dir='/glade/campaign/cisl/aiml/ptype/ptype_case_studies/kentucky'):
    """Open all case-study netCDF files for one NWP model as a single dataset.

    Files matching ``{base_dir}/{model}/{date}/{init_dir}/*.nc`` are combined
    (nested) along ``valid_time`` with dask chunks of 1 along ``time`` and
    CF decoding disabled.

    Parameters
    ----------
    model : str
        Model subdirectory name (e.g. 'rap'); also stored in ``attrs['nwp']``.
    init_dir : str, optional
        Init-time subdirectory or glob; the default ``'*'`` loads all of them.
    date : str, optional
        Date subdirectory.
    base_dir : str, optional
        Root directory of the case study.

    Returns
    -------
    xarray.Dataset
        The combined dataset with ``attrs['nwp'] = model``.
    """
    pattern = join(base_dir, model, date, init_dir, '*.nc')
    ds = xr.open_mfdataset(pattern,
                           parallel=True, engine='netcdf4',
                           decode_cf=False, concat_dim='valid_time', combine='nested',
                           chunks={'time': 1})
    ds.attrs['nwp'] = model
    # The original built the dataset but never returned it.
    return ds

In [4]:
model = "rap"  # NWP model subdirectory to load; also stored in ds.attrs['nwp']


In [None]:
%%time
# Only the 0000 init-time files for the selected model.
file_pattern = f"/glade/campaign/cisl/aiml/ptype/ptype_case_studies/kentucky/{model}/20220223/0000/*.nc"
open_kwargs = dict(parallel=True, engine='netcdf4',
                   decode_cf=False, concat_dim='valid_time', combine='nested',
                   chunks={'time': 1})
ds = xr.open_mfdataset(file_pattern, **open_kwargs)
ds.attrs['nwp'] = model

In [None]:
ds  # display the loaded dataset (dims, coordinates, variables, attrs)

In [None]:
def _agg_stats_summary(subset, label, prof_vars, quantiles, reduce_dims, prob_threshold):
    """Mean and quantiles of ``subset`` over ``reduce_dims``, tagged with ``predtype=label``.

    Profile variables are renamed to ``{var}>{prob_threshold}_mean`` and
    ``{var}>{prob_threshold}_qs``; both results gain a length-1 ``predtype`` dim.
    """
    mean = subset.mean(dim=reduce_dims).rename(
        {var: f'{var}>{prob_threshold}_mean' for var in prof_vars})
    # quantile() on dask-backed data needs every reduced dim in one chunk,
    # so rechunk all of them (the original only merged valid_time).
    qs = (subset.chunk({dim: -1 for dim in reduce_dims})
          .quantile(quantiles, dim=reduce_dims)
          .rename({var: f'{var}>{prob_threshold}_qs' for var in prof_vars}))
    return (mean.expand_dims({'predtype': [label]}),
            qs.expand_dims({'predtype': [label]}))


def agg_stats(ds, save_dir='/glade/scratch/dkimpara/composite_calcs', prob_threshold=0.7):
    """Composite profile statistics where the ML model is confident in a precip type.

    Points where none of crain/csnow/cicep/cfrzr equals 1 are masked out. Then,
    for each precip type, the mean and the 0.0, 0.1, ..., 1.0 quantiles of the
    profile variables 't_h', 'dpt_h', 'wb_h' are taken over
    valid_time/time/x/y at points where ``ML_{ptype} >= prob_threshold``
    (label ``ML_{ptype}``). They are computed again on the subset of those
    points where ``c{ptype} == 1`` as well (label ``c{ptype}``).

    Parameters
    ----------
    ds : xarray.Dataset
        Must have ``attrs['nwp']`` and contain c{rain,snow,icep,frzr},
        ML_{rain,snow,icep,frzr} (presumably ML probabilities -- confirm),
        't_h', 'dpt_h' and either 'wb_h' or what
        ``sounding_utils.wet_bulb_from_rel_humid`` needs to derive it.
    save_dir : str
        Output directory for ``{nwp}_by_prob`` (netCDF) and
        ``{nwp}_metadata_by_prob`` (joblib dump).
    prob_threshold : float
        ML threshold. It also appears in the output variable names, which
        match the original '>0.7' names at the default.

    Returns
    -------
    xarray.Dataset
        The mean and quantile variables, stacked along a ``predtype`` dim.

    Raises
    ------
    ValueError
        If ``ds.attrs`` has no ``'nwp'`` entry.
    """
    if 'nwp' not in ds.attrs:
        raise ValueError('dataset must have nwp attr set')
    # Capture the model name now; the helpers below may not preserve attrs.
    nwp = ds.attrs['nwp']
    print(nwp)

    ds = sounding_utils.filter_latlon(ds)

    precip_mask = (
        (ds["crain"] == 1)
        | (ds["csnow"] == 1)
        | (ds["cicep"] == 1)
        | (ds["cfrzr"] == 1)
    )
    ds = ds.where(precip_mask)
    if 'wb_h' not in ds.data_vars:
        ds = sounding_utils.wet_bulb_from_rel_humid(ds)

    ptypes = ['rain', 'snow', 'icep', 'frzr']
    prof_vars = ['t_h', 'dpt_h', 'wb_h']
    quantiles = np.arange(0, 1.01, 0.1)
    reduce_dims = ('valid_time', 'time', 'x', 'y')

    # Persist exactly the variables read below. persist() returns a new
    # object, so it must be assigned back to have any effect.
    persist_vars = (prof_vars
                    + [f'ML_{p}' for p in ptypes]
                    + [f'c{p}' for p in ptypes])
    persisted = ds[persist_vars].persist()
    ds = ds.assign({name: persisted[name] for name in persist_vars})

    # Non-NaN count of the first height level after the precip mask.
    total_obs = ds['t_h'].isel(heightAboveGround=0).count(dim=reduce_dims)
    metadata = {'total_obs': total_obs.compute()}

    mean_by_prob, q_by_prob = [], []
    for ptype in ptypes:
        ml_var, nwp_var = f'ML_{ptype}', f'c{ptype}'
        # Points where the ML model is confident in this ptype ...
        confident = ds[prof_vars + [nwp_var]].where(ds[ml_var] >= prob_threshold)
        # ... and the subset where the NWP categorical flag agrees.
        agreeing = confident.where(confident[nwp_var] == 1)

        mean_ml, qs_ml = _agg_stats_summary(confident, ml_var, prof_vars,
                                            quantiles, reduce_dims, prob_threshold)
        mean_nwp, qs_nwp = _agg_stats_summary(agreeing, nwp_var, prof_vars,
                                              quantiles, reduce_dims, prob_threshold)

        # Drop the categorical flag so every ptype's dataset has the same variables.
        mean_by_prob.append(
            xr.concat((mean_ml, mean_nwp), dim='predtype').drop_vars(nwp_var))
        q_by_prob.append(
            xr.concat((qs_ml, qs_nwp), dim='predtype').drop_vars(nwp_var))

    # The original concatenated the never-filled ds_mean/ds_q/ds_hist lists,
    # which raises on empty input.
    # TODO: histograms (xhistogram) were planned but never computed.
    ds_mean = xr.concat(mean_by_prob, dim='predtype')
    ds_q = xr.concat(q_by_prob, dim='predtype')
    result = xr.merge((ds_mean, ds_q))

    result.to_netcdf(path=join(save_dir, f"{nwp}_by_prob"))
    dump(metadata, join(save_dir, f"{nwp}_metadata_by_prob"))

    return result
    

In [None]:
'''
client.shutdown()
import subprocess
subprocess.run("qdel $PBS_JOBID", shell=True, capture_output=True, encoding='utf-8')