In [10]:
import warnings
import numpy as np
import xarray as xr
from numba import jit
import proplot as pplt
# proplot runtime setting; presumably the resolution of geographic features drawn on maps — TODO confirm
pplt.rc.reso='hi'
# NOTE(review): this silences *all* warnings for the rest of the kernel session
# (e.g. numpy RuntimeWarnings from dividing by empty bins), which can hide real problems.
warnings.filterwarnings('ignore')

In [11]:
# Input directory for the processed datasets and output directory for figures.
FILEDIR = '/global/cfs/cdirs/m4334/sferrett/monsoon-pod/data/processed'
SAVEDIR = '/global/cfs/cdirs/m4334/sferrett/monsoon-pod/figs'
# Analysis boxes: lat/lon bounds per region name (used by get_region).
REGIONS = {
    'Eastern Arabian Sea':{'latmin':9.,'latmax':19.5,'lonmin':64.,'lonmax':72.},
    'Central India':{'latmin':18.,'latmax':24.,'lonmin':76.,'lonmax':83.},
    'Central Bay of Bengal':{'latmin':9.,'latmax':14.5,'lonmin':86.5,'lonmax':90.},
    'Equatorial Indian Ocean':{'latmin':5.,'latmax':10.,'lonmin':62.,'lonmax':67.5},
    'Konkan Coast':{'latmin':15.,'latmax':19.5,'lonmin':69.,'lonmax':72.5}}
# Per-variable binning: first bin coordinate ('min'), last bin coordinate ('max') and spacing ('width').
BINPARAMS = {
    'bl':{'min':-0.6,'max':0.1,'width':0.0025},
    'cape':{'min':-70.,'max':20.,'width':1.},
    'subsat':{'min':-20.,'max':70.,'width':1.}}
SAMPLETHRESH  = 50    # bins with fewer samples than this are masked to NaN
MONTHPAIRS    = [(6,7),(7,8)]
NITERATIONS   = 1000  # number of bootstrap resamples
YEARSINSAMPLE = 5     # years drawn (with replacement) for each bootstrap resample

# Month cases compared in each group of regions: (month number, month name, plot colour).
# prepare_bootstrap_data looks these up, so they must be defined (they were commented out).
CASES   = {
    'JJ':[(6,'June','#D42028'),(7,'July','#F2C85E')],
    'JA':[(7,'July','#F2C85E'),(8,'August','#5BA7DA')]}

# NOTE(review): PRTHRESH, AUTHOR and EMAIL are used as default arguments of
# fast_binned_stats / calc_binned_stats in a later cell but are defined nowhere in
# this notebook, so that cell raises NameError when run. Define them here (their
# values cannot be recovered from this notebook).

In [12]:
def load(filename,filedir=FILEDIR):
    """Read a processed netCDF file fully into memory.

    Parameters
    ----------
    filename : str
        File name inside `filedir`.
    filedir : str
        Directory holding the file (defaults to FILEDIR).

    Returns
    -------
    xr.Dataset with all variables loaded into memory; the underlying file is closed.
    """
    filepath = f'{filedir}/{filename}'
    # load_dataset = open_dataset + .load() + close, so no file handle is leaked
    # (the previous open_dataset(...).load() left the file open).
    return xr.load_dataset(filepath)

In [13]:
# Read each processed input file fully into memory.
hiresimergds = load(filename='ERA5_IMERG_pr_bl_terms.nc')
loresimergds = load(filename='LOW_ERA5_IMERG_pr_bl_terms.nc')
loresgpcpds  = load(filename='ERA5_GPCP_pr_bl_terms.nc')

In [9]:
def get_region(data,key,regions=REGIONS):
    """Subset `data` to the lat/lon box of region `key`.

    Parameters
    ----------
    data : xarray object with `lat` and `lon` coordinates.
    key : str
        A key of `regions`.
    regions : dict
        Maps region name -> {'latmin','latmax','lonmin','lonmax'}.

    Returns
    -------
    `data` selected by label between the bounds (inclusive). Ascending and
    descending coordinates are both handled.
    """
    region = regions[key]

    def _bounds(coord,lo,hi):
        # Label slices must follow the coordinate's own order: on a descending
        # coordinate, slice(lo,hi) silently selects nothing.
        values = data[coord].values
        if values.size>1 and values[0]>values[-1]:
            return slice(hi,lo)
        return slice(lo,hi)

    return data.sel(lat=_bounds('lat',region['latmin'],region['latmax']),
                    lon=_bounds('lon',region['lonmin'],region['lonmax']))

