# LSTM Model

In [None]:
import torch
import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from raims.data import GenerativeDataModule, load_word2vec, FixedSizedDataModule
from raims.nn import PureLSTM, EmbeddingLSTM

# Seed sequence with fixed entropy 42; in the visible cells it is only
# referenced by the disabled FixedSizedDataModule alternative.
seed = np.random.SeedSequence(entropy=42)
# Weights & Biases logger for the 'raims' project, constructed in offline mode.
logger = WandbLogger(offline=True, project='raims')

In [None]:
import matchms.logging_functions as mmsl

# Route matchms log records to "matchms.log"; remove_stream_handlers=True asks
# matchms to drop its stream handlers (presumably to keep the notebook output
# free of matchms messages -- confirm against the matchms docs).
mmsl.add_logging_to_file(
    "matchms.log",
    remove_stream_handlers=True,
)

In [None]:
# Load the pre-trained word2vec model; the loader returns a pair that is
# unpacked into the token vocabulary and the embedding weights.
(vocabulary, embeddings) = load_word2vec(
    'model/mona-random-w2v.model',
)

In [None]:
# Sanity check: display the vocabulary entry stored under the token 'peak@27'
# (presumably its integer index -- confirm against load_word2vec).
vocabulary['peak@27']

In [None]:
# Data module passed to every trainer.fit call in this notebook.
datamodule = GenerativeDataModule(
    path='data/split/mona-random',
    vocabulary=vocabulary,
    onehot=True,
    intensity=False,
    batch_size=128,
    n_workers=1,
)
# Disabled alternative, kept for reference:
# datamodule = FixedSizedDataModule('data/split/mona-random', max_mz=1001, seed=seed)

In [None]:
# Two PureLSTM variants that differ only in include_intensity.
# NOTE(review): the datamodule in this notebook is built with intensity=False;
# confirm that the include_intensity=True model actually receives intensities.
pure_lstm_1 = PureLSTM(
    num_classes=len(vocabulary),
    hidden_size=200,
    learning_rate=1e-3,
    include_intensity=False,
)
pure_lstm_2 = PureLSTM(
    num_classes=len(vocabulary),
    hidden_size=200,
    learning_rate=1e-3,
    include_intensity=True,
)

# Four EmbeddingLSTM variants on the word2vec embeddings, covering the grid
# freeze_embeddings in {True, False} x include_intensity in {False, True}.
embedding_lstm_1 = EmbeddingLSTM(
    embeddings=embeddings,
    hidden_size=200,
    learning_rate=1e-3,
    freeze_embeddings=True,
    include_intensity=False,
)
embedding_lstm_2 = EmbeddingLSTM(
    embeddings=embeddings,
    hidden_size=200,
    learning_rate=1e-3,
    freeze_embeddings=True,
    include_intensity=True,
)
embedding_lstm_3 = EmbeddingLSTM(
    embeddings=embeddings,
    hidden_size=200,
    learning_rate=1e-3,
    freeze_embeddings=False,
    include_intensity=False,
)
embedding_lstm_4 = EmbeddingLSTM(
    embeddings=embeddings,
    hidden_size=200,
    learning_rate=1e-3,
    freeze_embeddings=False,
    include_intensity=True,
)

# Baseline without pre-trained knowledge: same shape/dtype/device as the
# word2vec embeddings, but filled with N(0, 1) noise (the default init of
# torch.nn.Embedding) and left trainable. The previous torch.zeros_like call
# started every token at the identical all-zero vector, which is not the
# random initialisation the model name promises.
embedding_lstm_random = EmbeddingLSTM(
    embeddings=torch.randn_like(embeddings),
    hidden_size=200,
    freeze_embeddings=False,
    include_intensity=False,
    learning_rate=1e-3,
)

In [None]:
# Trainer configuration: GPU accelerator, at most 500 epochs, and early
# stopping on 'val_loss' with a patience of 3.
trainer = Trainer(
    accelerator='gpu',
    max_epochs=500,
    logger=logger,
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
)

In [None]:
# Train pure_lstm_1 (PureLSTM, include_intensity=False) on the data module.
trainer.fit(pure_lstm_1, datamodule=datamodule)

In [None]:
# Display the model's repr as the cell output.
pure_lstm_2

In [None]:
# Use a fresh Trainer/EarlyStopping for this model. EarlyStopping keeps
# wait_count and best_score on the callback instance and the fit loop keeps
# its epoch counter, so reusing a trainer that already fitted another model
# would judge this one against the earlier model's best val_loss and could
# stop it after a single epoch. Settings mirror the notebook's trainer cell.
# NOTE(review): the shared `logger` still sends all models to one W&B run.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
    logger=logger,
    max_epochs=500,
    accelerator='gpu',
)
trainer.fit(model=pure_lstm_2, datamodule=datamodule)

In [None]:
# Display the model's repr as the cell output.
embedding_lstm_1

In [None]:
# Use a fresh Trainer/EarlyStopping for this model. EarlyStopping keeps
# wait_count and best_score on the callback instance and the fit loop keeps
# its epoch counter, so reusing a trainer that already fitted another model
# would judge this one against the earlier model's best val_loss and could
# stop it after a single epoch. Settings mirror the notebook's trainer cell.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
    logger=logger,
    max_epochs=500,
    accelerator='gpu',
)
trainer.fit(model=embedding_lstm_1, datamodule=datamodule)

In [None]:
# Use a fresh Trainer/EarlyStopping for this model. EarlyStopping keeps
# wait_count and best_score on the callback instance and the fit loop keeps
# its epoch counter, so reusing a trainer that already fitted another model
# would judge this one against the earlier model's best val_loss and could
# stop it after a single epoch. Settings mirror the notebook's trainer cell.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
    logger=logger,
    max_epochs=500,
    accelerator='gpu',
)
trainer.fit(model=embedding_lstm_2, datamodule=datamodule)

In [None]:
# Use a fresh Trainer/EarlyStopping for this model. EarlyStopping keeps
# wait_count and best_score on the callback instance and the fit loop keeps
# its epoch counter, so reusing a trainer that already fitted another model
# would judge this one against the earlier model's best val_loss and could
# stop it after a single epoch. Settings mirror the notebook's trainer cell.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
    logger=logger,
    max_epochs=500,
    accelerator='gpu',
)
trainer.fit(model=embedding_lstm_3, datamodule=datamodule)

In [None]:
# Use a fresh Trainer/EarlyStopping for this model. EarlyStopping keeps
# wait_count and best_score on the callback instance and the fit loop keeps
# its epoch counter, so reusing a trainer that already fitted another model
# would judge this one against the earlier model's best val_loss and could
# stop it after a single epoch. Settings mirror the notebook's trainer cell.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
    logger=logger,
    max_epochs=500,
    accelerator='gpu',
)
trainer.fit(model=embedding_lstm_4, datamodule=datamodule)

In [None]:
# Use a fresh Trainer/EarlyStopping for this model. EarlyStopping keeps
# wait_count and best_score on the callback instance and the fit loop keeps
# its epoch counter, so reusing a trainer that already fitted another model
# would judge this one against the earlier model's best val_loss and could
# stop it after a single epoch. Settings mirror the notebook's trainer cell.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
    logger=logger,
    max_epochs=500,
    accelerator='gpu',
)
trainer.fit(model=embedding_lstm_random, datamodule=datamodule)