In [28]:
!pip install torchtune torchao -q

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from tokenizers import decoders

import nltk
from nltk.tokenize import word_tokenize

import re
import os
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from collections import Counter

from torchtune.modules import RotaryPositionalEmbeddings
from torch.nn import Transformer
import matplotlib.pyplot as plt
%matplotlib inline

### Model

In [30]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [31]:
!wget https://data.statmt.org/opus-100-corpus/v1.0/supervised/en-ru/opus.en-ru-train.ru
!wget https://data.statmt.org/opus-100-corpus/v1.0/supervised/en-ru/opus.en-ru-train.en
!wget https://data.statmt.org/opus-100-corpus/v1.0/supervised/en-ru/opus.en-ru-test.ru
!wget https://data.statmt.org/opus-100-corpus/v1.0/supervised/en-ru/opus.en-ru-test.en

--2025-04-02 13:23:43--  https://data.statmt.org/opus-100-corpus/v1.0/supervised/en-ru/opus.en-ru-train.ru
Resolving data.statmt.org (data.statmt.org)... 129.215.32.28
Connecting to data.statmt.org (data.statmt.org)|129.215.32.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 121340806 (116M)
Saving to: ‘opus.en-ru-train.ru.1’


2025-04-02 13:23:48 (27.1 MB/s) - ‘opus.en-ru-train.ru.1’ saved [121340806/121340806]

--2025-04-02 13:23:48--  https://data.statmt.org/opus-100-corpus/v1.0/supervised/en-ru/opus.en-ru-train.en
Resolving data.statmt.org (data.statmt.org)... 129.215.32.28
Connecting to data.statmt.org (data.statmt.org)|129.215.32.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 67760131 (65M)
Saving to: ‘opus.en-ru-train.en.1’


2025-04-02 13:23:52 (23.4 MB/s) - ‘opus.en-ru-train.en.1’ saved [67760131/67760131]

--2025-04-02 13:23:52--  https://data.statmt.org/opus-100-corpus/v1.0/supervised/en-ru/opus.en-ru-test.ru
Resolvin

In [32]:
text = open('opus.en-ru-train.ru').read().replace('\xa0', ' ')
f = open('opus.en-ru-train.ru', 'w')
f.write(text)
f.close()

In [33]:
en_sents = open('opus.en-ru-train.en').read().splitlines()
ru_sents = open('opus.en-ru-train.ru').read().splitlines()

In [34]:
tokenizer_en = Tokenizer(BPE())
tokenizer_en.pre_tokenizer = Whitespace()

trainer_en = BpeTrainer(special_tokens=["[PAD]", "[BOS]", "[EOS]"], end_of_word_suffix='</w>')
tokenizer_en.train(files=["opus.en-ru-train.en"], trainer=trainer_en)

tokenizer_ru = Tokenizer(BPE())
tokenizer_ru.pre_tokenizer = Whitespace()

trainer_ru = BpeTrainer(special_tokens=["[PAD]"], end_of_word_suffix='</w>')
tokenizer_ru.train(files=["opus.en-ru-train.ru"], trainer=trainer_ru)

In [35]:
tokenizer_en.decoder = decoders.BPEDecoder()
tokenizer_ru.decoder = decoders.BPEDecoder()

In [36]:
def encode(text, tokenizer, max_len, encoder=False):
    if encoder:
        return tokenizer.encode(text).ids[:max_len]
    else:
        return [tokenizer.token_to_id('[BOS]')] + tokenizer.encode(text).ids[:max_len] + [tokenizer.token_to_id('[EOS]')]

In [37]:
PAD_IDX = tokenizer_en.token_to_id('[PAD]')
max_len_en, max_len_ru = 47, 48 
X_ru = [encode(t, tokenizer_ru, max_len_ru, encoder=True) for t in ru_sents]
X_en = [encode(t, tokenizer_en, max_len_en) for t in en_sents]

In [38]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, texts_ru, texts_en):
        self.texts_ru = [torch.LongTensor(sent) for sent in texts_ru]
        self.texts_ru = torch.nn.utils.rnn.pad_sequence(self.texts_ru, batch_first=True, padding_value=PAD_IDX)

        self.texts_en = [torch.LongTensor(sent) for sent in texts_en]
        self.texts_en = torch.nn.utils.rnn.pad_sequence(self.texts_en, batch_first=True, padding_value=PAD_IDX)

        self.length = len(texts_ru)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        ids_ru = self.texts_ru[index]
        ids_en = self.texts_en[index]
        return ids_ru, ids_en

