In [None]:
import intake

from workflow.scripts.utils import regrid_global, calculate_pooled_variance, compute_annual_emission_budget
import time
import xarray as xr
import xesmf
import pandas as pd
from pyclim_noresm.general_util_funcs import global_avg
from intake_esm.derived import DerivedVariableRegistry
import numpy as np
import scipy.stats as st
import yaml
from scipy.stats._resampling import _bootstrap_resample
from functools import partial

In [None]:
# Burden dataset for this model/experiment (Snakemake input).
burd_exp = xr.open_dataset(snakemake.input.burden_exp)

# Workflow configuration and the wildcards identifying this job.
config = snakemake.config
mod_id = snakemake.wildcards.model
exp_id = snakemake.wildcards.experiment

# Optional time window from the config, turned into a slice (default start=2, end=None).
time_slice_cfg = config.get('time_slice', {'start': 2, 'end': None})
time_slice = slice(time_slice_cfg['start'], time_slice_cfg['end'])

# Ensemble member: an experiment-specific per-model override when one is configured
# (and non-empty), otherwise the workflow-wide default variant.
exp_variants = config['model_specific_variant'].get(exp_id, None)
memb_id = (exp_variants.get(mod_id, config['variant_default'])
           if exp_variants
           else config['variant_default'])

In [None]:
dvr = DerivedVariableRegistry()


@dvr.register(variable='duDepRatio',
              query=dict(variable_id=['drydust', 'wetdust'], table_id='AERmon'))
def depo_ratio(ds):
    """Derive the dust deposition ratio ``duDepRatio``.

    The ratio is ``drydust / (wetdust + drydust)``, i.e. the dry share of the
    total dust deposition. A copy of ``ds`` is returned with the ratio added;
    ``drydust`` and ``wetdust`` are kept.
    """
    result = ds.copy()
    dry = result['drydust']
    wet = result['wetdust']
    ratio = dry / (wet + dry)
    ratio.attrs.update({
        'units': '1',
        'long_name': 'Dust deposition ratio',
        'standard_name': 'dust_deposition_ratio',
    })
    return result.assign(duDepRatio=ratio)

@dvr.register(variable='radatm',
              query=dict(
                  variable_id=['rlds', 'rsds', 'rlus', 'rsus', 'rlut', 'rsut', 'rsdt'],
                  table_id='Amon',
              ))
def atm_col_netrad(ds):
    """Derive the all-sky atmospheric radiative cooling ``radatm``.

    radatm = (rlut - rlus + rlds) - (rsdt - rsut - rsds + rsus)

    The first term is the longwave the atmospheric column loses across TOA and
    the surface. The second term is the shortwave it absorbs: net SW entering at
    TOA minus net SW entering the surface.

    Returns a copy of ``ds`` with ``radatm`` added and the seven input fluxes
    removed.
    """
    ds = ds.copy()
    lw_loss = ds['rlut'] - ds['rlus'] + ds['rlds']
    sw_absorbed = ds['rsdt'] - ds['rsut'] - ds['rsds'] + ds['rsus']
    radatm = lw_loss - sw_absorbed
    # The CF/UDUNITS symbol for watt is upper-case "W".
    radatm.attrs['units'] = 'W m-2'
    radatm.attrs['long_name'] = 'atmospheric radiative cooling'
    radatm.attrs['standard_name'] = 'atmospheric_radiative_cooling'
    radatm.attrs['comment'] = "net radiative loss of energy by the atmosphere due to SW and LW radiation flowing across TOA and surface"
    ds = ds.assign(radatm=radatm)
    # drop_vars is the non-deprecated way to remove data variables.
    return ds.drop_vars(['rlds', 'rsds', 'rlus', 'rsus', 'rlut', 'rsut', 'rsdt'])

@dvr.register(variable='radatmcs',
              query=dict(
                  variable_id=['rldscs', 'rsdscs', 'rlutcs', 'rsutcs', 'rsdt', 'rsuscs', 'rlus'],
                  table_id='Amon',
              ))
