In [None]:
import intake

from workflow.scripts.utils import copy_meta_data_CMIP,transelate_aerocom_helper, regrid_global
from pyclim_noresm.general_util_funcs import yearly_avg
import time
import xarray as xr
from functools import partial
import xesmf
import pandas as pd
from intake_esm.derived import DerivedVariableRegistry

In [None]:
def update_datatracker():
    """Remove the current variable from the data tracker, then abort the job.

    Called when the requested data cannot be loaded from the catalog. Every key of
    the config sections 'burdens_dict' and 'variables' whose value mentions the
    notebook-level `var_id` is treated as affected; if none is, `var_id` itself is.
    Those names are removed from every request list in the YAML file at
    `snakemake.input.data_tracker`, which is rewritten in place.

    Relies on the notebook-level globals `snakemake` and `var_id`.

    Raises
    ------
    ValueError
        Always, after the tracker has been rewritten, so the Snakemake job fails.
    """
    import yaml

    def find_issue(mapping):
        # Config keys whose value mentions var_id (values presumably are lists of
        # variable names). A missing config section yields no matches.
        return [key for key, value in (mapping or {}).items() if var_id in value]

    erfs = snakemake.config.get('variables')
    burdens_dict = snakemake.config.get('burdens_dict')
    issue_vars = find_issue(burdens_dict) + find_issue(erfs)
    if not issue_vars:
        issue_vars = [var_id]

    with open(snakemake.input.data_tracker, 'r') as f:
        # An empty tracker file loads as None.
        data = yaml.load(f, yaml.SafeLoader) or {}

    print(data)
    drop = set(issue_vars)
    # Filter instead of taking a set difference. Request order is kept (duplicates
    # are still collapsed), so the rewritten file does not depend on set iteration order.
    data = {
        key: list(dict.fromkeys(req for req in (requests or []) if req not in drop))
        for key, requests in data.items()
    }
    print(data)

    with open(snakemake.input.data_tracker, 'w') as f:
        yaml.dump(data, f, yaml.SafeDumper, default_flow_style=False)

    raise ValueError(f"Could not find {var_id} in the catalog, removing {','.join(issue_vars)} from data tracker ")

In [None]:
# The Snakemake wildcards select the experiment, model, variable and output frequency of this job.
exp_id, mod_id, var_id, freq = (
    snakemake.wildcards.experiment,
    snakemake.wildcards.model,
    snakemake.wildcards.variable,
    snakemake.wildcards.freq,
)
# CMOR table for this variable, falling back to the configured default table.
table_id = snakemake.config['table_ids'].get(var_id, snakemake.config['table_id_default'])
# Member/variant id: an experiment may override it per model; any falsy override
# (missing, None, empty) means the configured default is used.
experiment_variants = snakemake.config['model_specific_variant'].get(exp_id, None) or {}
memb_id = experiment_variants.get(mod_id, snakemake.config['variant_default'])

In [None]:
kind = snakemake.params.get("kind", "experiment")
params = snakemake.params
# NOTE(review): `kind` and `accumlative_vars` are not referenced in the visible cells.
accumlative_vars = params.get('accumalative_vars',None)
# Either a plain flag, or a per-model mapping with an optional 'all' fallback.
use_derived_vars = snakemake.config['use_derived_vars'].get(var_id,False)
if isinstance(use_derived_vars, dict):
    # Resolve the 'all' fallback lazily. `dict.get(mod_id, d['all'])` would evaluate
    # d['all'] eagerly and raise KeyError when it is absent, even if mod_id is present.
    if mod_id in use_derived_vars:
        use_derived_vars = use_derived_vars[mod_id]
    else:
        use_derived_vars = use_derived_vars.get('all', False)


