# Compare models

In [None]:
# Imports and notebook-wide display/plot settings
import logging
import random
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
import seaborn as sns

# show up to 120 rows of a DataFrame; longer frames are truncated to 50 rows
pd.options.display.max_rows = 120
pd.options.display.min_rows = 50

# project package (vaep): data splits, imputation baselines, metrics and plotting helpers
# NOTE(review): `logging` and `sampling` appear unused in the visible notebook (sampling only in a comment)
import vaep
import vaep.imputation
from vaep import sampling
import vaep.models
from vaep.io import datasplits
from vaep.analyzers import compare_predictions


import vaep.nb
# default figure size (inches) for all matplotlib figures in this notebook
matplotlib.rcParams['figure.figsize'] = [10.0, 8.0]


# notebook logger configured by the project helper
logger = vaep.logging.setup_nb_logger()

In [None]:
# trained models; used below to flag samples/features with unusually poor results
models = ['collab', 'DAE', 'VAE']
# display order in tables and plots: reference methods first, then the trained models
ORDER_MODELS = [
    'random shifted normal',
    'median',
    'interpolated',
    *models,
]

In [None]:
# files and folders -- notebook parameters (presumably overridable by an injected "Parameters" cell)
folder_experiment: str = 'runs/experiment_03/df_intensities_proteinGroups_long_2017_2018_2019_2020_N05015_M04547/Q_Exactive_HF_X_Orbitrap_Exactive_Series_slot_#6070'  # Datasplit folder with data for experiment
folder_data: str = ''  # specify data directory if needed
file_format: str = 'pkl'  # change default to pickled files
fn_rawfile_metadata: str = 'data/files_selected_metadata.csv'  # Machine parsed metadata from rawfile workflow

In [None]:
# # Parameters
# fn_rawfile_metadata = "data/ALD_study/processed/raw_meta.csv"
# folder_experiment = "runs/appl_ald_data/plasma/proteinGroups"

In [None]:
args = vaep.nb.Config()

# move the notebook parameters onto the config object (same order as before)
args.fn_rawfile_metadata = fn_rawfile_metadata
args.folder_experiment = Path(folder_experiment)
args.folder_experiment.mkdir(parents=True, exist_ok=True)
args.file_format = file_format
# drop the loose globals so later cells only use `args`
del fn_rawfile_metadata, folder_experiment, file_format

# derive the standard paths used below (args.data, args.out_figures, args.out_models, args.out_preds, ...)
args = vaep.nb.add_default_paths(args, folder_data=folder_data)
del folder_data

args

In [None]:
# load the train / validation / test splits of this experiment
data = datasplits.DataSplits.from_folder(
    args.data,
    file_format=args.file_format,
)

In [None]:
fig, axes = plt.subplots(1, 2, sharey=True)

# per sample: number of artificially removed (fake NA) values in each evaluation split
_panels = [
    (data.val_y, dict(title='Validation data', ylabel='number of feat')),
    (data.test_y, dict(title='Test data')),
]
for ax, (split_y, plot_kws) in zip(axes, _panels):
    n_fake_na = split_y.unstack().notna().sum(axis=1).sort_values()
    _ = n_fake_na.plot(ax=ax, **plot_kws)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')

fig.suptitle("Fake NAs per sample availability.", size=24)
vaep.savefig(fig, name='fake_na_val_test_splits', folder=args.out_figures)

## Across data completeness

In [None]:
# Per-feature frequency in the training data, stored with the splits
# (alternative: sampling.frequency_by_index(data.train_X, 0) renamed to 'freq').
# NOTE: the JSON file apparently does not keep the index name; it is set again further below.
freq_feat = vaep.io.datasplits.load_freq(args.data, file='freq_train.json')

freq_feat.head()  # training data

In [None]:
# feature frequency relative to the number of samples (first index level of the long training data)
prop = freq_feat.div(len(data.train_X.index.levels[0]))
prop.to_frame()

# Reference methods

- drawing from a shifted normal distribution
- drawing from a (-) normal distribution? (open question)
- median imputation