def get_month(data,months):
    """Select the time steps of `data` whose calendar month is in `months`.

    Parameters
    ----------
    data : xarray object with a datetime `time` coordinate.
    months : int, or list/tuple of ints.

    Returns
    -------
    `data` restricted along `time` to the requested month(s).
    """
    # a bare month number is treated as a one-element list
    wanted = months if isinstance(months,(list,tuple)) else [months]
    return data.sel(time=data.time.dt.month.isin(wanted))

def get_bin_mean_pr(monthstats,bintype,samplethresh=SAMPLETHRESH):
    """Mean precipitation per bin (precipitation sum / sample count).

    Parameters
    ----------
    monthstats : xr.Dataset
        Binned statistics: Q0/Q1 on `bl` for bintype '1D', or P0/P1 on
        (`subsat`, `cape`) for bintype '2D'.
    bintype : {'1D', '2D'}
    samplethresh : number
        Bins with fewer samples than this are set to NaN.

    Returns
    -------
    xr.DataArray of bin-mean precipitation; empty or under-sampled bins are NaN.

    Raises
    ------
    ValueError
        If `bintype` is not '1D' or '2D'.
    """
    if bintype=='1D':
        counts,sums = monthstats.Q0.values,monthstats.Q1.values
        coords = {'bl':monthstats.bl.values}
    elif bintype=='2D':
        counts,sums = monthstats.P0.values,monthstats.P1.values
        coords = {'subsat':monthstats.subsat.values,'cape':monthstats.cape.values}
    else:
        raise ValueError(f"bintype must be '1D' or '2D', got {bintype!r}")
    # Work on float copies: `.values` is the dataset's own buffer, so the previous
    # in-place NaN assignment silently corrupted the caller's counts (and failed
    # outright for integer counts).
    counts = np.array(counts,dtype=float)
    sums   = np.asarray(sums,dtype=float)
    valid  = (counts>0)&(counts>=samplethresh)
    binmeanpr = np.full(counts.shape,np.nan)
    np.divide(sums,counts,out=binmeanpr,where=valid)
    return xr.DataArray(binmeanpr,coords=coords)

def get_pdf(monthstats,bintype,precipitating=False,samplethresh=SAMPLETHRESH):
    """Normalized histogram of all (or only precipitating) points over the bins.

    Parameters
    ----------
    monthstats : xr.Dataset
        Binned statistics: Q0/QE on `bl` for bintype '1D', or P0/PE on
        (`subsat`, `cape`) for bintype '2D'.
    bintype : {'1D', '2D'}
    precipitating : bool
        If True, use the precipitating counts (QE/PE) instead of all counts (Q0/P0).
    samplethresh : number
        Bins whose total count (Q0/P0) is below this are set to NaN.

    Returns
    -------
    xr.DataArray whose non-NaN values sum to 1 (all NaN if every bin is masked).

    Raises
    ------
    ValueError
        If `bintype` is not '1D' or '2D'.
    """
    if bintype=='1D':
        counts,precipcounts = monthstats.Q0.values,monthstats.QE.values
        coords = {'bl':monthstats.bl.values}
    elif bintype=='2D':
        counts,precipcounts = monthstats.P0.values,monthstats.PE.values
        coords = {'subsat':monthstats.subsat.values,'cape':monthstats.cape.values}
    else:
        raise ValueError(f"bintype must be '1D' or '2D', got {bintype!r}")
    # Copy before masking: the previous version wrote NaNs straight into the
    # dataset's Q0/QE (or P0/PE) buffers, corrupting the caller's statistics.
    weights = np.array(precipcounts if precipitating else counts,dtype=float)
    weights[np.asarray(counts)<samplethresh] = np.nan
    pdf = weights/np.nansum(weights)
    return xr.DataArray(pdf,coords=coords)

In [None]:
def get_bin_edges(key,binparams=BINPARAMS):
    """Bin coordinates for variable `key`: min, min+width, ... up to and covering max.

    Parameters
    ----------
    key : str
        A key of `binparams`.
    binparams : dict
        Maps variable name -> {'min','max','width'}.

    Returns
    -------
    1-D float ndarray of ceil((max-min)/width)+1 coordinates.
    """
    params = binparams[key]
    # np.arange with a non-integer step can return one element too many because of
    # rounding in its length computation; trim to the intended count (the small
    # tolerance absorbs rounding in the division).
    nedges = int(np.ceil((params['max']-params['min'])/params['width']-1e-9))+1
    return np.arange(params['min'],params['max']+params['width'],params['width'])[:nedges]
     
