In [None]:
import intake

from workflow.scripts.utils import regrid_global, calculate_pooled_variance, compute_annual_emission_budget
import time
import xarray as xr
from functools import partial
import xesmf
import pandas as pd
from pyclim_noresm.general_util_funcs import global_avg
from intake_esm.derived import DerivedVariableRegistry
import numpy as np
import scipy.stats as st
import yaml
from scipy.stats._resampling import _bootstrap_resample

In [None]:
burd_exp = xr.open_dataset(snakemake.input.burden_exp)

config = snakemake.config

# Wildcards selecting the model (used as the catalog source_id) and experiment.
mod_id = snakemake.wildcards.model
exp_id = snakemake.wildcards.experiment
# NOTE(review): time_slice is not referenced in the visible cells.
time_slice = config.get('time_slice', slice(2, None))

# Member (variant) id: use the model-specific override for this experiment
# from `model_specific_variant` when one is configured, else `variant_default`.
memb_id = (config['model_specific_variant'].get(exp_id, None) or {}).get(
    mod_id, config['variant_default'])

In [None]:
dvr = DerivedVariableRegistry()

@dvr.register(variable='duDepRatio',
              query={'variable_id': ['drydust', 'wetdust'], 'table_id': 'AERmon'})
def depo_ratio(ds):
    """Derived variable `duDepRatio`: fraction of dust deposition that is dry.

    Registered with intake-esm so that requesting `duDepRatio` pulls the
    AERmon `drydust` and `wetdust` variables and adds
    drydust / (wetdust + drydust) to the dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain `drydust` and `wetdust`.

    Returns
    -------
    xarray.Dataset
        A copy of `ds` with a `duDepRatio` variable (units '1').
    """
    ds = ds.copy()
    dry = ds['drydust']
    ratio = dry / (ds['wetdust'] + dry)
    ratio.attrs['units'] = '1'
    ratio.attrs['long_name'] = 'Dust deposition ratio'
    ratio.attrs['standard_name'] = 'dust_deposition_ratio'
    return ds.assign(duDepRatio=ratio)

def total_dust_deposition(ds):
    """Add `totdust`, the sum of dry and wet dust deposition.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain `drydust` and `wetdust`. A missing variable raises
        KeyError from the lookup; the notebook catches that to fall back.

    Returns
    -------
    xarray.Dataset
        `ds` with a new `totdust` variable.

    NOTE(review): the units label is hard-coded to 'kg m-2 s-1'. If the inputs
    already went through `resample_time` (which relabels kg m-2 s-1 data as
    kg m-2 year-1), this label would be stale. Confirm.
    """
    combined = ds['drydust'] + ds['wetdust']
    combined.attrs.update({
        'units': 'kg m-2 s-1',
        'long_name': 'Total dust deposition',
        'standard_name': 'total_dust_deposition',
    })
    return ds.assign(totdust=combined)

def calc_MEE(ds):
    """Add dust mass-extinction diagnostics.

    Adds two variables, both labelled with units 'g-1':
    - `od550dustMEE`: per-cell od550dust / (concdust * cell_area * 1e-9);
    - `mass_ext`: global_avg(od550dust) divided by the lon/lat sum of
      concdust * cell_area * 1e-9.

    NOTE(review): if concdust is a column load in kg m-2 and cell_area is in m2,
    the 1e-9 factor yields Tg, which conflicts with the 'g-1' label. Verify the
    intended units. Cells with zero concdust give inf/NaN in the map.
    `global_avg` comes from pyclim_noresm and is presumably an area-weighted
    lat/lon mean. Confirm.
    """
    # Same expression is used for both the global and the per-cell ratio.
    dust_mass = ds['concdust'] * ds['cell_area'] * 1e-9
    global_mee = global_avg(ds['od550dust']) / dust_mass.sum(dim=['lon', 'lat'])
    mee_map = ds['od550dust'] / dust_mass
    for diag in (global_mee, mee_map):
        diag.attrs['units'] = 'g-1'
    return ds.assign(od550dustMEE=mee_map).assign(mass_ext=global_mee)

