In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torchinfo import summary
from config import en_id_model as mtconf, get_default_device
from dataset import get_tokenizers
from model import build_model
from train import train_model
from utils import TrainCheckpoint, EarlyStopping, ReduceLROnPlateau, TrainingCallback

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the source/target tokenizers. force_retrain_tokenizer=False and no training
# dataset is passed, so this relies on the tokenizer files already existing on disk
# (the captured output reports "tokenizer exist, getting from: ...").
t_src, t_tgt = get_tokenizers(mtconf, ds_train=None, force_retrain_tokenizer=False)
model = build_model(mtconf, t_src, t_tgt)

# Dummy batch used only so torchinfo can trace tensor shapes through the model.
batch_size, seq_len = mtconf.batch_size, mtconf.seq_len
src_vocab, tgt_vocab = t_src.get_vocab_size(), t_tgt.get_vocab_size()

# Random token ids drawn from each vocabulary (encoder first, then decoder).
dummy_encoder_input = torch.randint(0, src_vocab, (batch_size, seq_len))
dummy_decoder_input = torch.randint(0, tgt_vocab, (batch_size, seq_len))

# All-ones masks: nothing is masked out. The decoder mask is therefore NOT causal,
# which is fine for shape tracing only. The singleton dims are presumably broadcast
# by the attention layers. TODO confirm against model.py.
dummy_encoder_mask = torch.ones(batch_size, 1, 1, seq_len, dtype=torch.int)
dummy_decoder_mask = torch.ones(batch_size, 1, seq_len, seq_len, dtype=torch.int)

# Last expression in the cell, so the summary table is rendered as the cell output.
summary(
    model,
    input_data=[
        dummy_encoder_input,
        dummy_decoder_input,
        dummy_encoder_mask,
        dummy_decoder_mask,
    ],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    depth=10,
    row_settings=["var_names"],
)

tokenizer exist, getting from: .output\tokenizer_en.json
tokenizer exist, getting from: .output\tokenizer_id.json


Layer (type (var_name))                                           Input Shape               Output Shape              Param #                   Trainable
Transformer (Transformer)                                         [64, 50]                  [64, 50, 30000]           --                        True
├─InputEmbedding (src_embed)                                      [64, 50]                  [64, 50, 256]             --                        True
│    └─Embedding (embedding)                                      [64, 50]                  [64, 50, 256]             7,680,000                 True
├─PositionalEncoding (src_pos)                                    [64, 50, 256]             [64, 50, 256]             --                        --
│    └─Dropout (dropout)                                          [64, 50, 256]             [64, 50, 256]             --                        --
├─Encoder (encoder)                                               [64, 50, 256]             [64, 50, 256]

In [3]:
# Sanity check of which device the project helper selects before starting a long
# training run (the recorded run shows device(type='cuda')). The value is only
# displayed here and is not stored or passed anywhere.
get_default_device()

device(type='cuda')

In [4]:
# Train the model with checkpointing, early stopping and LR-on-plateau callbacks.
# Autoreload is already enabled in the first cell. Loading it again here only printed
# "The autoreload extension is already loaded", so it is not repeated.
#
# NOTE(review): in the recorded run, EarlyStopping printed "Early stopping, no
# improvement in the last 5 epochs" after epoch 18, yet epoch 19 still started and
# had to be stopped with KeyboardInterrupt (so `history` was never assigned).
# Verify that train_model actually breaks out of its epoch loop when the
# early-stop callback fires.
# NOTE(review): the logs show CER > 1 and WER > 2 (e.g. epoch 1). Error rates above 1
# usually mean hypotheses much longer than references. Possibly decoding is not cut
# at EOS; confirm in the metric/validation code.

# Callback hyperparameters. "Improvement" is tracked on val_loss, per the
# "metrics improved from ..." lines in the training log.
EARLY_STOP_PATIENCE = 5  # epochs without improvement before stopping
LR_REDUCE_FACTOR = 0.5   # multiply LR by this on a plateau (0.0001 -> 0.00005 in the log)
LR_REDUCE_PATIENCE = 2   # epochs without improvement before reducing LR
LR_REDUCE_COOLDOWN = 2   # epochs to wait after a reduction before reducing again

