In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import pandas as pd
import xesmf as xe
import matplotlib as mpl
from workflow.scripts.utils import global_avg,regrid_global


In [None]:
# Open every run file, keyed by the second-to-last "_"-separated token of its
# path (presumably the model id — confirm against the Snakemake rule naming).
# Experiment runs skip their first time step (presumably spin-up — confirm).
ds_exp = dict(
    (exp_path.split("_")[-2], xr.open_dataset(exp_path).isel(time=slice(1, None)))
    for exp_path in snakemake.input.exp_data
)
ds_ctrl = dict(
    (ctrl_path.split("_")[-2], xr.open_dataset(ctrl_path))
    for ctrl_path in snakemake.input.ctrl_data
)

In [None]:
# Region name -> lon/lat bounds, straight from the workflow config.
source_regions_dict = snakemake.config['dust_source_regions']
# Region names in config order.
dust_source_regions = list(source_regions_dict.keys())

In [None]:
def create_source_region_df(dset_dict,
                            source_regions,
                            var_id='emidust'):
    """Tabulate the all-dimension mean of ``'<region> <var_id>'`` per model.

    Parameters
    ----------
    dset_dict : dict
        Model id -> dataset-like object indexable by ``'<region> <var_id>'``.
    source_regions : list of str
        Region names; each forms the variable name together with ``var_id``.
    var_id : str, default 'emidust'
        Variable suffix appended (after a space) to each region name.

    Returns
    -------
    pandas.DataFrame
        Float table with one row per region and one column per model id.
    """
    regions = list(source_regions)
    per_model = {
        mod_id: [float(dset_dict[mod_id][f'{region} {var_id}'].mean().values)
                 for region in regions]
        for mod_id in dset_dict
    }
    return pd.DataFrame(per_model, index=regions, columns=list(dset_dict), dtype=float)


In [None]:
# Rows: source regions; columns: models (mean emidust of the experiment runs).
df_source = create_source_region_df(dset_dict=ds_exp, source_regions=dust_source_regions)

In [None]:
def multi_model_mean(ds_dict, var_id, common_grid=None):
    """Average ``var_id`` across models after time-averaging and regridding.

    For each dataset: take the mean over ``time``, drop the auxiliary
    ``wavelength``/``height``/``member_id`` variables if present, keep only
    ``var_id`` and regrid it with ``regrid_global(..., ds_out=common_grid)``.
    The regridded fields are stacked along a new ``model`` dimension and
    averaged with equal weight per model.

    Parameters
    ----------
    ds_dict : dict
        Model id -> xarray Dataset with a ``time`` dimension and ``var_id``.
    var_id : str
        Name of the variable to average.
    common_grid : xarray Dataset, optional
        Target grid forwarded to ``regrid_global`` as ``ds_out``.

    Returns
    -------
    xarray.Dataset
        Unweighted multi-model mean of the regridded time means.

    Raises
    ------
    ValueError
        If ``ds_dict`` is empty.
    """
    if not ds_dict:
        raise ValueError("multi_model_mean: ds_dict is empty, nothing to average")
    regridded = []
    for ds in ds_dict.values():
        time_mean = ds.mean(dim='time')
        # Strip per-model auxiliary variables (presumably they conflict between
        # models when concatenating); names that are absent are ignored.
        time_mean = time_mean.drop_vars(['wavelength', 'height', 'member_id'],
                                        errors='ignore')
        regridded.append(regrid_global(time_mean[[var_id]], ds_out=common_grid))
    return xr.concat(regridded, dim='model').mean(dim='model')
        
            

In [None]:
# Common target grid; its `cell_area` is also used for the per-cell totals.
ga = xr.open_dataset(snakemake.input.universial_area_mask)
# Multi-model mean dust emission for the control and experiment ensembles.
mm_dust_ctrl, mm_dust_exp = (
    multi_model_mean(runs, 'emidust', common_grid=ga) for runs in (ds_ctrl, ds_exp)
)


In [None]:
# rcParams applied via mpl.rc_context to the final figure: small fonts and
# no axis spines.
context_dict = {
    'axes.labelsize': 7.5,
    'axes.spines.left': False,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.spines.bottom': False,
    'xtick.labelsize': 6,
    'ytick.labelsize': 7,
}
import seaborn as sns

