## **NLP Practical**

### **LSTM**

It is now time to train a Recurrent Neural Network.

#### **Libraries**

We import the necessary libraries for the notebook.

In [1]:
# general
import pandas as pd
import numpy as np
from tqdm import tqdm
tqdm.pandas()

# dataset
from torch.utils.data import Dataset, DataLoader

# pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn

# metrics
from sklearn.metrics import mean_squared_error

# custom imports
from models import PytorchModel

print("> Libraries Imported")

> Libraries Imported


#### **Setup**

- We set the device to *cuda* when a GPU is available, falling back to *cpu* otherwise
- We import the dataset

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("> Device:", device)

> Device: cuda


In [3]:
dataframe = pd.read_pickle("../data/3_multi_eurlex_encoded_v3.pkl")
dataframe.head(3)

| Unnamed: 0 | celex_id | labels | labels_new | text_en | text_de | text_it | text_pl | text_sv | text_en_enc | text_de_enc | text_it_enc | text_pl_enc | text_sv_enc | set |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 32003R1012 | 2 | 1 | commission regulation ec no of june amending f... | verordnung eg nr der kommission vom juni zur n... | regolamento ce n della commissione del giugno ... | rozporzadzenie komisji we nr z dnia czerwca r ... | kommissionens forordning eg nr av den juni om ... | [[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 3, 4... | [[2, 3, 4, 5, 6, 7, 8, 9, 1, 10, 5, 2, 3, 4, 1... | [[2, 3, 4, 5, 6, 7, 8, 9, 1, 10, 7, 2, 3, 4, 1... | [[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 13... | [[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 6, 1... | test |
| 1 | 32003R2229 | 18 | 3 | council regulation ec no of december imposing ... | verordnung eg nr des rates vom dezember zur ei... | regolamento ce n del consiglio del dicembre ch... | rozporzadzenie rady we nr z dnia grudnia r nak... | radets forordning eg nr av den december om inf... | [[13, 3, 4, 5, 6, 117, 14, 118, 119, 120, 121,... | [[2, 3, 4, 11, 12, 7, 116, 9, 117, 118, 119, 1... | [[2, 3, 4, 7, 37, 7, 128, 11, 44, 129, 130, 13... | [[2, 13, 4, 5, 6, 7, 134, 9, 135, 136, 137, 13... | [[14, 3, 4, 5, 6, 7, 113, 9, 16, 6, 114, 115, ... | validation |
| 2 | 32003R0223 | 7 | 2 | commission regulation ec no of february on lab... | verordnung eg nr der kommission vom februar zu... | regolamento ce n della commissione del febbrai... | rozporzadzenie komisji we nr z dnia lutego r w... | kommissionens forordning eg nr av den februari... | [[2, 3, 4, 5, 6, 1033, 78, 1034, 343, 574, 38,... | [[2, 3, 4, 5, 6, 7, 1305, 9, 1136, 57, 1306, 1... | [[2, 3, 4, 5, 6, 7, 1259, 1260, 80, 393, 53, 1... | [[2, 3, 4, 5, 6, 7, 1603, 9, 59, 150, 1604, 59... | [[2, 3, 4, 5, 6, 7, 1240, 9, 1241, 122, 1242, ... | train |


#### **Instantiate a Pytorch Model: LSTM**

We use our custom class `PytorchModel` to train an LSTM on the Italian text (`language = "it"`).

In [4]:
# vocabulary size per language (the matching one is passed as vocab_size below)
COUNTS_EN = 14751
COUNTS_DE = 31903
COUNTS_IT = 20577
COUNTS_PL = 33812
COUNTS_SV = 29075
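
The notebook does not show how these constants were obtained. Whatever their origin, an embedding layer needs `vocab_size` to be strictly larger than the highest token id in the encoded text. The following is a minimal sketch of that check, assuming each `text_*_enc` cell is a list of lists of integer ids (as the preview above suggests) and that ids start at 0. The helper name `required_vocab_size` is hypothetical.

```python
def required_vocab_size(encoded_docs):
    """Smallest embedding table size that can index every token id."""
    max_id = max(
        token_id
        for doc in encoded_docs       # one entry per document
        for sentence in doc           # each document is a list of id lists
        for token_id in sentence
    )
    return max_id + 1                 # assumption: ids start at 0


# Toy stand-in for a column such as dataframe["text_it_enc"]
toy_column = [
    [[2, 3, 4], [5, 1]],
    [[2, 7, 6, 3]],
]
print(required_vocab_size(toy_column))
```

On the real data, comparing `required_vocab_size(dataframe["text_it_enc"])` with `COUNTS_IT` would confirm that no token id falls outside the embedding table.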

In [5]:
LSTM_MODEL = PytorchModel(

    # set model and text language
    model_type      = "LSTM",
    dataset         = dataframe,
    language        = "it",

    # set device, batch size and epochs
    device          = device,
    batch_size      = 50,
    epochs          = 50,

    # set hyperparameters
    vocab_size      = COUNTS_IT,
    embedding_dim   = 2048,
    hidden_dim      = 1024,
    learning_rate   = 0.001,
    dropout_p       = 0.1
)

> Parameters imported
> Dataset correctly divided in training set, validation set and test set
> Created Pytorch datasets and dataloaders
> Model 'LSTM' instantiated
> Initialization required 1.6358 seconds
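
The internals of `PytorchModel` are not shown in this notebook. To make the hyperparameters above concrete, here is a minimal sketch of the kind of network they could describe: an embedding layer, a single LSTM layer, dropout, and a linear layer over four classes (the validation log reports classes 0 to 3). The class name `LSTMClassifier`, the padding index 0, and the use of the last hidden state are assumptions, not the author's implementation. The toy sizes are much smaller than the notebook's 2048 and 1024.

```python
import torch
import torch.nn as nn


class LSTMClassifier(nn.Module):
    """Hypothetical stand-in for the LSTM built inside PytorchModel."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim,
                 num_classes, dropout_p=0.1, padding_idx=0):
        super().__init__()
        # padding_idx=0 is an assumption about how sequences are padded
        self.embedding = nn.Embedding(vocab_size, embedding_dim,
                                      padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, token_ids):
        embedded = self.embedding(token_ids)      # (batch, seq_len, embedding_dim)
        _, (h_n, _) = self.lstm(embedded)         # h_n: (1, batch, hidden_dim)
        return self.fc(self.dropout(h_n[-1]))     # (batch, num_classes) logits


# Toy forward pass with small dimensions
model = LSTMClassifier(vocab_size=100, embedding_dim=16, hidden_dim=8, num_classes=4)
batch = torch.randint(0, 100, (5, 12))            # 5 sequences of 12 token ids
logits = model(batch)
print(logits.shape)
```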


**Train the model**

We can now train the model.

After each epoch, `train_model()` reports the training loss, the validation loss, and the validation accuracy, both overall and per class.

In [6]:
global_res_df, classes_res_df = LSTM_MODEL.train_model()

> Training Started
  - Total Epochs: 50


> Epoch 1: 100%|██████████| 150/150 [02:32<00:00,  1.02s/it]


 - Training Loss        1.3626
 - Validation Loss      1.3481
 - Validation Accuracy  0.3812
 - Validation Accuracy (per class)
   * Class 0	 0.0144 [6 out of 416]
   * Class 1	 0.9292 [669 out of 720]
   * Class 2	 0.0482 [15 out of 311]
   * Class 3	 0.0523 [22 out of 421]



> Epoch 2: 100%|██████████| 150/150 [02:33<00:00,  1.02s/it]


 - Training Loss        1.2342
 - Validation Loss      1.4212
 - Validation Accuracy  0.379
 - Validation Accuracy (per class)
   * Class 0	 0.0457 [19 out of 416]
   * Class 1	 0.9083 [654 out of 720]
   * Class 2	 0.0514 [16 out of 311]
   * Class 3	 0.0451 [19 out of 421]



> Epoch 3: 100%|██████████| 150/150 [02:34<00:00,  1.03s/it]


 - Training Loss        1.1502
 - Validation Loss      1.4475
 - Validation Accuracy  0.3828
 - Validation Accuracy (per class)
   * Class 0	 0.0409 [17 out of 416]
   * Class 1	 0.9125 [657 out of 720]
   * Class 2	 0.0643 [20 out of 311]
   * Class 3	 0.0499 [21 out of 421]



> Epoch 4: 100%|██████████| 150/150 [02:33<00:00,  1.02s/it]


 - Training Loss        1.1253
 - Validation Loss      1.4739
 - Validation Accuracy  0.3817
 - Validation Accuracy (per class)
   * Class 0	 0.0409 [17 out of 416]
   * Class 1	 0.9167 [660 out of 720]
   * Class 2	 0.0611 [19 out of 311]
   * Class 3	 0.0404 [17 out of 421]



> Epoch 5: 100%|██████████| 150/150 [02:33<00:00,  1.03s/it]


 - Training Loss        1.207
 - Validation Loss      1.4251
 - Validation Accuracy  0.3399
 - Validation Accuracy (per class)
   * Class 0	 0.3606 [150 out of 416]
   * Class 1	 0.6194 [446 out of 720]
   * Class 2	 0.0675 [21 out of 311]
   * Class 3	 0.0428 [18 out of 421]



> Epoch 6: 100%|██████████| 150/150 [02:34<00:00,  1.03s/it]


 - Training Loss        1.1711
 - Validation Loss      1.4333
 - Validation Accuracy  0.3431
 - Validation Accuracy (per class)
   * Class 0	 0.2885 [120 out of 416]
   * Class 1	 0.6611 [476 out of 720]
   * Class 2	 0.0643 [20 out of 311]
   * Class 3	 0.0594 [25 out of 421]



> Epoch 7: 100%|██████████| 150/150 [02:34<00:00,  1.03s/it]


 - Training Loss        1.1512
 - Validation Loss      1.4261
 - Validation Accuracy  0.3849
 - Validation Accuracy (per class)
   * Class 0	 0.0337 [14 out of 416]
   * Class 1	 0.9236 [665 out of 720]
   * Class 2	 0.0643 [20 out of 311]
   * Class 3	 0.0475 [20 out of 421]



> Epoch 8: 100%|██████████| 150/150 [02:34<00:00,  1.03s/it]


 - Training Loss        1.1483
 - Validation Loss      1.4416
 - Validation Accuracy  0.3415
 - Validation Accuracy (per class)
   * Class 0	 0.2885 [120 out of 416]
   * Class 1	 0.6611 [476 out of 720]
   * Class 2	 0.0675 [21 out of 311]
   * Class 3	 0.0499 [21 out of 421]



> Epoch 9: 100%|██████████| 150/150 [02:32<00:00,  1.01s/it]


 - Training Loss        1.1472
 - Validation Loss      1.436
 - Validation Accuracy  0.386
 - Validation Accuracy (per class)
   * Class 0	 0.0337 [14 out of 416]
   * Class 1	 0.9236 [665 out of 720]
   * Class 2	 0.0675 [21 out of 311]
   * Class 3	 0.0499 [21 out of 421]



> Epoch 10: 100%|██████████| 150/150 [02:33<00:00,  1.02s/it]


 - Training Loss        1.1377
 - Validation Loss      1.468
 - Validation Accuracy  0.3854
 - Validation Accuracy (per class)
   * Class 0	 0.0361 [15 out of 416]
   * Class 1	 0.9208 [663 out of 720]
   * Class 2	 0.0675 [21 out of 311]
   * Class 3	 0.0499 [21 out of 421]



> Epoch 11: 100%|██████████| 150/150 [02:34<00:00,  1.03s/it]


 - Training Loss        1.1415
 - Validation Loss      1.4611
 - Validation Accuracy  0.3292
 - Validation Accuracy (per class)
   * Class 0	 0.0361 [15 out of 416]
   * Class 1	 0.6194 [446 out of 720]
   * Class 2	 0.0707 [22 out of 311]
   * Class 3	 0.3135 [132 out of 421]



> Epoch 12: 100%|██████████| 150/150 [02:34<00:00,  1.03s/it]


 - Training Loss        1.1409
 - Validation Loss      1.469
 - Validation Accuracy  0.3865
 - Validation Accuracy (per class)
   * Class 0	 0.0337 [14 out of 416]
   * Class 1	 0.9222 [664 out of 720]
   * Class 2	 0.0707 [22 out of 311]
   * Class 3	 0.0523 [22 out of 421]



> Epoch 13: 100%|██████████| 150/150 [02:30<00:00,  1.01s/it]


 - Training Loss        1.1405
 - Validation Loss      1.4663
 - Validation Accuracy  0.387
 - Validation Accuracy (per class)
   * Class 0	 0.0337 [14 out of 416]
   * Class 1	 0.925 [666 out of 720]
   * Class 2	 0.0675 [21 out of 311]
   * Class 3	 0.0523 [22 out of 421]



> Epoch 14:  46%|████▌     | 69/150 [01:09<01:19,  1.02it/s]

In [None]:
global_res_df

In [None]:
classes_res_df
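
The per-class figures in the training log are ratios of correct predictions to the number of validation samples with that label, for example `0.9292 [669 out of 720]` for class 1 in epoch 1. The overall accuracy is the same ratio pooled over all classes: (6 + 669 + 15 + 22) / 1868 ≈ 0.3812. A minimal sketch of that computation, assuming plain lists of integer labels; the function name `per_class_accuracy` is hypothetical and not part of `PytorchModel`:

```python
from collections import Counter


def per_class_accuracy(y_true, y_pred):
    """Map each label to (correct, total, accuracy)."""
    totals = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {
        label: (correct[label], totals[label], correct[label] / totals[label])
        for label in sorted(totals)
    }


# Toy labels, not taken from the EUR-Lex data
y_true = [0, 0, 1, 1, 1, 2, 3, 3]
y_pred = [1, 0, 1, 1, 1, 1, 3, 1]
results = per_class_accuracy(y_true, y_pred)
for label, (hits, total, acc) in results.items():
    print(f"Class {label}\t {acc:.4f} [{hits} out of {total}]")
```

A breakdown like this makes it easy to spot a model that scores well on one class while nearly ignoring the others, which the overall accuracy alone can hide.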

**Evaluate the model**

We now evaluate the model on the test set, which contains observations the model has never seen.

In [None]:
# TODO: create function 'test_model()' 
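
Until that method exists, the evaluation step could look roughly like the standalone sketch below. It assumes a classifier that returns logits, a `DataLoader` yielding `(inputs, targets)` batches, and cross-entropy loss. The function name `evaluate_on_test_set` and the toy model and data are hypothetical and do not reflect how `PytorchModel` stores its test loader.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_on_test_set(model, loader, device):
    """Return (average loss, accuracy) of `model` over `loader`."""
    criterion = nn.CrossEntropyLoss(reduction="sum")
    model.eval()                                   # disable dropout
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():                          # no gradients needed at test time
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, targets).item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)
    return total_loss / seen, correct / seen


# Toy check with a stand-in model and random data (not the EUR-Lex set)
device = torch.device("cpu")
toy_model = nn.Sequential(nn.Linear(6, 4)).to(device)
toy_data = TensorDataset(torch.randn(20, 6), torch.randint(0, 4, (20,)))
test_loss, test_acc = evaluate_on_test_set(
    toy_model, DataLoader(toy_data, batch_size=5), device
)
print(f"Test Loss {test_loss:.4f} | Test Accuracy {test_acc:.4f}")
```

Summing the loss per batch and dividing by the number of samples at the end keeps the average correct even when the last batch is smaller than the others.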