def atm_col_netrad_cs(ds):
    """Derive the clear-sky atmospheric radiative cooling ``radatmcs``.

    radatmcs = (rlutcs - rlus + rldscs) - (rsdt - rsutcs - rsdscs + rsuscs)

    This is the clear-sky counterpart of ``radatm``. The same budget is used with
    clear-sky fluxes, except that surface upwelling LW is the all-sky ``rlus``,
    which is what the query requests.

    Returns a copy of ``ds`` with ``radatmcs`` added and the input fluxes removed.
    """
    ds = ds.copy()
    # rldscs (LW lost downward to the surface) enters with a plus sign, exactly
    # as rlds does in the all-sky budget. Subtracting it flips the sign of that
    # term and yields strongly negative "cooling" values.
    lw_loss = ds['rlutcs'] - ds['rlus'] + ds['rldscs']
    sw_absorbed = ds['rsdt'] - ds['rsutcs'] - ds['rsdscs'] + ds['rsuscs']
    radatmcs = lw_loss - sw_absorbed
    # The CF/UDUNITS symbol for watt is upper-case "W".
    radatmcs.attrs['units'] = 'W m-2'
    radatmcs.attrs['long_name'] = 'clear sky atmospheric radiative cooling'
    radatmcs.attrs['standard_name'] = 'clear_sky_atmospheric_radiative_cooling'
    radatmcs.attrs['comment'] = "net clear sky radiative loss of energy by the atmosphere due to SW and LW radiation flowing across TOA and surface"
    ds = ds.assign(radatmcs=radatmcs)
    # drop_vars is the non-deprecated way to remove data variables.
    return ds.drop_vars(['rldscs', 'rsdscs', 'rlutcs', 'rsutcs', 'rsdt', 'rsuscs', 'rlus'])


def total_dust_deposition(ds):
    """Add ``totdust``, the sum of dry and wet dust deposition, to ``ds``.

    The units label is copied from ``drydust``, so it stays correct when the
    inputs have already been converted (e.g. to per-year fluxes). If
    ``drydust`` carries no units, ``'kg m-2 s-1'`` is used as before.

    Raises
    ------
    KeyError
        If ``drydust`` or ``wetdust`` is missing from ``ds``.
    """
    dry = ds['drydust']
    wet = ds['wetdust']
    total = dry + wet
    total.attrs['units'] = dry.attrs.get('units', 'kg m-2 s-1')
    total.attrs['long_name'] = 'Total dust deposition'
    total.attrs['standard_name'] = 'total_dust_deposition'
    return ds.assign(totdust=total)

def dust_calc_MEE(ds, odvar='od550aer'):
    """Add mass extinction (or absorption) coefficients of dust to ``ds``.

    The optical depth ``ds[odvar]`` is divided by ``concdust * 1e3``. This is
    done twice: once on the ``global_avg`` values, giving a single coefficient,
    and once point-wise, giving a map. Both are labelled ``m2 g-1``.

    For ``odvar == 'abs550aer'`` the results are stored as ``<odvar>_MAE``
    (map) and ``<odvar>_mass_abs`` (global). Otherwise they are stored as
    ``<odvar>_MEE`` and ``<odvar>_mass_ext``.
    """
    if odvar == 'abs550aer':
        quantity, map_suffix, mean_suffix = 'absorbtion', 'MAE', 'mass_abs'
    else:
        quantity, map_suffix, mean_suffix = 'extinction', 'MEE', 'mass_ext'

    mean_coeff = global_avg(ds[odvar]) / (global_avg(ds['concdust']) * 1e3)
    map_coeff = ds[odvar] / (ds['concdust'] * 1e3)

    for coeff in (mean_coeff, map_coeff):
        coeff.attrs['units'] = 'm2 g-1'
        coeff.attrs['long_name'] = f'Mass {quantity} coefficient {odvar}'
        coeff.attrs['standard_name'] = f'mass_{quantity}_coefficient {odvar}'

    ds = ds.assign({f'{odvar}_{map_suffix}': map_coeff})
    ds = ds.assign({f'{odvar}_{mean_suffix}': mean_coeff})
    return ds

