# Compare models

1. Load available configurations
2. Load validation predictions
    - calculate absolute error
    - select the top N models for plotting, ordered by MAE from smallest (best) to largest (worst) (N is set by `plot_to_n`, default 5)
    - correlation per sample, correlation per feat, correlation overall
    - MAE plots
3. Load test data predictions
    - as for validation data
    - top N based on validation data

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
import seaborn as sns

import vaep
import vaep.imputation
import vaep.models
from vaep.io import datasplits
from vaep.analyzers import compare_predictions
import vaep.nb

# Show more rows and wider cells so the comparison tables below are readable.
pd.set_option("display.max_rows", 120,
              "display.min_rows", 50,
              "display.max_colwidth", 100)

logger = vaep.logging.setup_nb_logger()

In [None]:
# catch passed parameters
args = None
args = dict(globals()).keys()

Papermill script parameters:

In [None]:
# files and folders
folder_experiment:str = 'runs/example' # Datasplit folder with data for experiment
folder_data:str = '' # specify data directory if needed
file_format: str = 'csv' # change default to pickled files
fn_rawfile_metadata: str = 'data/dev_datasets/HeLa_6070/files_selected_metadata_N50.csv' # Machine parsed metadata from rawfile workflow
models: str = 'Median,CF,DAE,VAE'  # picked models to compare (comma separated)
plot_to_n:int = 5 # Restrict plotting to top N methods for imputation based on error of validation data, maximum 10

Some argument transformations

In [None]:
# Collect the parameters (defaults or papermill-injected values) into a dict,
# using the name snapshot taken before the parameters cell.
args = vaep.nb.get_params(args, globals=globals())
args

In [None]:
# Convert the parameter dict to an attribute-style args object. Later cells use
# entries not set in the parameters cell (`args.data`, `args.out_figures`,
# `args.out_models`); presumably they are derived here — TODO confirm.
args = vaep.nb.args_from_dict(args)
args

In [None]:
# Registries of files written by this notebook: file stem -> path.
figures, dumps = {}, {}

In [None]:
# Column with the observed (true) values the model predictions are compared to.
TARGET_COL = 'observed'
METRIC = 'MAE'
MIN_FREQ = None  # passed as `min_freq` to the rolling error plots
# Parse the comma-separated model list. Surrounding whitespace and empty
# entries (e.g. 'Median, CF,') are dropped so the keys match the saved
# prediction files.
MODELS_PASSED = [model.strip() for model in args.models.split(',')
                 if model.strip()]
MODELS = MODELS_PASSED.copy()

In [None]:
# list(sns.color_palette().as_hex()) # string representation of colors
if args.plot_to_n > 10:
    logger.warning("Set maximum of models to 10 (maximum)")
    args.overwrite_entry('plot_to_n', 10)
COLORS_TO_USE = [sns.color_palette()[5] ,*sns.color_palette()[:5]]

In [None]:
def assign_colors(models):
    """Return one seaborn palette color per entry of *models* (same order).

    CF, DAE and VAE get fixed palette colors (indices 1, 2 and 3). Every other
    model takes the remaining colors (index 0, then 4 onwards) in order of
    appearance, cycling if needed; an info message is logged when colors had
    to be reused.
    """
    palette = sns.color_palette()
    fixed = {'CF': palette[1], 'DAE': palette[2], 'VAE': palette[3]}
    spare = [palette[0], *palette[4:]]
    n_spare_used = 0
    colors = []
    for name in models:
        if name in fixed:
            colors.append(fixed[name])
            continue
        colors.append(spare[n_spare_used % len(spare)])
        n_spare_used += 1
    if n_spare_used > len(spare):
        logger.info("Reused some colors!")
    return colors

# Show the colors assigned to an example list of models (cell output).
assign_colors(['CF', 'DAE', 'knn', 'VAE'])

In [None]:
# Load the train/validation/test splits (train_X, val_y, test_y are used below).
data = datasplits.DataSplits.from_folder(args.data, file_format=args.file_format)

In [None]:
# Plot the number of simulated missing values per sample for the validation
# and test splits side by side (shared y-axis) and save the figure.
vaep.plotting.make_large_descriptors('x-large')
fig, axes = plt.subplots(1, 2, sharey=True)