# Seaborn's "colorblind" palette with 10 colours, one per source region below.
palette = sns.color_palette("colorblind", 10)

# Region name -> colour, shared by the map boxes, pie wedges and legend.
# Keys must match the region names in snakemake.config['dust_source_regions']
# and the index of df_source, because the plotting code looks them up here.
# Insertion order is the legend order.
_region_order = [
    'Western north Africa',
    'Eastern north Africa',
    'Sahel',
    'Middle east',
    'Central Asia',
    'East Asia',
    'Southern Africa',
    'North America',
    'Australia',
    'South America',
]
if len(palette) < len(_region_order):
    raise ValueError(
        f"palette has {len(palette)} colours but {len(_region_order)} regions need one"
    )
color_dict = dict(zip(_region_order, palette))


In [None]:
import numpy as np
import statsmodels.api as sm
def get_treshold(data, surfFrac=0.8):
    """Return the smallest value whose empirical CDF exceeds ``surfFrac``.

    The empirical CDF pairs the sorted valid values with ``1/n, 2/n, ..., 1``
    (the same construction as ``statsmodels.distributions.ECDF``). The result
    is the first sorted value whose CDF is strictly greater than ``surfFrac``,
    so more than ``surfFrac`` of the valid values are <= the result.

    Parameters
    ----------
    data : array-like
        numpy array, xarray DataArray, pandas object, or list; it is flattened.
        Non-finite entries (NaN, +/-inf) are ignored, so masked cells do not
        inflate the sample size and bias the threshold upward.
    surfFrac : float, default 0.8
        Fraction in [0, 1) of valid values that should lie at or below the result.

    Returns
    -------
    numpy.float64

    Raises
    ------
    ValueError
        If ``data`` has no finite values, or no value's CDF exceeds
        ``surfFrac`` (e.g. ``surfFrac >= 1``).
    """
    values = np.asarray(data, dtype=float).ravel()
    values = np.sort(values[np.isfinite(values)])
    n_valid = values.size
    if n_valid == 0:
        raise ValueError("get_treshold: data contains no finite values")
    ecdf = np.linspace(1.0 / n_valid, 1.0, n_valid)
    above = values[ecdf > surfFrac]
    if above.size == 0:
        raise ValueError(f"get_treshold: no value has ECDF > surfFrac={surfFrac}")
    # `values` is sorted, so the first surviving entry is the minimum.
    return above[0]

In [None]:
# Per-cell emission totals: flux times cell area, scaled by 1e-9 (assumed to
# yield Tg yr-1 as in the figure's colourbar label — confirm emidust units).
mm_dust_ctrl_da, mm_dust_exp_da = (
    mm_ds['emidust'] * ga['cell_area'] * 1e-9
    for mm_ds in (mm_dust_ctrl, mm_dust_exp)
)
# Threshold from the *unmasked* control field: the smallest value whose
# empirical CDF exceeds 0.9.
thresh = get_treshold(mm_dust_ctrl_da, surfFrac=0.9)
# Blank out negligible cells (<= 1e-2) in both runs.
mm_dust_ctrl_da, mm_dust_exp_da = (
    field.where(field > 1e-2) for field in (mm_dust_ctrl_da, mm_dust_exp_da)
)

In [None]:
def _draw_source_region_boxes(ax, regions, colors):
    """Outline every dust source region on ``ax`` in its assigned colour.

    ``regions`` maps region name -> dict with 'lonmin', 'lonmax', 'latmin',
    'latmax' (interpreted in PlateCarree lon/lat); ``colors`` maps the same
    names to matplotlib colours.
    """
    for region, bounds in regions.items():
        ax.add_artist(
            mpl.patches.Rectangle(
                (bounds['lonmin'], bounds['latmin']),
                # The 0.5 / 0.6 shrink is kept from the original figure;
                # presumably it keeps neighbouring boxes from overlapping.
                width=bounds['lonmax'] - bounds['lonmin'] - 0.5,
                height=bounds['latmax'] - bounds['latmin'] - 0.6,
                facecolor='none',
                edgecolor=colors[region],
                transform=ccrs.PlateCarree(),
                linewidth=1.5,
            )
        )


def _pie_autopct(pct):
    """Pie wedge label: rounded percentage, hidden for wedges of 3% or less."""
    return '{:1.0f}%'.format(pct) if pct > 3 else ''


