В этом семинаре рассмотрим трансформер для задачи перевода текста с одного языка на другой. Возьмем тексты на английском и будем переводить их на французский.

In [None]:
import numpy as np
import math
import pandas as pd
import re
import string
from collections import Counter, OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
from torchtext.data import get_tokenizer
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7c3601329110>

In [None]:
df = pd.read_csv(
    'eng_-french.csv',
    usecols=['English words/sentences', 'French words/sentences']
)
df.columns = ['en', 'fr']
df.head(5)

Unnamed: 0,en,fr
0,Hi.,Salut!
1,Run!,Cours !
2,Run!,Courez !
3,Who?,Qui ?
4,Wow!,Ça alors !


In [None]:
valid_punctuations = "-,'!?."
invalid_punctuations = string.punctuation
for c in valid_punctuations:
    invalid_punctuations = invalid_punctuations.replace(c, "")
invalid_punctuations_re = re.compile("["+ invalid_punctuations +"]")

# очистим текст от лишних символов
def sanitize_text(text: str) -> str:
    text = invalid_punctuations_re.sub("", text)
    text = re.sub(r"\s+", " ", text)
    return text

In [None]:
df["en"] = df["en"].apply(sanitize_text)
df["fr"] = df["fr"].apply(sanitize_text)
df

Unnamed: 0,en,fr
0,Hi.,Salut!
1,Run!,Cours !
2,Run!,Courez !
3,Who?,Qui ?
4,Wow!,Ça alors !
...,...,...
175616,"Top-down economics never works, said Obama. Th...","« L'économie en partant du haut vers le bas, ç..."
175617,A carbon footprint is the amount of carbon dio...,Une empreinte carbone est la somme de pollutio...
175618,Death is something that we're often discourage...,La mort est une chose qu'on nous décourage sou...
175619,Since there are usually multiple websites on a...,Puisqu'il y a de multiples sites web sur chaqu...


