## **Introduction to ML for NLP [Network + Practical]**

### **LSTM**

It is now time to train our first Recurrent Neural Network.

#### **Libraries**

We import the necessary libraries for the notebook.

In [1]:
# general
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

# pytorch
import torch

# custom imports
from utility.models_pytorch import PytorchModel
from utility.dataviz import plot_model_fit_loss, plot_classes_accuracy

print("> Libraries Imported")

> Libraries Imported


#### **Setup**

- We set the device to *cuda* when a GPU is available, and fall back to *cpu* otherwise
- We import the dataset

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("> Device:", device)

> Device: cuda


In [3]:
dataframe = pd.read_pickle("data/3_multi_eurlex_encoded.pkl")
dataframe.head(3)

Unnamed: 0,celex_id,labels,labels_new,text_en,text_de,text_it,text_pl,text_sv,text_en_enc,text_de_enc,text_it_enc,text_pl_enc,text_sv_enc,set
0,32010D0395,2,0,commission decision of december on state aid c...,beschluss der kommission vom dezember uber die...,decisione della commissione del dicembre conce...,decyzja komisji z dnia grudnia r w sprawie pom...,kommissionens beslut av den december om det st...,"[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, ...","[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, ...","[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, ...","[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, ...","[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, ...",train
1,32012R0453,2,0,commission implementing regulation eu no of ma...,durchfuhrungsverordnung eu nr der kommission v...,regolamento di esecuzione ue n della commissio...,rozporzadzenie wykonawcze komisji ue nr z dnia...,kommissionens genomforandeforordning eu nr av ...,"[[2, 1275, 1276, 29, 100, 4, 743, 1277, 15, 12...","[[1302, 33, 1303, 3, 4, 5, 807, 15, 1304, 3, 6...","[[453, 10, 1422, 38, 14, 3, 4, 5, 990, 1423, 1...","[[1753, 1754, 3, 34, 24, 4, 5, 829, 7, 1755, 9...","[[2, 1239, 33, 23, 4, 5, 806, 7, 774, 4, 132, ...",train
2,32012D0043,2,0,commission implementing decision of january au...,durchfuhrungsbeschluss der kommission vom janu...,decisione di esecuzione della commissione del ...,decyzja wykonawcza komisji z dnia stycznia r u...,kommissionens genomforandebeslut av den januar...,"[[2, 1275, 3, 4, 1310, 1311, 15, 1015, 4, 1312...","[[1344, 3, 4, 5, 1345, 15, 1346, 74, 1347, 134...","[[2, 10, 1422, 3, 4, 5, 1454, 245, 1455, 24, 1...","[[2, 1791, 3, 4, 5, 1792, 7, 1, 1793, 1794, 65...","[[2, 1279, 4, 5, 1280, 7, 1281, 19, 1282, 1283...",train


#### **LSTM**

**Instantiate a Pytorch Model**

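We use our custom class PytorchModel to build the model. The architecture is selected through the model_type argument, which the cell below sets to "CNN_fixed".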

In [4]:
COUNTS_EN = 3506
COUNTS_DE = 4216
COUNTS_IT = 4180
COUNTS_PL = 5255
COUNTS_SV = 4010
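The constants above are passed to the model as `vocab_size`, so they directly determine how large the embedding layer is. As a rough sketch, assuming a plain embedding table of shape `(vocab_size, embedding_dim)` with no extra rows, the weight count per language can be estimated as follows. The dictionary `VOCAB_SIZES` and the helper `embedding_weights` are hypothetical names used only for illustration; the value 1024 is the `embedding_dim` used later in this notebook.

```python
# Vocabulary sizes copied from the COUNTS_* constants above.
VOCAB_SIZES = {"en": 3506, "de": 4216, "it": 4180, "pl": 5255, "sv": 4010}

# embedding_dim value used when the model is instantiated in this notebook.
EMBEDDING_DIM = 1024


def embedding_weights(vocab_size, embedding_dim):
    """Number of weights in a (vocab_size x embedding_dim) embedding table."""
    return vocab_size * embedding_dim


for language, vocab_size in VOCAB_SIZES.items():
    print(language, embedding_weights(vocab_size, EMBEDDING_DIM))
```

Keeping the sizes in one dictionary keyed by language code also makes it harder to pair a language with the wrong vocabulary size when switching the `language` argument.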

In [5]:
CNN_MODEL = PytorchModel(

    # set model and text language
    model_type      = "CNN_fixed",
    dataset         = dataframe,
    language        = "en",

    # set device, batch size and epochs
    device          = device,
    batch_size      = 64,
    epochs          = 50,

    # set general hyperparameters
    learning_rate   = 0.001,

    # set specific hyperparameters
    vocab_size      = COUNTS_EN,
    embedding_dim   = 1024,
    out_channels    = 1,
    kernel_size     = 5,
    stride          = 1,
    padding         = 2,
    dropout_p       = 0.1,
)

> Parameters imported for CNN_fixed
> Dataset correctly divided in training set, validation set and test set
> Created Pytorch datasets and dataloaders
> Initialization required 0.1297 seconds
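The combination `kernel_size = 5`, `stride = 1`, `padding = 2` is the usual "same length" setting for a 1-D convolution. The sketch below applies the standard output-length formula for a 1-D convolution, assuming a dilation of 1. It only illustrates the hyperparameters; the internals of `CNN_fixed` are not shown in this notebook, and `conv1d_output_length` is a hypothetical helper.

```python
def conv1d_output_length(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution over a sequence of `length` steps."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# With kernel_size=5, stride=1, padding=2 the sequence length is preserved.
for seq_len in (10, 128, 512):
    out_len = conv1d_output_length(seq_len, kernel_size=5, stride=1, padding=2)
    print(seq_len, "->", out_len)
```

Without padding the same kernel would shorten every sequence by `kernel_size - 1 = 4` positions.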


**Train the model**

We can now train the model.

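After each epoch, the method evaluates the model on the validation set and reports the training loss, the validation loss and the validation accuracy (global and per class). Judging from the log, the model is saved whenever the validation accuracy reaches a new best.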

In [6]:
global_res_df, classes_res_df = CNN_MODEL.train_model()

> Training Started
  - Total Epochs: 50


> Epoch 1: 100%|██████████| 60/60 [00:03<00:00, 19.16it/s]


 - Training Loss        0.9773
 - Validation Loss      0.8618
 - Validation Accuracy  0.699

 - Validation Accuracy (per class)
   * Class 0	 0.4985 [169 out of 339]
   * Class 1	 0.901 [282 out of 313]
   * Class 2	 0.7143 [220 out of 308]
   * Mean        0.7046

> ATTENTION: epoch 1 was the best one so far! The model has been saved :)



> Epoch 2: 100%|██████████| 60/60 [00:01<00:00, 57.20it/s]


 - Training Loss        0.8297
 - Validation Loss      0.796
 - Validation Accuracy  0.7438

 - Validation Accuracy (per class)
   * Class 0	 0.5428 [184 out of 339]
   * Class 1	 0.8435 [264 out of 313]
   * Class 2	 0.8636 [266 out of 308]
   * Mean        0.75

> ATTENTION: epoch 2 was the best one so far! The model has been saved :)



> Epoch 3: 100%|██████████| 60/60 [00:01<00:00, 58.14it/s]


 - Training Loss        0.7731
 - Validation Loss      0.7513
 - Validation Accuracy  0.7917

 - Validation Accuracy (per class)
   * Class 0	 0.6578 [223 out of 339]
   * Class 1	 0.8978 [281 out of 313]
   * Class 2	 0.8312 [256 out of 308]
   * Mean        0.7956

> ATTENTION: epoch 3 was the best one so far! The model has been saved :)



> Epoch 4: 100%|██████████| 60/60 [00:01<00:00, 59.71it/s]


 - Training Loss        0.7183
 - Validation Loss      0.7319
 - Validation Accuracy  0.8323

 - Validation Accuracy (per class)
   * Class 0	 0.7906 [268 out of 339]
   * Class 1	 0.9169 [287 out of 313]
   * Class 2	 0.7922 [244 out of 308]
   * Mean        0.8332

> ATTENTION: epoch 4 was the best one so far! The model has been saved :)



> Epoch 5: 100%|██████████| 60/60 [00:01<00:00, 59.09it/s]


 - Training Loss        0.6977
 - Validation Loss      0.6992
 - Validation Accuracy  0.8417

 - Validation Accuracy (per class)
   * Class 0	 0.8024 [272 out of 339]
   * Class 1	 0.8786 [275 out of 313]
   * Class 2	 0.8474 [261 out of 308]
   * Mean        0.8428

> ATTENTION: epoch 5 was the best one so far! The model has been saved :)



> Epoch 6: 100%|██████████| 60/60 [00:01<00:00, 58.08it/s]


 - Training Loss        0.6874
 - Validation Loss      0.7277
 - Validation Accuracy  0.8302

 - Validation Accuracy (per class)
   * Class 0	 0.9174 [311 out of 339]
   * Class 1	 0.8786 [275 out of 313]
   * Class 2	 0.6851 [211 out of 308]
   * Mean        0.827



> Epoch 7: 100%|██████████| 60/60 [00:01<00:00, 56.92it/s]


 - Training Loss        0.6847
 - Validation Loss      0.7339
 - Validation Accuracy  0.8146

 - Validation Accuracy (per class)
   * Class 0	 0.8289 [281 out of 339]
   * Class 1	 0.6677 [209 out of 313]
   * Class 2	 0.9481 [292 out of 308]
   * Mean        0.8149



> Epoch 8: 100%|██████████| 60/60 [00:01<00:00, 57.96it/s]


 - Training Loss        0.6665
 - Validation Loss      0.6862
 - Validation Accuracy  0.8573

 - Validation Accuracy (per class)
   * Class 0	 0.8761 [297 out of 339]
   * Class 1	 0.8722 [273 out of 313]
   * Class 2	 0.8214 [253 out of 308]
   * Mean        0.8566

> ATTENTION: epoch 8 was the best one so far! The model has been saved :)



> Epoch 9: 100%|██████████| 60/60 [00:01<00:00, 59.89it/s]


 - Training Loss        0.6637
 - Validation Loss      0.7013
 - Validation Accuracy  0.8656

 - Validation Accuracy (per class)
   * Class 0	 0.8525 [289 out of 339]
   * Class 1	 0.8658 [271 out of 313]
   * Class 2	 0.8799 [271 out of 308]
   * Mean        0.8661

> ATTENTION: epoch 9 was the best one so far! The model has been saved :)



> Epoch 10: 100%|██████████| 60/60 [00:01<00:00, 60.00it/s]


 - Training Loss        0.6614
 - Validation Loss      0.6767
 - Validation Accuracy  0.8698

 - Validation Accuracy (per class)
   * Class 0	 0.9027 [306 out of 339]
   * Class 1	 0.869 [272 out of 313]
   * Class 2	 0.8344 [257 out of 308]
   * Mean        0.8687

> ATTENTION: epoch 10 was the best one so far! The model has been saved :)



> Epoch 11: 100%|██████████| 60/60 [00:00<00:00, 60.05it/s]


 - Training Loss        0.6669
 - Validation Loss      0.7302
 - Validation Accuracy  0.8115

 - Validation Accuracy (per class)
   * Class 0	 0.7817 [265 out of 339]
   * Class 1	 0.7029 [220 out of 313]
   * Class 2	 0.9545 [294 out of 308]
   * Mean        0.813



> Epoch 12: 100%|██████████| 60/60 [00:01<00:00, 59.72it/s]


 - Training Loss        0.6662
 - Validation Loss      0.7125
 - Validation Accuracy  0.8427

 - Validation Accuracy (per class)
   * Class 0	 0.8525 [289 out of 339]
   * Class 1	 0.722 [226 out of 313]
   * Class 2	 0.9545 [294 out of 308]
   * Mean        0.843



> Epoch 13: 100%|██████████| 60/60 [00:01<00:00, 59.58it/s]


 - Training Loss        0.652
 - Validation Loss      0.7035
 - Validation Accuracy  0.8438

 - Validation Accuracy (per class)
   * Class 0	 0.8614 [292 out of 339]
   * Class 1	 0.738 [231 out of 313]
   * Class 2	 0.9318 [287 out of 308]
   * Mean        0.8437



> Epoch 14: 100%|██████████| 60/60 [00:01<00:00, 59.42it/s]


 - Training Loss        0.6611
 - Validation Loss      0.6949
 - Validation Accuracy  0.8531

 - Validation Accuracy (per class)
   * Class 0	 0.8466 [287 out of 339]
   * Class 1	 0.7796 [244 out of 313]
   * Class 2	 0.9351 [288 out of 308]
   * Mean        0.8538



> Epoch 15: 100%|██████████| 60/60 [00:00<00:00, 60.06it/s]


 - Training Loss        0.6593
 - Validation Loss      0.7462
 - Validation Accuracy  0.8042

 - Validation Accuracy (per class)
   * Class 0	 0.8761 [297 out of 339]
   * Class 1	 0.5911 [185 out of 313]
   * Class 2	 0.9416 [290 out of 308]
   * Mean        0.8029



> Epoch 16: 100%|██████████| 60/60 [00:01<00:00, 52.43it/s]


 - Training Loss        0.6585
 - Validation Loss      0.7013
 - Validation Accuracy  0.8563

 - Validation Accuracy (per class)
   * Class 0	 0.8437 [286 out of 339]
   * Class 1	 0.7827 [245 out of 313]
   * Class 2	 0.9448 [291 out of 308]
   * Mean        0.8571



> Epoch 17: 100%|██████████| 60/60 [00:01<00:00, 56.60it/s]


 - Training Loss        0.6528
 - Validation Loss      0.6886
 - Validation Accuracy  0.8604

 - Validation Accuracy (per class)
   * Class 0	 0.8083 [274 out of 339]
   * Class 1	 0.8914 [279 out of 313]
   * Class 2	 0.8864 [273 out of 308]
   * Mean        0.862



> Epoch 18: 100%|██████████| 60/60 [00:01<00:00, 59.84it/s]


 - Training Loss        0.6435
 - Validation Loss      0.696
 - Validation Accuracy  0.8583

 - Validation Accuracy (per class)
   * Class 0	 0.7935 [269 out of 339]
   * Class 1	 0.8914 [279 out of 313]
   * Class 2	 0.8961 [276 out of 308]
   * Mean        0.8603



> Epoch 19: 100%|██████████| 60/60 [00:01<00:00, 58.21it/s]


 - Training Loss        0.6437
 - Validation Loss      0.7201
 - Validation Accuracy  0.8385

 - Validation Accuracy (per class)
   * Class 0	 0.7463 [253 out of 339]
   * Class 1	 0.8435 [264 out of 313]
   * Class 2	 0.9351 [288 out of 308]
   * Mean        0.8416



> Epoch 20: 100%|██████████| 60/60 [00:01<00:00, 58.59it/s]


 - Training Loss        0.6667
 - Validation Loss      0.6929
 - Validation Accuracy  0.851

 - Validation Accuracy (per class)
   * Class 0	 0.8024 [272 out of 339]
   * Class 1	 0.9137 [286 out of 313]
   * Class 2	 0.8409 [259 out of 308]
   * Mean        0.8523



> Epoch 21: 100%|██████████| 60/60 [00:01<00:00, 59.17it/s]


 - Training Loss        0.6677
 - Validation Loss      0.7119
 - Validation Accuracy  0.8

 - Validation Accuracy (per class)
   * Class 0	 0.9764 [331 out of 339]
   * Class 1	 0.6965 [218 out of 313]
   * Class 2	 0.711 [219 out of 308]
   * Mean        0.7946



> Epoch 22: 100%|██████████| 60/60 [00:01<00:00, 58.88it/s]


 - Training Loss        0.6571
 - Validation Loss      0.6907
 - Validation Accuracy  0.8542

 - Validation Accuracy (per class)
   * Class 0	 0.7906 [268 out of 339]
   * Class 1	 0.8882 [278 out of 313]
   * Class 2	 0.8896 [274 out of 308]
   * Mean        0.8561



> Epoch 23: 100%|██████████| 60/60 [00:01<00:00, 58.48it/s]


 - Training Loss        0.6579
 - Validation Loss      0.7207
 - Validation Accuracy  0.8448

 - Validation Accuracy (per class)
   * Class 0	 0.8112 [275 out of 339]
   * Class 1	 0.7796 [244 out of 313]
   * Class 2	 0.9481 [292 out of 308]
   * Mean        0.8463



> Epoch 24: 100%|██████████| 60/60 [00:01<00:00, 58.71it/s]


 - Training Loss        0.6776
 - Validation Loss      0.6955
 - Validation Accuracy  0.8615

 - Validation Accuracy (per class)
   * Class 0	 0.8289 [281 out of 339]
   * Class 1	 0.8403 [263 out of 313]
   * Class 2	 0.9188 [283 out of 308]
   * Mean        0.8627



> Epoch 25: 100%|██████████| 60/60 [00:01<00:00, 58.32it/s]


 - Training Loss        0.6548
 - Validation Loss      0.6826
 - Validation Accuracy  0.8677

 - Validation Accuracy (per class)
   * Class 0	 0.8584 [291 out of 339]
   * Class 1	 0.9233 [289 out of 313]
   * Class 2	 0.8214 [253 out of 308]
   * Mean        0.8677



> Epoch 26: 100%|██████████| 60/60 [00:01<00:00, 58.54it/s]


 - Training Loss        0.6608
 - Validation Loss      0.7305
 - Validation Accuracy  0.8271

 - Validation Accuracy (per class)
   * Class 0	 0.8732 [296 out of 339]
   * Class 1	 0.6901 [216 out of 313]
   * Class 2	 0.9156 [282 out of 308]
   * Mean        0.8263



> Epoch 27: 100%|██████████| 60/60 [00:01<00:00, 59.15it/s]


 - Training Loss        0.6574
 - Validation Loss      0.6846
 - Validation Accuracy  0.8635

 - Validation Accuracy (per class)
   * Class 0	 0.8348 [283 out of 339]
   * Class 1	 0.9073 [284 out of 313]
   * Class 2	 0.8506 [262 out of 308]
   * Mean        0.8642



> Epoch 28: 100%|██████████| 60/60 [00:01<00:00, 59.05it/s]


 - Training Loss        0.6516
 - Validation Loss      0.7151
 - Validation Accuracy  0.8448

 - Validation Accuracy (per class)
   * Class 0	 0.9233 [313 out of 339]
   * Class 1	 0.7188 [225 out of 313]
   * Class 2	 0.8864 [273 out of 308]
   * Mean        0.8428



> Epoch 29: 100%|██████████| 60/60 [00:01<00:00, 59.00it/s]


 - Training Loss        0.6628
 - Validation Loss      0.6851
 - Validation Accuracy  0.8594

 - Validation Accuracy (per class)
   * Class 0	 0.8614 [292 out of 339]
   * Class 1	 0.9393 [294 out of 313]
   * Class 2	 0.776 [239 out of 308]
   * Mean        0.8589



> Epoch 30: 100%|██████████| 60/60 [00:01<00:00, 59.05it/s]


 - Training Loss        0.6592
 - Validation Loss      0.6944
 - Validation Accuracy  0.8625

 - Validation Accuracy (per class)
   * Class 0	 0.8761 [297 out of 339]
   * Class 1	 0.8115 [254 out of 313]
   * Class 2	 0.8994 [277 out of 308]
   * Mean        0.8623



> Epoch 31: 100%|██████████| 60/60 [00:01<00:00, 52.86it/s]


 - Training Loss        0.6541
 - Validation Loss      0.687
 - Validation Accuracy  0.8635

 - Validation Accuracy (per class)
   * Class 0	 0.9174 [311 out of 339]
   * Class 1	 0.8754 [274 out of 313]
   * Class 2	 0.7922 [244 out of 308]
   * Mean        0.8617



> Epoch 32: 100%|██████████| 60/60 [00:01<00:00, 49.92it/s]


 - Training Loss        0.6576
 - Validation Loss      0.7075
 - Validation Accuracy  0.849

 - Validation Accuracy (per class)
   * Class 0	 0.8437 [286 out of 339]
   * Class 1	 0.8115 [254 out of 313]
   * Class 2	 0.8929 [275 out of 308]
   * Mean        0.8494



> Epoch 33: 100%|██████████| 60/60 [00:01<00:00, 49.54it/s]


 - Training Loss        0.6521
 - Validation Loss      0.6906
 - Validation Accuracy  0.8635

 - Validation Accuracy (per class)
   * Class 0	 0.8555 [290 out of 339]
   * Class 1	 0.8498 [266 out of 313]
   * Class 2	 0.8864 [273 out of 308]
   * Mean        0.8639



> Epoch 34: 100%|██████████| 60/60 [00:01<00:00, 49.26it/s]


 - Training Loss        0.6559
 - Validation Loss      0.6942
 - Validation Accuracy  0.8677

 - Validation Accuracy (per class)
   * Class 0	 0.8584 [291 out of 339]
   * Class 1	 0.9169 [287 out of 313]
   * Class 2	 0.8279 [255 out of 308]
   * Mean        0.8677



> Epoch 35: 100%|██████████| 60/60 [00:01<00:00, 49.06it/s]


 - Training Loss        0.6573
 - Validation Loss      0.681
 - Validation Accuracy  0.8729

 - Validation Accuracy (per class)
   * Class 0	 0.885 [300 out of 339]
   * Class 1	 0.8403 [263 out of 313]
   * Class 2	 0.8929 [275 out of 308]
   * Mean        0.8727

> ATTENTION: epoch 35 was the best one so far! The model has been saved :)



> Epoch 36: 100%|██████████| 60/60 [00:01<00:00, 49.54it/s]


 - Training Loss        0.662
 - Validation Loss      0.7005
 - Validation Accuracy  0.8563

 - Validation Accuracy (per class)
   * Class 0	 0.8201 [278 out of 339]
   * Class 1	 0.8658 [271 out of 313]
   * Class 2	 0.8864 [273 out of 308]
   * Mean        0.8574



> Epoch 37: 100%|██████████| 60/60 [00:01<00:00, 49.83it/s]


 - Training Loss        0.6574
 - Validation Loss      0.6732
 - Validation Accuracy  0.8625

 - Validation Accuracy (per class)
   * Class 0	 0.9351 [317 out of 339]
   * Class 1	 0.8658 [271 out of 313]
   * Class 2	 0.7792 [240 out of 308]
   * Mean        0.86



> Epoch 38: 100%|██████████| 60/60 [00:01<00:00, 49.34it/s]


 - Training Loss        0.6537
 - Validation Loss      0.68
 - Validation Accuracy  0.8667

 - Validation Accuracy (per class)
   * Class 0	 0.9086 [308 out of 339]
   * Class 1	 0.8594 [269 out of 313]
   * Class 2	 0.8279 [255 out of 308]
   * Mean        0.8653



> Epoch 39: 100%|██████████| 60/60 [00:01<00:00, 49.22it/s]


 - Training Loss        0.6569
 - Validation Loss      0.7076
 - Validation Accuracy  0.849

 - Validation Accuracy (per class)
   * Class 0	 0.8378 [284 out of 339]
   * Class 1	 0.9042 [283 out of 313]
   * Class 2	 0.8052 [248 out of 308]
   * Mean        0.8491



> Epoch 40: 100%|██████████| 60/60 [00:01<00:00, 49.54it/s]


 - Training Loss        0.6674
 - Validation Loss      0.7061
 - Validation Accuracy  0.8542

 - Validation Accuracy (per class)
   * Class 0	 0.882 [299 out of 339]
   * Class 1	 0.8275 [259 out of 313]
   * Class 2	 0.8506 [262 out of 308]
   * Mean        0.8534



> Epoch 41: 100%|██████████| 60/60 [00:01<00:00, 49.73it/s]


 - Training Loss        0.6653
 - Validation Loss      0.697
 - Validation Accuracy  0.8521

 - Validation Accuracy (per class)
   * Class 0	 0.8407 [285 out of 339]
   * Class 1	 0.9297 [291 out of 313]
   * Class 2	 0.7857 [242 out of 308]
   * Mean        0.852



> Epoch 42: 100%|██████████| 60/60 [00:01<00:00, 50.34it/s]


 - Training Loss        0.6826
 - Validation Loss      0.7282
 - Validation Accuracy  0.8417

 - Validation Accuracy (per class)
   * Class 0	 0.7434 [252 out of 339]
   * Class 1	 0.9105 [285 out of 313]
   * Class 2	 0.8799 [271 out of 308]
   * Mean        0.8446



> Epoch 43: 100%|██████████| 60/60 [00:01<00:00, 49.96it/s]


 - Training Loss        0.6713
 - Validation Loss      0.6889
 - Validation Accuracy  0.8573

 - Validation Accuracy (per class)
   * Class 0	 0.8614 [292 out of 339]
   * Class 1	 0.8722 [273 out of 313]
   * Class 2	 0.8377 [258 out of 308]
   * Mean        0.8571



> Epoch 44: 100%|██████████| 60/60 [00:01<00:00, 49.71it/s]


 - Training Loss        0.6567
 - Validation Loss      0.6874
 - Validation Accuracy  0.8615

 - Validation Accuracy (per class)
   * Class 0	 0.8968 [304 out of 339]
   * Class 1	 0.8466 [265 out of 313]
   * Class 2	 0.8377 [258 out of 308]
   * Mean        0.8604



> Epoch 45: 100%|██████████| 60/60 [00:01<00:00, 48.98it/s]


 - Training Loss        0.6577
 - Validation Loss      0.6745
 - Validation Accuracy  0.8719

 - Validation Accuracy (per class)
   * Class 0	 0.8968 [304 out of 339]
   * Class 1	 0.8914 [279 out of 313]
   * Class 2	 0.8247 [254 out of 308]
   * Mean        0.871



> Epoch 46: 100%|██████████| 60/60 [00:01<00:00, 58.32it/s]


 - Training Loss        0.6558
 - Validation Loss      0.7005
 - Validation Accuracy  0.8417

 - Validation Accuracy (per class)
   * Class 0	 0.9292 [315 out of 339]
   * Class 1	 0.722 [226 out of 313]
   * Class 2	 0.8669 [267 out of 308]
   * Mean        0.8394



> Epoch 47: 100%|██████████| 60/60 [00:01<00:00, 59.43it/s]


 - Training Loss        0.6597
 - Validation Loss      0.7078
 - Validation Accuracy  0.851

 - Validation Accuracy (per class)
   * Class 0	 0.7729 [262 out of 339]
   * Class 1	 0.9297 [291 out of 313]
   * Class 2	 0.8571 [264 out of 308]
   * Mean        0.8532



> Epoch 48: 100%|██████████| 60/60 [00:01<00:00, 59.46it/s]


 - Training Loss        0.6639
 - Validation Loss      0.6866
 - Validation Accuracy  0.8688

 - Validation Accuracy (per class)
   * Class 0	 0.9027 [306 out of 339]
   * Class 1	 0.8275 [259 out of 313]
   * Class 2	 0.8734 [269 out of 308]
   * Mean        0.8679



> Epoch 49: 100%|██████████| 60/60 [00:01<00:00, 59.46it/s]


 - Training Loss        0.6733
 - Validation Loss      0.7213
 - Validation Accuracy  0.8083

 - Validation Accuracy (per class)
   * Class 0	 0.9381 [318 out of 339]
   * Class 1	 0.6326 [198 out of 313]
   * Class 2	 0.8442 [260 out of 308]
   * Mean        0.805



> Epoch 50: 100%|██████████| 60/60 [00:01<00:00, 59.76it/s]


 - Training Loss        0.6655
 - Validation Loss      0.7133
 - Validation Accuracy  0.8417

 - Validation Accuracy (per class)
   * Class 0	 0.7817 [265 out of 339]
   * Class 1	 0.9553 [299 out of 313]
   * Class 2	 0.7922 [244 out of 308]
   * Mean        0.8431
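In the log above, the "best one so far" message appears at epochs 1 to 5, 8, 9, 10 and 35, which matches the epochs where validation accuracy sets a new maximum (the lowest validation loss, at epoch 37, does not trigger a save). A minimal sketch of that rule, using the first ten accuracy values copied from the log; `best_epochs` is a hypothetical helper and not part of PytorchModel, whose actual criterion is not shown here.

```python
# Validation accuracy (global) for epochs 1-10, copied from the training log.
val_accuracy = [0.699, 0.7438, 0.7917, 0.8323, 0.8417,
                0.8302, 0.8146, 0.8573, 0.8656, 0.8698]


def best_epochs(accuracies):
    """Return the 1-based epochs at which accuracy sets a new maximum."""
    best = float("-inf")
    saved = []
    for epoch, acc in enumerate(accuracies, start=1):
        if acc > best:
            best = acc
            saved.append(epoch)
    return saved


print(best_epochs(val_accuracy))
```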



In [7]:
global_res_df

Unnamed: 0,epoch,training_loss,validation_loss,validation_accuracy (global),validation_accuracy (mean)
0,1,0.9773,0.8618,0.699,0.7046
1,2,0.8297,0.796,0.7438,0.75
2,3,0.7731,0.7513,0.7917,0.7956
3,4,0.7183,0.7319,0.8323,0.8332
4,5,0.6977,0.6992,0.8417,0.8428
5,6,0.6874,0.7277,0.8302,0.827
6,7,0.6847,0.7339,0.8146,0.8149
7,8,0.6665,0.6862,0.8573,0.8566
8,9,0.6637,0.7013,0.8656,0.8661
9,10,0.6614,0.6767,0.8698,0.8687


In [8]:
classes_res_df

Unnamed: 0,correct,total,accuracy,epoch
0,169,339,0.4985,1
1,282,313,0.9010,1
2,220,308,0.7143,1
0,184,339,0.5428,2
1,264,313,0.8435,2
...,...,...,...,...
1,198,313,0.6326,49
2,260,308,0.8442,49
0,265,339,0.7817,50
1,299,313,0.9553,50
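The two accuracy columns in the global results differ in how they aggregate: the global value divides all correct predictions by all samples, while the mean value averages the three per-class accuracies. The sketch below recomputes both for epoch 1 from the `correct` and `total` counts shown above; the variable names are illustrative only.

```python
# (correct, total) per class for epoch 1, copied from classes_res_df.
epoch_1 = {0: (169, 339), 1: (282, 313), 2: (220, 308)}

# Accuracy of each class taken on its own.
per_class = {label: correct / total for label, (correct, total) in epoch_1.items()}

# Global accuracy: all correct predictions over all samples.
total_correct = sum(correct for correct, _ in epoch_1.values())
total_samples = sum(total for _, total in epoch_1.values())
global_acc = total_correct / total_samples

# Mean accuracy: unweighted average of the per-class accuracies.
mean_acc = sum(per_class.values()) / len(per_class)

print(round(global_acc, 4), round(mean_acc, 4))
```

Because the classes are not perfectly balanced (339, 313 and 308 samples), the two numbers differ slightly, as they do in the log (0.699 versus 0.7046).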


**Visualize the training results**

We plot the training and validation loss over the epochs, as well as the validation accuracy of each class.

In [9]:
plot_model_fit_loss(
    train_loss=global_res_df['training_loss'],
    val_loss=global_res_df['validation_loss'],
    subtitle="Models Details: " + CNN_MODEL.MODEL_DESCRIPTION
)

In [10]:
plot_classes_accuracy(
    classes_res_df,
    subtitle="Models Details: " + CNN_MODEL.MODEL_DESCRIPTION
)