In [None]:
import intake
import xarray as xr
import numpy as np 
from functools import partial
from workflow.scripts.utils import  regrid_global
import matplotlib.pyplot as plt
from workflow.scripts.plotting_tools import global_map
import statsmodels.api as sm

In [None]:

# Variable / experiment handled by this notebook and the CMIP table to read it from.
var_id = 'od550dust'
exp_id = 'piClim-control'
cfg = snakemake.config
table_id = cfg['table_ids'].get(var_id, cfg['table_id_default'])
# ECDF fraction later handed to get_dod_treshold (via create_mask_dataset).
tresh = snakemake.params.get('od_percentile', .9)

esm_cat = intake.open_esm_datastore(snakemake.input.catalog)
col = esm_cat.search(experiment_id=exp_id, variable_id=var_id)

# Keep only the models that also provide the piClim-2xdust experiment,
# preserving the ordering of the control catalog.
dust2x_search = esm_cat.search(experiment_id='piClim-2xdust', variable_id=var_id)
mod_id2xdust = list(dust2x_search.unique()['source_id'])
mod_ids = [candidate for candidate in col.unique()['source_id'] if candidate in mod_id2xdust]

In [None]:
def regrid_dataset(ds, grid_params, grid_path):
    """Regrid ``ds`` onto a target grid with ``regrid_global``.

    The target grid is taken from ``grid_path`` (opened with ``xr.open_dataset``)
    when it is truthy; otherwise from ``grid_params['dxdy']``, whose first and
    second items are passed as ``lon`` and ``lat``. ``grid_params['method']``
    selects the regridding method (default ``'conservative'``). If neither grid
    source is available, a message is printed and ``ds`` is returned unchanged.
    """
    method = grid_params.get('method', 'conservative')
    resolution = grid_params.get('dxdy', None)

    # Guard: nothing to regrid onto.
    if not (grid_path or resolution):
        print('No outgrid provided!')
        return ds

    if grid_path:
        target_grid = xr.open_dataset(grid_path)
        return regrid_global(ds, target_grid, method=method)

    lon_step, lat_step = resolution[0], resolution[1]
    return regrid_global(ds, lon=lon_step, lat=lat_step, method=method, ignore_degenerate=True)

def get_dod_treshold(data, surfFrac=0.8):
    """Return the value above which a grid cell is counted as "high" in ``data``.

    The threshold is the smallest sample value ``v`` whose empirical CDF
    (fraction of non-NaN samples <= ``v``) is strictly greater than ``surfFrac``.
    This is the same value a statsmodels ``ECDF`` built on the data yields. NaNs
    are discarded before the ECDF is built, so masked or missing cells no longer
    inflate the sample count and bias the threshold upward.

    Parameters
    ----------
    data : array-like (e.g. xarray.DataArray, numpy array)
        Values of any shape; they are flattened.
    surfFrac : float
        Cumulative fraction, expected in [0, 1).

    Returns
    -------
    numpy scalar
        The threshold value.

    Raises
    ------
    ValueError
        If ``data`` contains no non-NaN values, or no sample's ECDF exceeds
        ``surfFrac`` (e.g. ``surfFrac >= 1``).
    """
    values = np.asarray(data).ravel()
    values = np.sort(values[~np.isnan(values)])
    n_valid = values.size
    if n_valid == 0:
        raise ValueError('get_dod_treshold: data contains no non-NaN values')

    # ECDF heights at the sorted samples; statsmodels' ECDF uses
    # np.linspace(1/n, 1, n), so boundary comparisons match the old implementation.
    ecdf = np.linspace(1.0 / n_valid, 1, n_valid)
    above = values[ecdf > surfFrac]
    if above.size == 0:
        raise ValueError(f'get_dod_treshold: no ECDF value exceeds surfFrac={surfFrac}')
    # values are sorted and ecdf is increasing, so the first hit is the minimum.
    return above[0]

In [None]:
# Load each model's dataset for exp_id, optionally regrid it, and reduce it to a
# monthly climatology (groupby month + mean), keyed by model id in `dsets`.
config = snakemake.config