@jit(nopython=True)
def fast_binned_stats(blidxs,capeidxs,subsatidxs,prdata,nblbins,ncapebins,nsubsatbins,prthresh=PRTHRESH): 
    """Accumulate 1-D (BL) and 2-D (subsat x cape) bin statistics (Numba-compiled).

    Parameters
    ----------
    blidxs, capeidxs, subsatidxs : integer bin indices, iterated flat alongside
        `prdata`; indices outside [0, n) are ignored.
    prdata : precipitation values; non-finite values are ignored.
    nblbins, ncapebins, nsubsatbins : number of bins along each axis.
    prthresh : values strictly greater than this count as precipitating.

    Returns
    -------
    (Q0, QE, Q1, P0, PE, P1) : counts, precipitating counts and precipitation sums,
        1-D over BL bins (Q*) and 2-D over (subsat, cape) bins (P*).
    """
    Q0 = np.zeros(nblbins)
    QE = np.zeros(nblbins)
    Q1 = np.zeros(nblbins)
    P0 = np.zeros((nsubsatbins,ncapebins))
    PE = np.zeros((nsubsatbins,ncapebins))
    P1 = np.zeros((nsubsatbins,ncapebins))
    for k in range(prdata.size):
        pr = prdata.flat[k]
        if not np.isfinite(pr):
            continue  # missing precipitation contributes to no bin
        raining = pr>prthresh
        b = blidxs.flat[k]
        if b>=0 and b<nblbins:
            Q0[b] += 1
            Q1[b] += pr
            if raining:
                QE[b] += 1
        s = subsatidxs.flat[k]
        c = capeidxs.flat[k]
        if s>=0 and s<nsubsatbins and c>=0 and c<ncapebins:
            P0[s,c] += 1
            P1[s,c] += pr
            if raining:
                PE[s,c] += 1
    return Q0,QE,Q1,P0,PE,P1

def calc_binned_stats(data,binparams=BINPARAMS,prthresh=PRTHRESH,author=AUTHOR,email=EMAIL):
    """Bin precipitation by BL (1-D) and by subsaturation x CAPE (2-D).

    NOTE(review): `calc_binned_stats` and `fast_binned_stats` are redefined in later
    cells with different signatures; this version only works while the 7-argument
    `fast_binned_stats` from this cell is the one currently defined.

    Parameters
    ----------
    data : xr.Dataset with variables 'bl', 'cape', 'subsat' and 'pr' of matching size.
    binparams : dict of {'min','max','width'} per variable (see get_bin_edges).
    prthresh : points with pr > prthresh are counted as precipitating (QE/PE).
    author, email : written into the `history` attribute.

    Returns
    -------
    xr.Dataset with counts Q0, precipitating counts QE and precipitation sums Q1 on
    `bl`, and P0/PE/P1 on (`subsat`, `cape`). Each value goes to the bin whose
    coordinate is nearest; non-finite or out-of-range values are skipped.
    """
    from datetime import datetime  # used for the history attribute; was never imported
    bl,cape,subsat,pr = (data[var].values for var in ['bl','cape','subsat','pr'])
    bins = {key:get_bin_edges(key,binparams) for key in ['bl','cape','subsat']}

    def _nearest_bin_index(values,key):
        # Index of the nearest bin coordinate. np.floor (unlike an int cast, which
        # truncates toward zero) keeps values below the first bin out of bin 0;
        # non-finite values map to -1 and the clip prevents int overflow. Indices
        # outside [0, n) are skipped by fast_binned_stats.
        params = binparams[key]
        raw = np.floor((values-params['min'])/params['width']+0.5)
        raw = np.where(np.isfinite(raw),raw,-1)
        return np.clip(raw,-1,bins[key].size).astype(np.int64)

    # NOTE(review): cape/subsat previously used a '-0.5' offset (unlike 'bl'), which
    # with the truncating cast put values nearest coordinate c[i+1] into bin i; all
    # three now use the same nearest-coordinate rule as 'bl' — confirm intent.
    blidxs     = _nearest_bin_index(bl,'bl')
    capeidxs   = _nearest_bin_index(cape,'cape')
    subsatidxs = _nearest_bin_index(subsat,'subsat')
    Q0,QE,Q1,P0,PE,P1 = fast_binned_stats(blidxs,capeidxs,subsatidxs,pr,bins['bl'].size,bins['cape'].size,bins['subsat'].size,prthresh)
    ds = xr.Dataset(data_vars={'Q0':('bl',Q0),'QE':('bl',QE),'Q1':('bl',Q1),
                               'P0':(('subsat','cape'),P0),'PE':(('subsat','cape'),PE),'P1':(('subsat','cape'),P1)},
                    coords={'bl':bins['bl'],'cape':bins['cape'],'subsat':bins['subsat']})
    ds.Q0.attrs     = dict(long_name='Count of points in each bin')
    ds.QE.attrs     = dict(long_name=f'Count of precipitating ( > {prthresh} mm/day) points in each bin')
    ds.Q1.attrs     = dict(long_name='Sum of precipitation in each bin',units='mm/day')
    ds.P0.attrs     = dict(long_name='Count of points in each bin')
    ds.PE.attrs     = dict(long_name=f'Count of precipitating ( > {prthresh} mm/day) points in each bin')
    ds.P1.attrs     = dict(long_name='Sum of precipitation in each bin',units='mm/day')
    ds.bl.attrs     = dict(long_name='Average buoyancy in the lower troposphere',units='m/s²')
    ds.cape.attrs   = dict(long_name='Undilute plume buoyancy',units='K')
    ds.subsat.attrs = dict(long_name='Subsaturation in the lower free-troposphere',units='K')
    ds.attrs        = dict(history=f'Created on {datetime.today().strftime("%Y-%m-%d")} by {author} ({email})')
    return ds