In [None]:
# pivot the splits to wide format; the training data then is samples x features
data.to_wide_format()
data.train_X  # display the wide training data

In [None]:
# rows are samples, columns are features in the wide training data
N_SAMPLES, M_FEAT = len(data.train_X.index), len(data.train_X.columns)
print(f"N samples: {N_SAMPLES:,d}, M features: {M_FEAT}")

In [None]:
# Reference method: impute every value by a random draw from a shifted normal distribution,
# presumably with its mean shifted by `mean_shift` and its spread shrunk by `std_shrinkage`
# (see vaep.imputation.impute_shifted_normal). The unused `mean` / `std` computed here
# before were removed.
# TODO(review): the draw is random and not seeded here, so this baseline is not reproducible.
imputed_shifted_normal = vaep.imputation.impute_shifted_normal(
    data.train_X, mean_shift=1.8, std_shrinkage=0.3, axis=0)
imputed_shifted_normal

In [None]:
# per-feature median of the training data (median imputation reference)
medians_train = data.train_X.median().rename('median')

# Model specifications

In [None]:
import yaml 
def select_content(s:str, stub='metrics_'):
    """Extract a model key from a file stem.

    The part of `s` after `stub` is taken. If that part contains underscores,
    the last underscore-separated token is dropped, e.g.
    ``select_content('metrics_collab_test') == 'collab'`` and
    ``select_content('config_dae', 'config_') == 'dae'``.

    Raises:
        ValueError: if `stub` does not occur exactly once in `s`.
    """
    parts = s.split(stub)
    # the former `assert isinstance(s, str)` could never fail; check the split instead
    if len(parts) != 2:
        raise ValueError(
            f"Expected exactly one occurrence of {stub!r} in {s!r}, found {len(parts) - 1}")
    content = parts[1]
    entries = content.split('_')
    if len(entries) > 1:
        return '_'.join(entries[:-1])
    return content

from functools import partial


# Collect all model configurations (config_<model>*.yaml in args.out_models) into one table.
# Files are visited in sorted order so that, when several files map to the same model key,
# the value kept for a conflicting entry does not depend on the filesystem's listing order.
all_configs = {}
for fname in sorted(args.out_models.iterdir()):
    if fname.suffix != '.yaml':
        continue
    # the model key is derived from the file name, e.g. 'config_dae' -> 'dae'
    key = select_content(fname.stem, 'config_')
    print(f"{key = }")
    with open(fname) as f:
        loaded = yaml.safe_load(f)
    if key not in all_configs:
        all_configs[key] = loaded
        continue
    # merge further files of the same model: add new entries, report (but keep) conflicting ones
    for k, v in loaded.items():
        if k not in all_configs[key]:
            all_configs[key][k] = v
        elif all_configs[key][k] != v:
            print("Diverging values for {k}: {v1} vs {v2}".format(
                k=k, v1=all_configs[key][k], v2=v))

model_configs = pd.DataFrame(all_configs).T
model_configs.T

# Load predictions

- calculate correlations -> only meaningful per feature (then also save the overall correlation statistics)

## Test data

In [None]:
# name the frequency index like the feature axis of the wide training data,
# so that it can be used as join key (`on=freq_feat.index.name`) below
freq_feat.index = freq_feat.index.rename(data.train_X.columns.name)

In [None]:
# load all test-split predictions and add the shifted-normal reference and feature frequencies
split = 'test'
pred_files = list(filter(lambda fpath: split in fpath.name, args.out_preds.iterdir()))
pred_test = compare_predictions.load_predictions(pred_files)
pred_test['random shifted normal'] = imputed_shifted_normal
pred_test = pred_test.join(freq_feat, on=freq_feat.index.name)
# index levels: (sample, feature)
SAMPLE_ID, FEAT_NAME = pred_test.index.names
pred_test

