In [1]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

import sounding_utils
from importlib import reload

from functools import partial

import os
from os.path import join
import subprocess

import sys
sys.path.append('../../') # lets us import ptype package from the subdir
#import ptype.

from joblib import Parallel, delayed

from xhistogram.xarray import histogram

import time

In [2]:
# Ask PBS how many CPUs this job was allocated; the last whitespace-separated
# field of the matching qstat line is the count (displayed as the cell output).
qstat_ncpus_line = subprocess.run(
    "qstat -f $PBS_JOBID | grep Resource_List.ncpus",
    shell=True,
    capture_output=True,
    encoding='utf-8',
).stdout
qstat_ncpus_line.split()[-1]

72

In [3]:
def get_metadata(path):
    """Derive case-study / model / day labels from a leaf data directory path.

    The path is expected to end in ``.../<case_study>/<model>/<day>/<leaf>``
    (a trailing separator is tolerated because the path is normalised first).
    Each value is wrapped in a one-element list so the result can be passed
    straight to ``Dataset.expand_dims``.

    Returns
    -------
    dict
        ``{'case_study': [...], 'day': [...], 'model': [...]}``, in that key order.
    """
    parts = os.path.normpath(path).split(os.sep)
    case_study, model, day = parts[-4], parts[-3], parts[-2]
    return {'case_study': [case_study], 'day': [day], 'model': [model]}
    
def xr_map_reduce(base_path, func, n_jobs=-1):
    """Apply ``func`` to every leaf data directory under ``base_path`` in parallel
    and concatenate the per-directory results along ``'time'``.

    A leaf directory is one that contains files but no subdirectories; each
    one is opened, masked and reduced independently by ``xr_map``.

    Parameters
    ----------
    base_path : str
        Root directory to walk.
    func : callable
        Reduction applied to each directory's Dataset (passed through to ``xr_map``).
    n_jobs : int, default -1
        Number of joblib workers. ``-1`` uses ``min(#leaf dirs, #available CPUs)``.

    Returns
    -------
    xarray.Dataset
        The results of ``func``, concatenated along ``'time'``.

    Raises
    ------
    FileNotFoundError
        If no leaf directory containing files exists under ``base_path``.
    """
    # Sorted so the concatenation order does not depend on filesystem listing order.
    dirpaths = sorted(
        dirpath
        for dirpath, dirnames, filenames in os.walk(base_path)
        if filenames and not dirnames  # skip dirs with subdirs or no files
    )
    if not dirpaths:
        # Otherwise n_jobs would become 0 and joblib / xr.concat fail obscurely.
        raise FileNotFoundError(f"no leaf data directories found under {base_path!r}")

    if n_jobs == -1:
        # _available_cpus() always returns an int; the original compared a
        # qstat *string* against an int here, which raises TypeError.
        n_jobs = min(len(dirpaths), _available_cpus())

    ########################## map and reduce ##############################
    results = Parallel(n_jobs=n_jobs)(delayed(xr_map)(path, func) for path in dirpaths)
    return xr.concat(results, dim='time')  # each result ds will be for a different time


def _available_cpus():
    """Return the number of CPUs this process may use, as a positive int.

    When the working directory contains 'glade' (the PBS cluster this notebook
    targets) the count is read from ``qstat``'s ``Resource_List.ncpus`` line.
    If that lookup yields nothing usable (e.g. not running inside a PBS job),
    or off-cluster, ``os.cpu_count()`` is used, falling back to 1 if unknown.
    """
    if 'glade' in os.getcwd():
        fields = subprocess.run("qstat -f $PBS_JOBID | grep Resource_List.ncpus",
                                shell=True, capture_output=True,
                                encoding='utf-8').stdout.split()
        # e.g. ['Resource_List.ncpus', '=', '72']; empty outside a PBS job
        if fields and fields[-1].isdigit() and int(fields[-1]) > 0:
            return int(fields[-1])
    return os.cpu_count() or 1
        
def xr_map(dirpath, func):
    """Load one leaf directory, keep precipitating points, tag it, and reduce it.

    All ``*.nc`` files in ``dirpath`` are opened lazily and stacked along
    ``valid_time``. A wet-bulb profile ``wb_h`` is derived via
    ``sounding_utils.wet_bulb_from_rel_humid`` only if it is not already a
    data variable. Values are set to NaN wherever none of the categorical
    flags ``crain``/``csnow``/``cicep``/``cfrzr`` equals 1. Metadata labels
    from ``get_metadata(dirpath)`` are added as new leading dimensions before
    ``func`` is applied.

    Returns
    -------
    Whatever ``func`` returns for the prepared Dataset.
    """
    dataset = xr.open_mfdataset(join(dirpath, "*.nc"),
                                concat_dim='valid_time',
                                combine='nested')

    if 'wb_h' not in dataset.data_vars:
        dataset = sounding_utils.wet_bulb_from_rel_humid(dataset)

    # OR together the four categorical precipitation flags, left to right.
    precip_mask = dataset['crain'] == 1
    for flag in ('csnow', 'cicep', 'cfrzr'):
        precip_mask = precip_mask | (dataset[flag] == 1)
    dataset = dataset.where(precip_mask)

    # one new length-1 dimension per metadata label (case_study, day, model)
    dataset = dataset.expand_dims(get_metadata(dirpath))
    return func(dataset)