In [None]:
@jit(nopython=True)
def fast_1D_counts(bl,pr,blbinmin,blbinwidth,nblbins):
    """Count points and sum precipitation in nearest-coordinate BL bins (Numba-compiled).

    Parameters
    ----------
    bl, pr : arrays of the same size (iterated flat).
    blbinmin, blbinwidth : first bin coordinate and bin spacing.
    nblbins : number of bins.

    Returns
    -------
    (Q0, Q1) : float arrays of length nblbins with per-bin counts and precipitation
        sums. Points with non-finite bl or pr, or outside the bins, are skipped.
    """
    Q0 = np.zeros(nblbins)
    Q1 = np.zeros(nblbins)
    for i in range(bl.size):
        blval = bl.flat[i]
        prval = pr.flat[i]
        # skip missing data before indexing: casting NaN to an int is undefined
        if not (np.isfinite(blval) and np.isfinite(prval)):
            continue
        # fractional bin position; +0.5 rounds to the nearest bin coordinate
        pos = (blval-blbinmin)/blbinwidth+0.5
        # Range-check the float: the old int cast truncated toward zero, so values
        # 0.5-1.5 widths below the first coordinate were counted in bin 0.
        if pos<0.0 or pos>=nblbins:
            continue
        blidx = int(pos)  # pos >= 0, so truncation == floor
        Q0[blidx] += 1
        Q1[blidx] += prval
    return Q0,Q1

def calc_binned_stats(data,binparams=BINPARAMS):
    """Count points and sum precipitation in BL bins (1-D only).

    NOTE(review): this redefines `calc_binned_stats` from the previous cell; whichever
    cell ran last is the one callers (e.g. calc_bootstrap_stats) get.

    Parameters
    ----------
    data : xarray object with 'bl' and 'pr' variables.
    binparams : dict of {'min','max','width'}; only binparams['bl'] is used.

    Returns
    -------
    xr.Dataset with Q0 (count) and Q1 (precipitation sum, mm/day) on the `bl` coordinate.
    """
    def _bin_coords(key):
        # Coordinates min, min+width, ... covering max. np.arange with a non-integer
        # step can return one extra element from rounding, so trim to the intended count.
        params = binparams[key]
        nedges = int(np.ceil((params['max']-params['min'])/params['width']-1e-9))+1
        return np.arange(params['min'],params['max']+params['width'],params['width'])[:nedges]

    blbins = _bin_coords('bl')
    Q0,Q1 = fast_1D_counts(data.bl.values,data.pr.values,binparams['bl']['min'],binparams['bl']['width'],blbins.size)
    ds = xr.Dataset(data_vars={'Q0':('bl',Q0),'Q1':('bl',Q1)},coords={'bl':blbins})
    ds.Q0.attrs = dict(long_name='Count of points in each bin')
    ds.Q1.attrs = dict(long_name='Sum of precipitation in each bin',units='mm/day')
    ds.bl.attrs = dict(long_name='Average buoyancy in the lower troposphere', units='m/s²')
    return ds