def calc_lifetime(ds):
    """Add `lifetime` (labelled 'Days'): global burden / global emission * 365.

    Emission (`emidust`) is used in place of deposition, so this assumes
    emission and deposition balance. Both fields are multiplied by
    `cell_area` and summed over lon/lat before taking the ratio.

    NOTE(review): the *365 yields days only if emidust is per year (which
    `resample_time` produces for kg m-2 s-1 input) and concdust is a column
    load consistent with that. Confirm.
    """
    horizontal = ['lon', 'lat']
    area = ds['cell_area']
    emitted = (ds['emidust'] * area).sum(dim=horizontal)
    burden = (ds['concdust'] * area).sum(dim=horizontal)
    lifetime = (burden / emitted) * 365
    lifetime.attrs['units'] = 'Days'
    return ds.assign(lifetime=lifetime)
    
def get_emissions_source_region(ds, latmin, latmax, lonmin, lonmax):
    """Average the dataset's main variable over a lat/lon bounding box.

    Parameters
    ----------
    ds : xarray.Dataset
        Must carry a `variable_id` attribute naming the variable to average,
        and `lat`/`lon` coordinates.
    latmin, latmax, lonmin, lonmax : float
        Box bounds. Longitudes are in the -180..180 convention.

    Returns
    -------
    xarray.DataArray
        `global_avg` of the boxed variable, labelled 'kg m-2 s-1'.

    NOTE(review): the lat slice assumes ascending latitude. `global_avg`
    (pyclim_noresm) is presumably an area-weighted mean. Confirm.
    """
    var_id = ds.variable_id
    # A maximum above 180 means a 0..360 grid. Shift it to -180..180 (and
    # re-sort) so boxes with western-hemisphere (negative) bounds select
    # correctly. The old test (`< 180`) was inverted: it re-wrapped grids
    # that were already -180..180 and left 0..360 grids unshifted.
    if ds.lon.max() > 180:
        ds = ds.assign_coords(lon=(ds.lon + 180) % 360 - 180)
        ds = ds.sortby(ds.lon)
    region = ds.sel(lat=slice(latmin, latmax), lon=slice(lonmin, lonmax))
    regional_mean = global_avg(region[var_id])
    regional_mean.attrs['units'] = 'kg m-2 s-1'
    return regional_mean
    
def resample_time(data):
    """Resample a single-variable dataset to annual means, keeping lon/lat bounds.

    Used as the intake-esm `preprocess` hook and applied directly to the
    burden dataset.

    Parameters
    ----------
    data : xarray.Dataset
        Must carry a `variable_id` attribute naming its main variable, which
        must have a `units` attribute, and a `time` dimension.

    Returns
    -------
    xarray.Dataset
        Annual means. Variables in 'kg m-2 s-1' are converted to per-year
        totals (365-day year) and relabelled 'kg m-2 year-1'. In both cases a
        note is appended to the `history` attribute. Lon/lat bounds present in
        the input are copied over; missing ones are generated with
        `cf.add_bounds`.

    NOTE(review): relies on the cf_xarray `.cf` accessor, which is not imported
    in this notebook's visible imports. Presumably it is registered via another
    import. Confirm.
    """
    vname = data.variable_id
    # Untouched copy of the input: the source of the bounds variables, which
    # are re-attached after resampling.
    original = data.copy()

    with xr.set_options(keep_attrs=True):
        if data[vname].units == 'kg m-2 s-1':  # annual emission / deposition
            # Multiplication order kept as before so float rounding is unchanged.
            data = data.resample(time='Y').mean() * 365 * 24 * 60 * 60  # -> kg m-2 yr-1
            data[vname].attrs['units'] = 'kg m-2 year-1'
            data.attrs['history'] = data.attrs.get('history', '') + ', annual average converted to kg m-2 yr-1'
        else:
            data = data.resample(time='Y').mean()
            data.attrs['history'] = data.attrs.get('history', '') + ', annual average'

    # Restore each horizontal axis's bounds independently. Previously, lon
    # bounds present without lat bounds raised KeyError.
    bounds = original.cf.bounds
    missing_axes = []
    for axis in ('lon', 'lat'):
        if bounds.get(axis):
            bnd_name = bounds[axis][0]
            data = data.assign({bnd_name: original[bnd_name]})
        else:
            missing_axes.append(axis)
    if missing_axes:
        data = data.cf.add_bounds(missing_axes)
    return data