vaep.plotting.data.plot_observations(data.val_y.unstack(), ax=axes[0],
                                     title='Validation split',)
vaep.plotting.data.plot_observations(data.test_y.unstack(), ax=axes[1],
                                     title='Test split',)

fig.suptitle("Simulated missing values per sample", size=20)

fname = args.out_figures / 'fake_na_val_test_splits.png'
figures[fname.stem] = fname
vaep.savefig(fig, name=fname)
# descriptor size used for all remaining figures
vaep.plotting.make_large_descriptors('xx-large')

## Across data completeness

In [None]:
# load frequency of training features... 
freq_feat = vaep.io.datasplits.load_freq(args.data, file='freq_features.json')   # needs to be pickle -> index.name needed

freq_feat.head() # training data

In [None]:
# Feature frequency relative to the number of training samples.
# Assumes the first index level of the (still long-format) train_X holds the
# sample ids — TODO confirm.
prop = freq_feat / len(data.train_X.index.levels[0])
prop.to_frame()

In [None]:
# Pivot the splits in place to wide format (rows: samples, columns: features).
data.to_wide_format()
data.train_X

In [None]:
# Dimensions of the wide training data: samples x features.
N_SAMPLES = data.train_X.shape[0]
M_FEAT = data.train_X.shape[1]
print(f"N samples: {N_SAMPLES:,d}, M features: {M_FEAT}")

In [None]:
# Excel workbook collecting the performance summary (registered in `dumps`).
# The writer stays open across cells; it is closed after the MAE statistics
# sheet has been written.
fname = args.folder_experiment / '01_2_performance_summary.xlsx'
dumps[fname.stem] = fname
writer = pd.ExcelWriter(fname)

# Model specifications
- used for bar plot annotations

In [None]:
import yaml
from vaep.models.collect_dumps import collect, select_content

def load_config_file(fname: Path, first_split='config_') -> tuple:
    """Load one model configuration YAML file.

    Parameters
    ----------
    fname : Path
        Path to the ``.yaml`` configuration file.
    first_split : str
        Marker passed to ``select_content`` to extract the key from the
        file stem.

    Returns
    -------
    tuple
        ``(key, config)``: the key (str) derived from the file stem and the
        parsed YAML content.
    """
    # Explicit encoding: do not depend on the platform's locale encoding.
    with open(fname, encoding='utf-8') as f:
        loaded = yaml.safe_load(f)
    key = str(select_content(fname.stem, first_split=first_split))
    return key, loaded


# model_key could be used as key from config file
# load only specified configs?
# case: no config file available?
# Collect all "*model_config*.yaml" files from the models output folder into
# one table (index: 'model', one column per configuration entry), add it to
# the Excel summary (transposed) and display it.
# NOTE(review): open questions from the original author:
#   - model_key could be used as key from config file
#   - load only specified configs?
#   - case: no config file available? (set_index('model') would fail)
all_configs = collect(
    paths=(fname for fname in args.out_models.iterdir()
           if fname.suffix == '.yaml'
           and 'model_config' in fname.name),
    load_fn=load_config_file
)
model_configs = pd.DataFrame(all_configs).set_index('model')
model_configs.T.to_excel(writer, sheet_name='model_params')
model_configs.T

Set the feature index name (in the wide format, columns are features and rows are samples)

In [None]:
# index name
freq_feat.index.name = data.train_X.columns.name

In [None]:
# index name
# index name of the samples (rows of the wide training data); used to group
# predictions per sample.
sample_index_name = data.train_X.index.name

# Load predictions on validation and test data split


## Validation data
- set top N models to plot based on validation data split

In [None]:
# Load the validation-split predictions of the passed models into one table
# holding the shared observed-value column (TARGET_COL) and one column per model.
pred_val = compare_predictions.load_split_prediction_by_modelkey(
    experiment_folder=args.folder_experiment,
    split='val',
    model_keys=MODELS_PASSED,
    shared_columns=[TARGET_COL])
pred_val[MODELS]