def calc_lifetime(ds):
    """Add the global dust ``lifetime`` in days, computed as burden / emission.

    Both ``emidust`` and ``concdust`` are multiplied by ``cell_area`` and summed
    over lon and lat. The emission total stands in for the deposition rate
    (a steady-state assumption). The ratio is scaled by 365. This presumes
    ``emidust`` is a per-year flux, so the result is in days — TODO confirm.
    """
    area = ds['cell_area']
    emission_rate = (ds['emidust'] * area).sum(dim=['lon', 'lat'])
    column_mass = (ds['concdust'] * area).sum(dim=['lon', 'lat'])
    days = (column_mass / emission_rate) * 365
    days.attrs['units'] = 'Days'
    return ds.assign(lifetime=days)
    
def get_emissions_source_region(ds, latmin, latmax, lonmin, lonmax):
    """Average ``ds``'s own variable over a lat/lon box via ``global_avg``.

    The variable is the one named by ``ds.variable_id``. The box is selected
    with label slices ``[latmin, latmax]`` and ``[lonmin, lonmax]``. The
    returned average is labelled with units ``kg m-2 s-1``.
    """
    variable = ds.variable_id
    box = ds.sel(lat=slice(latmin, latmax), lon=slice(lonmin, lonmax))
    box_mean = global_avg(box[variable])
    box_mean.attrs['units'] = 'kg m-2 s-1'
    return box_mean


def calc_ang4487aer(ds):
    """Add the Angstrom exponent ``ang4487aer`` relative to 870 nm.

    alpha = -ln(tau_short / tau_870) / ln(lambda_short / 870)

    The short wavelength is 440 nm when ``od440aer`` is present, otherwise
    550 nm using ``od550aer``. ``long_name`` records which pair was used.

    The computation is best-effort: ``ds`` is returned unchanged when
    ``od870aer`` is missing, or when neither ``od440aer`` nor ``od550aer`` is
    available. Previously the latter case raised ``KeyError``.
    """
    od870 = ds.get('od870aer', None)
    if od870 is None:
        return ds

    # Prefer 440 nm and fall back to 550 nm.
    for name, wavelength in (('od440aer', 440), ('od550aer', 550)):
        od_short = ds.get(name, None)
        if od_short is not None:
            break
    else:
        return ds

    ang4487aer = -np.log(od_short / od870) / np.log(wavelength / 870)
    ang4487aer.attrs['units'] = '1'
    ang4487aer.attrs['long_name'] = f'Angstrom exponent {wavelength}-870 nm'
    return ds.assign(ang4487aer=ang4487aer)


    
def _align_coord_to_ref(data, ref, coord):
    """Copy ``ref[coord]`` onto ``data`` when both have the same shape but differ in value.

    Shapes are compared before values. On recent NumPy, ``==`` between arrays
    whose shapes cannot be broadcast raises instead of returning False. A grid
    of a different size is left unchanged (presumably for later regridding).
    """
    ref_coord = ref[coord]
    if ref_coord.shape == data[coord].shape and not np.array_equal(ref_coord.values, data[coord].values):
        data = data.assign_coords({coord: ref_coord})
    return data


