In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from workflow.scripts.utils import global_avg, calculate_pooled_variance,diff_means_greater_than_varability
import matplotlib.pyplot as plt
from workflow.scripts.plotting_tools import get_model_colordict
from scipy.stats import t
config = snakemake.config
CI_alpha = snakemake.params.get('CI_alpha', 0.05)

dst_interest_regs = config['dust_interest_regions']


def _open_by_model(paths, field):
    """Open every dataset in ``paths`` keyed by one '_'-separated token of its file name.

    ``field`` is the index of the token (after splitting the base name on '_')
    used as the dictionary key; the keys later become the ``model`` coordinate.
    """
    opened = {}
    for path in paths:
        fname = path.split('/')[-1]
        opened[fname.split('_')[field]] = xr.open_dataset(path)
    return opened


# Cloud diagnostics: key is token 3 of the file name.
ctrl_cld_diag_dsets = _open_by_model(snakemake.input.ctrl_clddiag, 3)
exp_cld_diag_dsets = _open_by_model(snakemake.input.exp_clddiag, 3)

# Dust diagnostics: key is token 2 of the file name.
ctrl_ddiag = _open_by_model(snakemake.input.ctrl_ddiag, 2)
exp_ddiag = _open_by_model(snakemake.input.exp_ddiag, 2)

# Keep variable attributes (units, long_name) through later arithmetic.
xr.set_options(keep_attrs=True)

In [None]:
def create_source_region_df(dset_dict: dict,
                            interest_regions: dict,
                            variables: list):
    """Average diagnostics over rectangular lon/lat regions for every model.

    Parameters
    ----------
    dset_dict : dict
        Mapping ``model_id -> xarray.Dataset``. Each dataset needs ``lon``,
        ``lat`` and ``time`` coordinates and every name in ``variables``.
        The mapping passed in is NOT modified.
    interest_regions : dict
        Mapping ``region_name -> {'lonmin', 'lonmax', 'latmin', 'latmax'}``.
        Longitudes are matched against a -180..180 grid (datasets whose
        longitudes exceed 180 are converted first).
    variables : list
        Names of the variables to average.

    Returns
    -------
    xarray.Dataset
        One ``(region, model, time)`` variable per entry in ``variables``.
        ``time`` is an integer index ``0 .. max_len-1``; models with shorter
        records are NaN-padded at the end. ``lwp`` and ``clivi`` are
        multiplied by 1000 and labelled ``'g m-2'``.

    Raises
    ------
    ValueError
        If ``dset_dict`` is empty.
    """
    if not dset_dict:
        raise ValueError('create_source_region_df: dset_dict is empty')

    # Normalise grids into a local mapping so the caller's dict stays
    # untouched and re-running the cell gives the same result.
    dsets = {}
    for mod_id, ds in dset_dict.items():
        if ds.lon.max() > 180:
            # 0..360 -> -180..180 so each region box is a single lon slice.
            ds = ds.assign_coords({'lon': ((ds.coords['lon'] + 180) % 360 - 180)}).sortby('lon')
            ds = ds.cf.add_bounds(['lon', 'lat'])
        if ds.lat.size > 1 and bool(ds.lat[0] > ds.lat[-1]):
            # slice(latmin, latmax) would select nothing on a descending axis.
            ds = ds.sortby('lat')
        dsets[mod_id] = ds

    model_ids = list(dsets)
    region_names = list(interest_regions)
    max_time = max(dsets[m].time.size for m in model_ids)

    out_dset = xr.Dataset(
        coords=dict(region=('region', region_names),
                    model=('model', model_ids),
                    time=('time', np.arange(max_time))),
        attrs=dict(description='Dust diagnostics averaged over regions of interest'),
    )

    for dvar in variables:
        # Metadata is taken explicitly from the first model (assumed shared
        # by all models) rather than from a leaked loop variable.
        attrs = dsets[model_ids[0]][dvar].attrs.copy()
        scale_f = 1
        if dvar in ['lwp', 'clivi']:
            # Presumably kg m-2 -> g m-2; confirm the input units.
            scale_f = 1000
            attrs['units'] = 'g m-2'

        # NaN-filled so shorter records are padded at the end of the time axis.
        data = np.full((len(region_names), len(model_ids), max_time), np.nan)
        for j, mod_id in enumerate(model_ids):
            field = dsets[mod_id][dvar]
            for i, source_region in enumerate(region_names):
                box = interest_regions[source_region]
                _temp = field.sel(lon=slice(box['lonmin'], box['lonmax']),
                                  lat=slice(box['latmin'], box['latmax']))
                # Plain (not area-weighted) mean over the grid cells in the box.
                values = _temp.mean(dim=['lat', 'lon']).values
                data[i, j, :len(values)] = values * scale_f

        da = xr.DataArray(data,
                          dims=['region', 'model', 'time'],
                          coords=dict(region=out_dset.region,
                                      model=out_dset.model,
                                      time=out_dset.time),
                          name=dvar)
        da.attrs['long_name'] = attrs.get('long_name', '')
        da.attrs['units'] = attrs.get('units', '')
        out_dset = out_dset.assign({dvar: da})

    return out_dset

