In [31]:
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
import tqdm
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

In [29]:
train_data = load_dataset("wmt16", "de-en", split="train[:50000]")
val_data = load_dataset("wmt16", "de-en", split="validation")
test_data = load_dataset("wmt16", "de-en", split="test")

In [30]:
print(f"Train data size: {len(train_data)}, type: {type(train_data)}")
print(f"Validation data size: {len(val_data)}, type: {type(val_data)}")
print(f"Test data size: {len(test_data)}, type: {type(test_data)}")

Train data size: 50000, type: <class 'datasets.arrow_dataset.Dataset'>
Validation data size: 2169, type: <class 'datasets.arrow_dataset.Dataset'>
Test data size: 2999, type: <class 'datasets.arrow_dataset.Dataset'>


In [4]:
for i in range(5):
    data = train_data[i]
    german = data["translation"]["de"]
    english = data["translation"]["en"]
    print(f"German: {german}")
    print(f"English: {english}")

German: Wiederaufnahme der Sitzungsperiode
English: Resumption of the session
German: Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten.
English: I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.
German: Wie Sie feststellen konnten, ist der gefürchtete "Millenium-Bug " nicht eingetreten. Doch sind Bürger einiger unserer Mitgliedstaaten Opfer von schrecklichen Naturkatastrophen geworden.
English: Although, as you will have seen, the dreaded 'millennium bug' failed to materialise, still the people in a number of countries suffered a series of natural disasters that truly were dreadful.
German: Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sit

In [50]:
source_language = "de"
target_language = "en"

In [51]:
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

#### Installations

In [None]:
!pip install portalocker

In [8]:
!pip install -U  torchdata