def get_bootstrap_samples(data,niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE,rng=None):
    """Draw year-block bootstrap resamples of `data`.

    Each resample is made of `yearsinsample` years drawn with replacement from the
    distinct years in `data.time`; a year drawn k times contributes its time steps
    k times.

    Parameters
    ----------
    data : xarray object with a datetime `time` dimension.
    niterations : int
        Number of resamples to draw.
    yearsinsample : int
        Number of years per resample.
    rng : numpy Generator or RandomState, optional
        Source of randomness. When None the global ``np.random`` state is used (as
        before), so ``np.random.seed`` still makes the draws reproducible.

    Returns
    -------
    list of `niterations` objects like `data`, each indexed along `time`.
    """
    years    = data.time.dt.year.values
    allyears = np.unique(years)
    sampler  = np.random if rng is None else rng
    yearidxs = sampler.choice(len(allyears),size=(niterations,yearsinsample),replace=True)
    # time positions belonging to each distinct year, computed once
    yearpositions = [np.flatnonzero(years==year) for year in allyears]
    samples = []
    for iterationidxs in yearidxs:
        # Concatenating positions keeps repeated years. The previous
        # `time.dt.year.isin(selectedyears)` mask collapsed duplicates, so the
        # "with replacement" draw silently behaved as one without replacement.
        positions = [yearpositions[idx] for idx in iterationidxs]
        timeidxs  = np.concatenate(positions) if positions else np.array([],dtype=np.intp)
        samples.append(data.isel(time=timeidxs))
    return samples

def calc_bootstrap_stats(data,niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE):
    """Binned statistics for `niterations` year-block bootstrap resamples of `data`.

    Parameters
    ----------
    data : xarray object accepted by get_bootstrap_samples and calc_binned_stats.
    niterations : int
        Number of resamples (must be >= 1).
    yearsinsample : int
        Years per resample.

    Returns
    -------
    The calc_binned_stats output of every resample, concatenated along a new
    'iteration' dimension.

    Raises
    ------
    ValueError
        If niterations < 1 (xr.concat cannot concatenate nothing).
    """
    if niterations<1:
        raise ValueError(f'niterations must be >= 1, got {niterations}')
    statslist = []
    for _ in range(niterations):
        # Draw and reduce one resample at a time: materializing every resample first
        # kept `niterations` copies of the selected data in memory at once.
        (sample,) = get_bootstrap_samples(data,1,yearsinsample)
        statslist.append(calc_binned_stats(sample))
    return xr.concat(statslist,dim='iteration')

def get_confidence_intervals(data,confidencelevel=0.95):
    """Bootstrap mean and equal-tailed confidence interval along 'iteration'.

    Parameters
    ----------
    data : xarray object with an 'iteration' dimension.
    confidencelevel : float
        Coverage of the interval (0.95 -> 2.5th and 97.5th percentiles).

    Returns
    -------
    (mean, lower, upper), each reduced over 'iteration'.
    """
    tail   = (1-confidencelevel)/2
    bounds = [data.quantile(q,dim='iteration') for q in (tail,1-tail)]
    return data.mean('iteration'),bounds[0],bounds[1]

In [5]:
# def get_region(data,key,regions=REGIONS):
#     region = regions[key]
#     return data.sel(lat=slice(region['latmin'],region['latmax']),lon=slice(region['lonmin'],region['lonmax']))

# def get_month(data,months):
#     if not isinstance(months,(list,tuple)):
#         months = [months]
#     monthmask = data.time.dt.month.isin(months)
#     return data.sel(time=monthmask)

# def get_bin_mean_pr(monthstats,bintype,samplethresh=SAMPLETHRESH):
#     Q0 = monthstats.Q0.values
#     Q1 = monthstats.Q1.values
#     Q0[Q0==0.0] = np.nan
#     binmeanpr = Q1/Q0
#     binmeanpr[Q0<samplethresh] = np.nan
#     coords = {'bl':monthstats.bl.values}
#     return xr.DataArray(binmeanpr,coords=coords)