tcp = TrainCheckpoint(mtconf.model_output)
es = EarlyStopping(patience=EARLY_STOP_PATIENCE)
rlr = ReduceLROnPlateau(
    factor=LR_REDUCE_FACTOR,
    patience=LR_REDUCE_PATIENCE,
    cooldown=LR_REDUCE_COOLDOWN,
)
callback = TrainingCallback(checkpoint=tcp, early_stop=es, reduce_lr=rlr)

# preload=True presumably resumes from an existing checkpoint when one is present.
# TODO confirm in train.train_model.
history = train_model(mtconf, callback, preload=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload




total sentence pair for training = 1000000
tokenizer exist, getting from: .output\tokenizer_en.json
tokenizer exist, getting from: .output\tokenizer_id.json
max length of source sentence: 579
max length of target sentence: 535


Map: 100%|██████████| 1000000/1000000 [00:25<00:00, 39244.53 examples/s]
Filter: 100%|██████████| 1000000/1000000 [00:03<00:00, 264003.08 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 27851.64 examples/s]
Filter: 100%|██████████| 2000/2000 [00:00<00:00, 103927.45 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 29864.60 examples/s]
Filter: 100%|██████████| 2000/2000 [00:00<00:00, 9618.73 examples/s] 


total sentence pair for training after filtering = 993856


epoch 1: 100%|██████████| 15529/15529 [13:17<00:00, 19.47it/s, loss=3.832]


Epoch 1 - 0:13:39.6 | train_loss=4.545930 | val_loss=3.924799 | CER=1.1001691818237305 | WER=2.28844952583313 | BLEU=0.030222974717617035
* metrics improved from inf to 3.924799




epoch 2: 100%|██████████| 15529/15529 [14:41<00:00, 17.62it/s, loss=3.461]


Epoch 2 - 0:15:3.4 | train_loss=3.608983 | val_loss=3.542757 | CER=1.0269685983657837 | WER=2.3082995414733887 | BLEU=0.03882858529686928
* metrics improved from 3.924799 to 3.542757




epoch 3: 100%|██████████| 15529/15529 [14:40<00:00, 17.64it/s, loss=3.624]


Epoch 3 - 0:15:3.0 | train_loss=3.343867 | val_loss=3.395112 | CER=0.9589769244194031 | WER=2.345029592514038 | BLEU=0.040664542466402054
* metrics improved from 3.542757 to 3.395112




epoch 4: 100%|██████████| 15529/15529 [14:44<00:00, 17.56it/s, loss=3.477]


Epoch 4 - 0:15:5.8 | train_loss=3.203865 | val_loss=3.309647 | CER=0.8717939257621765 | WER=2.1752891540527344 | BLEU=0.04519190639257431
* metrics improved from 3.395112 to 3.309647




epoch 5: 100%|██████████| 15529/15529 [13:30<00:00, 19.15it/s, loss=2.730]


Epoch 5 - 0:13:50.5 | train_loss=3.110428 | val_loss=3.269497 | CER=0.9205103516578674 | WER=2.244373321533203 | BLEU=0.045407358556985855
* metrics improved from 3.309647 to 3.269497




epoch 6: 100%|██████████| 15529/15529 [12:50<00:00, 20.15it/s, loss=3.029]


Epoch 6 - 0:13:10.9 | train_loss=3.041237 | val_loss=3.239122 | CER=0.8776869773864746 | WER=2.195607900619507 | BLEU=0.04744183272123337
* metrics improved from 3.269497 to 3.239122




epoch 7: 100%|██████████| 15529/15529 [13:05<00:00, 19.76it/s, loss=3.186]


Epoch 7 - 0:13:25.6 | train_loss=2.986087 | val_loss=3.211819 | CER=0.931972324848175 | WER=2.2969677448272705 | BLEU=0.046742040663957596
* metrics improved from 3.239122 to 3.211819




epoch 8: 100%|██████████| 15529/15529 [13:05<00:00, 19.77it/s, loss=2.963]


Epoch 8 - 0:13:25.1 | train_loss=2.941195 | val_loss=3.192698 | CER=0.8660329580307007 | WER=2.1448891162872314 | BLEU=0.04849379137158394
* metrics improved from 3.211819 to 3.192698




epoch 9: 100%|██████████| 15529/15529 [13:05<00:00, 19.77it/s, loss=2.722]


Epoch 9 - 0:13:25.3 | train_loss=2.903594 | val_loss=3.184977 | CER=1.0348060131072998 | WER=2.520944118499756 | BLEU=0.04282823204994202
* metrics improved from 3.192698 to 3.184977




epoch 10: 100%|██████████| 15529/15529 [13:05<00:00, 19.77it/s, loss=2.937]


Epoch 10 - 0:13:24.9 | train_loss=2.871603 | val_loss=3.179969 | CER=0.6643142700195312 | WER=1.521647334098816 | BLEU=0.06790437549352646
* metrics improved from 3.184977 to 3.179969




epoch 11: 100%|██████████| 15529/15529 [13:05<00:00, 19.77it/s, loss=2.813]


Epoch 11 - 0:13:25.8 | train_loss=2.842298 | val_loss=3.182621 | CER=0.8091071844100952 | WER=1.8688652515411377 | BLEU=0.05629277229309082
metrics did not improve from 3.179969




epoch 12: 100%|██████████| 15529/15529 [13:05<00:00, 19.77it/s, loss=3.042]


Epoch 12 - 0:13:25.0 | train_loss=2.816740 | val_loss=3.184168 | CER=0.6622259020805359 | WER=1.5068771839141846 | BLEU=0.06821643561124802
Reducing LR in the next epoch from 0.000100 to 0.000050
metrics did not improve from 3.179969




epoch 13: 100%|██████████| 15529/15529 [13:05<00:00, 19.77it/s, loss=3.179]


Epoch 13 - 0:13:24.6 | train_loss=2.745396 | val_loss=3.168913 | CER=0.5856407284736633 | WER=1.2482025623321533 | BLEU=0.0833871141076088
* metrics improved from 3.179969 to 3.168913




epoch 14: 100%|██████████| 15529/15529 [13:00<00:00, 19.89it/s, loss=2.818]


Epoch 14 - 0:13:19.9 | train_loss=2.725329 | val_loss=3.171391 | CER=0.6765923500061035 | WER=1.5023444890975952 | BLEU=0.06990515440702438
metrics did not improve from 3.168913




epoch 15: 100%|██████████| 15529/15529 [12:28<00:00, 20.75it/s, loss=2.852]


Epoch 15 - 0:12:47.8 | train_loss=2.711651 | val_loss=3.177461 | CER=0.655792772769928 | WER=1.450609564781189 | BLEU=0.07251495867967606
metrics did not improve from 3.168913




epoch 16: 100%|██████████| 15529/15529 [12:35<00:00, 20.56it/s, loss=2.772]


Epoch 16 - 0:12:54.6 | train_loss=2.699510 | val_loss=3.181434 | CER=0.6938153505325317 | WER=1.591200351715088 | BLEU=0.06519240885972977
metrics did not improve from 3.168913




epoch 17: 100%|██████████| 15529/15529 [12:18<00:00, 21.03it/s, loss=2.480]


Epoch 17 - 0:12:36.8 | train_loss=2.687878 | val_loss=3.188818 | CER=0.6042919158935547 | WER=1.3137699365615845 | BLEU=0.08063607662916183
Reducing LR in the next epoch from 0.000050 to 0.000025
metrics did not improve from 3.168913




epoch 18: 100%|██████████| 15529/15529 [12:50<00:00, 20.15it/s, loss=3.071]


Epoch 18 - 0:13:11.3 | train_loss=2.650518 | val_loss=3.190245 | CER=0.712478518486023 | WER=1.5928415060043335 | BLEU=0.06702635437250137
metrics did not improve from 3.168913
Early stopping, no improvement in the last 5 epochs




epoch 19:  55%|█████▍    | 8497/15529 [06:57<05:45, 20.34it/s, loss=2.832]


KeyboardInterrupt: 