In [30]:
from datasets import load_dataset
import numpy as np
import pandas as pd
import torch
import os
import sentencepiece as spm
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math


In [2]:
os.environ["HF_DATASETS_ALLOW_CODE"] = "1"

In [3]:
dataset = load_dataset("opus100","de-en")

In [4]:
src_lang = "en"
target_lang = "de"

In [5]:
train_data = dataset["train"].shuffle(seed=42).select(range(100_000))
val_data = dataset["validation"]
test_data = dataset["test"]

In [6]:
train_data

Dataset({
    features: ['translation'],
    num_rows: 100000
})

In [7]:
print(train_data[0]["translation"]["en"])
print(train_data[0]["translation"]["de"])

[{"Id":"52281fa3-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa5-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa6-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa7-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":2},{"Id":"52281fa8-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":1},{"Id":"3711adf0-87d6-4615-aed3-abc0671c8b85","Index":0,"Count":1},{"Id":"a8696682-c170-4f76-a6f2-92aaa7276042","Index":0,"Count":1}]
[{"Id":"52281fa3-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa5-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa6-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa7-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":2},{"Id":"52281fa8-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":1},{"Id":"3711adf0-87d6-4615-aed3-abc0671c8b85","Index":0,"Count":1},{"Id":"a8696682-c170-4f76-a6f2-92aaa7276042","Index":0,"Count":1}]


In [8]:
def get_lengths(example):
    src = example["translation"][src_lang]
    tgt = example["translation"][target_lang]
    return {
        "src_len": len(src.split()),
        "tgt_len": len(tgt.split())
    }
lengths = train_data.map(get_lengths)

In [9]:
lengths[0]

{'translation': {'de': '[{"Id":"52281fa3-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa5-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa6-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa7-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":2},{"Id":"52281fa8-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":1},{"Id":"3711adf0-87d6-4615-aed3-abc0671c8b85","Index":0,"Count":1},{"Id":"a8696682-c170-4f76-a6f2-92aaa7276042","Index":0,"Count":1}]',
  'en': '[{"Id":"52281fa3-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa5-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa6-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":0},{"Id":"52281fa7-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":2},{"Id":"52281fa8-51ca-11d1-8f14-00a02427d15e","Index":0,"Count":1},{"Id":"3711adf0-87d6-4615-aed3-abc0671c8b85","Index":0,"Count":1},{"Id":"a8696682-c170-4f76-a6f2-92aaa7276042","Index":0,"Count":1}]'},
 'src_len': 1,
 'tgt_len': 1}

In [10]:
print(train_data[1]["translation"]["en"])

1 x Spider Straps System f...


In [11]:
def is_valid_text(example):
    src = example["translation"][src_lang]
    tgt = example["translation"][target_lang]

    if src.strip().startswith("[{") or tgt.strip().startswith("[{"):
        return False

    if len(src.split()) < 3 or len(tgt.split()) < 3:
        return False

    return True

In [12]:
clean_train = train_data.filter(is_valid_text)

In [13]:
print(clean_train[0]["translation"]["en"])
print(clean_train[0]["translation"]["de"])

1 x Spider Straps System f...
1 x Metall-Helmhalterung f...


In [14]:
df = clean_train.to_pandas()

In [15]:
df["en"] = df["translation"].apply(lambda x: x["en"])
df["de"] = df["translation"].apply(lambda x: x["de"])

In [16]:
df = df.drop(columns=["translation"])

In [17]:
df

Unnamed: 0,en,de
0,1 x Spider Straps System f...,1 x Metall-Helmhalterung f...
1,Leave a message.,Hinterlasst 'ne Nachricht.
2,But you take the night.,Aber schlafen Sie einmal drüber.
3,Weekly rate from: EUR 750,Wochenpreis ab: EUR 1.070
4,"Stop with the coffee, okay?","Hör mit dem Kaffee auf, okay?"
...,...,...
83226,"Rare, wild flowers include a species of red wi...","Seltene, wilde Blumen einschließlich eine Art ..."
83227,Look at him over there having lunch with his dad.,"Guck nur, wie er da mit seinem Vater isst."
83228,"6. In view of the above considerations, the cl...","13. Die Kammer kommt deshalb zum Schluss, dass..."
83229,No sign of Zorn.,Kein Anzeichen von Zorn.


In [18]:
df_sample = df.sample(30000, random_state=42)

df_sample["en"].to_csv("train.en", index=False, header=False)
df_sample["de"].to_csv("train.de", index=False, header=False)

In [19]:
spm.SentencePieceTrainer.train(
    input="train.en,train.de",
    model_prefix="bpe_en_de",
    vocab_size=8000,
    model_type="bpe",
    character_coverage=1.0,
    bos_id=1,
    eos_id=2,
    unk_id=0,
    pad_id=3
)


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: train.en
  input: train.de
  input_format: 
  model_prefix: bpe_en_de
  model_type: BPE
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: 3
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  

