In [34]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd

from importlib import reload

from functools import partial

import os
from os.path import join
import subprocess
import time
import sys
sys.path.append('../../') # lets us import ptype package from the subdir
#import ptype.

from joblib import Parallel, delayed

from xhistogram.xarray import histogram


import soundings.utils as sounding_utils
import batch_jobs.xr_map_reduce as xmr

In [8]:
# Re-import batch_jobs.xr_map_reduce so edits to the module take effect without a kernel restart.
reload(xmr)

# RAP output directories for the Kentucky case study (presumably <date>/<init hour> subdirectories).
dirpaths = [
    '/glade/scratch/dkimpara/ptype_case_studies/kentucky/rap/20220224/0100',
    '/glade/scratch/dkimpara/ptype_case_studies/kentucky/rap/20220224/0600',
    '/glade/scratch/dkimpara/ptype_case_studies/kentucky/rap/20220223/0100',
    '/glade/scratch/dkimpara/ptype_case_studies/kentucky/rap/20220223/0600',
]

# Open every NetCDF file of the first case, stacked along 'step' in sorted-filename order,
# with CF decoding disabled (raw stored values).
# NOTE(review): `ds` is not used by the later visible cells — kept for interactive inspection?
ds = xr.open_mfdataset(
    f'{dirpaths[0]}/*.nc',
    engine='netcdf4',
    combine='nested',
    concat_dim='step',
    decode_cf=False,
)


In [41]:
def compute_stats(subset, label, proftypes, predtype, bins=None):
    """Summarize masked profile variables for one prediction type and one subset label.

    Parameters
    ----------
    subset : xarray.Dataset
        The profile variables in ``proftypes``, already masked with ``.where(...)``
        so that points outside the subset are NaN. Must contain ``"t_h"`` with a
        ``heightAboveGround`` dimension, and spatial dims ``x`` and ``y``.
    label : str
        Subset tag appended to every output variable name
        (e.g. ``"disagree"`` or ``"0.5"``).
    proftypes : list of str
        Names of the profile variables to summarize.
    predtype : str
        Value stored in the new length-1 ``predtype`` dimension added to
        every output dataset.
    bins : array-like, optional
        Histogram bin edges. Defaults to ``np.arange(-60, 40, 0.1)``.

    Returns
    -------
    dict of str -> xarray.Dataset
        Keys ``"num_obs"``, ``"frac_abv"``, ``"means"`` and ``"hists"``.
        Each dataset gains a ``predtype`` dimension of length 1.
    """
    if bins is None:
        # NOTE(review): presumably degrees C for t_h / dpt_h — confirm units.
        bins = np.arange(-60, 40, 0.1)

    # Number of non-NaN points in the subset, counted on the first heightAboveGround level.
    num_obs = subset["t_h"].isel(heightAboveGround=0).count(dim=("x", "y"))
    # isel leaves a scalar heightAboveGround coordinate only if one existed; drop it when present.
    num_obs = xr.Dataset({f"num_obs_{label}": num_obs}).drop_vars(
        "heightAboveGround", errors="ignore"
    )

    # Per-variable "fraction above" statistic from sounding_utils. It is tagged with
    # `label` like every other statistic; it used to be hard-coded to "disagree",
    # which mislabelled the confidence-threshold subsets.
    frac_abv = sounding_utils.frac_abv_split_time(subset, proftypes)
    frac_abv = frac_abv.rename({var: f"{var}_fabv_{label}" for var in proftypes})

    ####### compute means ##############
    mean = subset[proftypes].mean(dim=("x", "y"))
    mean = mean.rename({var: f"{var}_mean_{label}" for var in proftypes})

    ####### compute histograms ############
    # Histogram over the spatial dims, normalized with density=True; bin dim renamed to a shared "bin".
    densities = xr.Dataset(
        {
            f"{var}_hist_{label}": histogram(
                subset[var], bins=bins, dim=["x", "y"], density=True
            ).rename({f"{var}_bin": "bin"})
            for var in proftypes
        }
    )

    results = {
        "num_obs": num_obs,
        "frac_abv": frac_abv,
        "means": mean,
        "hists": densities,
    }
    # Tag every statistic with its prediction type so the caller can concat along "predtype".
    return {k: v.expand_dims({"predtype": [predtype]}) for k, v in results.items()}

