In [229]:
import torch
import imp
import TransformerTrainer
import MyTransformer
from TransformerTrainer import MyTranslator
from DataModule import BaseDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor
import random
import numpy as np
import utils

In [230]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
%env CUBLAS_WORKSPACE_CONFIG :16:8

env: CUBLAS_WORKSPACE_CONFIG=:16:8


In [231]:
DEVICE = "cuda"
BATCH_SIZE = 64
MAX_LEN = 50

torch.manual_seed(SEED)
torch.set_deterministic(True)

data_module = BaseDataModule(
    batch_size=BATCH_SIZE,
    device = DEVICE,
    data_path="./data/eng_rus.txt",
    seed=SEED
)

data_module.prepare_data()

In [205]:
params = {
    "src_vocab_size": data_module.src_vocab_len,
    "trg_vocab_size": data_module.trg_vocab_len,
    "d_model": 512,
    "n_enc_layers": 6,
    "n_dec_layers": 6,
    "n_enc_heads": 8,
    "n_dec_heads": 8,
    "enc_dropout": 0.1,
    "dec_dropout": 0.1,
    "src_pad_idx": data_module.src_pad_idx,
    "trg_pad_idx": data_module.trg_pad_idx
}

In [206]:
plmodel = MyTranslator(**params)
plmodel.to(DEVICE)

num_params = sum(p.numel() for p in plmodel.parameters() if p.requires_grad)
print(f"{num_params/1e6} млн")

57.070626 млн


In [7]:
N_EPOCHS = 12
CLIP = 1

tb_logger = pl_loggers.TensorBoardLogger('./logs/')
lr_monitor = LearningRateMonitor(logging_interval='step')
early_stop_callback = EarlyStopping(
   monitor='avg_val_loss',
   min_delta=0.01,
   patience=2,
   verbose=False,
   mode='mean'
)
trainer = Trainer(
    max_epochs=N_EPOCHS,
    gradient_clip_val=CLIP,
    progress_bar_refresh_rate=1,
    callbacks=[early_stop_callback, lr_monitor], 
    logger=tb_logger,
    log_every_n_steps=20
)
data_module.setup('fit')
trainer.fit(plmodel, data_module)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | model     | Transformer      | 57.1 M
-----------------------------------------------
57.1 M    Trainable params
0         Non-trainable params
57.1 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [8]:
trainer.save_checkpoint("models/myTransformer3.ckpt")

In [228]:
torch.save(plmodel.model.state_dict(), 'models/transformer_model.pt')

In [227]:
model = MyTranslator.load_from_checkpoint(checkpoint_path="models/myTransformer3.ckpt", **params)
model.to(DEVICE)

MyTranslator(
  (criterion): CrossEntropyLoss()
  (model): Transformer(
    (encoder): Encoder(
      (embed): Embedding(6736, 512)
      (pe): PositionalEncoder(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layers): ModuleList(
        (0): EncoderLayer(
          (norm_1): Norm()
          (norm_2): Norm()
          (attn): MultiHeadAttention(
            (q_linear): Linear(in_features=512, out_features=512, bias=True)
            (v_linear): Linear(in_features=512, out_features=512, bias=True)
            (k_linear): Linear(in_features=512, out_features=512, bias=True)
            (attention): PosAttentionLayer(
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (dropout): Dropout(p=0.1, inplace=False)
            (fc_out): Linear(in_features=512, out_features=512, bias=True)
          )
          (ff): FeedForward(
            (linear_1): Linear(in_features=512, out_features=2048, bias=True)
            (dropout): Dropout(p=0.1, inpl

## Посмотрим, как обученный трансформер справляется с переводом

In [208]:
idx = 0
src = data_module.test_iter[idx].src
trg = data_module.test_iter[idx].trg
translation = utils.translate_sentence(data_module.test_iter[idx].src, model.model, data_module.src_field, data_module.trg_field, 80, DEVICE)
print(" ".join(src))
print(" ".join(trg))
print(" ".join(translation))

the nearest airport is vnukovo international airport , 23 km from apartment clubapart on studenchenskaya 16 .
расстояние до международного аэропорта внуково от апартаментов « clubapart на студенческой , 16 » составляет 23 км .
расстояние от апартаментов до международного аэропорта внуково составляет 23 , 7 км .


In [209]:
utils.calculate_bleu(
    data = data_module.test_iter, 
    src_field = data_module.src_field, 
    trg_field = data_module.trg_field,
    model = model.model,
    device=DEVICE
)

0.26132895673405926