In [20]:
sp = spm.SentencePieceProcessor()
sp.load("bpe_en_de.model")

True

In [21]:
encoded_data_dict = {
    "src_en_ids": [],
    "tgt_de_ids": []
}

max_src_len = 0
max_tgt_len = 0
for _, row in df.iterrows():
    en_text = row["en"]
    de_text = row["de"]

    en_ids = [sp.bos_id()] + sp.encode(en_text, out_type=int) + [sp.eos_id()]
    de_ids = [sp.bos_id()] + sp.encode(de_text, out_type=int) + [sp.eos_id()]

    encoded_data_dict["src_en_ids"].append(en_ids)
    encoded_data_dict["tgt_de_ids"].append(de_ids)

    if len(en_ids) > max_src_len:
        max_src_len = len(en_ids)

    if len(de_ids) > max_tgt_len:
        max_tgt_len = len(de_ids)

In [22]:
print(encoded_data_dict["src_en_ids"][0])
print(encoded_data_dict["tgt_de_ids"][0])

[1, 65, 1236, 580, 958, 2294, 3624, 2524, 33, 289, 2]
[1, 65, 1236, 4074, 234, 7530, 7547, 86, 7512, 402, 85, 62, 33, 289, 2]


In [23]:
print(sp.decode(encoded_data_dict["src_en_ids"][0][1:-1]))
print(sp.decode(encoded_data_dict["tgt_de_ids"][0][1:-1]))

1 x Spider Straps System f...
1 x Metall-Helmhalterung f...


In [24]:
print("Max EN length:", max_src_len)
print("Max DE length:", max_tgt_len)

Max EN length: 2593
Max DE length: 3445


In [39]:
PAD_ID = sp.pad_id()
BOS_ID = sp.bos_id()
EOS_ID = sp.eos_id()

def process_sequence(seq, max_len):
    if len(seq) > max_len:
        seq = seq[:max_len]
        seq[-1] = EOS_ID 

    if len(seq) < max_len:
        seq = seq + [PAD_ID] * (max_len - len(seq))

    return seq

In [40]:
MAX_LEN = 512
encoded_data_dict["src_en_ids"] = [
    process_sequence(seq, MAX_LEN)
    for seq in encoded_data_dict["src_en_ids"]
]

encoded_data_dict["tgt_de_ids"] = [
    process_sequence(seq, MAX_LEN)
    for seq in encoded_data_dict["tgt_de_ids"]
]

In [41]:
print(len(encoded_data_dict["src_en_ids"][0]))
print(len(encoded_data_dict["tgt_de_ids"][0]))

512
512


In [42]:
print(encoded_data_dict["src_en_ids"][0][:10])
print(encoded_data_dict["src_en_ids"][0][-10:])  

[1, 65, 1236, 580, 958, 2294, 3624, 2524, 33, 289]
[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]


In [48]:
class SingleAttentionHead(torch.nn.Module):

    def __init__(self, query_key_embedding_dim, value_embedding_dim,
                 sha_dim, masked, is_dropout, dropout_probability):
        super().__init__()

        self.sha_dim = sha_dim
        self.masked = masked
        self.is_dropout = is_dropout

        self.query_projection_layer = torch.nn.Linear(
            query_key_embedding_dim, sha_dim, bias=False
        )
        self.key_projection_layer = torch.nn.Linear(
            query_key_embedding_dim, sha_dim, bias=False
        )
        self.value_projection_layer = torch.nn.Linear(
            value_embedding_dim, sha_dim, bias=False
        )

        self.softmax_activation = torch.nn.Softmax(dim=-1)

        if self.is_dropout:
            self.dropout = torch.nn.Dropout(p=dropout_probability)

    def forward(self, query_embedding, key_embedding, value_embedding):

        Q = self.query_projection_layer(query_embedding)
        K = self.key_projection_layer(key_embedding)
        V = self.value_projection_layer(value_embedding)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.sha_dim, dtype=torch.float32)
        )

        if self.masked:
            mask = torch.tril(torch.ones_like(scores))
            scores = scores.masked_fill(mask == 0, float("-inf"))

        attention_weights = self.softmax_activation(scores)

        if self.is_dropout:
            attention_weights = self.dropout(attention_weights)

        return torch.matmul(attention_weights, V)


In [49]:
class MultiHeadAttentionLayer(torch.nn.Module):

    def __init__(self, embedding_dim, num_attn_heads,
                 masked, is_dropout, dropout_probability):
        super().__init__()

        sha_dim = embedding_dim // num_attn_heads
        self.attn_heads = torch.nn.ModuleList([
            SingleAttentionHead(
                embedding_dim, embedding_dim,
                sha_dim, masked, is_dropout, dropout_probability
            )
            for _ in range(num_attn_heads)
        ])

        self.output_projection = torch.nn.Linear(
            embedding_dim, embedding_dim, bias=False
        )

        if is_dropout:
            self.dropout = torch.nn.Dropout(p=dropout_probability)
        self.is_dropout = is_dropout

    def forward(self, query, key, value):

        head_outputs = [
            head(query, key, value) for head in self.attn_heads
        ]

        concat = torch.cat(head_outputs, dim=-1)
        out = self.output_projection(concat)

        if self.is_dropout:
            out = self.dropout(out)

        return out


