In [None]:
from functools import partial
from multiprocessing import Pool

import iuliia
import torch
from joblib import Parallel, delayed
from tokenizers import Tokenizer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
from torch import nn, optim
from torchtext import data
from torchtext.data import Field, Example, BucketIterator
from tqdm import tqdm

In [None]:
BATCH_SIZE = 16
EMBEDDING_SIZE = 200
NUM_HEADS = 8
NUM_ENCODERS = 6
NUM_DECODERS = 6
DROPOUT = 0.1
HIDDEN_DIM = 1024
LR = 0.001
LOAD_MODEL = True
NUM_EPOCH = 10
SAVE_MODEL = True

MAX_SIZE = 256
MIN_SIZE = 5

In [None]:
## DATA LOADING#
import os
print('Reading datasets from disk ...', end='')
#os.chdir(r'C:\Users\ysuho\dl_hws\data\Homework3\MHQG-like implementation')
EN_FILE = './data/corpus.en_ru.1m.en'
RU_FILE = './data/corpus.en_ru.1m.ru'
device = torch.device("cpu")

src = Field(init_token = '<sos>', eos_token = '<eos>', tokenize=None)  # No need to tokenize because of BPE
tgt = Field(init_token = '<sos>', eos_token = '<eos>', tokenize=None)  # No need to tokenize because of BPE
tgt_rev = Field(init_token = '<sos>', eos_token = '<eos>', tokenize=None)  # No need to tokenize because of BPE
named_fields = {'ru': ('src', src), 'en': ('tgt', tgt), 'en_rev': ('tgt_rev', tgt_rev)}
fields = {name: field for name, field in named_fields.values()}

src_data = open(RU_FILE).readlines()
tgt_data = open(EN_FILE).readlines()
print(' ✔')

In [4]:
print('Transliterating ...', end='')
## BPE tokenization
NUM_PROCESSES=8
pool = Pool(NUM_PROCESSES)
src_data = Parallel(n_jobs=8)(delayed(partial(iuliia.translate, schema=iuliia.WIKIPEDIA))(x) for x in tqdm(src_data))
print(' ✔')

Transliterating ...

100%|██████████| 1000000/1000000 [00:31<00:00, 32192.81it/s]


 ✔


In [5]:
print('Learning BPE ...', end='')
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
trainer = BpeTrainer(special_tokens=["<sos>", "<eos>", "<unk>", "<pad>"])
tokenizer.train_from_iterator(src_data + tgt_data, trainer=trainer)
print(' ✔')

Learning BPE ... ✔


In [6]:
def tokenize(x):
    return tokenizer.encode(x).tokens


print('BPE encoding source text ...', end='')
src_data = [tokenizer.encode(x).tokens for x in tqdm(src_data)]
print(' ✔')

print('BPE encoding target text ...', end='')
tgt_data = [tokenizer.encode(x).tokens for x in tqdm(tgt_data)]
print(' ✔')



  0%|          | 1281/1000000 [00:00<01:17, 12807.20it/s]

BPE encoding source text ...

100%|██████████| 1000000/1000000 [01:19<00:00, 12522.48it/s]
  0%|          | 1255/1000000 [00:00<01:19, 12547.59it/s]

 ✔
BPE encoding target text ...

100%|██████████| 1000000/1000000 [01:28<00:00, 11344.92it/s]

 ✔





In [7]:
# Vocab and dataset creation
print('Creating dataset ...', end='')
samples = [
    Example.fromdict({'ru': src_line, 'en': tgt_line, 'en_rev': tgt_line[::-1]}, named_fields)
    for tgt_line, src_line in tqdm(zip(tgt_data, src_data), total=len(src_data))
]

def sanitize(x):
    return MIN_SIZE <= len(x.src) <= MAX_SIZE and MIN_SIZE <= len(x.tgt) <= MAX_SIZE

dataset = data.Dataset(samples, fields, filter_pred=sanitize)
print(' ✔')

