In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torchinfo import summary
from config import en_id_model as mtconf
from dataset import get_tokenizers
from model import build_model
from train import train_model
from utils import TrainCheckpoint, EarlyStopping, ReduceLROnPlateau, TrainingCallback

  from .autonotebook import tqdm as notebook_tqdm


In [31]:
# Load the existing tokenizers (no retraining) and build a model from the same
# config, then print a layer-by-layer torchinfo summary.
# NOTE: this `model` is only used for the summary; `train_model` in the training
# cell is not given it.
t_src, t_tgt = get_tokenizers(mtconf, ds_train=None, force_retrain_tokenizer=False)
model = build_model(mtconf, t_src, t_tgt)

batch_size, seq_len = mtconf.batch_size, mtconf.seq_len
src_vocab_size = t_src.get_vocab_size()
tgt_vocab_size = t_tgt.get_vocab_size()

# Dummy forward inputs: only their shapes/dtypes matter for the summary.
dummy_inputs = [
    torch.randint(0, src_vocab_size, (batch_size, seq_len)),               # encoder_input
    torch.randint(0, tgt_vocab_size, (batch_size, seq_len)),               # decoder_input
    torch.ones(batch_size, 1, 1, seq_len, dtype=torch.int),                # encoder_mask
    torch.ones(batch_size, 1, seq_len, seq_len, dtype=torch.int),          # decoder_mask
]

summary(
    model,
    input_data=dummy_inputs,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    depth=10,  # deep enough to expand every nested sub-module in the table
    row_settings=["var_names"],
)

tokenizer exist, getting from: .output\tokenizer_en.json
tokenizer exist, getting from: .output\tokenizer_id.json


Layer (type (var_name))                                           Input Shape               Output Shape              Param #                   Trainable
Transformer (Transformer)                                         [10, 225]                 [10, 225, 30000]          --                        True
├─InputEmbedding (src_embed)                                      [10, 225]                 [10, 225, 512]            --                        True
│    └─Embedding (embedding)                                      [10, 225]                 [10, 225, 512]            15,360,000                True
├─PositionalEncoding (src_pos)                                    [10, 225, 512]            [10, 225, 512]            --                        --
│    └─Dropout (dropout)                                          [10, 225, 512]            [10, 225, 512]            --                        --
├─Encoder (encoder)                                               [10, 225, 512]            [10, 225, 512

In [5]:
# Train (or resume training of) the translation model with checkpointing,
# early-stopping and LR-on-plateau callbacks.
# %autoreload is already enabled in the first cell, so it is not reloaded here
# (re-running `%load_ext autoreload` only prints an "already loaded" warning).

# Callback hyperparameters — tune here rather than inline.
EARLY_STOP_PATIENCE = 5   # non-improving epochs before stopping (semantics per utils.EarlyStopping — confirm)
LR_DECAY_FACTOR = 0.5     # LR multiplier on plateau (log shows 0.000100 -> 0.000050)
LR_PLATEAU_PATIENCE = 2   # non-improving epochs before the LR is reduced
LR_COOLDOWN = 2           # presumably epochs to wait after a reduction — confirm in utils.ReduceLROnPlateau

tcp = TrainCheckpoint(mtconf.model_output)
es = EarlyStopping(patience=EARLY_STOP_PATIENCE)
rlr = ReduceLROnPlateau(
    factor=LR_DECAY_FACTOR,
    patience=LR_PLATEAU_PATIENCE,
    cooldown=LR_COOLDOWN,
)
callback = TrainingCallback(checkpoint=tcp, early_stop=es, reduce_lr=rlr)

# preload=True: the log starting at epoch 8 suggests this resumes from a saved
# checkpoint — confirm in train.train_model.
# If the kernel is interrupted (KeyboardInterrupt), train_model never returns,
# so `history` is NOT (re)assigned by this cell.
history = train_model(mtconf, callback, preload=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
total sentence pair for training: 50000
tokenizer exist, getting from: .output\tokenizer_en.json
tokenizer exist, getting from: .output\tokenizer_id.json
max length of source sentence: 127
max length of target sentence: 222


epoch 8: 100%|██████████| 5000/5000 [07:00<00:00, 11.89it/s, loss=3.309]


Epoch 8 - 0:9:39.6 | train_loss=3.350558 | val_loss=4.449178 | CER=0.6688094735145569 | WER=1.0796176195144653 | BLEU=0.0
* metrics improved from inf to 4.449178




epoch 9: 100%|██████████| 5000/5000 [07:03<00:00, 11.82it/s, loss=2.923]


Epoch 9 - 0:9:48.4 | train_loss=3.176356 | val_loss=4.441678 | CER=0.7299025058746338 | WER=1.14340341091156 | BLEU=0.0
* metrics improved from 4.449178 to 4.441678




epoch 10: 100%|██████████| 5000/5000 [07:05<00:00, 11.75it/s, loss=2.598]


Epoch 10 - 0:9:47.4 | train_loss=3.011871 | val_loss=4.490134 | CER=0.7252923846244812 | WER=1.1638240814208984 | BLEU=0.0
metrics did not improve from 4.441678




epoch 11: 100%|██████████| 5000/5000 [07:01<00:00, 11.87it/s, loss=2.364]


Epoch 11 - 0:9:37.3 | train_loss=2.855515 | val_loss=4.518099 | CER=0.681349515914917 | WER=1.0991204977035522 | BLEU=0.0
scheduling LR in the next epoch from 0.000100 to 0.000050
metrics did not improve from 4.441678




epoch 12: 100%|██████████| 5000/5000 [07:01<00:00, 11.88it/s, loss=2.851]


Epoch 12 - 0:9:31.2 | train_loss=2.564801 | val_loss=4.543865 | CER=0.6809506416320801 | WER=1.112198829650879 | BLEU=0.0
metrics did not improve from 4.441678




epoch 13:   2%|▏         | 79/5000 [00:14<15:07,  5.42it/s, loss=2.316]  


KeyboardInterrupt: 