In [39]:
X_ru_train, X_ru_valid, X_en_train, X_en_valid = train_test_split(X_ru, X_en, test_size=0.05)

In [40]:
class TransformerEncoderDecoder(nn.Module):
    def __init__(self, vocab_size_enc, vocab_size_dec, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.embedding_enc = nn.Embedding(vocab_size_enc, embed_dim)
        self.embedding_dec = nn.Embedding(vocab_size_dec, embed_dim)
        self.positional_encoding = RotaryPositionalEmbeddings(embed_dim // num_heads, max_seq_len=128)

        self.transformer = Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True
        )

        self.output_layer = nn.Linear(embed_dim, vocab_size_dec)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):

        src_embedded = self.embedding_enc(src)
        B,S,E = src_embedded.shape
        src_embedded = self.positional_encoding(src_embedded.view(B,S,self.num_heads, E//self.num_heads)).view(B,S,E)

        tgt_embedded = self.embedding_dec(tgt)
        B,S,E = tgt_embedded.shape
        tgt_embedded = self.positional_encoding(tgt_embedded.view(B,S,self.num_heads, E//self.num_heads)).view(B,S,E)


        tgt_mask = (~torch.tril(torch.ones((S, S), dtype=torch.bool))).to(DEVICE)

        encoder_output = self.transformer.encoder(
            src_embedded,
            src_key_padding_mask=src_key_padding_mask
        )

        decoder_output = self.transformer.decoder(
            tgt_embedded,
            encoder_output,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )

        output = self.output_layer(decoder_output)
        return output

In [41]:
vocab_size_enc = tokenizer_en.get_vocab_size()
vocab_size_dec = tokenizer_ru.get_vocab_size()
embed_dim = 256
num_heads = 8
ff_dim = embed_dim*4
num_layers = 4

batch_size = 100

In [42]:
model = TransformerEncoderDecoder(vocab_size_enc=tokenizer_ru.get_vocab_size(), vocab_size_dec=tokenizer_en.get_vocab_size(), embed_dim=256, num_heads=8, ff_dim=256*4, num_layers=4)

training_set = Dataset(X_ru_train, X_en_train)
training_generator = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True)

valid_set = Dataset(X_ru_valid, X_en_valid)
valid_generator = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False)

In [43]:
from time import time
def train(model, iterator, optimizer, criterion, scheduler, print_every=100):

    epoch_loss = []
    ac = []

    model.train()

    for i, (texts_en, texts_ru) in enumerate(iterator):
        texts_en = texts_en.to(DEVICE)
        texts_ru = texts_ru.to(DEVICE)
        texts_ru_input = texts_ru[:,:-1].to(DEVICE)
        texts_ru_out = texts_ru[:, 1:].to(DEVICE)
        src_padding_mask = (texts_en == PAD_IDX).to(DEVICE)
        tgt_padding_mask = (texts_ru_input == PAD_IDX).to(DEVICE)


        logits = model(texts_en, texts_ru_input, src_padding_mask, tgt_padding_mask)
        optimizer.zero_grad()
        B,S,C = logits.shape
        loss = loss_fn(logits.reshape(B*S, C), texts_ru_out.reshape(B*S))
        loss.backward()
        optimizer.step()
        scheduler.step()
        epoch_loss.append(loss.item())

        if not (i+1) % print_every:
            print(f'Loss: {np.mean(epoch_loss)};')

    return np.mean(epoch_loss)

def evaluate(model, iterator, criterion):

    epoch_loss = []
    epoch_f1 = []

    model.eval()
    with torch.no_grad():
        for i, (texts_en, texts_ru) in enumerate(iterator):
            texts_en = texts_en.to(DEVICE)
            texts_ru = texts_ru.to(DEVICE)
            texts_ru_input = texts_ru[:,:-1].to(DEVICE)
            texts_ru_out = texts_ru[:, 1:].to(DEVICE)
            src_padding_mask = (texts_en == PAD_IDX).to(DEVICE)
            tgt_padding_mask = (texts_ru_input == PAD_IDX).to(DEVICE)

            logits = model(texts_en, texts_ru_input, src_padding_mask, tgt_padding_mask)

            B,S,C = logits.shape
            loss = loss_fn(logits.reshape(B*S, C), texts_ru_out.reshape(B*S))
            epoch_loss.append(loss.item())

    return np.mean(epoch_loss)

In [44]:
@torch.no_grad
def translate(text):

    input_ids = tokenizer_ru.encode(text).ids[:max_len_en]
    output_ids = [tokenizer_en.token_to_id('[BOS]')]

    input_ids_pad = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(input_ids)], batch_first=True).to(DEVICE)
    output_ids_pad = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(output_ids)], batch_first=True).to(DEVICE)

    src_padding_mask = (input_ids_pad == PAD_IDX).to(DEVICE)
    tgt_padding_mask = (output_ids_pad == PAD_IDX).to(DEVICE)

    logits = model(input_ids_pad, output_ids_pad, src_padding_mask, tgt_padding_mask)

    pred = logits.argmax(2).item()

    while pred not in [tokenizer_en.token_to_id('[EOS]'), tokenizer_en.token_to_id('[PAD]')] and len(output_ids) < 100:
        output_ids.append(pred)
        output_ids_pad = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(output_ids)], batch_first=True).to(DEVICE)
        tgt_padding_mask = (output_ids_pad == PAD_IDX).to(DEVICE)

        logits = model(input_ids_pad, output_ids_pad, src_padding_mask, tgt_padding_mask)
        pred = logits.argmax(2).view(-1)[-1].item()

    return tokenizer_en.decoder.decode([tokenizer_en.id_to_token(i) for i in output_ids[1:]])