src.build_vocab(dataset)
tgt.build_vocab(dataset)
tgt_rev.build_vocab(dataset)

train_iterator, = BucketIterator.splits(
    (dataset,),
    device = device,
    batch_size=BATCH_SIZE,
    sort_key=lambda x: len(x.src) + len(x.tgt), shuffle=True)


PAD_idx = src.vocab.stoi[src.pad_token]
SOS_idx = src.vocab.stoi['<sos>']
EOS_idx = src.vocab.stoi['<eos>']

src_vocab_size = len(src.vocab.stoi)
tgt_vocab_size = len(tgt.vocab.stoi)


  3%|▎         | 25792/1000000 [00:00<00:03, 257910.41it/s]

Creating dataset ...

100%|██████████| 1000000/1000000 [00:10<00:00, 94715.88it/s]


 ✔


In [8]:
class Transformer(nn.Module):
    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        src_positions = torch.arange(0, src_seq_length).unsqueeze(1).expand(src_seq_length, N).to(self.device)
        trg_positions = torch.arange(0, trg_seq_length).unsqueeze(1).expand(trg_seq_length, N).to(self.device)

        embed_src = self.dropout((self.src_word_embedding(src) + self.src_position_embedding(src_positions)))
        embed_trg = self.dropout((self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions)))

        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out

In [9]:
class DoubleTransformerModel(nn.Module):
    def __init__(self, l2r, r2l):
        super().__init__()
        self.model_l2r = l2r
        self.model_r2l = r2l

    def forward(self, source, l2r_data, r2l_data):
        return self.model_l2r(source, l2r_data), self.model_r2l(source, r2l_data)

In [10]:
# We create two identical models, one will learn src -> tgt, the second will learn src -> tgt-reversed

l2r = Transformer(
    EMBEDDING_SIZE,
    src_vocab_size,
    tgt_vocab_size,
    PAD_idx,
    NUM_HEADS,
    NUM_ENCODERS,
    NUM_DECODERS,
    HIDDEN_DIM,
    DROPOUT,
    MAX_SIZE,
    device,
).to(device)

r2l = Transformer(
    EMBEDDING_SIZE,
    src_vocab_size,
    tgt_vocab_size,
    PAD_idx,
    NUM_HEADS,
    NUM_ENCODERS,
    NUM_DECODERS,
    HIDDEN_DIM,
    DROPOUT,
    MAX_SIZE,
    device,
).to(device)

double_transformer = DoubleTransformerModel(l2r, r2l)

In [11]:
def translate_sentence(model, sentence, src, tgt, device, max_length=50):
    # Load german tokenizer
    if type(sentence) == str:
        tokens = tokenizer.encode(iuliia.translate(sentence, schema=iuliia.WIKIPEDIA)).tokens
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, src.init_token)
    tokens.append(src.eos_token)

    text_to_indices = [src.vocab.stoi[token] for token in tokens]

    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [tgt.vocab.stoi["<sos>"]]
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)
        trg_tensor_rev = torch.LongTensor(outputs[::-1]).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor, trg_tensor_rev)
            output = output[0]+output[1]

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == tgt.vocab.stoi["<eos>"]:
            break

    translated_sentence = [tgt.vocab.itos[idx] for idx in outputs]
    return tokenizer.decode([tokenizer.token_to_id(x) for x in translated_sentence[1:]])


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])


optimizer = optim.Adam(double_transformer.parameters(), lr=LR)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, verbose=True
)