In [50]:
class EncoderLayer(torch.nn.Module):

    def __init__(self, model_dim, num_heads,
                 dropout_probability, ffn_dim):
        super().__init__()

        self.mha = MultiHeadAttentionLayer(
            model_dim, num_heads, False, True, dropout_probability
        )
        self.norm1 = torch.nn.LayerNorm(model_dim)

        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(model_dim, ffn_dim),
            torch.nn.GELU(),
            torch.nn.Linear(ffn_dim, model_dim)
        )
        self.norm2 = torch.nn.LayerNorm(model_dim)
        self.dropout = torch.nn.Dropout(dropout_probability)

    def forward(self, x):

        x = self.norm1(x + self.mha(x, x, x))
        x = self.norm2(x + self.dropout(self.ffn(x)))

        return x


In [51]:
class DecoderLayer(torch.nn.Module):

    def __init__(self, model_dim, num_heads,
                 dropout_probability, ffn_dim):
        super().__init__()

        self.self_attn = MultiHeadAttentionLayer(
            model_dim, num_heads, True, True, dropout_probability
        )
        self.enc_dec_attn = MultiHeadAttentionLayer(
            model_dim, num_heads, False, True, dropout_probability
        )

        self.norm1 = torch.nn.LayerNorm(model_dim)
        self.norm2 = torch.nn.LayerNorm(model_dim)
        self.norm3 = torch.nn.LayerNorm(model_dim)

        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(model_dim, ffn_dim),
            torch.nn.GELU(),
            torch.nn.Linear(ffn_dim, model_dim)
        )

        self.dropout = torch.nn.Dropout(dropout_probability)

    def forward(self, x, encoder_output):

        x = self.norm1(x + self.self_attn(x, x, x))
        x = self.norm2(x + self.enc_dec_attn(x, encoder_output, encoder_output))
        x = self.norm3(x + self.dropout(self.ffn(x)))

        return x


In [52]:
class TransformerTranslator(torch.nn.Module):

    def __init__(self, vocab_size, model_dim,
                 num_layers, num_heads,
                 ffn_dim, max_len):
        super().__init__()

        self.token_embedding = torch.nn.Embedding(vocab_size, model_dim)
        self.position_embedding = torch.nn.Embedding(max_len, model_dim)

        self.encoder_layers = torch.nn.ModuleList([
            EncoderLayer(model_dim, num_heads, 0.1, ffn_dim)
            for _ in range(num_layers)
        ])

        self.decoder_layers = torch.nn.ModuleList([
            DecoderLayer(model_dim, num_heads, 0.1, ffn_dim)
            for _ in range(num_layers)
        ])

        self.output_projection = torch.nn.Linear(model_dim, vocab_size)

    def forward(self, src_ids, tgt_ids):

        src_pos = torch.arange(src_ids.size(1)).to(src_ids.device)
        tgt_pos = torch.arange(tgt_ids.size(1)).to(tgt_ids.device)

        src = self.token_embedding(src_ids) + self.position_embedding(src_pos)
        tgt = self.token_embedding(tgt_ids) + self.position_embedding(tgt_pos)

        for enc in self.encoder_layers:
            src = enc(src)

        for dec in self.decoder_layers:
            tgt = dec(tgt, src)

        return self.output_projection(tgt)

In [53]:
def training_data_generator(encoded_data_dict, max_len):

    for src_ids, tgt_ids in zip(
        encoded_data_dict["src_en_ids"],
        encoded_data_dict["tgt_de_ids"]
    ):
        yield (
            torch.tensor(src_ids[:max_len]),
            torch.tensor(tgt_ids[:-1][:max_len]),
            torch.tensor(tgt_ids[1:][:max_len])
        )


In [None]:
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=sp.pad_id())
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

epochs = 20
X_src = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(x) for x in encoded_data_dict["src_en_ids"]],
    batch_first=True,
    padding_value=sp.pad_id()
)

Y_tgt = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(y) for y in encoded_data_dict["tgt_de_ids"]],
    batch_first=True,
    padding_value=sp.pad_id()
)

X_src = X_src[:, :512].cuda()
Y_tgt = Y_tgt[:, :512].cuda()
model.train()

for epoch in range(epochs):

    # shift target
    tgt_input  = Y_tgt[:, :-1]
    tgt_output = Y_tgt[:, 1:]

    # forward pass
    y_hat = model(X_src, tgt_input)

    # loss
    loss = loss_fn(
        y_hat.reshape(-1, y_hat.size(-1)),
        tgt_output.reshape(-1)
    )

    # backward
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(
        "Epoch # {}, Train Loss Value = {}".format(
            epoch + 1, loss.item()
        )
    )


KeyboardInterrupt: 