[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [9]:
!pip install -U spacy

Collecting spacy
  Downloading spacy-3.7.4-cp311-cp311-win_amd64.whl.metadata (27 kB)
Collecting spacy-legacy<3.1.0,>=3.0.11 (from spacy)
  Downloading spacy_legacy-3.0.12-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting spacy-loggers<2.0.0,>=1.0.0 (from spacy)
  Using cached spacy_loggers-1.0.5-py3-none-any.whl.metadata (23 kB)
Collecting murmurhash<1.1.0,>=0.28.0 (from spacy)
  Using cached murmurhash-1.0.10-cp311-cp311-win_amd64.whl.metadata (2.0 kB)
Collecting cymem<2.1.0,>=2.0.2 (from spacy)
  Using cached cymem-2.0.8-cp311-cp311-win_amd64.whl.metadata (8.6 kB)
Collecting preshed<3.1.0,>=3.0.2 (from spacy)
  Using cached preshed-3.0.9-cp311-cp311-win_amd64.whl.metadata (2.2 kB)
Collecting thinc<8.3.0,>=8.2.2 (from spacy)
  Downloading thinc-8.2.3-cp311-cp311-win_amd64.whl.metadata (15 kB)
Collecting wasabi<1.2.0,>=0.9.1 (from spacy)
  Using cached wasabi-1.1.2-py3-none-any.whl.metadata (28 kB)
Collecting srsly<3.0.0,>=2.4.3 (from spacy)
  Using cached srsly-2.4.8-cp311-cp311-win_


[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [13]:
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.7.1


[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip



  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
     --------------------------------------- 0.1/12.8 MB 660.6 kB/s eta 0:00:20
     - -------------------------------------- 0.5/12.8 MB 3.5 MB/s eta 0:00:04
     --- ------------------------------------ 1.1/12.8 MB 6.2 MB/s eta 0:00:02
     ----- ---------------------------------- 1.8/12.8 MB 8.3 MB/s eta 0:00:02
     -------- ------------------------------- 2.8/12.8 MB 10.5 MB/s eta 0:00:01
     ---------- ----------------------------- 3.5/12.8 MB 11.1 MB/s eta 0:00:01
     ------------- -------------------------- 4.3/12.8 MB 12.0 MB/s eta 0:00:01
     ---------------- ----------------------- 5.3/12.8 MB 13.6 MB/s eta 0:00:01
     ------------------- -------------------- 6.1/12.8 MB 14.0 MB/s eta 0:0

In [14]:
!python -m spacy download de_core_news_sm

Collecting de-core-news-sm==3.7.0


[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip



  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.7.0/de_core_news_sm-3.7.0-py3-none-any.whl (14.6 MB)
     ---------------------------------------- 0.0/14.6 MB ? eta -:--:--
     ---------------------------------------- 0.0/14.6 MB ? eta -:--:--
     --------------------------------------- 0.1/14.6 MB 518.5 kB/s eta 0:00:29
     - -------------------------------------- 0.4/14.6 MB 2.9 MB/s eta 0:00:05
     -- ------------------------------------- 0.8/14.6 MB 4.7 MB/s eta 0:00:03
     --- ------------------------------------ 1.4/14.6 MB 5.9 MB/s eta 0:00:03
     ----- ---------------------------------- 2.0/14.6 MB 6.9 MB/s eta 0:00:02
     ------- -------------------------------- 2.6/14.6 MB 7.9 MB/s eta 0:00:02
     -------- ------------------------------- 3.2/14.6 MB 8.9 MB/s eta 0:00:02
     ---------- ----------------------------- 3.9/14.6 MB 9.6 MB/s eta 0:00:02
     ------------ --------------------------- 4.7/14.6 MB 9.9 MB/s eta 0:00:0

#### Dataloaders

In [123]:
# Tokenization
token_transform = {}
token_transform[source_language] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[target_language] = get_tokenizer('spacy', language='en_core_web_sm')

In [67]:
for i in range(5): 
    print(train_data[i])
    data_pt = train_data[i]
    data_src = data_pt['translation'][source_language]
    data_tgt = data_pt['translation'][target_language]
    print(f"German: {data_src}")
    print(f"English: {data_tgt}")
    print(f"Tokenized German: {token_transform[source_language](data_src)}")
    print(f"Tokenized English: {token_transform[target_language](data_tgt)}")
    break
    

{'translation': {'de': 'Wiederaufnahme der Sitzungsperiode', 'en': 'Resumption of the session'}}
German: Wiederaufnahme der Sitzungsperiode
English: Resumption of the session
Tokenized German: ['Wiederaufnahme', 'der', 'Sitzungsperiode']
Tokenized English: ['Resumption', 'of', 'the', 'session']


In [97]:
# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    for data_sample in data_iter:
        yield token_transform[language](data_sample['translation'][language])

tokens = yield_tokens(train_data, source_language)
for token in tokens: 
    print(token)
    break

['Wiederaufnahme', 'der', 'Sitzungsperiode']


In [93]:
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

In [102]:
for ln in [source_language, target_language]:
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_data, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

In [100]:
for ln in [source_language, target_language]:
  vocab_transform[ln].set_default_index(UNK_IDX)

In [122]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True)

for i, batch in enumerate(train_dataloader):
    print(batch)
    if i == 4:
        break

{'translation': {'de': ['Dieses Parlament wird heute einen gemeinsamen Entschließungsentwurf mehrerer Fraktionen annehmen, den ich, Herr Präsident, meine Damen und Herren, für wichtig und angebracht halte.', 'Wir tun wesentlich mehr, als ich dachte, und ich bin gern bereit, das Parlament über die vielfältigen Aktivitäten im Bereich der Primarbildung zu informieren, die wir in Afrika durchführen.', 'Die Kommission ist aufgerufen, eine Bestandsaufnahme der nationalen Strafgesetzgebungen in diesem Bereich vorzunehmen und bei Bedarf entsprechende Vorschläge für Straftatbestände zu machen.', 'Nach diesen meines Erachtens doch sehr grundsätzlichen Bemerkungen bestehen nun praktische Vorbehalte gegen die Tobin-Steuer, wie sie von Herrn Karas, aber auch von anderen Rednern geäußert wurden.', 'Die PSE-Fraktion hat in den letzten Tagen einige wichtige Verbesserungsvorschläge unterbreitet. Ich weise Sie vor allem auf Abänderungsantrag 45 hin.', 'Es wäre schlimm, wenn dieses Beispiel Schule machen

In [103]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [104]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [105]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)



In [106]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [source_language, target_language]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[source_language](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[target_language](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

In [107]:
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

In [119]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = train_data
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE)

    for i, batch in enumerate(train_dataloader):
        src = batch[i]['translation'][SRC_LANGUAGE]
        tgt = batch[i]['translation'][TGT_LANGUAGE]
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))


def evaluate(model):
    model.eval()
    losses = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [120]:
from timeit import default_timer as timer
NUM_EPOCHS = 18

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))


# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")

KeyError: 0