# Latent space visualization

In [None]:
import logging
from pathlib import Path
from pprint import pprint
from src.nb_imports import *


from fastai.losses import MSELossFlat
from fastai.learner import Learner


import fastai
# from fastai.tabular.all import *

from fastai.basics import *
from fastai.callback.all import *
from fastai.torch_basics import *
from fastai.data.all import *

# import fastai.callback.hook # Learner.summary

import sklearn
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

import vaep.io_images
from vaep.models import ae
from vaep.transform import VaepPipeline
from vaep.io import datasplits
from vaep.io.dataloaders import get_dls

import src
import src.analyzers as analyzers
from src import config
from src.logging import setup_logger
# Route all notebook logging through the project's 'vaep' logger.
logger = setup_logger(logger=logging.getLogger(name='vaep'))
logger.info("Experiment 03 - Analysis of latent spaces and performance comparisions")

figures = dict()  # ax / figure objects collected by name (e.g. 'pca_train')

Papermill script parameters

In [None]:
# Papermill parameters: defaults that may be overridden when the notebook is executed.
n_peptides = 50  # NOTE(review): not used by any visible cell
data = 'data/msinstrument_in_QE4'  # folder with the data splits (rebound to a DataSplits object below)
epochs_max = 30  # number of training epochs for fit_one_cycle

Some argument transformations

In [None]:
# Gather run settings on one Config object; the data folder is kept as a Path.
args = config.Config()
args.data, args.epochs_max = Path(data), epochs_max

## Load data

In [None]:
# Load train/val/test splits from the folder; this rebinds `data` from the
# folder string to the DataSplits object used in all following cells.
data = datasplits.DataSplits.from_folder(args.data)

Data is loaded in long format.

In [None]:
# Quick look at a few random training entries.
data.train_X.sample(n=5)

## Initialize Comparison

- Replicates idea for truly missing values: define the truth by using n=3 replicates to impute
  each sample.
- Real test data: not used for training or early stopping.

In [None]:
# Interpolated test values serve as the reference for truly missing values.
test_predictions_real_na = data.interpolate('test_X').to_frame() # "gold standard"
test_predictions_real_na

In [None]:
# Observed test values, as a one-column frame named 'truth'.
test_predictions_observed = data.test_X.to_frame(name='truth')
test_predictions_observed

In [None]:
# PCA overview of the (long-format) training data; the plot is stored in `figures`.
ana_train_X = analyzers.AnalyzePeptides(
    data=data.train_X,
    is_wide_format=False,
    ind_unstack='peptide',
)
figures['pca_train'] = ana_train_X.plot_pca()

## Collaborative Filtering

## Data in Wide format

- Autoencoders need the data in wide format.

In [None]:
# Convert the splits to wide format (its return value is not used, so this
# presumably modifies `data` in place — TODO confirm); the autoencoder needs wide data.
data.to_wide_format()
data.val_X.head()

## Denoising Autoencoder

### DataLoaders

In [None]:
# `import sklearn` alone does not load the `sklearn.pipeline` submodule;
# import it explicitly instead of relying on another package having done so.
import sklearn.pipeline

# Encoding: standardize each feature, then fill missing values with
# SimpleImputer (default strategy: column mean).
dae_default_pipeline = sklearn.pipeline.Pipeline(
    [
        ('normalize', StandardScaler()),
        ('impute', SimpleImputer(add_indicator=False))
    ])

# Fitted on the training data; decode=['normalize'] presumably means only the
# scaling step is inverted on model outputs — TODO confirm in VaepPipeline.
dae_transforms = VaepPipeline(
    df_train=data.train_X, encode=dae_default_pipeline, decode=['normalize'])

dls = get_dls(data.train_X, data.val_X, transformer=dae_transforms)

### Model

In [None]:
M = data.train_X.shape[-1]  # number of feature columns in the wide training data
latent_dim = 30

# Hidden layer uses half as many neurons as there are input features.
model = ae.Autoencoder(
    n_features=M,
    n_neurons=M // 2,
    last_decoder_activation=None,
    dim_latent=latent_dim,
)

### Learner