In [None]:
def make_significant_mask(dd_exp, dd_ctrl, CI_alpha=0.05):
    """Flag (region, model) pairs whose experiment-vs-control change is significant.

    Parameters
    ----------
    dd_exp, dd_ctrl : xarray.Dataset
        Regional time series with dims (region, model, time), as produced by
        ``create_source_region_df``. Series are paired by region/model
        *label*, so the two datasets may list models in different orders.
    CI_alpha : float, default 0.05
        Significance level of the two-sided t-test.

    Returns
    -------
    tuple of xarray.Dataset
        Two boolean datasets with dims (region, model):
        ``res[0]`` of ``diff_means_greater_than_varability`` per pair, and
        ``|res[1]| > t_crit`` where ``t_crit`` is the two-sided Student-t
        critical value with ``n_exp + n_ctrl - 2`` degrees of freedom.
    """
    gtthvar = xr.zeros_like(dd_exp.isel(time=0))
    t_vals = xr.zeros_like(dd_exp.isel(time=0))
    q = 1 - CI_alpha / 2
    for var in dd_exp.data_vars:
        for i, reg in enumerate(dd_exp.region.values):
            for j, mod in enumerate(dd_exp.model.values):
                # Label-based pairing: positional indexing would mismatch
                # models if ctrl lists them in a different order.
                exp_series = dd_exp[var].sel(region=reg, model=mod)
                ctrl_series = dd_ctrl[var].sel(region=reg, model=mod)
                res = diff_means_greater_than_varability(exp_series, ctrl_series)
                # Writes through to the dataset's shared underlying array.
                gtthvar[var][i, j] = res[0]
                # Pooled two-sample df = n_exp + n_ctrl - 2 (the original had
                # the "-2" inside len()). Only non-NaN samples count, because
                # shorter records are NaN-padded.
                n_exp = int(exp_series.count())
                n_ctrl = int(ctrl_series.count())
                t_crit = t.ppf(q=q, df=n_exp + n_ctrl - 2)
                t_vals[var][i, j] = np.abs(res[1]) > t_crit
    return gtthvar.astype('bool'), t_vals.astype('bool')

In [None]:
# Variables to regionally average from each diagnostic family.
dust_diag_vars = ['tas', 'od550aer', 'abs550aer', 'concdust']
cld_diag_vars = ['pr', 'lwp', 'cl_high', 'cl_low', 'cl_middle', 'clivi', 'clt']