def compute_func(ds, bins=None):
    """Summarise t_h / dpt_h / wb_h profiles for each predicted precipitation type.

    For every ptype in (icep, frzr, snow, rain) and every prediction prefix in
    ('ML_c', 'c'), points where ``ds[prefix + ptype] == 1`` are selected (others
    become NaN) and three variables are written to the result, keyed by
    ``pred = prefix + ptype``:

    * ``<pred>_mean_obs`` -- despite the name, a *count*: the number of non-NaN
      ``t_h`` values over ``x``/``y``, maximised over ``heightAboveGround``.
    * ``<pred>_hist``     -- density histogram over ``x``/``y`` of each profile,
      stacked along a new ``profile`` dimension, bin dimension named ``bins``.
    * ``<pred>_mean``     -- mean over ``x``/``y`` of each profile, stacked along
      ``profile`` with coordinate values ``['t_h', 'dpt_h', 'wb_h']``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain the eight prediction flag variables and ``t_h``, ``dpt_h``,
        ``wb_h`` with dims including ``x``, ``y`` and ``heightAboveGround``.
    bins : array-like, optional
        Histogram bin edges shared by all profiles. Defaults to
        ``np.arange(-40, 40, 0.5)`` (edges -40.0 through 39.5).

    Returns
    -------
    xarray.Dataset
    """
    # Bin edges are loop-invariant: build them once instead of 24 times.
    if bins is None:
        bins = np.arange(-40, 40, 0.5)

    res = xr.Dataset()
    for ptype in ['icep', 'frzr', 'snow', 'rain']:
        for model in ['ML_c', 'c']:
            predtype = model + ptype
            subset = ds.where(ds[predtype] == 1)

            num_obs = subset['t_h'].count(dim=('x', 'y')).max(dim='heightAboveGround')
            res[predtype + '_mean_obs'] = num_obs

            prof_means = []
            hists = []
            for proftype in ['t_h', 'dpt_h', 'wb_h']:
                ####### compute means ##############
                mean = subset[proftype].mean(dim=('x', 'y'))
                prof_means.append(mean.expand_dims({'profile': [proftype]}))

                ####### compute histograms ############
                h_x = histogram(subset[proftype], bins=[bins], dim=['x', 'y'], density=True)
                # xhistogram names the bin dim '<var>_bin'; unify so profiles can be stacked
                h_x = h_x.rename({f"{proftype}_bin": "bins"})
                hists.append(h_x)

            res[predtype + '_hist'] = xr.concat(hists, dim='profile')
            res[predtype + '_mean'] = xr.concat(prof_means, dim='profile')

    return res
        

In [4]:
def timer(tic):
    """Print the wall-clock time elapsed since ``tic`` (a ``time.time()`` stamp).

    Output looks like ``Elapsed time: 2 minutes, 5 seconds``; the minutes part
    is omitted when fewer than one whole minute has passed.
    """
    elapsed = time.time() - tic
    whole_minutes = int(elapsed / 60)
    minutes_part = f"{whole_minutes} minutes, " if whole_minutes else ""
    print("Elapsed time: " + minutes_part + f"{int(elapsed % 60)} seconds")

In [5]:
# Wall-clock start stamp for the aggregation run; reported by timer(tic) in the final cell.
tic = time.time()

In [None]:
# Aggregate the kentucky / 20220223 case study for each model and write one
# NetCDF file of per-ptype profile statistics per model.
CASE_DIR = "/glade/campaign/cisl/aiml/ptype/ptype_case_studies/kentucky"
CASE_DAY = "20220223"
OUT_DIR = "/glade/work/dkimpara/ptype-aggs"
N_JOBS = 4  # explicit worker count; -1 lets xr_map_reduce size it automatically

# to_netcdf does not create missing parent directories.
os.makedirs(OUT_DIR, exist_ok=True)
for model in ['rap', 'gfs', 'hrrr']:
    res = xr_map_reduce(join(CASE_DIR, model, CASE_DAY), compute_func, N_JOBS)
    res.to_netcdf(join(OUT_DIR, f'trial_{model}.nc'))

  int_num = np.asarray(num, dtype=np.int64)
  return func(*(_execute_task(a, cache) for a in args))
  int_num = np.asarray(num, dtype=np.int64)


In [None]:
# Report total wall-clock time since `tic` was stamped before the aggregation loop.
timer(tic)