# Panel c) layout: 5 pies in the top row, up to 4 in the bottom row (cols 1-4).
_MAX_PIES = 9
if df_source.shape[1] > _MAX_PIES:
    raise ValueError(
        f"panel c) has room for {_MAX_PIES} pies, got {df_source.shape[1]} models"
    )

with mpl.rc_context(context_dict):
    fig = plt.figure(figsize=(8.3, 5.8))
    gs1 = fig.add_gridspec(4, 5, width_ratios=[1, 1, 0.1, 1, 1])
    gs1.update(top=0.90, left=0.14, bottom=0.2, wspace=.05)

    # Panel a): control-run multi-model mean emission on a log colour scale,
    # with the source-region boxes overlaid.
    ax1 = fig.add_subplot(gs1[:2, :2], projection=ccrs.Robinson())
    cmap = mpl.colormaps['YlOrRd'].resampled(14)
    norm = mpl.colors.LogNorm(vmin=0.1, vmax=2e1)
    mm_dust_ctrl_da.plot(ax=ax1, transform=ccrs.PlateCarree(), cmap=cmap, norm=norm,
                         cbar_kwargs={'label': 'Model mean dust emission 1850 (Tg yr$^{-1}$)',
                                      'orientation': 'horizontal', 'shrink': 0.7, 'pad': 0.035,
                                      'ticks': [0.1, 1, 10, 20], 'aspect': 30, 'extend': 'max',
                                      'format': '%.1f'})
    ax1.coastlines()
    _draw_source_region_boxes(ax1, source_regions_dict, color_dict)

    # Panel b): relative change, restricted to cells above `thresh`.
    ax2 = fig.add_subplot(gs1[:2, 3:], projection=ccrs.Robinson())
    # Mask into NEW names: reassigning mm_dust_*_da here would make a re-run of
    # this cell plot the threshold-masked field in panel a).
    ctrl_masked = mm_dust_ctrl_da.where(mm_dust_ctrl_da > thresh)
    exp_masked = mm_dust_exp_da.where(mm_dust_exp_da > thresh)
    rel_change = (exp_masked - ctrl_masked) / ctrl_masked * 100
    # mpl.cm.get_cmap was removed in Matplotlib 3.9; use the registry as above.
    cmap = mpl.colormaps['RdYlBu_r'].resampled(14)
    rel_change.plot(ax=ax2, transform=ccrs.PlateCarree(), cmap=cmap, vmin=60, vmax=140,
                    cbar_kwargs={'label': 'Relative change in dust emission 2xdust-ctrl (%)',
                                 'orientation': 'horizontal', 'shrink': 0.7, 'pad': 0.035,
                                 'ticks': [70, 90, 110, 130], 'aspect': 30, 'extend': 'both',
                                 'format': '%.0f'})
    ax2.coastlines()

    # Panel c): one pie per model showing each region's share of emission.
    pie_colors = [color_dict[region] for region in df_source.index]
    gs2 = fig.add_gridspec(2, 5, width_ratios=[1, 1, 1, 1, 1])
    gs2.update(top=0.46, left=0.13)
    # DataFrame.iteritems was removed in pandas 2.0; .items() is equivalent.
    for idx, (name, data) in enumerate(df_source.items()):
        cell = gs2[0, idx] if idx <= 4 else gs2[1, idx - 4]
        tax = fig.add_subplot(cell)
        data.plot.pie(ax=tax, colors=pie_colors, labels=None, autopct=_pie_autopct,
                      textprops={'fontsize': 6.5}, pctdistance=1.2)
        tax.set_title(name, fontsize=8)
        tax.set_ylabel('')

    legend_handles = [mpl.patches.Patch(facecolor=color_dict[region], label=region)
                      for region in color_dict]
    fig.legend(handles=legend_handles, loc='lower center', ncol=5, fontsize=7,
               bbox_to_anchor=(0.5, 0.02), frameon=False)

    fig.text(0.1, 0.9, 'a)', fontsize=12, fontweight='bold', va='top', ha='left')
    fig.text(0.51, 0.9, 'b)', fontsize=12, fontweight='bold', va='top', ha='left')
    fig.text(0.1, 0.5, 'c)', fontsize=12, fontweight='bold', va='top', ha='left')
    fig.savefig(snakemake.output.outpath, dpi=300, bbox_inches='tight')