# Total (wet + dry) deposition variables that can be derived from their two
# components; 'vars' is [wet, dry].
deptotal = {
    'depdust': { 
    'vars':['wetdust', 'drydust'],
    'long_name': "Total dust deposition",
    'standard_name' : "total_deposition_of_dust"
    },
    'depso4' :{ 
    'vars': ['wetso4', 'dryso4'],
    'long_name': "Total SO4 deposition",
    'standard_name' : "total_deposition_of_so4"
    },
    'depss' :{ 
    'vars': ['wetss', 'dryss'],
    'long_name': "Total Sea salt deposition",
    'standard_name' : "total_deposition_of_seasalt"
    },
    'depoa' : {
    'vars' : ['wetoa', 'dryoa'],
    'long_name': "Total deposition of Organic Aerosols",
    'standard_name' : 'total_deposition_of_oa'
    },
    'depso2' : {
    'vars' : ['wetso2', 'dryso2'],
    'long_name' : "Total deposition of SO2",
    'standard_name' : 'total_deposition_so2'
    }
}

# Name of the total-deposition variable to derive, or None if var_id is not one.
depvar = var_id if var_id in deptotal else None

In [None]:
# Registry of derived variables that intake-esm can build from raw catalog variables.
dvr = DerivedVariableRegistry()
@dvr.register(variable='rlutaf', 
              query=dict(experiment_id=exp_id,
                    source_id=mod_id,
                     variable_id = ['rlutcs', 'rlut','rlutcsaf'],
                     member_id=memb_id,
                     )
             )
def calc_rlutaf(ds):
    """Derive aerosol-free TOA outgoing longwave radiation.

    rlutaf = rlut + (rlutcsaf - rlutcs). The inputs are dropped, and rlut's
    attributes are reused with an updated long_name.
    """
    ds['rlutaf'] = ds['rlut'] + (ds['rlutcsaf'] - ds['rlutcs'])
    attrs = ds['rlut'].attrs.copy()
    ds = ds.drop_vars(['rlut','rlutcsaf','rlutcs'])
    ds['rlutaf'].attrs = attrs 
    ds['rlutaf'].attrs['long_name'] = 'TOA Outgoing aerosol free Longwave Radiation'
    ds.attrs['variable_id'] = 'rlutaf'
    return ds
    
@dvr.register(variable='rsutaf', 
              query=dict(experiment_id=exp_id,
                    source_id=mod_id,
                     variable_id = ['rsutcs', 'rsut','rsutcsaf'],
                     member_id=memb_id,
                     )
             )
def calc_rsutaf(ds):
    """Derive aerosol-free TOA outgoing shortwave radiation.

    rsutaf = rsut + (rsutcsaf - rsutcs). The inputs are dropped, and rsut's
    attributes are reused with an updated long_name.
    """
    ds['rsutaf'] = ds['rsut'] + (ds['rsutcsaf'] - ds['rsutcs'])
    attrs = ds['rsut'].attrs.copy()
    ds = ds.drop_vars(['rsut','rsutcsaf','rsutcs'])
    ds['rsutaf'].attrs = attrs 
    ds['rsutaf'].attrs['long_name'] = 'TOA Outgoing aerosol free Shortwave Radiation'
    ds.attrs['variable_id'] = 'rsutaf'
    return ds

@dvr.register(variable='loaddust', 
              query=dict(experiment_id=exp_id,
                    source_id=mod_id,
                     variable_id = ['mmrdust', 'airmass'],
                     member_id=memb_id,
                     )
             )
def calc_loaddust(ds):
    """Derive the dust load as airmass * mmrdust (units set to kg m-2).

    The inputs are dropped, and mmrdust's attributes are reused with a new
    long_name, units and standard_name.
    """
    # NOTE(review): no sum over a vertical dimension is taken, so if the inputs
    # are per-level the result is per-level too; confirm that is intended for a "load".
    ds['loaddust'] = ds['airmass']*ds['mmrdust']
    attrs = ds['mmrdust'].attrs.copy()
    ds = ds.drop_vars(['mmrdust','airmass'])
    ds['loaddust'].attrs = attrs
    ds['loaddust'].attrs['long_name'] = "Load of Dust"
    ds['loaddust'].attrs['units'] = "kg m-2"
    ds['loaddust'].attrs['standard_name'] = "load_of_dust"
    ds.attrs['variable_id'] = 'loaddust'
    return ds

