In [16]:
import torch
from TransformerTrainer import BaseDataModule, MyTranslator
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor
import random
import numpy as np
import utils

In [3]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
%env CUBLAS_WORKSPACE_CONFIG :16:8

env: CUBLAS_WORKSPACE_CONFIG=:16:8


':16:8'

In [4]:

DEVICE = "cuda"
BATCH_SIZE = 64
MAX_LEN = 50

torch.manual_seed(SEED)
torch.set_deterministic(True)

data_module = BaseDataModule(
    batch_size=BATCH_SIZE,
    device = DEVICE,
    data_path="./data/eng_rus.txt",
    seed=SEED
)

data_module.prepare_data()

In [5]:
params = {
    "src_vocab_size": data_module.src_vocab_len,
    "trg_vocab_size": data_module.trg_vocab_len,
    "d_model": 512,
    "n_enc_layers": 6,
    "n_dec_layers": 6,
    "n_heads": 8,
    "enc_dropout": 0.1,
    "dec_dropout": 0.1,
    "src_pad_idx": data_module.src_pad_idx,
    "trg_pad_idx": data_module.trg_pad_idx
}

In [6]:
model = MyTranslator(**params)
model.to(DEVICE)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{num_params/1e6} млн")

57.070626 млн


In [7]:
N_EPOCHS = 12
CLIP = 1

tb_logger = pl_loggers.TensorBoardLogger('./logs/')
lr_monitor = LearningRateMonitor(logging_interval='step')
early_stop_callback = EarlyStopping(
   monitor='avg_val_loss',
   min_delta=0.01,
   patience=2,
   verbose=False,
   mode='mean'
)
trainer = Trainer(
    max_epochs=N_EPOCHS,
    gradient_clip_val=CLIP,
    progress_bar_refresh_rate=1,
    callbacks=[early_stop_callback, lr_monitor], 
    logger=tb_logger,
    log_every_n_steps=20
)
data_module.setup('fit')
trainer.fit(model, data_module)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | model     | Transformer      | 57.1 M
-----------------------------------------------
57.1 M    Trainable params
0         Non-trainable params
57.1 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [8]:
trainer.save_checkpoint("models/myTransformer3.ckpt")

In [None]:
#model = MyTranslator.load_from_checkpoint(checkpoint_path="models/myTransformer3.ckpt", **params)

In [9]:
src_field = data_module.src_field
trg_field = data_module.trg_field
src_pad_idx = data_module.src_pad_idx
trg_pad_idx = data_module.trg_pad_idx
trg_eos_idx = data_module.trg_eos_idx
trg_unk_idx = src_field.vocab.stoi[src_field.unk_token]
init_idx = src_field.vocab.stoi[src_field.init_token]

In [13]:
def translate(sentence, max_len = 80):
    src_tokens = [src_field.init_token] + sentence + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in src_tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(DEVICE)
    src_mask = model.model.make_src_mask(src_tensor, src_pad_idx)
    with torch.no_grad():
        e_outputs = model.model.encoder(src_tensor, src_mask)
    outputs = torch.zeros(max_len).type_as(src_tensor)
    outputs[0] = torch.LongTensor([init_idx])

    for i in range(1, max_len):
        trg_mask = model.model.make_trg_mask(outputs[:i].unsqueeze(0), trg_pad_idx)
        with torch.no_grad():
            out = model.model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask)
            probs = model.model.out(out)
        outputs[i] = probs.argmax(2)[:,-1].item()
        if outputs[i] == trg_eos_idx:
            break
    
    return [trg_field.vocab.itos[i] for i in outputs if i != trg_unk_idx][1:-1]

In [14]:
idx = 3
src = data_module.test_iter[idx].src
trg = data_module.test_iter[idx].trg
translation = translate(data_module.test_iter[idx].src)
print(" ".join(src))
print(" ".join(trg))
print(" ".join(translation))

set right on the seafront in ostia , a 30 - minute train ride from rome ' s historic centre , hotel la scaletta offers rooms with air conditioning , wi - fi access , and a flat - screen tv with satellite channels .
отель la scaletta располагается непосредственно на берегу моря в остии , в 30 минутах езды от исторического центра рима . здесь вам предложат кондиционированные номера , оснащенные беспроводным доступом в интернет , а также телевизором с плоским экраном и спутниковыми каналами .
отель типа « постель и завтрак » la расположен на побережье , всего в 30 минутах езды от исторического центра города рима , в которых можно посмотреть телевизор с плоским экраном и спутниковыми каналами . к услугам гостей номера с плоским экраном и кондиционером и кондиционером .


In [19]:
import imp
imp.reload(utils)

utils.calculate_bleu(
    data = data_module.test_iter, 
    src_field = data_module.src_field, 
    trg_field = data_module.trg_field,
    model = model.model,
    device=DEVICE
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7500.0), HTML(value='')))




0.26149169442227