In [None]:
# Reconstruction (MSE) training; ae.ModelAdapter is the project callback that
# adapts batches to the autoencoder — see vaep.models.ae.
learn = Learner(
    dls=dls,
    model=model,
    loss_func=MSELossFlat(),
    cbs=ae.ModelAdapter(),
)

In [None]:
# Show which callbacks run at each step of the fastai training loop.
learn.show_training_loop()

In [None]:
# Print the model architecture summary.
learn.summary()

In [None]:
# Learning-rate range test; the `.valley` suggestion is used as lr_max for training.
suggested_lr = learn.lr_find()
suggested_lr

### Training


In [None]:
# One-cycle training for args.epochs_max epochs with the suggested maximum LR.
learn.fit_one_cycle(args.epochs_max, lr_max=suggested_lr.valley)

### Predictions
- test dataset

In [None]:
# dls.test_dl
# needs to be part of setup procedure of a class
from vaep.io.datasets import DatasetWithTarget
def factory_test_dl(bs=64):
    """Return a builder for evaluation DataLoaders with a fixed batch size.

    The returned ``get_test_dl(df, transformer, dataset)`` wraps ``df`` as
    ``dataset(df, transformer)`` and puts it in a ``DataLoader`` with
    batch size ``bs``.
    """
    def get_test_dl(df, transformer, dataset):
        return DataLoader(dataset(df, transformer), bs=bs)

    return get_test_dl

# DataLoader over the wide-format test split (default batch size 64).
get_test_dl = factory_test_dl()
dl_test = get_test_dl(df=data.test_X, transformer=dae_transforms, dataset=DatasetWithTarget)

In [None]:
# No `dl` given: fastai predicts on its default dataset (validation set).
# act=noop keeps raw model outputs; reorder=False keeps batch order.
pred, target = learn.get_preds(act=noop, concat_dim=0, reorder=False)
len(pred), len(target)

In [None]:
# Raw (no activation) predictions on the test DataLoader, in batch order.
pred, target = learn.get_preds(dl=dl_test, act=noop, concat_dim=0, reorder=False)
len(pred), len(target)

In [None]:
def get_preds_from_df(df, learn, transformer, dataset=DatasetWithTarget):
    """Predict on ``df`` and return outputs on the original data scale.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide-format data to predict on; its index and columns label the outputs.
    learn : fastai Learner
        Trained learner whose ``get_preds`` is used.
    transformer :
        Fitted transformer used to build the dataset and, via
        ``inverse_transform``, to map outputs back to the original scale.
    dataset : callable, optional
        Dataset factory called as ``dataset(df, transformer)``.

    Returns
    -------
    L
        The ``get_preds`` results (unpacked by callers as ``pred, target``),
        each as an inverse-transformed DataFrame labelled like ``df``.
    """
    # Predict on a DataLoader built from *this* df. Previously the global
    # `dl_test` was used, so any other df got predictions for the wrong rows.
    dl = get_test_dl(df=df, transformer=transformer, dataset=dataset)
    res = L(learn.get_preds(dl=dl, concat_dim=0, reorder=False))
    res = res.map(lambda x: pd.DataFrame(x, index=df.index, columns=df.columns))
    return res.map(transformer.inverse_transform)

res = get_preds_from_df(df=data.test_X, learn=learn, transformer=dae_transforms)

# Shapes of the returned frames.
L(frame.shape for frame in res)

In [None]:
pred, target = res
# Predictions and targets must each have one row per test sample.
assert len(pred) == len(target) == len(data.test_X)

In [None]:
# Check that the dataset's mask flags exactly the missing entries of test_X.
# Builtin `all(frame)` iterates over the column labels (always truthy here),
# so reduce over every cell of the element-wise comparison instead.
np.asarray(dl_test.dataset.mask_obs == data.test_X.isna()).all()

In [None]:
# Split predictions by the dataset mask (True = missing in test_X).
# Masking leaves NaN in the excluded cells; drop them explicitly, because the
# new DataFrame.stack implementation (pandas 3 default) keeps NaN values.
pred_observed = pred[~dl_test.dataset.mask_obs].stack().dropna()
pred_true_na = pred[dl_test.dataset.mask_obs].stack().dropna()
assert len(pred_true_na) + len(pred_observed) == data.test_X.size

### Plots

## Config

In [None]:
# Display the final run configuration.
args