<center>
<h1>Multi-layer perceptron</h1>
</center>

---

As an example of a deep learning model that keeps the computational requirements minimal, we will train and evaluate a multi-layer perceptron (MLP).

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.feature_extraction.text import CountVectorizer

# Enable MLflow autologging and group runs under one experiment
mlflow.autolog()
mlflow.set_experiment("Peptide retention time regression")
sns.set_style("darkgrid")

# Make the project's src package importable from the notebook directory
sys.path.append("..")
from src.data import load_data, preprocess_data, PeptideDataset
from src.models import LGBMModelHandler, MLP, TorchModelHandler
from src.util import rMAE, rMSE

### Load the data and set up the MLP model handler

In [None]:
data = load_data("../data/Peptides_and_iRT.tsv")


In [None]:
mlpHandler = TorchModelHandler(MLP, 
                               data=data, 
                               val_frac=0.15, 
                               test_frac=0.15, 
                               vectorizer=CountVectorizer, 
                               model_parameters=dict(hidden_dim=128, 
                                                     hidden_layers=3, 
                                                     output_dim=1, 
                                                     dropout_prob=0.3)
                            )
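
The `MLP` class lives in `src.models` and is not shown in this notebook. As a rough illustration of what the `model_parameters` above could describe, here is a minimal sketch of a fully connected network with the same hyperparameters. The class name `SketchMLP`, the ReLU activation, the layer ordering and the input size of 20 (one count per standard amino acid) are all assumptions for illustration, not the project's actual implementation.

```python
import torch
from torch import nn


class SketchMLP(nn.Module):
    """Hypothetical stand-in for src.models.MLP (the real class is not shown here)."""

    def __init__(self, input_dim, hidden_dim=128, hidden_layers=3,
                 output_dim=1, dropout_prob=0.3):
        super().__init__()
        layers = []
        in_features = input_dim
        # Each hidden block: linear layer, non-linearity, dropout (assumed ordering)
        for _ in range(hidden_layers):
            layers += [nn.Linear(in_features, hidden_dim),
                       nn.ReLU(),
                       nn.Dropout(dropout_prob)]
            in_features = hidden_dim
        # Final linear layer maps to a single regression output (the iRT value)
        layers.append(nn.Linear(in_features, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


# input_dim=20 is an assumption: one count feature per standard amino acid
model = SketchMLP(input_dim=20)
model.eval()  # switch dropout off for inference
with torch.no_grad():
    out = model(torch.zeros(4, 20))
print(out.shape)  # torch.Size([4, 1])
```

With `output_dim=1` the network returns one predicted value per peptide, which matches the regression setup used here.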

In [None]:
mlpHandler.train_eval()

In [None]:
mlpHandler.eval(mlpHandler.test_dataset)

## Plot the results

In [None]:
pred = mlpHandler.predict_all()

In [None]:
plt.figure()

y_true, y_pred = pred.T

# Pass x and y by keyword; seaborn 0.12+ no longer accepts them positionally
sns.scatterplot(x=y_true, y=y_pred, marker='+', color="darkred")

plt.plot([-100,150], [-100, 150], color="black", lw=0.75)

plt.xlabel("iRT measured")
plt.ylabel("iRT predicted")


plt.title("MLP predictions vs. ground truth")
plt.tight_layout()

plt.show()
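
A scatter plot shows the overall agreement, but a numeric summary makes it easier to compare models. The sketch below computes the mean absolute error and the root mean squared error with the standard library only. The helper names `mae` and `rmse` are hypothetical and are not necessarily equivalent to `rMAE` and `rMSE` from `src.util`, whose definitions are not shown here. The toy values are made up for illustration.

```python
import math


def mae(y_true, y_pred):
    """Mean absolute error between two equally long sequences."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true, y_pred):
    """Root mean squared error between two equally long sequences."""
    squared = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return math.sqrt(squared / len(y_true))


# Made-up toy values, not results from this notebook
toy_true = [0.0, 10.0, 20.0, 30.0]
toy_pred = [1.0, 9.0, 22.0, 27.0]

print(mae(toy_true, toy_pred))   # 1.75
print(rmse(toy_true, toy_pred))  # sqrt(3.75), roughly 1.94
```

In this notebook the same helpers could be applied to the `y_true` and `y_pred` arrays unpacked from `pred.T`.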

In [None]:
mlpHandler.dump("../models/mlp.pth")

In [None]:
mlpHandler.load("../models/mlp.pth")

<h1>
Conclusions
</h1>

The MLP does not perform on par with the LightGBM model, which does not surprise me. It has been neither tuned nor trained for very long because of limited resources, so there is probably some room for improvement at the expense of training time and model size. For tabular data such as this, LightGBM is usually an excellent choice in terms of both accuracy and resource efficiency.

Note that saving and loading have not been implemented with MLflow, since this requires pytorch-lightning. In the interest of keeping this project lightweight, I did not include it in the environment. Instead, the model can be saved with the model handler's `dump` method and loaded again with its `load` method.
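
The internals of the handler's `dump` and `load` methods are not shown in this notebook. For orientation, a common plain-PyTorch pattern for persisting a model without extra dependencies is to save and restore its `state_dict`. The sketch below assumes that pattern. It uses a tiny `nn.Linear` model and an in-memory buffer as stand-ins for the real MLP and the `.pth` file.

```python
import io

import torch
from torch import nn

# Tiny stand-in model; the real handler would hold the trained MLP
model = nn.Linear(3, 1)

# Save only the learned parameters (the state_dict), not the whole object
buffer = io.BytesIO()  # in-memory stand-in for a file path such as "mlp.pth"
torch.save(model.state_dict(), buffer)

# Restore into a freshly constructed model with the same architecture
buffer.seek(0)
restored = nn.Linear(3, 1)
restored.load_state_dict(torch.load(buffer))

print(torch.equal(model.weight, restored.weight))  # True
```

Because only parameters are stored, the loading side has to rebuild the model with the same architecture before calling `load_state_dict`.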