In [None]:
# Signed error per prediction: model prediction minus observed value.
errors_val = (pred_val
              .drop(TARGET_COL, axis=1)
              .sub(pred_val[TARGET_COL], axis=0)
              [MODELS])
errors_val.describe() # over all samples, and all features

Describe absolute error

In [None]:
errors_val.abs().describe() # absolute errors over all samples, and all features

## Select top N for plotting and set colors

In [None]:
# Rank the models by validation MAE, best (smallest) first.
_mae_val = errors_val.abs().mean()
ORDER_MODELS = list(_mae_val.sort_values().index)
ORDER_MODELS

In [None]:
# Summary statistics of the absolute validation errors, models ordered from
# best to worst. This is the last sheet of the performance summary workbook,
# so the writer is closed (file written) here.
mae_stats_ordered = errors_val.abs().describe()[ORDER_MODELS]
mae_stats_ordered.to_excel(writer, sheet_name='mae_stats_ordered')
writer.close()
mae_stats_ordered

Hack the color order by assigning CF, DAE and VAE fixed, unique colors regardless of their rank.
This could be extended to all supported imputation methods.

In [None]:
def assign_colors(models):
    """Map each model name in *models* to a plot color (list, same order).

    CF, DAE and VAE always get seaborn palette colors 2, 3 and 4, so they are
    recognizable across figures regardless of their rank. All other models
    take palette colors 0, 1, 5, 6, ... in order of appearance, cycling (and
    logging an info message) if there are more of them than spare colors.
    """
    palette = sns.color_palette()
    reserved = {'CF': palette[2], 'DAE': palette[3], 'VAE': palette[4]}
    spare = palette[:2] + palette[5:]
    colors = []
    n_unreserved = 0
    for model in models:
        if model not in reserved:
            colors.append(spare[n_unreserved % len(spare)])
            n_unreserved += 1
        else:
            colors.append(reserved[model])
    if n_unreserved > len(spare):
        logger.info("Reused some colors!")
    return colors