# Source-region series (region, model, time) for experiment, then control.
dust_diag_exp, dust_diag_ctrl = [
    create_source_region_df(dsets, dst_interest_regs, dust_diag_vars)
    for dsets in (exp_ddiag, ctrl_ddiag)
]
cld_diag_exp, cld_diag_ctrl = [
    create_source_region_df(dsets, dst_interest_regs, cld_diag_vars)
    for dsets in (exp_cld_diag_dsets, ctrl_cld_diag_dsets)
]

In [None]:
# Pass the configured significance level (snakemake param ``CI_alpha``);
# without it the function's 0.05 default is always used.
gtthvar_cld, two_sided_tt_cld = make_significant_mask(cld_diag_exp, cld_diag_ctrl, CI_alpha=CI_alpha)
gtthvar_dd, two_sided_tt_dd = make_significant_mask(dust_diag_exp, dust_diag_ctrl, CI_alpha=CI_alpha)

In [None]:
# Time-mean of each regional series (xarray's mean skips NaN by default).
m_cld_diag_exp, m_cld_diag_ctrl = (ds.mean(dim='time') for ds in (cld_diag_exp, cld_diag_ctrl))
m_dust_diag_exp, m_dust_diag_ctrl = (ds.mean(dim='time') for ds in (dust_diag_exp, dust_diag_ctrl))

In [None]:
# Absolute change: experiment minus control (xarray aligns on region/model).
diff_cld_diag, diff_dust_diag = (
    exp_mean - ctrl_mean
    for exp_mean, ctrl_mean in ((m_cld_diag_exp, m_cld_diag_ctrl),
                                (m_dust_diag_exp, m_dust_diag_ctrl))
)


In [None]:
# Relative change, in percent of the control time-mean.
rel_diff, rel_diff_cld = (
    (diff / ctrl) * 100
    for diff, ctrl in ((diff_dust_diag, m_dust_diag_ctrl),
                       (diff_cld_diag, m_cld_diag_ctrl))
)