In [None]:
fr_tokenizer = get_tokenizer('spacy', language='fr_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

Создадим для каждого языка уникальный набор токенов - словарь, используя уже существующие токенизаторы.

In [None]:
en_words = Counter()
for sentence in tqdm(df["en"]):
    en_words.update(en_tokenizer(sentence))
en_words = OrderedDict(sorted(en_words.items(), key=lambda x:-x[1]))

fr_words = Counter()
for sentence in tqdm(df["fr"]):
    fr_words.update(fr_tokenizer(sentence))
fr_words = OrderedDict(sorted(fr_words.items(), key=lambda x:-x[1]))

print(f"en: {len(en_words)}")
print(f"fr: {len(fr_words)}")

  0%|          | 0/175621 [00:00<?, ?it/s]

  0%|          | 0/175621 [00:00<?, ?it/s]

en: 15895
fr: 26306


Добавим токены неизвестного символа, начала и конца последовательности.

In [None]:
unk_token = '<unk>'  # Unknown
sos_token = '<sos>'  # Start of sentence
eos_token = '<eos>'  # End of sentence
en_vocab = torchtext.vocab.vocab(en_words, specials=[unk_token, eos_token])
en_vocab.set_default_index(en_vocab[unk_token])
fr_vocab = torchtext.vocab.vocab(fr_words, specials=[unk_token, eos_token, sos_token])
fr_vocab.set_default_index(fr_vocab[unk_token])

In [None]:
# вспомогательные функции для токенизации всего датасета
def en_tokenize(text: str, append_eos=True):
    words = en_tokenizer(text)
    if append_eos:
        words.append(eos_token)
    return en_vocab(words)

def fr_tokenize(text: str, append_eos=True):
    words = [sos_token] + fr_tokenizer(text)
    if append_eos:
        words.append(eos_token)
    return fr_vocab(words)

In [None]:
x_train, x_test, y_train, y_test = train_test_split(df['en'], df['fr'], test_size=0.2)

In [None]:
x_train_seqs = [en_tokenize(x) for x in x_train]
x_test_seqs = [en_tokenize(x) for x in x_test]

y_train_seqs = [fr_tokenize(y) for y in y_train]
y_test_seqs = [fr_tokenize(y) for y in y_test]

Для того, чтобы все последовательности в датасете были одинакового размера, добавим функцию pad_batch.

In [None]:
def pad_batch(batch):
    return nn.utils.rnn.pad_sequence(batch, batch_first=True)

def collate_fn(batch):
    encoder_input = []
    decoder_input = []
    answer = []
    for e, d, a in batch:
        encoder_input.append(e)
        decoder_input.append(d)
        answer.append(a)
    encoder_input = pad_batch(encoder_input)
    decoder_input = pad_batch(decoder_input)
    answer = pad_batch(answer)
    return encoder_input, decoder_input, answer

class SeqDataset(Dataset):
    def __init__(self, en_seqs, fr_seqs):
        super().__init__()
        self.en_seqs = en_seqs
        self.fr_seqs = fr_seqs

    def __len__(self):
        return len(self.en_seqs)

    def __getitem__(self, index):
        # в качестве входа энкодера будем брать последовательность токенов из английского словаря
        encoder_input = torch.tensor(self.en_seqs[index], requires_grad=False).long()
        # в качестве входа декодера будем брать последовательность токенов из французского словаря
        # без последнего слова - его попытаемся предсказать
        decoder_input = torch.tensor(self.fr_seqs[index][:-1], requires_grad=False).long()
        # в качестве выхода декодера будем брать последовательность токенов из французского словаря
        # без первого слова - предсказываем последнее
        answer = torch.tensor(self.fr_seqs[index][1:], requires_grad=False).long()
        return encoder_input, decoder_input, answer

    def __call__(self, batch):
        # применяем паддинг для батчей
        return collate_fn(batch)

In [None]:
trainset = SeqDataset(x_train_seqs, y_train_seqs)
testset = SeqDataset(x_test_seqs, y_test_seqs)

# Трансформер

Рассмотрим архитектуру трансформера с нуля.

![трансформер](https://habrastorage.org/getpro/habr/upload_files/ff3/412/8ea/ff34128ea15cd27e42b73d1acb260ac7.png)

Сначала реализуем слой-эмбеддинг.

In [None]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        out = self.embed(x)
        return out

Реализуем эмбеддинг позиции токена в последовательности. Позиционный эмбеддинг можно задавать разными способами, например, можно кодировать с помощью последовательности sin/cos. Можно завести просто обучаемый параметр. В семинаре будем рассматривать первый вариант.

In [None]:
class PositionalEmbedding(nn.Module):
    def __init__(self, embed_model_dim, max_seq_len=100):
        super().__init__()
        self.embed_dim = embed_model_dim

        pe = torch.zeros(max_seq_len, self.embed_dim)
        for pos in range(max_seq_len):
            for i in range(0, self.embed_dim, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / self.embed_dim)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * i) / self.embed_dim)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x * torch.sqrt(torch.tensor(self.embed_dim, dtype=torch.float))
        seq_len = x.size(1)
        # прибавили к нашему тензору эмбеддинг позиции
        x = x + self.pe[:, :seq_len]
        return x


![attention](https://dz2cdn1.dzone.com/storage/temp/11139358-screen-shot-2019-01-07-at-84314-am.png)

Рассмотрим scaled dot product attention - то, на чем стоит эта архиткетура. Вспомним, что мы пытаемся реализовать следующую формулу:
$$
attention(Q,K,V) = softmax(\frac {QK^T} {\sqrt{d_k}}V)
$$

In [None]:
class ScaledDotProductAttention(nn.Module):
    def __init__(
        self, d_model: int = 64,
    ) -> None:
        super().__init__()
        self.scale = d_model ** -0.5
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12) -> torch.Tensor:
        # делаем k^T
        k_t = k.transpose(2, 3)

        # attn map: softmax(q @ k^T / sqrt(d_k))
        # делим, чтобы значения не улетали
        score = (q @ k_t) / self.scale

        if mask is not None:
             product = product.masked_fill(mask == 0, e)
        # проходим софтмаксом по полученной матрице
        score = self.softmax(score)
        # получаем матричку того, куда нам надо смотреть в value
        v = score @ v
        return v, score

Теперь реализуем "многоголовое" внимание - разные головы аттеншна, потенциально, могут смотреть на разные части изображения.

In [None]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(
        self,
        n_heads: int = 12,
        d_model: int = 64,
    ) -> None:
        super().__init__()
        self.n_heads = n_heads

        self.attn = ScaledDotProductAttention(d_model)
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.w_concat = nn.Linear(d_model, d_model)


    def forward(self, q, k, v, mask=None) -> torch.Tensor:
        q, k, v = self.wq(q), self.wk(k), self.wv(v)

        q, k, v = self.split(q), self.split(k), self.split(v)

        out, attn = self.attn(q, k, v, mask=mask)

        out = self.concat(out)
        out = self.w_concat(out)

        return out

    def split(self, x):
        batch_size, seq_length, d_model = x.size()
        head_dim = d_model // self.n_heads
        x = x.view(batch_size, seq_length, self.n_heads, head_dim)
        return x.permute(0, 2, 1, 3)

    def concat(self, x):
        x = x.permute(0, 2, 1, 3)
        batch_size, seq_length, n_head, head_dim = x.size()
        return x.contiguous().view(batch_size, seq_length, n_head * head_dim)


