In [2]:
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
from pyclim_noresm.general_util_funcs import global_avg
from workflow.scripts.utils import (calc_error, 
                                compute_annual_emission_budget, calc_relative_change)
import matplotlib
import numpy as np

In [3]:
# Names of all Snakemake inputs for this rule. Use .keys() explicitly:
# iterating snakemake.input itself yields the paths, not the names.
keys = [input_key for input_key in snakemake.input.keys()]

In [4]:
# Classify input keys by their '_'-separated tokens: a key is an experiment
# key if one token is exactly 'exp', a control key if one token is exactly 'ctrl'.
exp_keys = sorted(k for k in keys if 'exp' in k.split('_'))
ctrl_keys = sorted(k for k in keys if 'ctrl' in k.split('_'))
# The next cell pairs exp/ctrl keys by sorted position with zip(); unequal
# counts would silently drop (and possibly misalign) pairs, so fail early.
if len(exp_keys) != len(ctrl_keys):
    raise ValueError(
        f"Unbalanced experiment/control inputs: exp={exp_keys}, ctrl={ctrl_keys}"
    )
# All other inputs (except the cell-area files) are handled as stand-alone
# fields. Sorted so the row order of the output table is reproducible instead
# of depending on set iteration order.
remainingkeys = sorted(set(keys) - set(exp_keys + ctrl_keys + ['areacello']))
areapaths = sorted(snakemake.input['areacello'])

In [5]:
# For each exp/ctrl key pair, reduce every model's exp/ctrl files to a single
# number stored as data[variable_id][source_id]:
#   * 'emidust' inputs: annual emission budget difference (exp - ctrl), computed
#     with the areacello file at the same sorted position;
#   * all other inputs: global average of calc_relative_change(ctrl, exp).
# exp, ctrl (and area) files are paired purely by their sorted order.
data = {}
for exp_key, ctrl_key in zip(exp_keys, ctrl_keys):
    exp_paths = sorted(snakemake.input[exp_key])
    ctrl_paths = sorted(snakemake.input[ctrl_key])
    # zip() would silently drop unmatched files and misalign the remaining pairs.
    if len(exp_paths) != len(ctrl_paths):
        raise ValueError(
            f"{exp_key} has {len(exp_paths)} files but {ctrl_key} has {len(ctrl_paths)}"
        )
    is_emission = 'emidust' in exp_key.split('_')
    with xr.open_dataset(exp_paths[0]) as first_exp:
        vname = first_exp.variable_id
    data[vname] = {}

    if is_emission:
        # NOTE(review): assumes areapaths sorts into the same model order as the
        # exp/ctrl files; zip() stops at the shortest list.
        for exp_path, ctrl_path, area_path in zip(exp_paths, ctrl_paths, areapaths):
            with xr.open_dataset(area_path) as ga:
                with xr.open_dataset(exp_path) as ds_exp, xr.open_dataset(ctrl_path) as ds_ctrl:
                    exp_budget = compute_annual_emission_budget(ds_exp, ga)
                    ctrl_budget = compute_annual_emission_budget(ds_ctrl, ga)
                    # float() forces evaluation while the files are still open.
                    data[vname][ds_ctrl.source_id] = float((exp_budget - ctrl_budget).values)
    else:
        for exp_path, ctrl_path in zip(exp_paths, ctrl_paths):
            with xr.open_dataset(exp_path) as ds_exp, xr.open_dataset(ctrl_path) as ds_ctrl:
                ds_exp.load()
                ds_ctrl.load()
                diff = calc_relative_change(ds_ctrl, ds_exp)
                diff = global_avg(diff)[vname]
                # Store a plain float, like the other producers of `data`.
                data[vname][ds_ctrl.source_id] = float(diff.values)
            

In [12]:
# Inputs that are not exp/ctrl pairs are taken as ready-made per-model fields:
# average over 'year', then take the global average, giving one number per
# (variable, model) in data[vname][source_id].
for key in remainingkeys:
    for path in snakemake.input[key]:
        with xr.open_dataset(path) as ds:
            # Uses the first data variable in the file — assumes one field per file, TODO confirm.
            vname = list(ds.data_vars)[0]
            erf = global_avg(ds.mean(dim='year'))
            data.setdefault(vname, {})[ds.source_id] = float(erf[vname].values)

In [25]:
# One row per variable, one column per model (source_id).
df = pd.DataFrame(data).transpose().astype(float)

# (variable, model) combinations that were not produced are NaN. mean/std skip
# NaN, so the standard error must divide by the number of models that actually
# contribute to each row, not by the total number of model columns.
n_models = df.count(axis=1)
multimodel_mean = df.mean(axis=1)
std = df.std(axis=1)
error = std / np.sqrt(n_models)  # standard error of the multi-model mean

df['Multi model mean'] = multimodel_mean
df['Error'] = error

In [61]:
# Write the summary table; the format is chosen from the output path's ending.
outpath = snakemake.output.outpath
if outpath.endswith('csv'):
    df.to_csv(outpath)
elif outpath.endswith('tex'):
    with open(outpath, 'w') as f:
        df.to_latex(buf=f, float_format="%.3f")
else:
    # Otherwise nothing would be written to outpath and the failure would only
    # surface later as a missing output file.
    raise ValueError(f"Unsupported output format for {outpath!r}; expected a .csv or .tex path")