In [45]:
model = model.to(DEVICE)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

NUM_EPOCHS = 20
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.05,
                                                steps_per_epoch=len(training_generator), epochs=NUM_EPOCHS)

In [46]:
from timeit import default_timer as timer

losses = []
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train(model, training_generator, optimizer, loss_fn, scheduler)
    end_time = timer()
    val_loss = evaluate(model, valid_generator, loss_fn)

    if not losses:
        print(f'First epoch - {val_loss}, saving model..')
        torch.save(model, 'model')

    elif val_loss < min(losses):
        print(f'Improved from {min(losses)} to {val_loss}, saving model..')
        torch.save(model, 'model')

    losses.append(val_loss)

    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, \
           "f"Epoch time={(end_time-start_time):.3f}s"))

    print(translate("Солнце светит ярко сегодня."))
    print(translate('Она читает книгу в саду.'))
    print(translate('Вчера мы ходили в кино.'))
    print(translate('Ты любишь путешествовать?'))

Loss: 9.223043956756591;
Loss: 8.68271671295166;
Loss: 8.243358964920043;
Loss: 7.914159022569656;
Loss: 7.664299178123474;
Loss: 7.481049892902374;
Loss: 7.327651720728193;
Loss: 7.204285857677459;
Loss: 7.098638204468621;
Loss: 7.006090016841888;
Loss: 6.92505360993472;
Loss: 6.851592889229456;
Loss: 6.783779351161076;
Loss: 6.722773492676871;
Loss: 6.665105518658956;
Loss: 6.610862447321415;
Loss: 6.559599141233107;
Loss: 6.510222001340654;
Loss: 6.46438165162739;
Loss: 6.421387293338776;
Loss: 6.380620034989857;
Loss: 6.341961101835424;
Loss: 6.30358838122824;
Loss: 6.266447806755702;
Loss: 6.2312002338409425;
Loss: 6.196645984649658;
Loss: 6.163078776995341;
Loss: 6.129840738773346;
Loss: 6.10002393393681;
Loss: 6.070088535626729;
Loss: 6.040658587332695;
Loss: 6.01228588566184;
Loss: 5.98508316487977;
Loss: 5.95818591300179;
Loss: 5.932806850024632;
Loss: 5.907041537761688;
Loss: 5.881873382491034;
Loss: 5.8578438992249335;
Loss: 5.833890034357707;
Loss: 5.8103335326910015;
Loss:

  output = torch._nested_tensor_from_mask(


First epoch - 3.9933562245368956, saving model..
Epoch: 1, Train loss: 5.005, Val loss: 3.993,            Epoch time=1808.161s
The most important thing is the only one day .
She ' s in the house .
We ' ve been in the house .
You love a little ?
Loss: 4.022424058914185;
Loss: 4.006895552873612;
Loss: 4.000445860226949;
Loss: 3.9973084235191347;
Loss: 3.997392204284668;
Loss: 3.99490420182546;
Loss: 3.993723681994847;
Loss: 3.98845636844635;
Loss: 3.986220279534658;
Loss: 3.9846769127845763;
Loss: 3.980241358930414;
Loss: 3.9763664003213246;
Loss: 3.9755400998775774;
Loss: 3.9715638160705566;
Loss: 3.9671533416112266;
Loss: 3.9643355959653857;
Loss: 3.9605833091455347;
Loss: 3.956493777566486;
Loss: 3.9520689365738315;
Loss: 3.948674214839935;
Loss: 3.9455024447895233;
Loss: 3.942200259295377;
Loss: 3.940243249976117;
Loss: 3.937986590862274;
Loss: 3.9332942165374756;
Loss: 3.92935200544504;
Loss: 3.927350755267673;
Loss: 3.924936270884105;
Loss: 3.922020130979604;
Loss: 3.91916581432024

KeyboardInterrupt: 

### BLEU

In [19]:
text = open('opus.en-ru-test.ru').read().replace('\xa0', ' ')
f = open('opus.en-ru-test.ru', 'w')
f.write(text)
f.close()

In [20]:
en_sents_test = open('opus.en-ru-test.en').read().splitlines()[:100]
ru_sents_test = open('opus.en-ru-test.ru').read().splitlines()[:100]

In [58]:
bleu_scores = []
total_bleu = 0

for i in tqdm(range(len(ru_sents_test))):
    gold = word_tokenize(en_sents_test[i])
    pred = word_tokenize(translate(ru_sents_test[i]))
    bleu_score = nltk.translate.bleu_score.sentence_bleu([gold], pred, auto_reweigh=True)
    bleu_scores.append((bleu_score, ru_sents_test[i], en_sents_test[i], pred))
    total_bleu += bleu_score

average_bleu = total_bleu / len(ru_sents_test)
average_bleu

100%|██████████| 100/100 [00:16<00:00,  6.06it/s]


0.3786726138662766

In [60]:
bleu_scores.sort(reverse=True, key=lambda x: x[0])
for (score, ru_sent, en_sent, pred) in bleu_scores[:5]:
    print(score)
    print(ru_sent)
    print(en_sent)
    print(pred)
    print('--------------------')

1.0
12844
12844
['12844']
--------------------
0.7598356856515925
Они удручают ещё больше.
They're more depressing.
['They', "'", 're', 'still', 'breathing', '.']
--------------------
0.7598356856515925
Но разочарование «Хезболлы» превратилось в интенсивноебеспокойство, когда сирийцы восстали против Асада.
But Hezbollah’s disappointment turned to intense concernwhen Syrians rebelled against Assad.
['But', 'the', '“', 'Hezbollah', '”', 'turned', 'into', 'intense', 'concern', 'when', 'Syrian', 'settlers', 'became', 'a', 'fight', 'against', 'Aubarak', '.']
--------------------
0.7311104457090247
И как ты только справляешься, папа, таская эти коробки взад-вперед целый день.
I don't know how you do it, Pop, carrying these boxes around every day.
['And', 'as', 'soon', 'as', 'you', "'", 're', 'only', 'doing', ',', 'Dad', ',', 'these', 'are', 'the', 'boxes', 'of', 'the', 'box', 'forward', '.']
--------------------
0.7311104457090247
Коллекция администрации Адамса.
A collection from the Adams a

Некоторые предложения все-таки совсем неправильно перевелись (например, второе). Но, например, более длинные предложения как будто бы правда более менее переводятся нормально, и видно, что модель заучила какие-то конструкции, но возможно не всегда их правильно пока применяет.

### Translate

In [61]:
@torch.no_grad
def translate(texts):
    input_ids = [tokenizer_ru.encode(text).ids[:max_len_en] for text in texts]
    output_ids = [[tokenizer_en.token_to_id('[BOS]')] for _ in texts]

    input_ids_pad = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(ids) for ids in input_ids], batch_first=True).to(DEVICE)
    output_ids_pad = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(ids) for ids in output_ids], batch_first=True).to(DEVICE)

    src_padding_mask = (input_ids_pad == PAD_IDX).to(DEVICE)
    tgt_padding_mask = (output_ids_pad == PAD_IDX).to(DEVICE)

    logits = model(input_ids_pad, output_ids_pad, src_padding_mask, tgt_padding_mask)

    preds = logits.argmax(2)[:, 0].unsqueeze(1)

    done_flags = [False] * len(texts)
    max_len = 100

    while not all(done_flags):
        for i in range(len(texts)):
            if not done_flags[i]:
                output_ids[i].append(preds[i].item())

        output_ids_pad = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(ids) for ids in output_ids], batch_first=True).to(DEVICE)
        tgt_padding_mask = (output_ids_pad == PAD_IDX).to(DEVICE)

        logits = model(input_ids_pad, output_ids_pad, src_padding_mask, tgt_padding_mask)

        preds = logits.argmax(2)[:, -1]

        for i in range(len(texts)):
            if not done_flags[i]:
                if preds[i].item() == tokenizer_en.token_to_id('[EOS]') or len(output_ids[i]) >= max_len:
                    done_flags[i] = True

    translated_texts = [tokenizer_en.decoder.decode([tokenizer_en.id_to_token(i) for i in output_ids[i][1:]]) for i in range(len(texts))]

    return translated_texts