Не забудем и про линейный слой.

In [None]:
class MLP(torch.nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        mlp_hidden_size: int = 3072,
        dropout_rate: float = 0.1,
    ) -> None:
        super().__init__()

        # кэжуал многослойный персептрон с использованием GELU
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_size),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_hidden_size, embed_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.mlp(tensor)

Теперь создадим блок-энкодер.

In [None]:
class EncoderBlock(nn.Module):
    def __init__(
        self,
        n_heads: int = 12,
        d_model: int = 64,
        mlp_hidden_size: int = 1024,
        mlp_p: float = 0.1
        ):
        super().__init__()
        self.lnorm1 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(n_heads, d_model)
        self.lnorm2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, mlp_hidden_size, mlp_p)

    def forward(self, x):
        x = x + self.mha(self.lnorm1(x), self.lnorm1(x), self.lnorm1(x))
        x = x + self.mlp(self.lnorm2(x))
        return x

И сам энкодер.

In [None]:
class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        num_layers,
        n_heads: int = 12,
        d_model: int = 64,
        mlp_hidden_size: int = 1024,
        mlp_p: float = 0.1
        ):
        super().__init__()
        self.emb = Embedding(vocab_size, d_model)
        self.pos = PositionalEmbedding(d_model)
        self.layers = nn.ModuleList([EncoderBlock(n_heads, d_model, mlp_hidden_size, mlp_p) for _ in range(num_layers)])

    def forward(self, x):
        x = self.emb(x)
        x = self.pos(x)
        for layer in self.layers:
            x = layer(x)
        return x

Создадим также блок декодера. Вспомним, что он отличается дополнительным слоем, включающим в себя маску.

In [None]:
def make_trg_mask(trg):
    batch_size, trg_len = trg.shape
    trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
        batch_size, 1, trg_len, trg_len
    )
    return trg_mask

class DecoderBlock(nn.Module):
    def __init__(
        self,
        n_heads: int = 12,
        d_model: int = 64,
        mlp_hidden_size: int = 1024,
        mlp_p: float = 0.1
        ):
        super().__init__()
        self.lnorm1 = nn.LayerNorm(d_model)
        self.mha1 = MultiHeadAttention(n_heads, d_model)
        self.lnorm2 = nn.LayerNorm(d_model)
        self.mha2 = MultiHeadAttention(n_heads, d_model)
        self.lnorm3 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, mlp_hidden_size, mlp_p)

    def forward(self, x, enc_out, mask=None):
        x = x + self.mha1(self.lnorm1(x), self.lnorm1(x), self.lnorm1(x), mask=mask)
        x = x + self.mha2(self.lnorm2(x), self.lnorm2(enc_out), self.lnorm2(enc_out))
        x = x + self.mlp(self.lnorm3(x))
        return x

Создадим декодер.