if depvar:
    @dvr.register(variable=depvar,
                  query=dict(experiment_id=exp_id,
                        source_id=mod_id,
                         variable_id = deptotal[depvar]['vars'],
                         member_id=memb_id,
                 ))
    def calc_dep(ds):
        """Derive total deposition `depvar` as wet + dry deposition.

        The two inputs are dropped. The dry-deposition attributes are reused, with
        long_name and standard_name taken from `deptotal[depvar]`.
        """
        wdp = deptotal[depvar]['vars'][0]
        ddp = deptotal[depvar]['vars'][1]
        ds[depvar] = ds[wdp] + ds[ddp]
        attrs = ds[ddp].attrs.copy()
        ds = ds.drop_vars([ddp, wdp])
        ds[depvar].attrs = attrs
        ds[depvar].attrs['long_name'] = deptotal[depvar]['long_name']
        ds[depvar].attrs['standard_name'] = deptotal[depvar]['standard_name']
        ds.attrs['variable_id'] = depvar
        return ds



In [None]:
# The catalog is always opened with the derived-variable registry attached; whether
# a derived variable is actually used is decided below.
esm_cat = intake.open_esm_datastore(snakemake.input.catalog, registry=dvr)
col = esm_cat.search(experiment_id=exp_id,
                     source_id=mod_id,
                     variable_id=var_id,
                     member_id=memb_id,
                     )
# Prefer the raw variable when the catalog has it and derived variables are not
# requested. `== False` is kept deliberately: e.g. a None setting does not compare
# equal to False and therefore does not take this path.
if var_id in col.unique()['variable_id'] and (use_derived_vars == False):  # noqa: E712
    # Clear the registry so only raw catalog entries are used.
    col.derivedcat = DerivedVariableRegistry()
    col = col.search(variable_id=var_id)
    if var_id == 'concdust':
        # concdust is read from the Emon 'loaddust' entries instead.
        esm_cat.derivedcat = DerivedVariableRegistry()
        col = esm_cat.search(experiment_id=exp_id,
                             source_id=mod_id,
                             variable_id='loaddust',
                             member_id=memb_id,
                             table_id='Emon')
    elif col.nunique()['table_id'] > 1:
        # The variable exists in several CMOR tables; keep the configured one.
        col = col.search(table_id=table_id)

# With several dataset versions, keep only the latest (skipped for derived variables).
if col.nunique().version > 1 and not col.unique()['derived_variable_id']:
    latest = max(col.df['version'].unique())
    col = col.search(version=[latest])


In [None]:

try:
    ds = col.to_dataset_dict(xarray_open_kwargs={'use_cftime':True},)
    # Take the first (expected to be only) dataset of the returned dict.
    ds = ds[list(ds.keys())[0]]
    ds = ds.drop_vars('member_id').squeeze()
except IndexError:
    # Empty result dict: nothing matched the query. This call always raises.
    update_datatracker()

# For a derived variable, every input it needs must be present in the search result.
if col.derivedcat.values():
    expected_variables = col.derivedcat[var_id].query['variable_id']
    if sorted(expected_variables) != sorted(col.unique()['variable_id']):
        update_datatracker()

In [None]:
# If concdust was requested but the loaded dataset does not contain it, load the
# Emon 'loaddust' entry (without derived variables) and expose it as 'concdust'.
# NOTE(review): a load is substituted for a concentration here; confirm intent.
if var_id == 'concdust' and 'concdust' not in ds.data_vars:
    esm_cat.derivedcat = DerivedVariableRegistry()
    col = esm_cat.search(experiment_id=exp_id,
                         source_id=mod_id,
                         variable_id='loaddust',
                         member_id=memb_id,
                         table_id='Emon')
    ds = col.to_dataset_dict(xarray_open_kwargs={'use_cftime':True},)
    ds = ds[list(ds.keys())[0]]
    ds = ds.drop_vars('member_id').squeeze()
    ds = ds.rename({'loaddust':'concdust'})
    ds.attrs['variable_id'] = 'concdust'