In [None]:
pred_test_corr = pred_test.corr()
# overall correlation of each method's predictions with the observed (held-out) values
corr_with_observed = pred_test_corr.loc['observed', ORDER_MODELS]
bar_kws = dict(ylabel='correlation coefficient overall', ylim=(0.7, 1))
ax = corr_with_observed.plot.bar(**bar_kws)
ax = vaep.plotting.add_height_to_barplot(ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
vaep.savefig(ax.get_figure(), name='pred_corr_test_overall', folder=args.out_figures)
pred_test_corr

In [None]:
# correlation between observed and predicted values, computed separately for each sample
corr_per_sample_test = pred_test.groupby(SAMPLE_ID).aggregate(
    lambda df: df.corr().loc['observed'])[ORDER_MODELS]
# number of held-out values per sample -- counted on 'observed', as for the per-feature table
n_obs_per_sample = pred_test.groupby(SAMPLE_ID)['observed'].count().rename('n_obs')
corr_per_sample_test = corr_per_sample_test.join(n_obs_per_sample)
# correlations based on fewer than 3 values are excluded from the summary
too_few_obs = corr_per_sample_test['n_obs'] < 3
corr_per_sample_test.loc[~too_few_obs].describe()

In [None]:
kwargs = dict(ylim=(0.7,1), rot=90,
              # title='Corr. betw. fake NA and model predictions per sample on test data',
              ylabel='correlation per sample')
# plot only the model correlations: 'n_obs' is a count and does not belong on a correlation axis
ax = corr_per_sample_test[ORDER_MODELS].plot.box(**kwargs)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
vaep.savefig(ax.get_figure(), name='pred_corr_test_per_sample', folder=args.out_figures)
with pd.ExcelWriter(args.out_figures/'pred_corr_test_per_sample.xlsx') as writer:
    corr_per_sample_test.describe().to_excel(writer, sheet_name='summary')
    corr_per_sample_test.to_excel(writer, sheet_name='correlations')

Identify samples whose correlation is below the lower whisker for any of the models.

In [None]:
# samples where any trained model's correlation is below the smallest of the models' lower whiskers
threshold = vaep.pandas.get_lower_whiskers(corr_per_sample_test[models]).min()
below_whisker = corr_per_sample_test[models].lt(threshold).any(axis=1)
corr_per_sample_test.loc[below_whisker].style.highlight_min(axis=1)

In [None]:
# Show all test predictions for one randomly chosen feature.
# (No longer reassigns N_SAMPLES, which holds the number of training samples.)
# TODO(review): `random` is not seeded, so the chosen feature changes between runs.
feature_names = pred_test.index.levels[-1]
M = len(feature_names)
# randrange(M) draws from 0..M-1; randint(0, M) could return M and raise an IndexError
pred_test.loc[pd.IndexSlice[:, feature_names[random.randrange(M)]], :]

In [None]:
# Show the test predictions for one randomly chosen feature of the training data.
# random.sample() rejects sets since Python 3.11 (deprecated in 3.9), and set iteration
# order of strings is not reproducible, so sample from the index as a list.
options = random.sample(list(freq_feat.index), 1)
pred_test.loc[pd.IndexSlice[:, options[0]], :]

### Correlation per feature

In [None]:
# correlation between observed and predicted values, computed separately for each feature
corr_per_feat_test = pred_test.groupby(FEAT_NAME).aggregate(
    lambda df: df.corr().loc['observed'])[ORDER_MODELS]
n_obs_per_feat = pred_test.groupby(FEAT_NAME)['observed'].count().rename('n_obs')
corr_per_feat_test = corr_per_feat_test.join(n_obs_per_feat)

# correlations based on fewer than 3 values are excluded from the summary
too_few_obs = corr_per_feat_test['n_obs'] < 3
corr_per_feat_test.loc[~too_few_obs].describe()

In [None]:
# features with fewer than 3 observations, keeping only rows with at least 3 non-missing entries
corr_per_feat_test[too_few_obs].dropna(axis=0, thresh=3)

In [None]:
kwargs = dict(rot=90, ylabel=f'correlation per {FEAT_NAME}')
# only features with at least 3 observations enter the box plot and the summary sheet
corr_enough_obs = corr_per_feat_test.loc[~too_few_obs]
ax = corr_enough_obs.drop('n_obs', axis=1).plot.box(**kwargs)
_ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
vaep.savefig(ax.get_figure(), name='pred_corr_test_per_feat', folder=args.out_figures)
with pd.ExcelWriter(args.out_figures / 'pred_corr_test_per_feat.xlsx') as writer:
    corr_enough_obs.describe().to_excel(writer, sheet_name='summary')
    corr_per_feat_test.to_excel(writer, sheet_name='correlations')

In [None]:
# number of held-out test values per feature
feat_count_test = data.test_y.stack().groupby(FEAT_NAME).count().rename('count')
feat_count_test.head()

In [None]:
# features where any trained model's correlation is below the smallest of the models' lower whiskers
threshold = vaep.pandas.get_lower_whiskers(corr_per_feat_test[models]).min()
mask = corr_per_feat_test[models].lt(threshold).any(axis=1)

def highlight_min(s, color, tolerence=0.00001):
    """CSS for a Styler: mark entries of `s` within `tolerence` of its minimum.

    Returns an array aligned with `s` holding ``"background-color: <color>;"`` for
    (near-)minimal entries and None elsewhere.
    """
    near_min = s.sub(s.min()).abs().lt(tolerence)
    css = f"background-color: {color};"
    return np.where(near_min, css, None)
# Features below the whisker threshold, sorted by their number of test values. Per row, the
# method with the lowest correlation is highlighted; styling is restricted to the method columns
# so that 'n_obs' and 'count' never take part in the row minimum.
(corr_per_feat_test
 .join(feat_count_test)
 .loc[mask]
 .sort_values('count')
 .style.apply(highlight_min, color='yellow', axis=1, subset=ORDER_MODELS))

In [None]:
metrics = vaep.models.Metrics(no_na_key='NA interpolated', with_na_key='NA not interpolated')
# 'freq' is a feature attribute, not a prediction, so it is excluded from the metrics
test_metrics = metrics.add_metrics(pred_test.drop(columns='freq'), key='test data')
test_metrics = pd.DataFrame(test_metrics["NA interpolated"])[ORDER_MODELS]
test_metrics

In [None]:
# number of values the metrics are based on (first distinct 'N' across methods)
n_in_comparison = int(pd.unique(test_metrics.loc['N'])[0])
n_in_comparison

In [None]:
METRIC = 'MAE'
# one-row frame: the chosen metric per method, row labelled with the feature axis name
_to_plot = (test_metrics
            .loc[METRIC]
            .to_frame()
            .T
            .set_axis([feature_names.name], axis=0))
_to_plot

In [None]:
def _describe_architecture(config: pd.Series) -> str:
    """Bar annotation 'LD: <latent_dim> - HL: <hidden layers>' for one model config.

    Configs without a list of hidden layers (missing -> NaN/None in `model_configs`)
    get '-'. The former `is not np.nan` identity test only recognised the np.nan
    singleton and tried to iterate any other missing marker.
    """
    hidden_layers = config["hidden_layers"]
    if isinstance(hidden_layers, (list, tuple)):
        hl_text = ",".join(str(x) for x in hidden_layers)
    else:
        hl_text = "-"
    return f'LD: {config["latent_dim"]:3} - HL: {hl_text}'


text = model_configs[["latent_dim", "hidden_layers"]].apply(_describe_architecture, axis=1)
# config keys are lower case, the prediction columns use upper-case model names
text = text.rename({'dae': 'DAE', 'vae': 'VAE'})

_to_plot.loc["text"] = text
_to_plot = _to_plot.fillna('')
_to_plot

In [None]:
# palette colour 5 first, then colours 0-4: keeps the model colours comparable with the
# grid search plots (where 'random shifted normal' was omitted)
palette = sns.color_palette()
colors_to_use = [palette[5], *palette[:5]]
palette  # display the default palette

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))
bar_kws = dict(
    rot=0,
    ylabel=f"{METRIC} for {feature_names.name} (based on {n_in_comparison:,} log2 intensities)",
    color=colors_to_use,
    ax=ax,
    width=.8,
)
ax = _to_plot.loc[[feature_names.name]].plot.bar(**bar_kws)
# annotate bars with their height and with the model architecture text
ax = vaep.plotting.add_height_to_barplot(ax)
ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc["text"], size=16)
ax.set_xticklabels([])
vaep.savefig(fig, "performance_models_test", folder=args.out_figures)