def compute_prob_and_disagree(ds, ptypes=("icep", "frzr", "snow", "rain"),
                              proftypes=("t_h", "dpt_h"),
                              confidences=(0.3, 0.5, 0.7, 0.9)):
    """Profile statistics for ML vs. NWP categorical p-type disagreement and ML confidence.

    For every precipitation type and both prediction sources (``ML_c<ptype>`` and
    ``c<ptype>``), this builds several masked subsets of ``ds[proftypes]`` and
    summarizes each one with ``compute_stats``:

    * ``"disagree"``: this source predicts the type (== 1) and the other
      source does not (== 0).
    * One subset per threshold in ``confidences``, labelled ``f"{threshold}"``,
      where the ML probability ``ML_<ptype>`` is >= the threshold and:
        - for ``ML_c<ptype>``: ML also predicted the type (== 1);
        - for ``c<ptype>``: the NWP model did not predict the type (== 0).

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``ML_<ptype>``, ``ML_c<ptype>`` and ``c<ptype>`` for every
        ptype, plus the profile variables listed in ``proftypes``.
    ptypes : sequence of str
        Precipitation-type suffixes to evaluate.
    proftypes : sequence of str
        Profile variables to summarize ("wb_h" was previously disabled here).
    confidences : sequence of float
        ML probability thresholds.

    Returns
    -------
    xarray.Dataset
        Merge of all statistics, each concatenated along ``predtype``.
    """
    # Dataset indexing needs a list; a tuple would be looked up as one variable name.
    proftypes = list(proftypes)
    profiles = ds[proftypes]
    # Each categorical prediction mapped to its counterpart from the other source.
    other_pred = ({f"ML_c{ptype}": f"c{ptype}" for ptype in ptypes}
                  | {f"c{ptype}": f"ML_c{ptype}" for ptype in ptypes})

    res_dict = {"num_obs": [], "frac_abv": [], "means": [], "hists": []}

    def _collect(mask, label, predtype):
        # Summarize the masked profiles and append each statistic to its running list.
        results = compute_stats(profiles.where(mask), label, proftypes, predtype)
        for k in res_dict:
            res_dict[k].append(results[k])

    for ptype in ptypes:
        for model in ["ML_c", "c"]:
            # TODO (original author note): rescope this
            predtype = model + ptype

            # compute disagreement
            _collect((ds[predtype] == 1) & (ds[other_pred[predtype]] == 0),
                     "disagree", predtype)

            # compute confident preds
            for confidence in confidences:
                ml_confident = ds["ML_" + ptype] >= confidence
                if model == "ML_c":
                    # confident and predicted the ptype
                    mask = ml_confident & (ds["ML_c" + ptype] == 1)
                else:
                    # ML confident and nwp did not predict the ptype
                    mask = ml_confident & (ds[predtype] == 0)
                _collect(mask, f"{confidence}", predtype)

    # NOTE(review): each predtype value appears once per label (1 + len(confidences) times)
    # along the concatenated "predtype" dim. Variables from other labels must be NaN-filled
    # by xr.concat, which relies on an xarray version that accepts datasets with differing
    # variables — confirm on the installed version.
    ds_concat = [
        xr.concat(res_ds_list, dim="predtype") for res_ds_list in res_dict.values()
    ]
    return xr.merge(ds_concat)

In [35]:
# Run compute_prob_and_disagree over every case-study directory (one result per entry of dirpaths).
results = [
    xmr.xr_map(case_dir, compute_prob_and_disagree)
    for case_dir in dirpaths
]

  return func(*(_execute_task(a, cache) for a in args))
  return func(*(_execute_task(a, cache) for a in args))
  return func(*(_execute_task(a, cache) for a in args))
  return func(*(_execute_task(a, cache) for a in args))


In [36]:
# Apply xmr.time_to_inithr to each per-case result (presumably re-keys time by init hour — see batch_jobs.xr_map_reduce).
results_t = list(map(xmr.time_to_inithr, results))

In [37]:
# Combine the per-case datasets into a single Dataset.
final = xr.merge(objects=results_t)