# Compare models

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

import vaep
import vaep.imputation
from vaep import sampling
from vaep.io import datasplits

from src import config

In [None]:
# files and folders
# Notebook-level parameters: the following cell copies each of them onto `args`
# and then deletes the loose global.
# NOTE(review): presumably a papermill-style parameters cell whose trailing
# comments serve as parameter help -- confirm before reformatting these lines.
folder_experiment:str = 'runs/experiment_03/df_intensities_proteinGroups_long_2017_2018_2019_2020_N05015_M04547/Q_Exactive_HF_X_Orbitrap_Exactive_Series_slot_#6070' # Datasplit folder with data for experiment
folder_data:str = '' # specify data directory if needed
file_format: str = 'pkl' # change default to pickled files
fn_rawfile_metadata: str = 'data/files_selected_metadata.csv' # Machine parsed metadata from rawfile workflow

In [None]:
# Gather the notebook parameters on one config object and remove the loose
# globals, so that later cells read their settings from `args` only.
args = config.Config()
args.fn_rawfile_metadata = fn_rawfile_metadata
args.folder_experiment = Path(folder_experiment)
args.folder_experiment.mkdir(exist_ok=True, parents=True)
args.file_format = file_format
del fn_rawfile_metadata, folder_experiment, file_format

args.out_folder = args.folder_experiment
# An explicitly given data directory wins over the experiment's own 'data' folder.
args.data = Path(folder_data) if folder_data else args.folder_experiment / 'data'
assert args.data.exists(), f"Directory not found: {args.data}"
del folder_data

# Output sub-folders live directly below the experiment folder.
for _subfolder in ('figures', 'metrics', 'models', 'preds'):
    setattr(args, f'out_{_subfolder}', args.folder_experiment / _subfolder)
    getattr(args, f'out_{_subfolder}').mkdir(exist_ok=True)
del _subfolder
args

In [None]:
# Read the data splits from the data folder using the configured file format.
data = datasplits.DataSplits.from_folder(
    args.data,
    file_format=args.file_format,
)

In [None]:
# Two panels with a shared y-axis: number of non-missing values per row of the
# unstacked validation (left) and test (right) targets, sorted ascending.
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(18, 10))
ax_val, ax_test = axes

n_obs_val = data.val_y.unstack().notna().sum(axis=1).sort_values()
_ = n_obs_val.plot(
    rot=90,
    ax=ax_val,
    title='Validation data',
    ylabel='number of feat',
)

n_obs_test = data.test_y.unstack().notna().sum(axis=1).sort_values()
_ = n_obs_test.plot(
    rot=90,
    ax=ax_test,
    title='Test data',
)

## Across data completeness

In [None]:
# Frequency along index level 0 of the training data, stored under the name 'freq'.
freq_feat = sampling.frequency_by_index(data.train_X, 0).rename('freq')
freq_feat.head()  # training data

In [None]:
# Turn the counts into proportions by dividing by the number of distinct
# entries on the first index level of the training data.
n_level0 = len(data.train_X.index.levels[0])
prop = freq_feat / n_level0
prop.to_frame()

## Reference methods

In [None]:
# Reference imputation: apply the normal-distribution imputation column by
# column, then keep only the entries that were missing in the training data.
data.to_wide_format()
imputed_all = data.train_X.apply(vaep.imputation.imputation_normal_distribution)
was_missing = data.train_X.isna()
imputed_shifted_normal = imputed_all[was_missing].stack()
imputed_shifted_normal

In [None]:
# Call to_wide_format() (the previous cell calls it as well) and display
# train_X for inspection.
data.to_wide_format()
data.train_X

In [None]:
# Reference prediction: median of each training-data column, stored under the name 'median'.
medians_train = data.train_X.median().rename('median')

## Load predictions

In [None]:
def load_predictions(split='test', folder: "Path | None" = None):
    """Load and combine all prediction files of one data split.

    Every file in `folder` whose name contains `split` is read as CSV with
    its first two columns as index. The first matching file (in sorted name
    order) provides the shared 'observed' and 'interpolated' columns; from
    every further file only the remaining columns are joined on.

    Parameters
    ----------
    split : str
        Substring a file name must contain to be loaded, e.g. 'test' or 'val'.
    folder : Path, optional
        Directory to scan. Defaults to `args.out_preds`, looked up when the
        function is called.

    Returns
    -------
    pd.DataFrame
        The shared columns followed by the remaining columns of each file.

    Raises
    ------
    FileNotFoundError
        If no file name in `folder` contains `split`.
    AssertionError
        If the shared columns of two files are not equal.
    """
    if folder is None:
        folder = args.out_preds

    shared_columns = ['observed', 'interpolated']
    pred = None

    # Sorted, so the column order does not depend on the directory listing order.
    for fname in sorted(folder.iterdir()):
        if split not in fname.name:
            continue
        _pred_file = pd.read_csv(fname, index_col=[0, 1])
        if pred is None:
            pred = _pred_file
            continue
        # `all(df == df)` would only iterate over the (truthy) column labels
        # and never fail; compare the values themselves (NaNs count as equal).
        pd.testing.assert_frame_equal(pred[shared_columns],
                                      _pred_file[shared_columns])
        pred = pred.join(_pred_file.drop(shared_columns, axis=1))

    if pred is None:
        raise FileNotFoundError(
            f"No prediction file containing '{split}' found in: {folder}")
    return pred

# Test-split predictions, extended by the two reference methods (training
# median, shifted-normal imputation) and the proportion stored in `prop`.
feat_level = prop.index.name
pred_test = load_predictions(split='test').join(medians_train, on=feat_level)
pred_test['shifted normal'] = imputed_shifted_normal
pred_test = pred_test.join(prop, on=feat_level)
pred_test

In [None]:
# Show all rows of pred_test for one randomly chosen entry of the last index level.
feature_names = pred_test.index.levels[-1]
M = len(feature_names)
# random.randint(0, M) includes M and would occasionally index past the end;
# randrange(M) draws from 0 .. M-1 only.
pred_test.loc[pd.IndexSlice[:, feature_names[random.randrange(M)]], :]

In [None]:
# Validation-split predictions plus the two reference methods. `prop` is not
# joined here; it is added after the errors have been aggregated below.
pred_val = load_predictions(split='val').join(medians_train, on=prop.index.name)
pred_val['shifted normal'] = imputed_shifted_normal

# Mean absolute error of every prediction column, grouped by the index level
# that `prop` is indexed by, then ordered by increasing proportion.
errors_val = (
    pred_val
    .drop('observed', axis=1)
    .sub(pred_val['observed'], axis=0)
    .abs()
    .groupby(prop.index.name)
    .mean()
    .join(prop)
    .sort_values(by=prop.name, ascending=True)
)


In [None]:
# Smooth every column except the last one (the proportion used as x-axis) with
# a rolling mean over up to 200 rows. The result overwrites errors_val, so
# re-running this cell smooths the already smoothed values again.
error_columns = errors_val.columns[:-1]
smoothed = errors_val[error_columns].rolling(window=200, min_periods=1).mean()
errors_val[error_columns] = smoothed
ax = errors_val.plot(x=prop.name)

In [None]:
# Scatter plot to see the spread of the 'collab' column of errors_val against
# the proportion in `prop`. The description used to be a bare line of prose,
# which is a syntax error inside a code cell.
errors_val.plot.scatter(x=prop.name, y='collab', ylim=(0, 4))