In [None]:
# per-feature errors of every method plus the training frequency (vaep.pandas.calc_errors_per_feat);
# 'freq' is not a prediction, so it is removed before computing errors
errors_test = vaep.pandas.calc_errors_per_feat(
    pred_test.drop(columns="freq"),
    freq_feat=freq_feat,
)
errors_test = errors_test[[*ORDER_MODELS, 'freq']]
errors_test

In [None]:
def plot_rolling_error(errors: pd.DataFrame, metric_name, window: int = 200,
                       min_freq=None, freq_col: str = 'freq', 
                       ax=None):
    """Plot the rolling mean of per-feature errors against the feature frequency.

    Parameters
    ----------
    errors : pd.DataFrame
        One row per feature, one column per method plus `freq_col`. The rolling mean
        runs over rows in their given order, so rows should already be sorted by `freq_col`.
    metric_name : str
        Error metric name used in the y-axis label.
    window : int
        Number of consecutive rows averaged by the rolling mean.
    min_freq : optional
        If given, only rows with `freq_col` strictly greater than `min_freq` are plotted.
    freq_col : str
        Column with the feature frequency, used as x-axis.
    ax : matplotlib Axes, optional
        Axes to draw into; pandas creates a new one if None.

    Returns
    -------
    The matplotlib Axes containing the plot.

    Uses the module-level `colors_to_use` for the line colours.
    """
    errors_smoothed = errors.drop(freq_col, axis=1).rolling(window=window, min_periods=1).mean()
    # y-limit from all rows (before the optional frequency filter), capped at 5
    errors_smoothed_max = errors_smoothed.max().max()
    errors_smoothed[freq_col] = errors[freq_col]
    if min_freq is None:
        min_freq = errors_smoothed[freq_col].min()
    else:
        errors_smoothed = errors_smoothed.loc[errors_smoothed[freq_col] > min_freq]
    ax = errors_smoothed.plot(x=freq_col, ylabel=f'rolling average error ({metric_name})',
                              color=colors_to_use,
                              xlim=(min_freq, errors_smoothed[freq_col].max()),
                              ylim=(0, min(errors_smoothed_max, 5)),
                              ax=ax)  # previously ax=None, which ignored the caller's axes
    return ax