In [None]:
# Turn non-index coordinates (e.g. bounds) into data variables.
ds = ds.reset_coords()
# Data with a 'lev' dimension also gets surface pressure (Amon 'ps') attached.
if 'ps' not in ds.data_vars and 'lev' in ds.dims:
    col = esm_cat.search(experiment_id=exp_id,
                         source_id=mod_id,
                         variable_id='ps',
                         member_id=memb_id,
                         table_id='Amon')
    ps = col.to_dataset_dict(xarray_open_kwargs={'use_cftime':True},)
    ps = ps[list(ps.keys())[0]]
    ps = ps.drop_vars('member_id').squeeze()
    ds = ds.assign(ps=ps['ps'])

In [None]:
# Snapshot of the input and its variable names, so variables lost during
# regridding/resampling can be restored before writing.
ds_orr = ds.copy()
dvar_orr = set(ds.data_vars)

In [None]:
def check_bounds(ds, variable):
    """Return True if the bounds of `variable` carry a 'time' dimension that `variable` lacks.

    Returns False when cf_xarray finds no bounds variable for `variable`.
    """
    try:
        bounds = ds.cf.get_bounds(variable)
    except KeyError:
        return False
    return 'time' in set(bounds.dims) - set(ds.cf[variable].dims)


# After `ds.reset_coords()` the bounds are data variables, not coordinates, so look
# them up in `ds.variables` (coordinates and data variables).
if 'lon_bnds' in ds.variables or 'lat_bnds' in ds.variables:
    if check_bounds(ds, 'longitude') or check_bounds(ds, 'latitude'):
        # Replace time-dependent bounds with bounds guessed from the lat/lon coordinate values.
        ds = ds.drop_vars(['lon_bnds', 'lat_bnds'], errors='ignore')
        ds = ds.cf.add_bounds('latitude', dim='lat')
        ds = ds.cf.add_bounds('longitude', dim='lon')

In [None]:
def regrid_dataset(ds, grid_params, grid_path):
    """Regrid `ds` to a target grid file or to a regular lon/lat spacing.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to regrid.
    grid_params : dict
        Regrid configuration. 'method' defaults to 'conservative'. The optional
        'dxdy' is a (lon, lat) spacing pair, used only when `grid_path` is not given.
    grid_path : str or None
        Path to a dataset describing the output grid; takes precedence over 'dxdy'.

    Returns
    -------
    The regridded dataset, or `ds` unchanged when neither a grid file nor 'dxdy'
    is provided.
    """
    method = grid_params.get('method', 'conservative')
    if grid_path:
        # Load the (small) target grid into memory so the file handle can be closed.
        with xr.open_dataset(grid_path) as grid_file:
            out_grid = grid_file.load()
        return regrid_global(ds, out_grid, method=method)
    dxdy = grid_params.get('dxdy', None)
    if dxdy:
        return regrid_global(ds, lon=dxdy[0], lat=dxdy[1], method=method, ignore_degenerate=True)
    print('No outgrid provided!')
    return ds

In [None]:
if snakemake.config.get('regrid_params', None) and snakemake.params.get('regrid', True):
    # Longitude bounds with a time dimension are dropped before regridding
    # (presumably the regridder expects static bounds; confirm).
    if ds.data_vars.get('lon_bnds') is not None and 'time' in ds.lon_bnds.dims:
        ds = ds.drop_vars('lon_bnds')
    grid_params = snakemake.config['regrid_params']
    ds = regrid_dataset(ds, grid_params=grid_params,
                        grid_path=grid_params.get('grid_path', None))
    