In [None]:
# Variables to load. `duDepRatio` is the derived variable registered on `dvr`.
diagnostic_variables = ['od550dust', 'duDepRatio', 'emidust', 'abs550aer']

# Open the catalog with the derived-variable registry, then narrow it to this
# model / experiment / member and to the AERmon diagnostics.
cat = intake.open_esm_datastore(snakemake.input.catalog, registry=dvr)
cat_exp = cat.search(
    experiment_id=exp_id,
    source_id=mod_id,
    member_id=memb_id,
)
exp_data = cat_exp.search(
    variable_id=diagnostic_variables,
    table_id='AERmon',
)

cell_area = xr.open_dataset(snakemake.input.model_area_mask)

In [None]:
def calc_emission_per_source_reg(ds, source_regs):
    """Add one `'<region> emidust'` variable per dust source region.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain `emidust` with `lat` and `lon` dimensions.
    source_regs : mapping
        Region name -> bounding-box dict with keys `latmin`, `latmax`,
        `lonmin`, `lonmax`.

    Returns
    -------
    xarray.Dataset
        `ds` with a `f'{name} emidust'` variable for each region, holding
        `emidust` summed over the region's lat/lon box.

    NOTE(review): the sum is an unweighted sum over grid cells, not an
    area-weighted total. The box slices assume ascending lat and a lon
    convention matching the bbox values. Confirm both.
    """
    regional = {}
    for name, bbox in source_regs.items():
        box = ds['emidust'].sel(
            lat=slice(bbox['latmin'], bbox['latmax']),
            # Upper longitude bound is lonmax (previously latmax by mistake).
            lon=slice(bbox['lonmin'], bbox['lonmax']),
        )
        regional[f'{name} emidust'] = box.sum(dim=['lat', 'lon'])
    return ds.assign(regional)

In [None]:
exp_dict = exp_data.to_dataset_dict(aggregate=True, skip_on_error=True, preprocess=resample_time)
# With skip_on_error=True every load can fail silently, leaving an empty dict.
# Fail with a clear message instead of a bare IndexError.
if not exp_dict:
    raise ValueError(
        f'No datasets loaded for model={mod_id}, experiment={exp_id}, member={memb_id}; '
        'check the catalog query (load errors were skipped).'
    )
# NOTE(review): only the first aggregated dataset is used; any further keys are ignored.
ds_exp = next(iter(exp_dict.values()))


try:
    ds_exp = total_dust_deposition(ds_exp)
except KeyError:
    # drydust/wetdust unavailable: write a NaN placeholder under the same name
    # (`totdust`) that total_dust_deposition produces, so the output schema
    # does not depend on which branch ran (was `totdep`).
    ds_exp = ds_exp.assign(totdust=np.nan)


In [None]:
source_regs = config['dust_source_regions']
# Per-region emission sums, load into memory, then attach the model cell areas.
ds_exp = (
    calc_emission_per_source_reg(ds_exp, source_regs)
    .compute()
    .assign(cell_area=cell_area['cell_area'])
)


In [None]:
# Add lon/lat bounds to the raw burden file, reduce it to annual means,
# then attach its `concdust` field to the experiment dataset.
burden_exp = resample_time(burd_exp.cf.add_bounds(['lon', 'lat']))
ds_exp = ds_exp.assign(concdust=burden_exp['concdust'])

In [None]:

# Mass-extinction diagnostics first, then dust lifetime.
ds_exp = ds_exp.pipe(calc_MEE).pipe(calc_lifetime)

In [None]:
# Keep the first ensemble member (integer isel drops the member_id dimension)
# and write the diagnostics file.
ds_exp = ds_exp.isel({'member_id': 0})
ds_exp.to_netcdf(path=snakemake.output.dust_diag_exp)