Проверим, что на 100 такие же результаты:

In [69]:
en_sents_test = open('opus.en-ru-test.en').read().splitlines()[:100]
ru_sents_test = open('opus.en-ru-test.ru').read().splitlines()[:100]

In [70]:
batch_size = 32
bleu_scores = []
total_bleu = 0

def batchify(texts, batch_size):
    for i in range(0, len(texts), batch_size):
        yield texts[i:i + batch_size]

for batch_start in tqdm(range(0, len(ru_sents_test), batch_size)):
    batch_ru = ru_sents_test[batch_start:batch_start + batch_size]
    batch_en = en_sents_test[batch_start:batch_start + batch_size]

    batch_pred = translate(batch_ru)

    for (ru_sent, en_sent, pred) in zip(batch_ru, batch_en, batch_pred):
        gold = word_tokenize(en_sent)
        pred = word_tokenize(pred)
        bleu_score = nltk.translate.bleu_score.sentence_bleu([gold], pred, auto_reweigh=True)
        
        bleu_scores.append((bleu_score, ru_sent, en_sent, pred))
        total_bleu += bleu_score

average_bleu = total_bleu / len(ru_sents_test)
average_bleu

100%|██████████| 4/4 [00:02<00:00,  1.56it/s]


0.37835365738321547