In [None]:
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        num_layers,
        n_heads: int = 12,
        d_model: int = 64,
        mlp_hidden_size: int = 1024,
        mlp_p: float = 0.1,
        drop_p: float = 0.1
        ):
        super().__init__()
        self.emb = Embedding(vocab_size, d_model)
        self.pos = PositionalEmbedding(d_model)
        self.drop = nn.Dropout(drop_p)
        self.layers = nn.ModuleList([DecoderBlock(n_heads, d_model, mlp_hidden_size, mlp_p) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, enc_out, mask=None):
        x = self.emb(x)
        x = self.pos(x)
        x = self.drop(x)
        for layer in self.layers:
            x = layer(x, enc_out, mask)
        return self.softmax(self.fc_out(x))


Дополнительная функция создания маски.

Соберем наш трансформер.

In [None]:
class Transformer(nn.Module):
    def __init__(
            self,
            in_vocab_size,
            out_vocab_size,
            num_layers,
            n_heads: int = 8,
            d_model: int = 64,
            mlp_hidden_size: int = 1024,
            enc_mlp_p: float = 0.1,
            dec_mlp_p: float = 0.1,
            drop_p: float = 0.1,
    ):
        super().__init__()
        self.encoder = Encoder(
            vocab_size=in_vocab_size,
            num_layers=num_layers,
            n_heads=n_heads,
            d_model=d_model,
            mlp_hidden_size=mlp_hidden_size,
            mlp_p=enc_mlp_p
        )

        self.decoder = Decoder(
            vocab_size=out_vocab_size,
            num_layers=num_layers,
            n_heads=n_heads,
            d_model=d_model,
            mlp_hidden_size=mlp_hidden_size,
            mlp_p=dec_mlp_p,
            drop_p=drop_p
        )

        self.fc_out = nn.Linear(d_model, out_vocab_size)

    def decode(self, source, target):
        target_mask = make_trg_mask(target)
        enc_out = self.encoder(source)
        out_labels = []

        batch_size, seq_len = source.shape[0], source.shape[1]
        out = target
        for i in range(seq_len):
            out = self.decoder(out, enc_out, target_mask)
            out = out[:,-1,:]

            out = out.argmax(-1)
            out_labels.append(out.item())
            out = torch.unsqueeze(out,axis=0)

        return out_labels

    def forward(self, source, target):
        enc_out = self.encoder(source)
        dec_out = self.decoder(target, enc_out)
        return dec_out

# Обучение

In [None]:
def train(model, train_dataloader, valid_dataloader, criterion, optimizer, device, epochs):
    for epoch in range(epochs):
        model.train()
        model.to(device)
        losses = 0
        train_tqdm_iterator = tqdm(train_dataloader)
        for en, fr, out in train_tqdm_iterator:
            en = en.to(device)
            fr = fr.to(device)
            out = out.to(device)

            logits = model(en, fr)
            optimizer.zero_grad()
            loss = criterion(logits.reshape(-1, logits.shape[-1]), out.reshape(-1))
            loss.backward()

            optimizer.step()
            losses += loss.item()
            train_tqdm_iterator.set_description(f"Epoch {epoch}, train_loss: {loss.item()}")

        model.eval()
        losses = 0
        valid_tqdm_iterator = tqdm(valid_dataloader)
        for en, fr, out in valid_tqdm_iterator:
            en = en.to(device)
            fr = fr.to(device)
            out = out.to(device)

            logits = model(en, fr)
            loss = criterion(logits.reshape(-1, logits.shape[-1]), out.reshape(-1))
            losses += loss.item()
            valid_tqdm_iterator.set_description(f"Epoch {epoch}, valid_loss: {loss.item()}")

In [None]:
src_vocab_size = len(en_vocab)
tgt_vocab_size = len(fr_vocab)
num_layers = 6
n_heads = 8
d_model = 64
mlp_hidden_size = 256
enc_mlp_p = 0.1
dec_mlp_p = 0.1
drop_p = 0.1

model = Transformer(
    in_vocab_size=src_vocab_size,
    out_vocab_size=tgt_vocab_size,
    num_layers=num_layers,
    n_heads=n_heads,
    d_model=d_model,
    mlp_hidden_size=mlp_hidden_size,
    enc_mlp_p=enc_mlp_p,
    dec_mlp_p=dec_mlp_p,
    drop_p=drop_p
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
batch_size = 50

train_dataloader = DataLoader(trainset, batch_size=batch_size, collate_fn=trainset)
valid_dataloader = DataLoader(testset, batch_size=batch_size, collate_fn=testset)
device = torch.device("cuda")

In [None]:
model

Transformer(
  (encoder): Encoder(
    (emb): Embedding(
      (embed): Embedding(15897, 64)
    )
    (pos): PositionalEmbedding()
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (lnorm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mha): MultiHeadAttention(
          (attn): ScaledDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (wq): Linear(in_features=64, out_features=64, bias=False)
          (wk): Linear(in_features=64, out_features=64, bias=False)
          (wv): Linear(in_features=64, out_features=64, bias=False)
          (w_concat): Linear(in_features=64, out_features=64, bias=True)
        )
        (lnorm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=64, out_features=256, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.1, inplace=False)
            (3): Linear(in_features=256, out

In [50]:
num_epochs = 10
train(
    model,
    train_dataloader,
    valid_dataloader,
    criterion,
    optimizer,
    device,
    epochs=num_epochs
    )

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]

  0%|          | 0/2810 [00:00<?, ?it/s]

  0%|          | 0/703 [00:00<?, ?it/s]