#     def get_pdf(monthstats):
#     counts     = monthstats.P0.values
#     subsatbins = monthstats.subsat.values
#     capebins   = monthstats.cape.values
#     pdf = counts/(np.nansum(counts)*np.diff(subsatbins)[0]*np.diff(capebins)[0])
#     subsatmaxidx,capemaxidx = np.where(counts==np.nanmax(counts))
#     subsatmode,capemode     = subsatbins[subsatmaxidx],capebins[capemaxidx]
#     return xr.DataArray(pdf,coords={'subsat':subsatbins,'cape':capebins}),subsatmode,capemode


# def get_bin_mean_pr(stats,samplethresh=SAMPLETHRESH):
#     blbins = stats.bl.values
#     Q0 = stats.Q0.values
#     Q1 = stats.Q1.values
#     Q0[Q0==0.0] = np.nan
#     binmeanpr = Q1/Q0
#     binmeanpr[Q0<samplethresh] = np.nan
#     return xr.DataArray(binmeanpr,coords={'bl':blbins})

# def get_pdf(monthstats,precipitating=False,samplethresh=SAMPLETHRESH):
#     if bintype=='1D':
#         Q0 = stats.Q0.values
#         QE = stats.QE.values
#         Q  = QE if precipitating else Q0
#         Q[Q0<samplethresh] = np.nan
#         pdf = Q/np.nansum(Q)
#         coords = {'bl':monthstats.bl.values}
#     if bintype=='2D':
        
#     blbins = stats.bl.values
#     Q0 = stats.Q0.values
#     QE = stats.QE.values
#     Q  = QE if precipitating else Q0
#     Q[Q0<samplethresh] = np.nan
#     pdf = Q/np.nansum(Q)
#     return xr.DataArray(pdf,coords={'bl':stats.bl.values})

@jit(nopython=True)
def fast_binned_stats(bl,pr,blbinmin,blbinwidth,nblbins):
    """Count points and sum precipitation in nearest-coordinate BL bins (Numba-compiled).

    NOTE(review): an earlier cell defines a 7-argument `fast_binned_stats`; whichever
    cell ran last determines which one callers get.

    Parameters
    ----------
    bl, pr : arrays of the same size (iterated flat).
    blbinmin, blbinwidth : first bin coordinate and bin spacing.
    nblbins : number of bins.

    Returns
    -------
    (Q0, Q1) : float arrays of length nblbins with per-bin counts and precipitation
        sums. Points with non-finite bl or pr, or outside the bins, are skipped.
    """
    Q0 = np.zeros(nblbins)
    Q1 = np.zeros(nblbins)
    for i in range(bl.size):
        blval = bl.flat[i]
        prval = pr.flat[i]
        # skip missing data before indexing: casting NaN to an int is undefined
        if not (np.isfinite(blval) and np.isfinite(prval)):
            continue
        # fractional bin position; +0.5 rounds to the nearest bin coordinate
        pos = (blval-blbinmin)/blbinwidth+0.5
        # Range-check the float: the old int cast truncated toward zero, so values
        # 0.5-1.5 widths below the first coordinate were counted in bin 0.
        if pos<0.0 or pos>=nblbins:
            continue
        blidx = int(pos)  # pos >= 0, so truncation == floor
        Q0[blidx] += 1
        Q1[blidx] += prval
    return Q0,Q1

def get_bin_edges(key,binparams=BINPARAMS):
    """Bin coordinates for variable `key`: min, min+width, ... up to and covering max.

    Parameters
    ----------
    key : str
        A key of `binparams`.
    binparams : dict
        Maps variable name -> {'min','max','width'}.

    Returns
    -------
    1-D float ndarray of ceil((max-min)/width)+1 coordinates.
    """
    params = binparams[key]
    # np.arange with a non-integer step can return one element too many because of
    # rounding in its length computation; trim to the intended count (the small
    # tolerance absorbs rounding in the division).
    nedges   = int(np.ceil((params['max']-params['min'])/params['width']-1e-9))+1
    binedges = np.arange(params['min'],params['max']+params['width'],params['width'])[:nedges]
    return binedges

