## **Introduction to ML for NLP [Network + Practical]**

### **LSTM**

It is now time to train our first Recurrent Neural Network (RNN), a Long Short-Term Memory (LSTM) network.
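For reference, here is the standard LSTM cell update, the same formulation documented for `torch.nn.LSTM` except that PyTorch splits each bias into two terms. Whether the `"LSTM_fixed"` model used below relies on exactly this cell is an assumption, because its internals are not shown in this notebook. At time step $t$, with input $x_t$, previous hidden state $h_{t-1}$ and previous cell state $c_{t-1}$:

$$
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

The forget gate $f_t$ lets the cell state $c_t$ carry information across many time steps. This is what distinguishes the LSTM from a plain RNN, whose hidden state is overwritten at every step.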

#### **Libraries**

We import the necessary libraries for the notebook.

In [None]:
# general
import pandas as pd
from tqdm import tqdm
tqdm.pandas()  # registers .progress_apply() on pandas objects

# pytorch
import torch

# custom imports
from utility.models_pytorch import PytorchModel
from utility.dataviz import plot_model_fit_loss, plot_classes_accuracy

print("> Libraries Imported")

#### **Setup**

- We set the device to *cuda* when a GPU is available, falling back to *cpu* otherwise
- We load the encoded dataset from a pickle file

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("> Device:", device)
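`PytorchModel` receives `device` as a parameter and presumably moves the network and each batch onto it internally. That code is not shown here. The following minimal sketch shows what this involves in plain PyTorch: the parameters and the inputs must sit on the same device, otherwise the forward pass raises an error. The toy `nn.Linear` layer is a hypothetical stand-in for the real network.

```python
import torch
import torch.nn as nn

# Same selection logic as the cell above: prefer a CUDA GPU, else fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy layer standing in for the real network
layer = nn.Linear(4, 2).to(device)  # moves the layer's parameters to `device`

# Inputs must live on the same device as the parameters
batch = torch.randn(3, 4).to(device)
out = layer(batch)

print(out.device.type)  # "cuda" on a GPU machine, "cpu" otherwise
```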

In [None]:
dataframe = pd.read_pickle("data/3_multi_eurlex_encoded.pkl")
dataframe.head(3)

#### **LSTM**

**Instantiate a PyTorch Model**

We use our custom class `PytorchModel` to train an LSTM.

In [None]:
# vocabulary size for each language (passed to the model as vocab_size)
COUNTS_EN = 3506
COUNTS_DE = 4216
COUNTS_IT = 4180
COUNTS_PL = 5255
COUNTS_SV = 4010
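These constants are passed as `vocab_size` (for example, `COUNTS_DE` in the next cell), so they appear to be per-language vocabulary sizes of the encoded corpus. The sketch below assumes that each document is encoded as a list of integer token ids starting at 0. Under that assumption, an embedding table needs at least `max_id + 1` rows, otherwise the largest ids fall out of range. The helper `vocab_size_from_ids` is hypothetical.

```python
from typing import Iterable, List

def vocab_size_from_ids(encoded_docs: Iterable[List[int]]) -> int:
    """Smallest embedding-table size that covers every token id.

    Hypothetical helper: assumes ids are non-negative integers starting at 0.
    """
    max_id = -1
    for doc in encoded_docs:
        if doc:  # skip empty documents
            max_id = max(max_id, max(doc))
    return max_id + 1

docs = [[0, 3, 2], [1, 5], []]
print(vocab_size_from_ids(docs))  # 6
```

If the real encoding reserves extra ids (for padding or unknown tokens, say), those rows must be counted too.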

In [None]:
LSTM_MODEL = PytorchModel(

    # set model and text language
    model_type      = "LSTM_fixed",
    dataset         = dataframe,
    language        = "de",

    # set device, batch size and epochs
    device          = device,
    batch_size      = 64,
    epochs          = 50,

    # set hyperparameters
    vocab_size      = COUNTS_DE,
    embedding_dim   = 2048,
    hidden_dim      = 2048,
    learning_rate   = 0.001,
    dropout_p       = 0.1
)
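The hyperparameters above (`vocab_size`, `embedding_dim`, `hidden_dim`, `dropout_p`) map naturally onto an embedding, LSTM, dropout, and linear stack. The sketch below is one plausible such architecture, not the actual `"LSTM_fixed"` model, whose definition is not shown. The class name `LSTMClassifier` and the `num_classes` parameter are hypothetical.

```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Hypothetical embedding -> LSTM -> dropout -> linear classifier."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, dropout_p=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True -> inputs have shape (batch, seq_len, features)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, token_ids):
        embedded = self.embedding(token_ids)       # (batch, seq_len, embedding_dim)
        _, (h_n, _) = self.lstm(embedded)          # h_n: (num_layers, batch, hidden_dim)
        last_hidden = h_n[-1]                      # final hidden state of the last layer
        return self.fc(self.dropout(last_hidden))  # (batch, num_classes) raw logits

# Small sizes for a quick check; the notebook uses vocab_size=COUNTS_DE and 2048-dim layers
model = LSTMClassifier(vocab_size=100, embedding_dim=16, hidden_dim=32, num_classes=5)
tokens = torch.randint(0, 100, (4, 12))  # 4 documents, 12 token ids each
logits = model(tokens)
print(logits.shape)  # torch.Size([4, 5])
```

Using the final hidden state as a document summary is one common choice; pooling over all time steps is another. The 2048-dimensional settings also make the model large. If the recurrent layer is a single-layer `torch.nn.LSTM`, it alone holds 4 × (2048 × 2048 + 2048 × 2048 + 2 × 2048) = 33,570,816 parameters. A 4216 × 2048 embedding table adds another 8,634,368.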

**Train the model**

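We can now train the model.

The `train_model()` method evaluates the model at the end of each epoch and returns two DataFrames: global metrics per epoch (including training and validation loss) and per-class results.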

In [None]:
global_res_df, classes_res_df = LSTM_MODEL.train_model()
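Because `global_res_df` has `training_loss` and `validation_loss` columns, `train_model()` presumably runs one training pass followed by one validation pass per epoch. The loop below is a minimal sketch of that pattern in plain PyTorch, not the actual `PytorchModel` implementation. The function `fit`, the toy linear model, and the random data are hypothetical. The list of dicts it returns can be turned into a DataFrame with `pd.DataFrame(history)`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def fit(model, train_loader, val_loader, criterion, optimizer, epochs, device):
    """Hypothetical train/validate loop returning one record per epoch."""
    history = []
    model.to(device)
    for epoch in range(epochs):
        # --- training pass: dropout active, gradients tracked ---
        model.train()
        train_total, train_count = 0.0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            train_total += loss.item() * inputs.size(0)  # undo the batch mean
            train_count += inputs.size(0)

        # --- validation pass: dropout disabled, no gradients ---
        model.eval()
        val_total, val_count = 0.0, 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                val_total += criterion(model(inputs), targets).item() * inputs.size(0)
                val_count += inputs.size(0)

        # per-sample mean losses for this epoch
        history.append({
            "epoch": epoch + 1,
            "training_loss": train_total / train_count,
            "validation_loss": val_total / val_count,
        })
    return history

# Toy usage on random data (stand-in for the encoded dataset)
torch.manual_seed(0)
x, y = torch.randn(64, 8), torch.randint(0, 3, (64,))
train_loader = DataLoader(TensorDataset(x[:48], y[:48]), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(x[48:], y[48:]), batch_size=16)
model = nn.Linear(8, 3)
history = fit(model, train_loader, val_loader, nn.CrossEntropyLoss(),
              torch.optim.Adam(model.parameters(), lr=0.001), epochs=3,
              device=torch.device("cpu"))
print(history[-1])
```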

In [None]:
global_res_df

In [None]:
classes_res_df
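`classes_res_df` holds the per-class results, but how `PytorchModel` computes them is not shown. The sketch below assumes a multi-label setting, where each document can carry several labels (as in MultiEURLEX, which the dataset file name suggests), and predictions already thresholded to 0/1. Under those assumptions, per-class accuracy can be computed as follows. The helper `per_class_accuracy` is hypothetical.

```python
from typing import List, Sequence

def per_class_accuracy(y_true: Sequence[Sequence[int]],
                       y_pred: Sequence[Sequence[int]]) -> List[float]:
    """Fraction of samples whose 0/1 prediction matches the truth, per class.

    Assumes multi-label indicator vectors: y_true[i][c] == 1 if sample i has label c.
    """
    if not y_true:
        return []
    num_classes = len(y_true[0])
    correct = [0] * num_classes
    for true_row, pred_row in zip(y_true, y_pred):
        for c in range(num_classes):
            correct[c] += int(true_row[c] == pred_row[c])
    return [count / len(y_true) for count in correct]

y_true = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 1], [0, 1, 0]]
print(per_class_accuracy(y_true, y_pred))  # [0.6666666666666666, 1.0, 0.6666666666666666]
```

For rare labels, per-class accuracy is dominated by true negatives. It can look high even for a model that never predicts the label, which is why per-class F1 is often reported alongside it.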

**Visualize the training results**

We plot the training and validation loss, as well as the mean validation accuracy for each class.

In [None]:
plot_model_fit_loss(
    train_loss=global_res_df['training_loss'],
    val_loss=global_res_df['validation_loss'],
    subtitle="Model Details: " + LSTM_MODEL.MODEL_DESCRIPTION
)
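`plot_model_fit_loss` is a custom helper whose code is not shown. A rough matplotlib equivalent might look like the sketch below, which assumes the helper draws one curve per loss series. The function `plot_losses` is hypothetical, and the values passed to it are made-up inputs used only to exercise the function. In the notebook, it would be called with `global_res_df['training_loss']` and `global_res_df['validation_loss']`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

def plot_losses(train_loss, val_loss, subtitle=""):
    """Hypothetical stand-in for plot_model_fit_loss: one line per loss curve."""
    epochs = range(1, len(train_loss) + 1)
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(epochs, train_loss, label="training loss")
    ax.plot(epochs, val_loss, label="validation loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Model fit" + (f"\n{subtitle}" if subtitle else ""))
    ax.legend()
    return fig

# Made-up values, only to exercise the function
fig = plot_losses([0.9, 0.6, 0.45, 0.40], [0.95, 0.70, 0.62, 0.66],
                  subtitle="synthetic example")
```

A validation loss that starts rising while the training loss keeps falling is the usual sign of overfitting, which is what this plot is meant to reveal.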

In [None]:
plot_classes_accuracy(
    classes_res_df,
    subtitle="Model Details: " + LSTM_MODEL.MODEL_DESCRIPTION
)