In [None]:
with xr.set_options(keep_attrs=True):
    # Ensure lon/lat cell bounds exist; the Ayear branch copies them into the output.
    if not ds.cf.bounds.get('lon', None):
        ds = ds.cf.add_bounds(['lon', 'lat'])

    # `remove_vars` lists variables the next cell must NOT restore from the
    # unregridded input when they are missing from `data`.
    regrid_requested = snakemake.params.get('regrid', True)
    bounds_vars = {'time_bnds', 'lon_bnds', 'lat_bnds'}

    if freq == 'Ayear':
        data = ds
        vname = ds.variable_id
        if data[vname].units == 'kg m-2 s-1': # annual emission / deposition 
            data = data[data.variable_id].resample(time='Y').mean()*365*24*60*60 # convert to kg m-2 yr-1
            data.attrs['units'] = '{} year-1'.format(' '.join(data.attrs['units'].split(' ')[:-1]))
            data = data.to_dataset(name=vname)
            data.attrs = ds.attrs.copy()
            data.attrs['history'] = ds.attrs.get('history', '') + ', annual average converted to kg m-2 yr-1'
        else:
            data = data[data.variable_id].resample(time='Y').mean()
            data = data.to_dataset(name=vname)
            data.attrs = ds.attrs.copy()
            data.attrs['history'] = data.attrs.get('history','') + ', annual average'

        if 'ps' in ds.data_vars:
            ps = ds['ps'].resample(time='Y').mean().copy()
            data = data.assign(ps=ps)
        # Label each annual mean with the start of its year.
        t = data.time.dt.strftime('%Y')
        data = data.assign_coords({'time':xr.cftime_range(start=str(t[0].values), end= str(t[-1].values), freq='YS')})
        data = data.assign({ds.cf.bounds['lon'][0]:ds[ds.cf.bounds['lon'][0]]})
        data = data.assign({ds.cf.bounds['lat'][0]:ds[ds.cf.bounds['lat'][0]]})
        # Drop time_bnds from the annual result. Dropping it from `ds` and assigning
        # that to `data` would discard the annual means computed above.
        data = data.drop_vars('time_bnds', errors='ignore')
        remove_vars = set(bounds_vars) if regrid_requested else set()
    elif freq == 'clim':
        data = ds
        t0 = data.time[0].dt.strftime('%Y/%m').values
        t1 = data.time[-1].dt.strftime('%Y/%m').values
        data = data.groupby('time.month').mean('time')
        data[data.variable_id].attrs['history'] = data[data.variable_id].attrs.get('history','') + f', clim mean {t0}-{t1}'
        # NOTE(review): freq is always 'clim' inside this branch, so this block can
        # never run; it presumably expected a different wildcard. Confirm the intent.
        if freq == '2010':
            import cftime
            data = data.rename(month='time')
            cftimes = cftime.date2num(pd.date_range('2010-01-31','2010-12-31', freq='M').to_list(),
                                      'days since 2010-01-01',
                                      has_year_zero=False, calendar = 'gregorian')
            data = data.assign_coords(time=cftimes)
            data.time.attrs['units'] = 'days since 2010-01-01'
        dvar_attrs = copy_meta_data_CMIP(data[var_id].attrs)
        # The climatology has no time axis left, so time_bnds is never restored.
        remove_vars = set(bounds_vars) if regrid_requested else {'time_bnds'}
    elif freq == 'Amon':
        data = ds.drop_vars('time_bnds') if 'time_bnds' in ds.data_vars else ds
        remove_vars = set(bounds_vars) if regrid_requested else set()
    else:
        raise ValueError(f'{freq} is an invalid frequency')

In [None]:
if 'orog' in ds.data_vars:
    data = data.assign(orog=ds['orog'].copy())
dvar_regrid = set(data.data_vars)

# Variables present in the raw input but missing from `data` are copied back from
# the unregridded snapshot, except those listed in `remove_vars`. `set()` makes
# this work whether remove_vars is a set or an empty dict literal `{}`. The names
# are sorted so the variable order of the output is deterministic.
lost_vars = sorted(dvar_orr - dvar_regrid - set(remove_vars))

if lost_vars:
    data = data.assign({v: ds_orr[v] for v in lost_vars})

In [None]:
# Materialise any lazily loaded (e.g. dask-backed) arrays in memory before writing.
data = data.compute()

In [None]:
# Record the output frequency (the `freq` wildcard) and write the result to NetCDF.
data.attrs['frequency'] = freq
data.to_netcdf(snakemake.output.outpath)