def calc_binned_stats(data,binparams=BINPARAMS):
    """Count points and sum precipitation in BL bins (1-D only).

    Parameters
    ----------
    data : xarray object with 'bl' and 'pr' variables.
    binparams : dict of {'min','max','width'}; only binparams['bl'] is used.

    Returns
    -------
    xr.Dataset with Q0 (count) and Q1 (precipitation sum, mm/day) on the `bl` coordinate.
    """
    blparams = binparams['bl']
    blbins   = get_bin_edges('bl',binparams)
    Q0,Q1 = fast_binned_stats(data.bl.values,data.pr.values,blparams['min'],blparams['width'],blbins.size)
    ds = xr.Dataset(data_vars={'Q0':('bl',Q0),'Q1':('bl',Q1)},coords={'bl':blbins})
    attributes = (
        ('Q0',dict(long_name='Count of points in each bin')),
        ('Q1',dict(long_name='Sum of precipitation in each bin',units='mm/day')),
        ('bl',dict(long_name='Average buoyancy in the lower troposphere', units='m/s²')))
    for varname,attrs in attributes:
        ds[varname].attrs = attrs
    return ds

def get_bootstrap_samples(data,niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE,rng=None):
    """Draw year-block bootstrap resamples of `data`.

    Each resample is made of `yearsinsample` years drawn with replacement from the
    distinct years in `data.time`; a year drawn k times contributes its time steps
    k times.

    Parameters
    ----------
    data : xarray object with a datetime `time` dimension.
    niterations : int
        Number of resamples to draw.
    yearsinsample : int
        Number of years per resample.
    rng : numpy Generator or RandomState, optional
        Source of randomness. When None the global ``np.random`` state is used (as
        before), so ``np.random.seed`` still makes the draws reproducible.

    Returns
    -------
    list of `niterations` objects like `data`, each indexed along `time`.
    """
    years    = data.time.dt.year.values
    allyears = np.unique(years)
    sampler  = np.random if rng is None else rng
    yearidxs = sampler.choice(len(allyears),size=(niterations,yearsinsample),replace=True)
    # time positions belonging to each distinct year, computed once
    yearpositions = [np.flatnonzero(years==year) for year in allyears]
    samples = []
    for iterationidxs in yearidxs:
        # Concatenating positions keeps repeated years. The previous
        # `time.dt.year.isin(selectedyears)` mask collapsed duplicates, so the
        # "with replacement" draw silently behaved as one without replacement.
        positions = [yearpositions[idx] for idx in iterationidxs]
        timeidxs  = np.concatenate(positions) if positions else np.array([],dtype=np.intp)
        samples.append(data.isel(time=timeidxs))
    return samples

def calc_bootstrap_stats(data,niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE):
    """Binned statistics for `niterations` year-block bootstrap resamples of `data`.

    Parameters
    ----------
    data : xarray object accepted by get_bootstrap_samples and calc_binned_stats.
    niterations : int
        Number of resamples (must be >= 1).
    yearsinsample : int
        Years per resample.

    Returns
    -------
    The calc_binned_stats output of every resample, concatenated along a new
    'iteration' dimension.

    Raises
    ------
    ValueError
        If niterations < 1 (xr.concat cannot concatenate nothing).
    """
    if niterations<1:
        raise ValueError(f'niterations must be >= 1, got {niterations}')
    statslist = []
    for _ in range(niterations):
        # Draw and reduce one resample at a time: materializing every resample first
        # kept `niterations` copies of the selected data in memory at once.
        (sample,) = get_bootstrap_samples(data,1,yearsinsample)
        statslist.append(calc_binned_stats(sample))
    return xr.concat(statslist,dim='iteration')

def get_confidence_intervals(data,confidencelevel=0.95):
    """Return the mean and an equal-tailed confidence band over 'iteration'.

    Parameters
    ----------
    data : xarray object with an 'iteration' dimension (bootstrap resamples).
    confidencelevel : float
        Fraction of the resampled distribution enclosed by the band.

    Returns
    -------
    (mean, cilower, ciupper)
    """
    alpha  = 1-confidencelevel
    lowerq = alpha/2
    cilower,ciupper = (data.quantile(q,dim='iteration') for q in (lowerq,1-lowerq))
    return (data.mean('iteration'),cilower,ciupper)