# Regression check of the color mapping, assuming seaborn's default palette is
# active: CF -> 2 (green), DAE -> 3 (red), knn -> 0 (blue), VAE -> 4 (purple),
# lls -> 1 (orange).
expected = [(0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
            (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
            (0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
            (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
            (1.0, 0.4980392156862745, 0.054901960784313725)]

actual = assign_colors(['CF', 'DAE', 'knn', 'VAE', 'lls'])

assert expected == actual

# Final colors, one per model in ranked order (best first).
COLORS_TO_USE = assign_colors(ORDER_MODELS)

In [None]:
# For top_N -> define colors
TOP_N_ORDER = ORDER_MODELS[:args.plot_to_n]

TOP_N_COLOR_PALETTE = {model: color for model,
                       color in zip(TOP_N_ORDER, COLORS_TO_USE)}

TOP_N_ORDER

### Correlation overall

In [None]:
# Correlation between observed values and each model's predictions over all
# validation predictions; bar plot (models ranked best first) saved to disk.
pred_val_corr = pred_val.corr()
ax = (pred_val_corr
      .loc[TARGET_COL, ORDER_MODELS]
      .plot
      .bar(
          # title='Correlation between Fake NA and model predictions on validation data',
          ylabel='correlation overall'))
ax = vaep.plotting.add_height_to_barplot(ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
fname = args.out_figures / 'pred_corr_val_overall.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)
pred_val_corr

### Correlation per sample

In [None]:
# Correlation between observed values and model predictions, per sample.
# NOTE(review): relies on groupby(...).aggregate falling back to calling the
# lambda with each group's DataFrame; this is pandas-version dependent — confirm.
corr_per_sample_val = (pred_val
                       .groupby(sample_index_name)
                       .aggregate(
                           lambda df: df.corr().loc[TARGET_COL]
                       )[ORDER_MODELS])

# Box plot of the per-sample correlations of the top N models.
kwargs = dict(ylim=(0.7, 1), rot=90,
              # title='Corr. betw. fake NA and model pred. per sample on validation data',
              ylabel='correlation per sample')
ax = corr_per_sample_val[TOP_N_ORDER].plot.box(**kwargs)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45,
                   horizontalalignment='right')
fname = args.out_figures / 'pred_corr_val_per_sample.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)

# Dump summary and full per-sample correlations (all models) to Excel.
fname = args.out_figures/'pred_corr_val_per_sample.xlsx'
dumps[fname.stem] = fname
with pd.ExcelWriter(fname) as writer:
    corr_per_sample_val.describe().to_excel(writer, sheet_name='summary')
    corr_per_sample_val.to_excel(writer, sheet_name='correlations')

Identify samples for which at least one of the top N models has a correlation below the lowest lower box-plot whisker of these models.

In [None]:
# Threshold: the lowest lower whisker among the top N models' correlations.
treshold = vaep.pandas.get_lower_whiskers(corr_per_sample_val[TOP_N_ORDER]).min()
# samples where any top N model falls below the threshold
mask = (corr_per_sample_val[TOP_N_ORDER] < treshold).any(axis=1)
corr_per_sample_val.loc[mask].style.highlight_min(axis=1) if mask.sum() else 'Nothing to display'

### Error plot

In [None]:
# Show validation predictions where any model's absolute error exceeds
# c_error_min (same units as the intensities), sorted by the second index level.
c_error_min = 4.5
mask = (errors_val[MODELS].abs() > c_error_min).any(axis=1)
errors_val.loc[mask].sort_index(level=1)

In [None]:
# Replace the per-prediction errors by the mean absolute error per feature,
# joined with each feature's training frequency and sorted from the rarest
# to the most frequent feature.
errors_val = (errors_val
              .abs()
              .groupby(freq_feat.index.name)
              .mean()
              .join(freq_feat)
              .sort_values(by=freq_feat.name, ascending=True))
errors_val

Some interpolated features are missing

In [None]:
errors_val.describe()  # statistics of the per-feature mean absolute errors (mean of means)

In [None]:
# Features whose mean absolute error reaches c_avg_error for any model.
c_avg_error = 2
mask = (errors_val[MODELS] >= c_avg_error).any(axis=1)
errors_val.loc[mask]

In [None]:
# Rolling average of the per-feature validation MAE along increasing feature
# frequency. The window spans ~1/15 of the features; it is clamped to at least
# 1 because fewer than 15 features would otherwise give a window of size 0.
ax = vaep.plotting.plot_rolling_error(errors_val[TOP_N_ORDER + ['freq']],
                                      metric_name=METRIC,
                                      window=max(1, len(errors_val) // 15),
                                      min_freq=MIN_FREQ,
                                      colors_to_use=COLORS_TO_USE)

In [None]:
# Save and register the rolling-error figure from the previous cell.
fname = args.out_figures / 'performance_methods_by_completness.pdf'
figures[fname.stem] = fname
vaep.savefig(
    ax.get_figure(),
    name=fname)

### Error by intensity, binned to integer values
- The number of observations per bin is given in parentheses.

In [None]:
# Plot validation errors of the top N models binned by observed intensity
# and keep the binned errors for dumping.
ax, errors_binned = vaep.plotting.errors.plot_errors_binned(
    pred_val[
        ['observed']+TOP_N_ORDER
    ],
    palette=TOP_N_COLOR_PALETTE)
fname = args.out_figures / 'errors_binned_by_int_val.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)

In [None]:
# Save the binned errors next to the figure (same stem, .csv) and register
# the file. The leading no-op `errors_binned.head()` call was removed.
fname_csv = fname.with_suffix('.csv')
dumps[fname.stem] = fname_csv
errors_binned.to_csv(fname_csv)
errors_binned.head()

## Test data

In [None]:
# Load the test-split predictions and attach the training frequency of each
# feature (joined on the feature index level).
pred_test = compare_predictions.load_split_prediction_by_modelkey(
    experiment_folder=args.folder_experiment,
    split='test',
    model_keys=MODELS_PASSED,
    shared_columns=[TARGET_COL])
pred_test = pred_test.join(freq_feat, on=freq_feat.index.name)
SAMPLE_ID, FEAT_NAME = pred_test.index.names
pred_test

### Correlation overall

In [None]:
# Correlation between observed values and each model's predictions over all
# test predictions; bar plot saved to disk.
pred_test_corr = pred_test.corr()
ax = pred_test_corr.loc[TARGET_COL, ORDER_MODELS].plot.bar(
    # title='Corr. between Fake NA and model predictions on test data',
    ylabel='correlation coefficient overall',
    ylim=(0.7,1)
)
ax = vaep.plotting.add_height_to_barplot(ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
fname = args.out_figures / 'pred_corr_test_overall.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)
pred_test_corr

### Correlation per sample

In [None]:
# Per-sample correlation between observed values and model predictions on the
# test split, plus the number of test observations per sample ('n_obs').
# NOTE(review): relies on groupby(...).aggregate calling the lambda with each
# group's DataFrame (pandas-version dependent) — confirm.
corr_per_sample_test = (pred_test
                        .groupby(sample_index_name)
                        .aggregate(lambda df: df.corr().loc[TARGET_COL])
                        [ORDER_MODELS])
corr_per_sample_test = corr_per_sample_test.join(
    pred_test
    .groupby(sample_index_name)[TARGET_COL]
    .count()
    .rename('n_obs')
)
# samples with fewer than 3 observations are excluded from the summary
too_few_obs = corr_per_sample_test['n_obs'] < 3
corr_per_sample_test.loc[~too_few_obs].describe()

In [None]:
# Box plot of per-sample test correlations of the top N models (samples with
# at least 3 observations), then dump summary and all correlations to Excel.
kwargs = dict(ylim=(0.7,1), rot=90,
              # title='Corr. betw. fake NA and model predictions per sample on test data',
              ylabel='correlation per sample')
ax = (corr_per_sample_test
      .loc[~too_few_obs, TOP_N_ORDER]
      .plot
      .box(**kwargs))
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
fname = args.out_figures / 'pred_corr_test_per_sample.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)

dumps[fname.stem] = fname.with_suffix('.xlsx')
with pd.ExcelWriter(fname.with_suffix('.xlsx')) as writer:
    corr_per_sample_test.describe().to_excel(writer, sheet_name='summary')
    corr_per_sample_test.to_excel(writer, sheet_name='correlations')

Identify samples for which at least one of the top N models has a correlation below the lowest lower box-plot whisker of these models.

In [None]:
# Threshold: the lowest lower whisker among the top N models' correlations;
# show samples where any top N model falls below it.
treshold = vaep.pandas.get_lower_whiskers(corr_per_sample_test[TOP_N_ORDER]).min()
mask = (corr_per_sample_test[TOP_N_ORDER] < treshold).any(axis=1)
# Highlight the minimum among the model columns only: corr_per_sample_test
# also holds the 'n_obs' count, which must not take part in the comparison.
(corr_per_sample_test.loc[mask].style.highlight_min(axis=1, subset=ORDER_MODELS)
 if mask.sum() else 'Nothing to display')

In [None]:
# Show all test predictions for one randomly picked feature.
# (The former `N_SAMPLES = pred_test.index` line was dropped: it overwrote the
# sample count computed earlier with an Index and was never used.)
feature_names = pred_test.index.levels[-1]
M = len(feature_names)
# randrange excludes the upper bound; random.randint(0, M) is inclusive and
# could return M, raising an IndexError.
pred_test.loc[pd.IndexSlice[:, feature_names[random.randrange(M)]], :]

In [None]:
# Pick another random feature. Sample from a list: random.sample() rejects
# sets since Python 3.11 (TypeError), and set order is not reproducible.
options = random.sample(list(feature_names), 1)
pred_test.loc[pd.IndexSlice[:, options[0]], :]

### Correlation per feature

In [None]:
# Per-feature correlation between observed values and model predictions on the
# test split, plus the number of test observations per feature ('n_obs').
# NOTE(review): relies on groupby(...).aggregate calling the lambda with each
# group's DataFrame (pandas-version dependent) — confirm.
corr_per_feat_test = pred_test.groupby(FEAT_NAME).aggregate(lambda df: df.corr().loc[TARGET_COL])[ORDER_MODELS]
corr_per_feat_test = corr_per_feat_test.join(pred_test.groupby(FEAT_NAME)[
                                   TARGET_COL].count().rename('n_obs'))

# features with fewer than 3 observations are excluded from the summary
too_few_obs = corr_per_feat_test['n_obs'] < 3
corr_per_feat_test.loc[~too_few_obs].describe()

In [None]:
# Features with fewer than 3 test observations whose row still has at least
# 3 non-missing values (correlations or 'n_obs').
corr_per_feat_test.loc[too_few_obs].dropna(thresh=3, axis=0)

In [None]:
# Box plot of per-feature test correlations of the top N models (features with
# at least 3 observations); dump summary and all correlations to Excel.
kwargs = dict(rot=90,
              # title=f'Corr. per {FEAT_NAME} on test data',
              ylabel=f'correlation per {FEAT_NAME}')
ax = (corr_per_feat_test
      .loc[~too_few_obs, TOP_N_ORDER]
      .plot
      .box(**kwargs)
      )
_ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
fname = args.out_figures / 'pred_corr_test_per_feat.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)
dumps[fname.stem] = fname.with_suffix('.xlsx')
with pd.ExcelWriter(fname.with_suffix('.xlsx')) as writer:
    corr_per_feat_test.loc[~too_few_obs].describe().to_excel(writer, sheet_name='summary')
    corr_per_feat_test.to_excel(writer, sheet_name='correlations')

In [None]:
# Number of test-split observations per feature, as a Series named 'count'.
feat_count_test = (data.test_y
                   .stack()
                   .groupby(FEAT_NAME)
                   .count()
                   .rename('count'))
feat_count_test.head()

In [None]:
# Features where any top N model's correlation is below the lowest lower
# whisker among the top N models.
treshold = vaep.pandas.get_lower_whiskers(corr_per_feat_test[TOP_N_ORDER]).min()
mask = (corr_per_feat_test[TOP_N_ORDER] < treshold).any(axis=1)

def highlight_min(s, color, tolerence=0.00001):
    """Return per-cell CSS that marks the minimum of *s* (for ``Styler.apply``).

    Cells within *tolerence* of ``s.min()`` get ``"background-color: <color>;"``;
    all other cells get ``None`` (no style). Ties are all highlighted.
    """
    css = f"background-color: {color};"
    distance_to_min = (s - s.min()).abs()
    return np.where(distance_to_min < tolerence, css, None)

# Low-correlation features with their test observation count, rarest first.
view = (corr_per_feat_test
  .join(feat_count_test)
  .loc[mask]
  .sort_values('count'))

if not view.empty:
    # Search the row minimum among the model columns only; including the
    # 'n_obs' count column (as `corr_per_feat_test.columns` did) would let the
    # count be flagged when a row has no valid correlations.
    display(view
        .style
        .apply(highlight_min, color='yellow', axis=1, subset=ORDER_MODELS)
    )
else:
    print("None found")

### Error plot

In [None]:
# Metrics of all models on the test split (the joined 'freq' column is dropped
# first); rows are metric names, columns restricted to the top N models.
metrics = vaep.models.Metrics()
test_metrics = metrics.add_metrics(pred_test.drop('freq', axis=1), key='test data')
test_metrics = pd.DataFrame(test_metrics)[TOP_N_ORDER]
test_metrics

In [None]:
# Number of test observations the metrics are based on; the first unique value
# of the 'N' row is used (assumed identical for all models).
n_in_comparison = int(test_metrics.loc['N'].unique()[0])
n_in_comparison

In [None]:
# One-row frame with the chosen METRIC per top N model, indexed by the name of
# the feature index level.
_to_plot = test_metrics.loc[METRIC].to_frame().T
_to_plot.index = [feature_names.name]
_to_plot

In [None]:
def build_text(s):
    """Build the bar annotation text for one model's configuration row.

    Parameters
    ----------
    s : pandas.Series
        Row of the model configuration table with the entries
        ``"latent_dim"`` and ``"hidden_layers"``; missing entries are NaN/None.

    Returns
    -------
    str
        ``"LD: <latent_dim> "`` if a latent dimension is set, followed by
        ``"HL: <n1>,<n2>,..."`` if hidden layers are set; an empty string if
        neither is set.
    """
    def _is_missing(value):
        # Only scalars can be missing. A list of layer sizes is a value; calling
        # np.isnan on it returns an array whose truth value is ambiguous for
        # more than one layer (ValueError), and np.isnan(None) raises TypeError.
        return np.ndim(value) == 0 and pd.isna(value)

    ret = ''
    latent_dim = s["latent_dim"]
    if not _is_missing(latent_dim):
        ret += f'LD: {int(latent_dim)} '
    hidden_layers = s["hidden_layers"]
    if not _is_missing(hidden_layers):
        if np.ndim(hidden_layers) == 0:  # single layer given as a scalar
            hidden_layers = [hidden_layers]
        t = ",".join(str(x) for x in hidden_layers)
        ret += f"HL: {t}"
    return ret

# Annotation text per model from its configuration (requires the columns
# 'latent_dim' and 'hidden_layers' in model_configs), stored as a "text" row
# aligned to the model columns; models without text get ''.
text = model_configs[["latent_dim", "hidden_layers"]].apply(
    build_text,
    axis=1)

_to_plot.loc["text"] = text
_to_plot = _to_plot.fillna('')
_to_plot

In [None]:
# Bar plot of the test METRIC per top N model, annotated with the bar heights
# and the configuration text; saved to disk.
fig, ax = plt.subplots(figsize=(10,8))
ax = _to_plot.loc[[feature_names.name]].plot.bar(rot=0,
                                                 ylabel=f"{METRIC} for {feature_names.name} (based on {n_in_comparison:,} log2 intensities)",
                                                 # title=f'performance on test data (based on {n_in_comparison:,} measurements)',
                                                 color=COLORS_TO_USE,
                                                 ax=ax,
                                                 width=.8)
ax = vaep.plotting.add_height_to_barplot(ax)
ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc["text"], size=16)
ax.set_xticklabels([])
fname = args.out_figures / 'performance_test.pdf'
figures[fname.stem] = fname
vaep.savefig(fig, name=fname)

In [None]:
# Long-format copy of the plotted values (one row per model, indexed by model
# and data level), saved next to the figure as CSV.
dumps[fname.stem] = fname.with_suffix('.csv')
_to_plot_long = (_to_plot
                 .T
                 .rename({feature_names.name: 'metric_value'}, axis=1)
                 .assign(**{'data level': feature_names.name})
                 .set_index('data level', append=True))
_to_plot_long.to_csv(fname.with_suffix('.csv'))

In [None]:
# Per-feature test errors of the top N models together with the feature frequency.
errors_test = vaep.pandas.calc_errors_per_feat(pred_test.drop("freq", axis=1), freq_feat=freq_feat)[[*TOP_N_ORDER, 'freq']]
errors_test

### Error plot by frequency

In [None]:
# Rolling average of the per-feature test error along increasing feature
# frequency. The window spans ~1/15 of the features; it is clamped to at least
# 1 because fewer than 15 features would otherwise give a window of size 0.
ax = vaep.plotting.plot_rolling_error(
    errors_test,
    metric_name=METRIC,
    window=max(1, len(errors_test) // 15),
    min_freq=MIN_FREQ,
    colors_to_use=COLORS_TO_USE)
fname = args.out_figures / 'errors_rolling_avg_test.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)

### Error by intensity, binned to integer values

- The number of observations per bin is given in parentheses.

In [None]:
# Plot test errors of the top N models binned by observed intensity and keep
# the binned errors for dumping.
ax, errors_bind = vaep.plotting.errors.plot_errors_binned(
    pred_test[
        ['observed']+TOP_N_ORDER
    ],
    palette=TOP_N_COLOR_PALETTE)
fname = args.out_figures / 'errors_binned_by_int_test.pdf'
figures[fname.stem] = fname
vaep.savefig(ax.get_figure(), name=fname)

In [None]:
# Save the binned test errors next to the figure (same stem, .csv).
dumps[fname.stem] = fname.with_suffix('.csv')
errors_bind.to_csv(fname.with_suffix('.csv'))
errors_bind.head()

## Figures dumped to disk

In [None]:
# All figures saved by this notebook (file stem -> path).
figures

In [None]:
# All tables dumped by this notebook (file stem -> path).
dumps