In [None]:
import json
from pathlib import Path
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import vaep.pandas
import vaep.nb

import logging
from vaep.logging import setup_logger
# Configure the 'vaep' logger through the project's setup_logger helper;
# `logger` is used for progress messages in the cells below.
logger = setup_logger(logger=logging.getLogger('vaep'))

# Apply seaborn's default theme to the plots created in this notebook.
sns.set_theme()

# Default figure size (width, height) for figures created from here on.
plt.rcParams['figure.figsize'] = [16.0, 7.0]

In [None]:
# Display order used further down when selecting rows/columns and ordering bars:
# IDX[0] lists the data levels, IDX[1] lists the models.
IDX = [
    ['proteinGroups', 'aggPeptides', 'evidence'],
    ['median', 'interpolated', 'collab', 'DAE', 'VAE'],
]

In [None]:
def select_content(s: str):
    """Extract the model name and the repeat number from a metrics file stem.

    The stem is expected to look like ``<prefix>metrics_<model>_<repeat>``,
    e.g. ``'model_metrics_DAE_0'`` yields ``('DAE', 0)``.

    Parameters
    ----------
    s : str
        File stem containing the marker ``'metrics_'`` exactly once.

    Returns
    -------
    tuple
        ``(model, repeat)`` with ``model`` as ``str`` and ``repeat`` as ``int``.

    Raises
    ------
    ValueError
        If the marker ``'metrics_'`` is missing or occurs more than once, if
        no underscore separates model and repeat, or if the repeat part is
        not an integer.
    """
    parts = s.split('metrics_')
    # Exactly one marker means exactly two parts. The previous isinstance check
    # on the selected part could never fail, so ambiguous stems went unnoticed.
    if len(parts) != 2:
        raise ValueError(
            f"Expected exactly one 'metrics_' in {s!r}, found {len(parts) - 1}.")
    # Split on the last underscore only, so model names may contain underscores.
    model, repeat = parts[1].rsplit('_', 1)
    return model, int(repeat)
    
# Quick visual check of the stem parser on a few representative file stems.
test_cases = [f'model_metrics_{suffix}' for suffix in ('DAE_0', 'VAE_3', 'collab_2')]

for test_case in test_cases:
    parsed = select_content(test_case)
    print(f"{test_case} = {parsed}")

In [None]:
# Collect the flattened metrics of every input file, merged per
# (data level, repeat) key; the data level is the name of the folder two
# levels above the file, the repeat is parsed from the file stem.
all_metrics = {}
for fname in map(Path, snakemake.input.metrics):
    logger.info(f"Load file: {fname = }")
    model, repeat = select_content(fname.stem)
    key = (fname.parents[1].name, repeat)
    logger.debug(f"{key = }")

    with open(fname) as f:
        loaded = vaep.pandas.flatten_dict_of_dicts(json.load(f))

    # First file for this key: take its metrics as they are.
    if key not in all_metrics:
        all_metrics[key] = loaded
        continue

    # Further files for the same key: add new entries and insist that
    # entries present in both files agree.
    existing = all_metrics[key]
    for k, v in loaded.items():
        if k not in existing:
            existing[k] = v
            continue
        logger.debug(f"Found existing key: {k = } ")
        assert existing[k] == v, "Diverging values for {k}: {v1} vs {v2}".format(
            k=k,
            v1=existing[k],
            v2=v)

# One row per (data level, repeat), one column per flattened metric key.
metrics = pd.DataFrame(all_metrics).T
metrics.index.names = ('data level', 'repeat')
metrics

In [None]:
# Output folder for the tables and the figure written by the cells below:
# three directory levels above a metrics file.
# NOTE: `fname` is the loop variable left over from the metrics-loading loop,
# i.e. the last entry of snakemake.input.metrics, so this cell only works after
# that loop has run (assumes all inputs share this ancestor folder — TODO confirm).
FOLDER = fname.parent.parent.parent
FOLDER

In [None]:
# Put the flattened metric keys on the rows, sort them, and keep only the
# listed entries of the first three key levels (all entries of the fourth).
row_selection = pd.IndexSlice[
    ['NA interpolated', 'NA not interpolated'],
    ['valid_fake_na', 'test_fake_na'],
    ['median', 'interpolated', 'collab', 'DAE', 'VAE'],
    :
]
metrics = metrics.T.sort_index().loc[row_selection]

# Persist the selection next to the other results.
metrics.to_csv(FOLDER / "metrics.csv")
metrics.to_excel(FOLDER / "metrics.xlsx")
metrics

In [None]:
# Restrict the metrics to one level/split combination and move the innermost
# column level into the row index.
level = 'NA interpolated'
split = 'valid_fake_na'
level_split_rows = pd.IndexSlice[level, split, :, :]
selected = metrics.loc[level_split_rows].stack()
selected

In [None]:
# Keep the MAE entries only, then arrange the values with one row per repeat
# so that describe() summarises across repeats.
mae_rows = selected.loc[level].loc[split].loc[pd.IndexSlice[:, 'MAE', :]]
mae_per_repeat = mae_rows.stack().unstack('repeat').T

# Mean and standard deviation over the repeats, with the first index level
# moved into the columns.
to_plot = mae_per_repeat.describe().loc[['mean', 'std']].T.unstack(0)

# Fix the display order: rows follow IDX[0], second column level follows IDX[1].
to_plot = to_plot.loc[IDX[0], pd.IndexSlice[:, IDX[1]]]

to_plot.to_csv(FOLDER / "model_performance_repeated_runs_avg.csv")
to_plot.to_excel(FOLDER / "model_performance_repeated_runs_avg.xlsx")
to_plot

In [None]:
# Grouped bar chart of the mean values with the standard deviations as error bars.
means, stds = to_plot['mean'], to_plot['std']
ax = means.plot.bar(yerr=stds, width=.8, rot=0)

In [None]:
# Same level/split as before, but restricted to the MAE metric; column level 1
# is moved into the row index, which is then relabelled.
level = 'NA interpolated'
split = 'valid_fake_na'
mae_rows = pd.IndexSlice[level, split, :, 'MAE']
selected = metrics.loc[mae_rows].stack(1)
selected.index.names = ('x', 'split', 'model', 'metric', 'repeat')

# Preview of the long-format table (one 'MAE' value per row).
selected.stack().to_frame('MAE').reset_index()

In [None]:
# Bar plot of the MAE per data level and model over the repeated runs, with a
# 95% confidence interval as error bar.
#
# The plot is drawn on an explicitly created figure/axes pair so that `fig`
# really is the figure holding this plot. Previously the Axes returned by
# sns.barplot was stored in `fig` and then overwritten with the figure of the
# `ax` of an earlier cell, so the wrong chart was handed on for saving.
fig, ax = plt.subplots()
sns.barplot(x='data level',
            y='MAE',
            hue='model',
            order=IDX[0],
            ci=95,  # NOTE(review): newer seaborn releases prefer `errorbar=('ci', 95)` — confirm installed version
            data=selected.stack().to_frame('MAE').reset_index(),
            ax=ax)

In [None]:
# Write the current `fig` to the results folder.
pdf_path = FOLDER / "model_performance_repeated_runs.pdf"
vaep.savefig(fig, pdf_path)