Теперь на полном:

In [65]:
en_sents_test = open('opus.en-ru-test.en').read().splitlines()
ru_sents_test = open('opus.en-ru-test.ru').read().splitlines()

In [68]:
batch_size = 32
bleu_scores = []
total_bleu = 0

def batchify(texts, batch_size):
    for i in range(0, len(texts), batch_size):
        yield texts[i:i + batch_size]

for batch_start in tqdm(range(0, len(ru_sents_test), batch_size)):
    batch_ru = ru_sents_test[batch_start:batch_start + batch_size]
    batch_en = en_sents_test[batch_start:batch_start + batch_size]

    batch_pred = translate(batch_ru)

    for (ru_sent, en_sent, pred) in zip(batch_ru, batch_en, batch_pred):
        gold = word_tokenize(en_sent)
        pred = word_tokenize(pred)
        bleu_score = nltk.translate.bleu_score.sentence_bleu([gold], pred, auto_reweigh=True)
        
        bleu_scores.append((bleu_score, ru_sent, en_sent, pred))
        total_bleu += bleu_score

average_bleu = total_bleu / len(ru_sents_test)
average_bleu

100%|██████████| 63/63 [00:47<00:00,  1.32it/s]


0.40471688852092874

### Back Translation

Back translation — это техника, которая помогает улучшить модель машинного перевода, когда параллельных данных недостаточно. Суть в том, чтобы взять уже существующие переводы с одного языка на другой и перевести их обратно, создавая дополнительные примеры для обучения модели. Этот метод позволяет использовать одноязычные тексты для создания синтетических пар данных, что особенно полезно, когда реальных параллельных данных не хватает. В отличие от традиционного подхода, при котором используется только существующий параллельный корпус, back translation расширяет обучающий набор, генерируя дополнительные переводы, которые могут быть использованы для обучения модели.

Для применения этой техники нужно сначала обучить модель перевода в одном направлении, например, как в семинаре, с английского на русский. После этого тренируется модель для обратного перевода с русского на английский. Затем, используя модель ru -> en, переводим одноязычные русские тексты на английский. Это даёт нам новые синтетические пары данных: оригинальные русские предложения и переведённые обратно на английский. Эти новые пары можно добавить в параллельный корпус.

Чтобы применить back translation, нужно дважды обучить модели: одну для перевода с английского на русский и другую для обратного перевода. Всего потребуется минимум два запуска обучения — первый для начальной модели и второй с новыми данными, полученными через back translation.