def resample_time(data, ref=None):
    """Preprocess hook: turn one dataset into annual ('Y') means.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset whose ``variable_id`` attribute names the data variable to process.
    ref : xarray.Dataset, optional
        Reference grid. Where ``lon``/``lat`` have the same shape as in ``data``
        but different values, the reference values are copied onto ``data``.

    Returns
    -------
    xarray.Dataset
        Annual means. A variable in ``kg m-2 s-1`` is first multiplied by the
        seconds in a 365-day year and relabelled ``<...> year-1``. lon/lat bounds
        (if any) are dropped. For CNRM-ESM2-1 the pre-resampling bounds are put
        back instead.
    """
    vname = data.variable_id
    ds = data.copy()  # untouched copy: original units and bounds are read from it below
    if ref is not None:
        data = _align_coord_to_ref(data, ref, 'lon')
        data = _align_coord_to_ref(data, ref, 'lat')

    with xr.set_options(keep_attrs=True):
        if data[vname].units == 'kg m-2 s-1':  # annual emission / deposition
            data = data.assign({vname: data[vname] * 365 * 24 * 60 * 60})  # -> kg m-2 yr-1
            data = data.resample(time='Y').mean()
            data[vname].attrs['units'] = '{} year-1'.format(' '.join(ds[vname].attrs['units'].split(' ')[:-1]))
            data[vname].attrs['history'] = data.attrs.get('history', '') + ', annual average converted to kg m-2 yr-1'
        else:
            data = data.resample(time='Y').mean()
            data.attrs['history'] = data.attrs.get('history', '') + ', annual average'

    bounds = ds.cf.bounds
    if bounds.get('lon'):
        lon_bnds = bounds['lon'][0]
        lat_bnds = bounds['lat'][0]
        # Restore the pre-resampling bounds variables.
        data = data.assign({lon_bnds: ds[lon_bnds], lat_bnds: ds[lat_bnds]})
        if ds.source_id != 'CNRM-ESM2-1':
            data = data.drop_vars([lon_bnds, lat_bnds])
    return data

In [None]:
# AERmon diagnostics (incl. the derived duDepRatio) and Amon fields (incl. the derived radatm/radatmcs).
diagnostic_variables = [
    'od550dust', 'duDepRatio', 'od550aer', 'emidust',
    'abs550aer', 'od440aer', 'od870aer',
]
diag_amon_variables = ['tas', 'radatm', 'radatmcs']

# Catalog with the derived-variable registry attached, narrowed to this model/experiment/member.
cat = intake.open_esm_datastore(snakemake.input.catalog, registry=dvr)
cat_exp = cat.search(experiment_id=exp_id, source_id=mod_id, member_id=memb_id)

exp_data = cat_exp.search(variable_id=diagnostic_variables, table_id='AERmon')
exp_data_Amon = cat_exp.search(variable_id=diag_amon_variables, table_id='Amon')

# Grid-cell areas used for the area-weighted sums later on.
cell_area = xr.open_dataset(snakemake.input.model_area_mask)

In [None]:
def calc_emission_per_source_reg(ds, source_regs):
    """Add the total dust emission of each source region to ``ds``.

    ``source_regs`` maps a region name to a box dict with the keys ``latmin``,
    ``latmax``, ``lonmin`` and ``lonmax``. For each region a variable named
    ``'<name> emidust'`` is added. It holds ``emidust * cell_area`` summed over
    the label-sliced box.
    """
    for region, box in source_regs.items():
        boxed = ds.sel(
            lat=slice(box['latmin'], box['latmax']),
            lon=slice(box['lonmin'], box['lonmax']),
        )
        regional_total = (boxed['emidust'] * boxed['cell_area']).sum(dim=['lat', 'lon'])
        ds = ds.assign({f'{region} emidust': regional_total})
    return ds

In [None]:
# Arguments shared by every Amon load: aggregate the files and annual-mean each one on open.
amon_open_kwargs = dict(aggregate=True, skip_on_error=False, preprocess=resample_time)

exp_dict_amon = exp_data_Amon.search(variable_id=diag_amon_variables[:-2]).to_dataset_dict(**amon_open_kwargs)
amon_key = list(exp_dict_amon.keys())[0]
ds_exp_amon = exp_dict_amon[amon_key].squeeze()

# The derived radiation terms are opened one at a time and looked up under the tas key.
ds_radamt = (exp_data_Amon.search(variable_id=diag_amon_variables[-2])
             .to_dataset_dict(**amon_open_kwargs)[amon_key]
             .squeeze())
ds_radamtcs = (exp_data_Amon.search(variable_id=diag_amon_variables[-1])
               .to_dataset_dict(**amon_open_kwargs)[amon_key]
               .squeeze())

# AERmon fields: annual means, with lon/lat taken from the tas grid where the shapes agree.
resamp_func = partial(resample_time, ref=ds_exp_amon)
exp_dict = exp_data.to_dataset_dict(aggregate=True, skip_on_error=False, preprocess=resamp_func)
ds_exp = exp_dict[list(exp_dict.keys())[0]].squeeze()