pad_idx = tgt.vocab.stoi["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

if LOAD_MODEL:
    load_checkpoint(torch.load("my_checkpoint.pth.tar"), double_transformer, optimizer)

sentence = "Главными движущими силами благого управления должны быть уважение закона, соблюдение принципов диалога и сотрудничества, защита демократии и поощрение прав человека"

=> Loading checkpoint


In [13]:
for epoch in range(NUM_EPOCH):
    print(f"[Epoch {epoch} / {NUM_EPOCH}]")

    double_transformer.eval()
    translated_sentence = translate_sentence(
        double_transformer, sentence, src, tgt, device, max_length=50
    )

    print(f"Translated example sentence: \n {translated_sentence}")
    double_transformer.train()
    losses = []

    for batch_idx, batch in tqdm(enumerate(train_iterator), total=len(src_data)//BATCH_SIZE):
        # Get input and targets and get to cuda
        inp_data = batch.src.to(device)
        target = batch.tgt.to(device)
        target_reverse = batch.tgt_rev.to(device)

        output_l2r, output_r2l = double_transformer(inp_data,  target[:-1, :], target_reverse[:-1, :])
        output = output_l2r + output_r2l
        output = output.reshape(-1, output.shape[2])
        target = target[1:].reshape(-1)

        optimizer.zero_grad()

        loss = criterion(output, target)
        losses.append(loss.item())

        # Back prop
        loss.backward()
        torch.nn.utils.clip_grad_norm_(double_transformer.parameters(), max_norm=1)

        optimizer.step()

    mean_loss = sum(losses) / len(losses)
    print(mean_loss)
    scheduler.step(mean_loss)
    
    if SAVE_MODEL:
        checkpoint = {
            "state_dict": double_transformer.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)


[Epoch 0 / 10]
=> Saving checkpoint
['ĠGlav', 'nymi', 'Ġdvizhu', 'shchimi', 'Ġsilami', 'Ġblago', 'go', 'Ġupravleniya', 'Ġdolzhny', 'Ġbyt', 'Ġuvazheniye', 'Ġzakona', ',', 'Ġsoblyudeniye', 'Ġprintsipov', 'Ġdialoga', 'Ġi', 'Ġsotrudnichestva', ',', 'Ġzashchita', 'Ġdemokratii', 'Ġi', 'Ġpooshch', 'reniye', 'Ġprav', 'Ġcheloveka']
[2, 4350, 101, 14784, 1329, 6693, 3531, 323, 386, 189, 87, 10594, 2514, 4, 9669, 3654, 6473, 7, 1019, 4, 6158, 4151, 7, 10938, 2941, 315, 325, 3]


IndexError: index 1 is out of bounds for dimension 0 with size 1

In [18]:
## INFERENCE

data_in = open('eval-ru-100.txt').readlines()

data_out = open('answer.txt', 'w')

data = []

for x in tqdm(data_in):
    try:
        data.append(translate_sentence(double_transformer, x, src, tgt, device))
    except:
        print('whoops')
        data.append('whoops')
data_out.write('\n'.join(data))

100%|██████████| 100/100 [02:25<00:00,  1.46s/it]


29010

In [20]:
data_out.close()

[' EstoniaGA nach BernGA arri sro textsgalter revel Pol sq lens lens AG Bangvra dealt obsticiencyERT implement spro Untilakon Entertainment ProgGO lens tryingboundboundbound Svetmun subsidiariesERT ProsparaGA zeroGA Svet Bangsupp Until doctrine return suspected world', 'alizing phenodzhgedgedfies pribfiesdzh Clearged Prosoughtvragedged zeroalizingvra Sveticiency vodagedativvra Pros Bangboundiciency texts householdsiciency 2006 Pros Pros households Banggalter Pros police world textsgedalizinggedgedvra Pros expertsiciency', 'gediciency Nim Pros households�ged� Pros Prosged Svet 2006iciency bear occupy households 2006 households BernERTvra Prosiciency sqadianactionGAbound Prosmun Prosged Pros Pros Pros SN Svet Prosged Pros Pros Prosgot texts overwhel Pros households Pros overwhel', ' plav Faficationsiciency Pros Rec protects Svet encouraging Svet Svet ofitsvra collaborationadian roziciencyalizing householdsalizing rozvraovogo associatevra householdsvrapara Prosvra Svet 2006vra bear textsa