# rolling average of the test errors; each window spans about 1/15 of the features
min_freq = None
rolling_window = int(len(errors_test) / 15)
ax = plot_rolling_error(errors_test, metric_name=METRIC,
                        window=rolling_window, min_freq=min_freq)
vaep.savefig(ax.get_figure(), name='errors_rolling_avg_test', folder=args.out_figures)

## Validation data

In [None]:
# load all validation-split predictions and add the shifted-normal reference
split = 'val'
pred_files = list(filter(lambda fpath: split in fpath.name, args.out_preds.iterdir()))
pred_val = compare_predictions.load_predictions(pred_files)
pred_val['random shifted normal'] = imputed_shifted_normal

# overall correlation of each method's predictions with the observed values
pred_val_corr = pred_val.corr()
corr_with_observed = pred_val_corr.loc['observed', ORDER_MODELS]
ax = corr_with_observed.plot.bar(ylabel='correlation overall')
ax = vaep.plotting.add_height_to_barplot(ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
vaep.savefig(ax.get_figure(), name='pred_corr_val_overall', folder=args.out_figures)
pred_val_corr

In [None]:
# correlation between observed and predicted values, computed separately for each validation sample
corr_per_sample_val = pred_val.groupby(SAMPLE_ID).aggregate(
    lambda df: df.corr().loc['observed'])[ORDER_MODELS]

kwargs = dict(ylim=(0.7,1), rot=90,
              # title='Corr. betw. fake NA and model pred. per sample on validation data',
              ylabel='correlation per sample')
ax = corr_per_sample_val.plot.box(**kwargs)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
vaep.savefig(ax.get_figure(), name='pred_corr_valid_per_sample', folder=args.out_figures)
# export the validation table (this cell previously wrote corr_per_sample_test to this file)
with pd.ExcelWriter(args.out_figures/'pred_corr_valid_per_sample.xlsx') as writer:
    corr_per_sample_val.describe().to_excel(writer, sheet_name='summary')
    corr_per_sample_val.to_excel(writer, sheet_name='correlations')

Identify samples whose correlation is below the lower whisker for any of the models.

In [None]:
# validation samples where any trained model's correlation is below the smallest lower whisker
threshold = vaep.pandas.get_lower_whiskers(corr_per_sample_val[models]).min()
below_whisker = corr_per_sample_val[models].lt(threshold).any(axis=1)
corr_per_sample_val.loc[below_whisker].style.highlight_min(axis=1)

### Error plot

In [None]:
# signed error (prediction - observed) of every method for each validation value
errors_val = pred_val.drop(columns='observed').sub(pred_val['observed'], axis=0)
errors_val = errors_val[ORDER_MODELS]
errors_val.describe()  # over all samples, and all features

Describe the absolute error

In [None]:
abs(errors_val).describe()  # absolute errors over all samples and all features

In [None]:
c_error_min = 4.5
# validation values where any trained model is off by more than c_error_min (presumably log2 units)
mask = errors_val[models].abs().gt(c_error_min).any(axis=1)
errors_val[mask].sort_index(level=1)

In [None]:
# mean absolute error per feature, joined with the training frequency and sorted from
# rarely to frequently observed features.
# NOTE: overwrites errors_val -- re-running only this cell fails (duplicate frequency column).
errors_val = (errors_val
              .abs()
              .groupby(freq_feat.index.name)
              .mean()
              .join(freq_feat)
              .sort_values(by=freq_feat.name, ascending=True))
errors_val

Some interpolated features are missing

In [None]:
errors_val.describe(percentiles=[.25, .5, .75])  # mean of means (distribution over features)

In [None]:
c_avg_error = 2
# features for which any trained model has a mean absolute error of at least c_avg_error
mask = errors_val[models].ge(c_avg_error).any(axis=1)
errors_val[mask]

In [None]:
# rolling mean (window=200) of the per-feature validation errors, ordered by frequency
errors_val_smoothed = errors_val.copy()
model_cols = errors_val.columns[:-1]  # every column except the trailing frequency column
errors_val_smoothed[model_cols] = errors_val[model_cols].rolling(window=200, min_periods=1).mean()
# plot the validation errors; this cell previously plotted errors_test,
# so the figure saved from `ax` in this section showed the test split
ax = plot_rolling_error(errors_val, metric_name=METRIC, window=int(len(errors_val) / 15),
                        min_freq=min_freq, freq_col=freq_feat.name)

In [None]:
errors_val_smoothed.describe(percentiles=[.25, .5, .75])  # summary of the smoothed errors

In [None]:
# save the rolling-error figure drawn in the previous cell
vaep.savefig(ax.get_figure(), name='performance_methods_by_completness', folder=args.out_figures)

# Average error per feature - example scatter plot for `collab`
- shows the values that are smoothed above, here for `collab`
- shows how the features are distributed in the training data

In [None]:
# scatter plot of per-feature mean absolute validation error against feature frequency
# NOTE(review): x=prop.name resolves to the joined frequency column of errors_val
# (prop keeps freq_feat's name), i.e. absolute counts, not proportions
model = models[0]
scatter_kws = dict(x=prop.name, y=model, c='darkblue', ylim=(0, 2),
                   ylabel=f'average error ({METRIC}) for {model} on valid. data')
ax = errors_val.plot.scatter(**scatter_kws)

vaep.savefig(ax.get_figure(), folder=args.out_figures,
             name='performance_methods_by_completness_scatter')

- [ ] plotly plot showing, for each feature, the number of observations its mean error is based on