ds_l = [ds_exp_amon, ds_radamt, ds_radamtcs]
# Amon fields are merged directly only when they share the AERmon lon grid;
# otherwise regrid_temp flags them for regridding.
regrid_temp = ds_exp_amon.lon.shape != ds_exp.lon.shape
if not regrid_temp:
    for d in ds_l:
        ds_exp = ds_exp.merge(d)

In [None]:

ds_exp = ds_exp.squeeze()
ds_exp = ds_exp.assign(cell_area=cell_area['cell_area'])

# Put the burden data on the diagnostics' time axis. When the lengths do not
# match (assign raises ValueError), the diagnostics are first subset to the
# burden's timestamps.
if mod_id == 'NorESM2-LM' and exp_id == 'piClim-control':
    ds_exp = ds_exp.sel(time=burd_exp.time)
    burd_exp = burd_exp.assign(time=ds_exp.time)
else:
    try:
        burd_exp = burd_exp.assign(time=ds_exp.time)
    except ValueError:
        ds_exp = ds_exp.sel(time=burd_exp.time)
        burd_exp = burd_exp.assign(time=ds_exp.time)

# Use the diagnostics' lon/lat values on the burden data whenever they differ.
for coord in ('lon', 'lat'):
    if not np.array_equal(burd_exp[coord].values, ds_exp[coord].values):
        burd_exp = burd_exp.assign_coords({coord: ds_exp[coord]})

if mod_id == 'GISS-E2-1-G':
    # NOTE(review): for GISS, abs550aer is replaced by od550aer - abs550aer;
    # presumably the published field is the scattering part — confirm.
    ds_exp = ds_exp.assign(abs550aer=ds_exp['od550aer'] - ds_exp['abs550aer'])

ds_exp = ds_exp.merge(burd_exp)
# Optional trimming with the configured time_slice (currently disabled):
# ds_exp = ds_exp.isel(time=time_slice)

try:
    ds_exp = total_dust_deposition(ds_exp)
except KeyError:
    # drydust/wetdust unavailable: add a NaN placeholder under the same name
    # that total_dust_deposition writes ('totdust'), so the output always has it.
    ds_exp = ds_exp.assign(totdust=np.nan)

In [None]:


x = ds_exp.cf['X'].name

if ds_exp.lon.max() > 180:
    # Wrap longitudes from [0, 360) into [-180, 180), re-sort along X and
    # rebuild the lon/lat bounds for the shifted grid.
    ds_exp = ds_exp.assign_coords({x: ((ds_exp.coords[x] + 180) % 360 - 180)}).sortby(x)
    ds_exp = ds_exp.cf.add_bounds(['lon', 'lat'])

if regrid_temp:
    # regrid_temp is set when the Amon lon grid differs in size from the AERmon
    # one: regrid the Amon fields bilinearly onto the ds_exp grid, then merge.
    # xesmf is already imported at the top of the notebook.
    ds_out = xr.zeros_like(ds_exp['emidust']).to_dataset(name='tas')
    ds_exp_amon = ds_exp_amon.cf.rename_like(ds_out)

    rg = xesmf.Regridder(ds_exp_amon, ds_out, 'bilinear')
    for d in ds_l:
        ds_exp = ds_exp.merge(rg(d))

In [None]:
# Bounding boxes of the dust source regions, from the workflow config.
source_regs = config['dust_source_regions']

# Load the data into memory, then add the per-region emission totals.
ds_exp = calc_emission_per_source_reg(ds_exp.compute(), source_regs)




In [None]:
# Mass extinction coefficient of dust, dust lifetime, then the Angstrom exponent.
ds_exp = (
    ds_exp
    .pipe(dust_calc_MEE, odvar='od550dust')
    .pipe(calc_lifetime)
    .pipe(calc_ang4487aer)
)

In [None]:
ds_exp.to_netcdf(path=snakemake.output.dust_diag_exp)  # write all diagnostics to the Snakemake output file