<center>
<h1>Recurrent Neural Network</h1>
</center>

---

Another type of deep learning model is the recurrent neural network (RNN), which promises better results because it takes the order of tokens, i.e. amino acids (AAs), in a sequence into account. This distinguishes it from LightGBM and the MLPs, which only see order-independent features such as AA counts. For the sake of brevity, we only consider gated recurrent units (GRUs), but long short-term memory networks (LSTMs) could also be explored.
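To make the point about token order concrete, the sketch below compares an order-free AA count vector with the final hidden state of a single GRU layer written in NumPy from the standard GRU update equations (reset gate, update gate, candidate state). The weights are random and the names `aa_counts` and `ToyGRU` are illustrative only; they are not part of `src.models`.

```python
import numpy as np
from collections import Counter

AAS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids
AA_INDEX = {aa: i for i, aa in enumerate(AAS)}


def aa_counts(seq):
    """Order-free features: how often each amino acid occurs in the peptide."""
    counts = Counter(seq)
    return np.array([counts[aa] for aa in AAS], dtype=float)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class ToyGRU:
    """Illustrative single GRU layer with random, untrained weights."""

    def __init__(self, input_dim, hidden_dim, seed=0):
        rng = np.random.default_rng(seed)
        # One input matrix W and one recurrent matrix U per gate:
        # reset (r), update (z) and candidate state (n)
        self.W = {g: rng.normal(scale=0.5, size=(hidden_dim, input_dim)) for g in "rzn"}
        self.U = {g: rng.normal(scale=0.5, size=(hidden_dim, hidden_dim)) for g in "rzn"}
        self.b = {g: np.zeros(hidden_dim) for g in "rzn"}
        self.hidden_dim = hidden_dim

    def step(self, x, h):
        r = sigmoid(self.W["r"] @ x + self.U["r"] @ h + self.b["r"])
        z = sigmoid(self.W["z"] @ x + self.U["z"] @ h + self.b["z"])
        n = np.tanh(self.W["n"] @ x + r * (self.U["n"] @ h) + self.b["n"])
        return (1.0 - z) * n + z * h  # blend new candidate with previous state

    def encode(self, seq):
        h = np.zeros(self.hidden_dim)
        for aa in seq:
            x = np.zeros(len(AAS))
            x[AA_INDEX[aa]] = 1.0  # one-hot encode the residue
            h = self.step(x, h)
        return h  # final hidden state summarises the whole sequence


gru = ToyGRU(input_dim=len(AAS), hidden_dim=8)
a, b = "PEPTIDEK", "KEDITPEP"  # same residues, reversed order
print(np.array_equal(aa_counts(a), aa_counts(b)))  # True: counts cannot tell them apart
print(np.allclose(gru.encode(a), gru.encode(b)))   # False: the GRU state depends on order
```

Count-based models necessarily assign both peptides the same prediction, whereas the recurrent state can, after training, encode where a residue sits in the chain.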

In [None]:
%load_ext autoreload
%autoreload 2

import sys

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import seaborn as sns
import torch

# Make the project's src package importable from the notebooks folder
sys.path.append("..")
from src.data import load_data, preprocess_data, PeptideDataset
from src.models import RNN, TorchModelHandler
from src.util import rMAE, rMSE

# Group runs under one experiment and enable MLflow autologging for supported libraries
mlflow.set_experiment("Peptide retention time regression")
mlflow.autolog()

sns.set_style("darkgrid")

In [None]:
data = preprocess_data(load_data("../data/Peptides_and_iRT.tsv"))
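The handler below is constructed with `tokenize=True` and `vocab_size=100`, but how `src.data` tokenizes peptides is not shown in this notebook. A minimal sketch, assuming one integer id per residue letter with 0 reserved for padding (`build_vocab` and `encode_batch` are hypothetical helpers, not the project's API):

```python
def build_vocab(sequences):
    """Map every residue letter seen in the data to an integer id; 0 is reserved for padding."""
    letters = sorted({aa for seq in sequences for aa in seq})
    return {aa: i + 1 for i, aa in enumerate(letters)}


def encode_batch(sequences, vocab):
    """Turn peptides into equal-length id lists by right-padding with 0."""
    max_len = max(len(seq) for seq in sequences)
    batch = []
    for seq in sequences:
        ids = [vocab[aa] for aa in seq]
        batch.append(ids + [0] * (max_len - len(ids)))
    return batch


peptides = ["PEPTIDEK", "ACDK", "GGR"]
vocab = build_vocab(peptides)
print(encode_batch(peptides, vocab))
# [[8, 4, 8, 10, 6, 3, 4, 7], [1, 2, 3, 7, 0, 0, 0, 0], [5, 5, 9, 0, 0, 0, 0, 0]]
```

The padding id lets peptides of different lengths share one tensor; an embedding layer's vocabulary size then only has to exceed the largest id.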


### Define the RNN and a model handler

In [None]:
rnnHandler = TorchModelHandler(RNN,
                               data=data,
                               val_frac=0.15,
                               test_frac=0.15,
                               tokenize=True,
                               preprocess_data=True,
                               model_parameters=dict(embedding_dim=128,
                                                     vocab_size=100,
                                                     rnn_hidden_dim=128,
                                                     hidden_dim=128,
                                                     dropout_p=0.3),
                               remove_non_numeric=False)
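The `RNN` class lives in `src.models` and is not reproduced here. As a rough sketch of how a GRU regressor could use the hyperparameters passed above, consider the following; the class name `GRURegressor`, the `lengths` argument and the layer layout are assumptions, not the project's code:

```python
import torch
from torch import nn


class GRURegressor(nn.Module):
    """Hypothetical GRU regressor: embedding -> GRU -> MLP head -> one iRT value per peptide."""

    def __init__(self, vocab_size=100, embedding_dim=128, rnn_hidden_dim=128,
                 hidden_dim=128, dropout_p=0.3):
        super().__init__()
        # padding_idx=0 keeps the padding token's embedding fixed at zero
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(embedding_dim, rnn_hidden_dim, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(rnn_hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, tokens, lengths):
        # tokens: (batch, max_len) integer ids; lengths: true length of each peptide
        embedded = self.embedding(tokens)
        # Packing stops the GRU from reading the padding at the end of short peptides
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, h_n = self.gru(packed)               # h_n: (num_layers, batch, rnn_hidden_dim)
        return self.head(h_n[-1]).squeeze(-1)   # (batch,)


model = GRURegressor()
model.eval()  # disable dropout for a deterministic forward pass
tokens = torch.tensor([[8, 4, 8, 10, 6, 3, 4, 7], [1, 2, 3, 7, 0, 0, 0, 0]])
lengths = torch.tensor([8, 4])
with torch.no_grad():
    print(model(tokens, lengths).shape)  # torch.Size([2])
```

Using the last hidden state as the peptide representation is one common design; pooling over all GRU outputs or a bidirectional GRU would be reasonable alternatives.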

In [None]:
rnnHandler.train_eval()

In [None]:
rnnHandler.eval(rnnHandler.test_dataset)

### Plot all predictions

In [None]:
pred = rnnHandler.predict_all()
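The plotting cell below unpacks `pred.T` into measured and predicted iRT, which suggests `predict_all` returns an `(n, 2)` array. Assuming that layout, standard regression metrics can be computed directly with NumPy. These are plain MAE, RMSE and Pearson r, and are not necessarily identical to the project's `rMAE` and `rMSE` helpers; `regression_report` is a hypothetical name:

```python
import numpy as np


def regression_report(pred):
    """pred: array of shape (n, 2) with columns (y_true, y_pred)."""
    y_true, y_pred = np.asarray(pred, dtype=float).T
    err = y_pred - y_true
    return {
        "MAE": float(np.mean(np.abs(err))),
        "RMSE": float(np.sqrt(np.mean(err ** 2))),
        "pearson_r": float(np.corrcoef(y_true, y_pred)[0, 1]),
    }


# Toy stand-in for the output of predict_all()
toy = np.array([[10.0, 12.0], [20.0, 18.0], [30.0, 33.0]])
print(regression_report(toy))
```

RMSE weights large errors more heavily than MAE, so reporting both helps spot a few badly predicted peptides.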

In [None]:
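plt.figure()

# Columns of pred: measured iRT, predicted iRT
y_true, y_pred = pred.T

# Recent seaborn releases only accept x and y as keyword arguments
sns.scatterplot(x=y_true, y=y_pred, marker="+", color="darkred")

# Identity line: perfect predictions lie on it
plt.plot([-100, 150], [-100, 150], color="black", lw=0.75)

plt.xlabel("iRT measured")
plt.ylabel("iRT predicted")
plt.title("RNN predictions vs. ground truth")
plt.tight_layout()

plt.show()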

In [None]:
rnnHandler.dump("../models/rnn.pth")

In [None]:
rnnHandler.load("../models/rnn.pth")
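`dump` and `load` are methods of the project's `TorchModelHandler`, whose internals are not shown. In plain PyTorch the usual pattern is to save the model's `state_dict` and load it into a freshly constructed model of the same architecture. A minimal sketch; the small `nn.Sequential` model, `make_model` and the temporary file path are placeholders:

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Placeholder architecture; a state_dict only loads into an identically shaped model
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


model = make_model()
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(model.state_dict(), path)        # parameters only, not the Python class

restored = make_model()
restored.load_state_dict(torch.load(path))  # keys and tensor shapes must match
restored.eval()                             # switch off dropout etc. before predicting

x = torch.randn(3, 4)
print(torch.allclose(model(x), restored(x)))  # True
```

Saving the `state_dict` rather than the whole pickled module keeps the file independent of where the model class is defined.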

<h1>
Conclusions
</h1>

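As expected, the RNN gives the best results of the models compared, as it processes the peptide chains sequentially instead of relying on simple statistics of AA counts. However, the RNN is very slow to train on a machine without a GPU, and its recurrence cannot be parallelized across sequence positions: each step depends on the hidden state of the previous one, so only the batch dimension can be processed in parallel.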
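The remark on training cost can be made concrete with a toy Elman-style recurrence, a simpler cell than a GRU used here only to show the data dependency. `run_rnn` and the random inputs are illustrative stand-ins, not the project's model:

```python
import numpy as np


def run_rnn(x, W, U):
    """x: (batch, seq_len, input_dim). Returns the final hidden state (batch, hidden_dim)."""
    batch, seq_len, _ = x.shape
    h = np.zeros((batch, U.shape[0]))
    for t in range(seq_len):                   # sequential: step t needs h from step t - 1
        h = np.tanh(x[:, t] @ W.T + h @ U.T)   # parallel: the whole batch in one product
    return h


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 15, 20))  # 32 peptides, 15 positions, random stand-ins for embedded residues
W = rng.normal(scale=0.3, size=(64, 20))
U = rng.normal(scale=0.3, size=(64, 64))
h = run_rnn(x, W, U)
print(h.shape)  # (32, 64)
```

A GPU speeds up each per-step matrix product, but the number of sequential steps still grows with peptide length.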