# member_id overrides are keyed by experiment, then by model; fall back to the default.
variant_overrides = config['model_specific_variant'].get(exp_id, None) or {}

# Regridding is decided once for all models. When it is disabled the data is
# passed through unchanged (previously `regrid_func` was left undefined in that
# case, raising NameError below).
do_regrid = bool(config.get('regrid_params', None)) and snakemake.params.get('regrid', True)
if do_regrid:
    grid_params = config['regrid_params']
    regrid_func = partial(regrid_dataset, grid_params=grid_params,
                          grid_path=grid_params.get('grid_path', None))
else:
    def regrid_func(ds):
        """Identity stand-in used when regridding is disabled."""
        return ds

dsets = {}
for mod_id in mod_ids:
    memb_id = variant_overrides.get(mod_id, config['variant_default'])
    temp_col = col.search(source_id=mod_id, member_id=memb_id, table_id=table_id)
    ds_dict = temp_col.to_dataset_dict(progressbar=False)
    if not ds_dict:
        raise ValueError(f'No {exp_id} data found for {mod_id} '
                         f'(member_id={memb_id}, table_id={table_id})')
    # NOTE(review): only the first returned dataset is used; if the search matches
    # several datasets (e.g. grid labels) the others are silently ignored.
    ds = next(iter(ds_dict.values()))
    ds = ds.drop_vars('member_id').squeeze()
    # Time-dependent longitude bounds are removed before regridding
    # (presumably regrid_global cannot handle them — TODO confirm).
    if do_regrid and ds.data_vars.get('lon_bnds') is not None and 'time' in ds.lon_bnds.dims:
        ds = ds.drop_vars('lon_bnds')
    dsets[mod_id] = regrid_func(ds).groupby("time.month").mean().load()


In [None]:
def create_mask_dataset(dsets, tresh=.9, var_id='od550dust'):
    """Build per-model and multi-model 0/1 masks of high values of ``var_id``.

    For every model, ``var_id`` is averaged over ``month``. Grid cells whose mean
    exceeds that model's own threshold from ``get_dod_treshold(mean, tresh)`` are
    flagged 1; all other cells are 0.

    Parameters
    ----------
    dsets : dict
        Maps model id -> Dataset containing ``var_id`` with a ``month`` dimension.
    tresh : float
        ECDF fraction passed to ``get_dod_treshold``. Roughly the largest
        ``1 - tresh`` fraction of cells ends up flagged.
    var_id : str
        Name of the variable to threshold (default ``'od550dust'``).

    Returns
    -------
    xarray.Dataset
        One mask variable per model id, plus ``'mean'``: the fraction of models
        flagging each cell. The template is ``xr.zeros_like`` of the first
        dataset, so that dataset's other variables are carried over zero-filled.

    Raises
    ------
    ValueError
        If ``dsets`` is empty.
    """
    if not dsets:
        raise ValueError('create_mask_dataset needs at least one dataset')

    # Template with the first model's grid/coords. var_id is renamed so 'mean'
    # keeps its slot in the variable order; it is overwritten at the end.
    out_ds = xr.zeros_like(next(iter(dsets.values())))
    out_ds = out_ds.rename({var_id: 'mean'})

    masks = []
    for mod_id, ds in dsets.items():
        time_mean = ds[var_id].mean(dim='month')
        tresh_od = get_dod_treshold(time_mean, tresh)
        model_mask = xr.where(time_mean > tresh_od, 1, 0)
        # The wavelength coordinate is dropped, presumably so masks from different
        # models combine without coordinate conflicts.
        if 'wavelength' in model_mask.coords:
            model_mask = model_mask.drop_vars('wavelength')
        out_ds = out_ds.assign({mod_id: model_mask})
        masks.append(model_mask)

    multi_models = xr.concat(masks, dim='mod').mean(dim='mod')
    return out_ds.assign({'mean': multi_models})


In [None]:
# Compute the model masks and write them to the Snakemake output path.
mask = create_mask_dataset(dsets, tresh=tresh)
out_path = snakemake.output.outpath
mask.to_netcdf(out_path)