In [None]:
def plot_dust_regional_effects(diag_abs: xr.Dataset,
                        dia_rel: xr.Dataset,
                        grid_spec: np.ndarray,
                        variables: list,
                        xticks = (-50, 0, 50, 100),
                        minorticks = (-25, 25, 75),
                        xticklabels=('-50', '0', '50', '100'),
                        final=False,
                        first=False,
                        start_row=0,
                        start_col=1,
                        xlim=(-75,125),
                        signif_mask=None):
    """Draw a grid of horizontal-bar panels, one per (variable, region).

    Each panel has one bar per model whose length is the relative change
    from ``dia_rel``. The absolute change from ``diag_abs`` is written as
    text at the right edge (``xlim[1]``) of the panel.

    Parameters
    ----------
    diag_abs : xarray.Dataset
        Absolute differences with dims (region, model); used for text labels.
    dia_rel : xarray.Dataset
        Relative differences with dims (region, model); used for bar lengths.
    grid_spec : matplotlib.gridspec.GridSpec
        Grid on the current figure to add the panels to.
    variables : list
        Variables to plot, one grid row each.
    xticks, minorticks, xticklabels : sequence
        Tick configuration applied to the last row when ``final`` is True.
    final : bool
        Show the x axis (ticks and label) on the last row.
    first : bool
        Write region names as titles above the first row.
    start_row, start_col : int
        Grid position of the top-left panel.
    xlim : tuple
        x-limits of every panel.
    signif_mask : xarray.Dataset, optional
        Boolean dataset with dims (region, model). When given, a text label
        is drawn only where the mask is True; otherwise every bar is labelled.

    Returns
    -------
    numpy.ndarray
        Object array of Axes with shape ``(len(variables), n_regions)``.
    """

    def format_number(number):
        """Compact string for an annotation value."""
        abs_number = abs(number)

        if abs_number < 1e-7:  # Adjust this threshold as needed
            return "0"

        if abs_number < 0.01:
            # Scientific notation with a compact exponent, e.g. "1.2e-3".
            mantissa, exponent = f"{number:.1e}".split("e")
            return f"{mantissa}e{int(exponent):+}"

        if abs_number >= 1000:
            return f"{number:.0f}"

        if abs_number < 0.1:
            return f"{number:.2f}"

        return f"{number:.1f}"

    def _set_lims(axt, nmodels):
        """Fix the limits and hide all ticks/spines; draw a dashed zero line."""
        axt.set_ylim(0, (nmodels * 2) / 10 + .2)
        axt.set_xlim(xlim)
        axt.set_yticklabels([])
        axt.set_yticks([])
        axt.set_xticklabels([])
        axt.set_xticks([])
        for side in ('bottom', 'top', 'right', 'left'):
            axt.spines[side].set_visible(False)
        axt.axvline(0, color='#A9A9A9', linewidth=1, linestyle='--')

    def _wrap_region_name(name):
        """Put each part of an "A and B" region name on its own line."""
        # Split on the whole word so names containing the substring
        # (e.g. "Iceland") are not broken apart.
        return '\n'.join(name.split(' and '))

    colors = get_model_colordict()
    nmodels = dia_rel.model.size
    regions = diag_abs.region
    fig = plt.gcf()
    axes = np.zeros((len(variables), len(regions)), dtype=object)

    # Pass 1: create and style every panel.
    for i, region in enumerate(regions):
        for j, var in enumerate(variables):
            ax_i = fig.add_subplot(grid_spec[j + start_row, i + start_col])
            if j == 0 and first:
                ax_i.set_title(_wrap_region_name(str(region.values)), fontsize=8)
            _set_lims(ax_i, nmodels)
            if j == len(variables) - 1 and final:
                ax_i.spines['bottom'].set_visible(True)
                ax_i.set_xticks(xticks)
                ax_i.set_xticklabels(xticklabels)
                ax_i.set_xticks(minorticks, minor=True)
                ax_i.set_xlabel('Relative change (%)', fontsize=6)
            axes[j, i] = ax_i

    # Pass 2: draw one bar per model and its value label.
    label_x = xlim[1]
    for i, region in enumerate(regions):
        region_label = region.item()
        _diag_t_abs = diag_abs.sel(region=region)[variables].to_dataframe().drop(columns=['region'])
        _diag_t_rel = dia_rel.sel(region=region)[variables].to_dataframe().drop(columns=['region'])
        for j, var in enumerate(variables):
            ax = axes[j, i]
            y = 0.2
            for mod in _diag_t_rel.index:
                bar = ax.barh(y, _diag_t_rel.loc[mod, var], height=0.2, color=colors[mod])[0]
                if signif_mask is None:
                    show_label = True
                else:
                    # Select by label: the mask may order models differently
                    # from ``dia_rel``, so positional indexing could mismatch.
                    show_label = bool(signif_mask[var].sel(region=region_label, model=mod))
                if show_label:
                    ax.text(label_x, bar.get_y() + bar.get_height() / 2,
                            f'{format_number(_diag_t_abs.loc[mod, var])}',
                            ha='right', va='center', fontsize=7)
                y += 0.2

    return axes




In [None]:
# Row-label text (matplotlib mathtext) per variable name; names missing here
# are shown raw by ``add_label`` via ``transelate_vars.get(var, var)``.
transelate_vars = {
    'od550aer': 'AOD$_{550}$ \n (all aerosols)',
    'abs550aer': 'AAOD$_{550}$\n (all aerosols)',
    'pr': 'Precipitation \n (mm/year)',
    'lwp': 'LWP \n (g/m$^2$)',
    'clivi': 'IWP \n (g/m$^2$)',
    'cl_high': 'CLDfrac$_{high}$ \n (%)',
    'cl_middle': 'CLDfrac$_{mid}$ \n (%)',
    'cl_low': 'CLDfrac$_{low}$ \n (%)',
}

