In [40]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from importlib import reload
import os
from os.path import join
import subprocess
import sys
sys.path.append('../../') # lets us import ptype package from the subdir
#import ptype.
import time
from collections.abc import Sequence
from joblib import Parallel, delayed

from xhistogram.xarray import histogram

import soundings.utils as utils
import soundings.mr_analysis as mra

In [50]:
# Open the NetCDF file through xarray's netcdf4 backend. The bare `ds` on the
# last line makes the dataset's repr the cell output.
ds = xr.open_dataset(
    "/glade/work/dkimpara/ptype-aggs/test.nc",
    engine='netcdf4',
)
ds

In [73]:
class SoundingQuery:
    """Observation-weighted queries over one or more merged xarray datasets.

    The datasets passed in are combined with ``xr.merge``. The code reads a
    ``num_obs`` variable, a ``predtype`` coordinate, the dimensions
    ``case_study_day``, ``step`` and ``init_hr`` and fields named
    ``<variable>_<stat>`` from the merged dataset; ``quantile`` additionally
    relies on ``bin`` and ``heightAboveGround`` dimensions on the ``*_hist``
    fields.
    """

    # Dimensions collapsed by the observation-weighted average in `query`.
    _AGG_DIMS = ('case_study_day', 'step', 'init_hr')

    def __init__(self, datasets):
        """Merge ``datasets`` (a single dataset or a sequence of them)."""
        self.ds = xr.merge(self._to_sequence(datasets))

    def query(self, predtypes, variables, stats, sel=None):
        """Return the ``num_obs``-weighted mean of ``<variable>_<stat>`` fields.

        Parameters
        ----------
        predtypes : scalar or sequence
            Value(s) selected along the ``predtype`` coordinate.
        variables, stats : str or sequence of str
            Combined into field names ``f'{var}_{stat}'``. Lists of equal
            length are paired element-wise; a single variable or a single
            stat is applied to every entry of the other list.
        sel : mapping, optional
            Extra indexers forwarded to ``Dataset.sel`` before the predtype
            selection. ``None`` (the default) selects nothing extra.

        Returns
        -------
        The selected fields multiplied by ``num_obs``, summed over
        ``case_study_day``, ``step`` and ``init_hr`` and divided by the summed
        ``num_obs``.

        Raises
        ------
        ValueError
            If ``variables`` and ``stats`` have different lengths and neither
            has length 1.
        """
        predtypes = self._to_sequence(predtypes)
        query_vars = self._build_query_vars(
            self._to_sequence(variables),
            self._to_sequence(stats),
        )

        ds = self.ds.sel({} if sel is None else sel)
        ds = ds.sel({'predtype': predtypes})

        total_obs = ds['num_obs'].sum(dim=self._AGG_DIMS)
        weighted = ds[query_vars] * ds['num_obs']

        return weighted.sum(dim=self._AGG_DIMS) / total_obs

    def quantile(self, quantiles, predtypes, variables, sel=None):
        """Estimate quantiles from the ``<variable>_hist`` histograms.

        For every queried histogram the counts are accumulated along ``bin``;
        for each requested quantile ``q`` the result is the ``bin`` coordinate
        at which the cumulative sum first reaches ``q`` times the normalising
        total.

        Parameters
        ----------
        quantiles : float or sequence of float
            Quantile levels in ``[0, 1]``; they are sorted before use.
        predtypes, variables, sel
            Forwarded to :meth:`query` with ``stats='hist'``.

        Returns
        -------
        ``xr.merge`` of the per-variable, per-quantile results, each carrying
        a length-1 ``quantile`` dimension.

        Raises
        ------
        ValueError
            If any quantile is below 0 or above 1.
        """
        quantiles = np.sort(self._to_sequence(quantiles))

        if np.any(quantiles > 1.0) or np.any(quantiles < 0.0):
            raise ValueError('quantiles must lie in the interval [0, 1]')

        hist = self.query(predtypes, variables, 'hist', sel)

        # NOTE(review): the normalising total is taken from the first
        # heightAboveGround level only and reused for every level. This
        # presumes all levels hold the same total count -- confirm.
        norm_const = hist.isel(heightAboveGround=0).sum(dim='bin')

        results = []
        for var in list(hist.keys()):
            cdf = self._compute_cdf(hist[var])
            for q in quantiles:
                # Mask out bins below the target count. Among the remaining
                # cumulative values the minimum sits at the first bin that
                # reaches the target (assumes non-negative counts, so the
                # cumulative sum never decreases).
                q_csum = cdf.where(cdf >= q * norm_const[var])
                qs = q_csum.idxmin(dim='bin')
                qs = qs.expand_dims({'quantile': [q]})
                results.append(qs)
        return xr.merge(results)

    def _build_query_vars(self, variables, stats):
        """Pair variables with stats and return the ``'<var>_<stat>'`` names.

        A bare ``zip`` silently drops entries when the lists differ in
        length, so a length-1 list is broadcast against the other one and a
        genuine length mismatch is rejected.
        """
        if len(stats) == 1:
            stats = list(stats) * len(variables)
        elif len(variables) == 1:
            variables = list(variables) * len(stats)
        elif len(variables) != len(stats):
            raise ValueError(
                'variables and stats must have the same length, '
                'or one of them must contain a single entry'
            )
        return [f'{var}_{stat}' for var, stat in zip(variables, stats)]

    def _compute_cdf(self, hist):
        """Return the (unnormalised) cumulative sum of ``hist`` along ``bin``."""
        return hist.cumsum(dim='bin')

    def _to_sequence(self, obj):
        """Return ``obj`` unchanged if it is list-like, else wrap it in a list.

        Strings and bytes count as scalars. NumPy arrays with at least one
        dimension are passed through as well; they are not registered as
        ``collections.abc.Sequence`` and would otherwise be wrapped whole.
        """
        if isinstance(obj, np.ndarray) and obj.ndim > 0:
            return obj
        if self._seq_but_not_str(obj):
            return obj
        return [obj]

    def _seq_but_not_str(self, obj):
        """True for ``Sequence`` instances other than str/bytes/bytearray."""
        return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes, bytearray))

In [74]:
# Build the query helper over the dataset loaded above (passed as a
# one-element list; SoundingQuery merges whatever it is given).
query = SoundingQuery(datasets=[ds])