def prepare_bootstrap_data(stats,niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE):
    """Bootstrap bin-mean precipitation and the BL PDF for every region and case month.

    NOTE(review): each `monthstats` is passed both to calc_bootstrap_stats (which needs
    a `time` coordinate plus raw 'bl'/'pr' fields) and to get_pdf (which needs binned
    Q0/QE fields); confirm `stats` really provides both.

    Parameters
    ----------
    stats : xarray object with `region` and `month` dimensions.
    niterations, yearsinsample : forwarded to calc_bootstrap_stats (previously ignored).

    Returns
    -------
    dict: results[region][month] -> dict with
        'mean_pr', 'ci_lower', 'ci_upper' : bootstrap mean and 95% interval of bin-mean
            precipitation (bins with fewer than SAMPLETHRESH samples masked),
        'blpdf' : PDF of precipitating points over BL bins,
        'mode'  : BL coordinate of the PDF maximum (NaN if the PDF is all NaN),
        'label', 'color' : taken from CASES.
    """
    jjregions = ['Eastern Arabian Sea','Central India','Central Bay of Bengal']
    results = {}
    for region in stats.region.values:
        results[region] = {}
        casekey = 'JJ' if region in jjregions else 'JA'
        for month,label,color in CASES[casekey]:
            monthstats = stats.sel(region=region,month=month)
            # Separate name: this used to be assigned to `results`, which clobbered the
            # output dict so the `results[region][month] = ...` below failed.
            bootstats = calc_bootstrap_stats(monthstats,niterations,yearsinsample)
            binmeanpr = (bootstats.Q1/bootstats.Q0).where(bootstats.Q0>=SAMPLETHRESH)
            mean,cilower,ciupper = get_confidence_intervals(binmeanpr)

            blpdf     = get_pdf(monthstats,bintype='1D',precipitating=True)
            pdfvalues = blpdf.values
            # np.nanargmax raises on an all-NaN array (e.g. every bin under-sampled)
            if np.isnan(pdfvalues).all():
                mode = np.nan
            else:
                mode = blpdf.bl.values[np.nanargmax(pdfvalues)]

            results[region][month] = {
                'mean_pr': mean,
                'ci_lower':cilower,
                'ci_upper':ciupper,
                'blpdf':blpdf,
                'mode':mode,
                'label':label,
                'color':color}
    return results

def create_plot(results,ncols=3):
    """Plot bootstrap bin-mean precipitation and BL PDFs, one panel per region.

    Parameters
    ----------
    results : dict
        As returned by prepare_bootstrap_data: results[region][month] -> dict with
        'mean_pr', 'ci_lower', 'ci_upper', 'blpdf', 'mode', 'label', 'color'.
    ncols : int
        Panels per row (rows are added as needed).

    Each panel shows bin-mean precipitation (scatter) with its confidence band and a
    dashed line at the PDF mode on the left axis, and the BL PDF on a log-scaled
    twin axis.
    """
    regions  = list(results.keys())
    nregions = len(regions)
    nrows    = max(1,-(-nregions//ncols))  # ceiling division; more regions no longer raise IndexError
    fig = pplt.figure(refwidth=2.5)
    axs = fig.subplots(nrows=nrows,ncols=ncols,share=False)

    blkwargs = dict(s=10,alpha=0.5)
    prkwargs = dict(s=30,marker='o')

    legendgroups = set()
    for i,region in enumerate(regions):
        ax = axs[i]
        ax.format(titleloc='l',title=f'{region}')
        bx = ax.twinx()
        bx.format(yscale='log',yformatter='log',ylim=(10e-6,0.1))
        # keep right-axis tick labels only in the last column and on the last panel
        if i%ncols!=ncols-1 and i!=nregions-1:
            bx.format(yticklabels=[])

        for data in results[region].values():
            bx.scatter(data['blpdf'],color=data['color'],**blkwargs)
            ax.scatter(data['mean_pr'],color=data['color'],label=data['label'],**prkwargs)
            ax.fill_between(data['mean_pr'].bl,data['ci_lower'],data['ci_upper'],color=data['color'],alpha=0.3)
            ax.axvline(data['mode'],color=data['color'],linestyle='--',alpha=0.5)

        # One legend per distinct set of months, on the first panel that shows it.
        # Previously the panel was chosen by index (assuming the June-July regions come
        # first) and the legend was redrawn after every region.
        group = tuple(data['label'] for data in results[region].values())
        if group not in legendgroups:
            ax.legend(loc='ul',ncols=1)
            legendgroups.add(group)

    # hide grid cells that have no region to show
    for j in range(nregions,len(axs)):
        axs[j].set_visible(False)

    fig.format(xlabel='BL (m/s²)',ylabel='Precipitation (mm/day)')
    pplt.show()