In [None]:
def add_label(var, gs, row, start_col=0, fig=None):
    """Add a frameless axes at ``gs[row, start_col]`` holding a rotated row label.

    Parameters
    ----------
    var : str
        Variable name; shown via ``transelate_vars`` when it has an entry,
        otherwise the raw name is shown.
    gs : matplotlib.gridspec.GridSpec
        Grid to place the label axes in.
    row : int
        Grid row of the label.
    start_col : int, default 0
        Grid column of the label.
    fig : matplotlib.figure.Figure, optional
        Figure to draw on. Defaults to the figure ``gs`` belongs to, or the
        current figure, so the function no longer needs a global ``fig``.

    Returns
    -------
    matplotlib.axes.Axes
        The label axes.
    """
    if fig is None:
        fig = getattr(gs, 'figure', None)
    if fig is None:
        fig = plt.gcf()
    axt = fig.add_subplot(gs[row, start_col])
    for side in ('bottom', 'top', 'right', 'left'):
        axt.spines[side].set_visible(False)
    axt.set_xticks([])
    axt.set_yticks([])
    axt.set_yticklabels([])
    axt.set_xticklabels([])
    axt.text(0.5, 0.5, transelate_vars.get(var, var), fontsize=8,
             ha='center', va='center', rotation=90)
    return axt


In [None]:
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# Grid layout: column 0 holds the rotated row labels, columns 1+ one region each.
# Rows 0-1: aerosol optical depth panels; row 2: spacer; rows 3-8: cloud/precip panels.
fig = plt.figure(figsize=(8.3, 10.8), constrained_layout=True)
gs = gridspec.GridSpec(9, 7,
                       height_ratios=[1, 1, 0.45, 1, 1, 1, 1, 1, 1],
                       width_ratios=[0.3, 1, 1, 1, 1, 1, 1],
                       figure=fig)

ddust_vars = ['od550aer', 'abs550aer']
ddust_cld_vars = ['pr', 'lwp', 'clivi', 'cl_high', 'cl_middle', 'cl_low']

axes = plot_dust_regional_effects(
    diff_dust_diag, rel_diff, gs, ddust_vars,
    final=True, first=True,
    xlim=(-10, 110), xticks=[0, 50, 100], xticklabels=['0', '50', '100'], minorticks=[25, 75],
    signif_mask=two_sided_tt_dd,
)
axes1 = plot_dust_regional_effects(
    diff_cld_diag, rel_diff_cld, gs, ddust_cld_vars,
    final=True, first=False, start_row=3,
    xlim=(-15, 15), xticks=[-10, 0, 10], minorticks=[-5, 5], xticklabels=['-10', '0', '10'],
    signif_mask=two_sided_tt_cld,
)

gs.update(hspace=0.01, wspace=0.06)

# Row labels in column 0 (row 2 is the empty spacer).
label_rows = dict(od550aer=0, abs550aer=1, pr=3, lwp=4, clivi=5,
                  cl_high=6, cl_middle=7, cl_low=8)
for label_var, label_row in label_rows.items():
    add_label(label_var, gs, label_row)

colors = get_model_colordict()
legments = [
    Line2D([0], [0], markerfacecolor=c, marker='s', label=m, color='w', markersize=10)
    for m, c in colors.items()
]

# Shade every other variable row for readability.
for panel_axes in (axes, axes1):
    for row_axes in panel_axes[::2]:
        for ax in row_axes:
            ax.set_facecolor('#F0F0F0')

fig.legend(handles=legments, ncol=3, loc='lower center', frameon=False,
           bbox_to_anchor=(0.555, 0.015), fontsize=8)
fig.suptitle('Change between piClim-2xdust and piClim-control', fontsize=12, y=0.93, x=0.555)
# NOTE(review): the figure uses constrained_layout, but tight_layout() below
# switches layout engines; confirm which layout is intended.
plt.tight_layout()

plt.savefig(snakemake.output.outpath, bbox_inches='tight')
