In [1]:
import os
import nltk
import torch
import heapq
import random

In [2]:
import torch.nn.functional as F

In [3]:
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning import seed_everything
from tqdm.auto import tqdm

In [4]:
from sacremoses import MosesTokenizer

In [5]:
import numpy as np

In [6]:
from dotless_arabic.experiments.translation.src.settings import (
    configure_environment,
)
from dotless_arabic.tokenizers import SentencePieceTokenizer



In [7]:
from dotless_arabic.experiments.translation.src.processing import (
    process_en,
    process_ar,
)
from dotless_arabic.experiments.translation.src.utils import get_source_tokenizer, get_target_tokenizer
from dotless_arabic.callbacks import EpochTimerCallback
from dotless_arabic.experiments.translation.src.datasets import get_dataloader
from dotless_arabic.experiments.translation.src.models import TranslationTransformer
from dotless_arabic.experiments.translation.src.utils import get_best_checkpoint, get_sequence_length, train_translator
from dotless_arabic.experiments.translation.src.utils import get_blue_score
from dotless_arabic.experiments.translation.src import constants
from dotless_arabic.experiments.translation.src.utils import create_features_from_text_list

In [8]:
# Display the device currently configured in the project's constants module.
constants.DEVICE

'cuda'

In [9]:
# Pin inference to the first GPU when CUDA is usable; fall back to the CPU so
# the notebook still runs (slowly) on machines without a visible GPU.
constants.DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [10]:
# Single seed shared by every RNG that is seeded in the setup cell below.
seed = 42

In [11]:
# Reproducibility and runtime switches for this session.
random.seed(seed)  # seed Python's own RNG
torch.cuda.empty_cache()  # to free gpu memory
# NOTE(review): stopwords are not used in any cell visible here — presumably
# required by the project's preprocessing code; confirm before removing.
nltk.download("stopwords")
seed_everything(seed, workers=True)  # Lightning helper; logs "Global seed set to 42"
os.environ["WANDB_MODE"] = "disabled"  # keep Weights & Biases from logging this run
# NOTE(review): anomaly detection is an autograd debugging aid; this notebook
# only runs inference, so it can probably be dropped — confirm.
torch.autograd.set_detect_anomaly(True)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # to see CUDA errors

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/majed_alshaibani/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Global seed set to 42


In [12]:
from dotless_arabic.datasets.iwslt2017.collect import (
    collect_parallel_train_dataset_for_translation,
    collect_parallel_val_dataset_for_translation,
    collect_parallel_test_dataset_for_translation,
)

In [13]:
# Collect the three IWSLT2017 parallel splits (train / validation / test).
train_dataset, val_dataset, test_dataset = (
    collect_parallel_train_dataset_for_translation(),
    collect_parallel_val_dataset_for_translation(),
    collect_parallel_test_dataset_for_translation(),
)

In [14]:
# Convert every split to a pandas DataFrame, keeping the same variable names.
# Not idempotent: re-running this cell requires re-running the collection cell.
train_dataset, val_dataset, test_dataset = (
    split.to_pandas() for split in (train_dataset, val_dataset, test_dataset)
)

In [15]:
# Register tqdm's pandas integration; the `progress_map` calls in the next
# cells rely on it.
tqdm.pandas()

In [16]:
# Apply the project's language-specific preprocessing to each split.
# Order matches the progress bars: all Arabic columns first, then all English.
for language_column, process_text in (("ar", process_ar), ("en", process_en)):
    for split in (train_dataset, val_dataset, test_dataset):
        split[language_column] = split[language_column].progress_map(process_text)

  0%|          | 0/231713 [00:00<?, ?it/s]

  0%|          | 0/888 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

  0%|          | 0/231713 [00:00<?, ?it/s]

  0%|          | 0/888 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

In [17]:
# One Moses tokenizer (default constructor arguments) is shared by both languages.
moses_tokenizer = MosesTokenizer()


def _moses_tokenize(text):
    """Tokenize ``text`` with the shared Moses tokenizer, returned as one string."""
    return moses_tokenizer.tokenize(text, return_str=True)


# English columns first, then Arabic; train -> val -> test within each language.
for language_column in ("en", "ar"):
    for split in (train_dataset, val_dataset, test_dataset):
        split[language_column] = split[language_column].progress_map(_moses_tokenize)

  0%|          | 0/231713 [00:00<?, ?it/s]

  0%|          | 0/888 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

  0%|          | 0/231713 [00:00<?, ?it/s]

  0%|          | 0/888 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

In [18]:
# Translation direction (Arabic -> English) and the tokenizer class for each side.
source_tokenizer_class = target_tokenizer_class = SentencePieceTokenizer
source_language_code, target_language_code = "ar", "en"

In [19]:
# Options common to both sides: same training frame, and undot_text=False
# (presumably: keep the original dotted text — confirm against the helper).
shared_tokenizer_options = dict(train_dataset=train_dataset, undot_text=False)

source_tokenizer = get_source_tokenizer(
    tokenizer_class=source_tokenizer_class,
    source_language_code=source_language_code,
    **shared_tokenizer_options,
)
target_tokenizer = get_target_tokenizer(
    tokenizer_class=target_tokenizer_class,
    target_language_code=target_language_code,
    **shared_tokenizer_options,
)

  0%|          | 0/231713 [00:00<?, ?it/s]

Training SentencePiece ...


  0%|          | 0/231713 [00:00<?, ?it/s]

Training SentencePiece ...


In [20]:
def _sequence_length_for(tokenizer, sentences):
    """Split every sentence with ``tokenizer`` and pass the result to get_sequence_length."""
    return get_sequence_length(
        dataset=[tokenizer.split_text(sentence) for sentence in tqdm(sentences)],
    )


source_max_sequence_length = _sequence_length_for(
    source_tokenizer,
    train_dataset[source_language_code],
)
target_max_sequence_length = _sequence_length_for(
    target_tokenizer,
    train_dataset[target_language_code],
)

# One shared length for both sides: the larger of the two.
sequence_length = max(source_max_sequence_length, target_max_sequence_length)
sequence_length

  0%|          | 0/231713 [00:00<?, ?it/s]



  0%|          | 0/231713 [00:00<?, ?it/s]

  0%|          | 0/231713 [00:00<?, ?it/s]

  0%|          | 0/231713 [00:00<?, ?it/s]

91

In [21]:
# Locate the best saved checkpoint for this direction / tokenizer pair, then
# restore the model onto the configured device.
best_checkpoint_path = get_best_checkpoint(
    is_dotted=True,
    source_language_code=source_language_code,
    target_language_code=target_language_code,
    source_tokenizer_class=source_tokenizer_class,
    target_tokenizer_class=target_tokenizer_class,
)
model = TranslationTransformer.load_from_checkpoint(
    best_checkpoint_path,
    map_location=constants.DEVICE,
)
model

best ckpt: NMT/ar_to_en/SentencePieceTokenizer_to_SentencePieceTokenizer/dotted/checkpoints/epoch=26-val_loss=2.568-step=145503.ckpt


TranslationTransformer(
  (train_ppl): Perplexity()
  (val_ppl): Perplexity()
  (test_ppl): Perplexity()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleL

## Greedy

Get the score, one-by-one (greedy decoding).

In [22]:
# BLEU with greedy decoding. No precomputed predictions are passed, so
# get_blue_score presumably decodes the test sentences itself.
get_blue_score(
    model=model,
    # data
    source_sentences=test_dataset[source_language_code],
    target_sentences=test_dataset[target_language_code],
    source_language_code=source_language_code,
    target_language_code=target_language_code,
    # tokenization
    source_tokenizer=source_tokenizer,
    target_tokenizer=target_tokenizer,
    max_sequence_length=sequence_length,
    # decoding / reporting
    is_dotted=True,
    decode_with_beam_search=False,
    show_translations_for=5,
    save_predictions_and_targets=False,
)

  0%|          | 0/1205 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

Source: سأحدثكم اليوم عن 30 عاما من تاريخ الهندسة.
Prediction: i'm going to talk to you today about 30 years of the history of engineering.
Target: today i'm going to speak to you about the last 30 years of architectural history.
********************************************************************************
Source: هذا أمر كبير جدا لألخصه في 18 دقيقة.
Prediction: this is a very big thing to summarize in 18 minutes.
Target: that's a lot to pack into 18 minutes.
********************************************************************************
Source: إنه موضوع معقد ، لذلك فإننا سنتوجه مباشرة إلى مكان معقد: إلى نيو جيرسي ،
Prediction: it's a complicated topic, so we're going to head right into a complex place: to new jersey.
Target: it's a complex topic, so we're just going to dive right in at a complex place: new jersey.
********************************************************************************
Source: لأنه منذ 30 سنة ، أنا من نيوجيرسي ، كنت في السادسة من عمري ، وكنت أعيش هناك مع 

30.242

In [23]:
def batch_translate(
    model,
    input_sentences,
    source_tokenizer=source_tokenizer,
    target_tokenizer=target_tokenizer,
    max_sequence_length=sequence_length,
    temperature=0,
    batch_size=64,
):
    """Translate ``input_sentences`` with greedy (argmax) decoding, in mini-batches.

    The inputs are encoded once, cut into chunks of at most ``batch_size`` rows
    (the last chunk may be smaller) and decoded token by token until every
    sequence in the chunk has produced ``<eos>`` or ``max_sequence_length``
    steps have been taken.

    Args:
        model: Translation model, called as ``model(src=..., trg=...)``; it is
            put in eval mode and must expose ``.device``.
        input_sentences: Source sentences to translate.
        source_tokenizer: Tokenizer used to encode the source sentences.
        target_tokenizer: Tokenizer providing ``token_to_id``, ``decode`` and
            ``detokenize`` for the target side.
        max_sequence_length: Width of the encoded source rows and maximum
            number of decoding steps.
        temperature: Currently ignored; decoding is always greedy. Kept so
            existing callers do not break.
        batch_size: Maximum number of sentences decoded together.

    Returns:
        list[str]: One ``"<bos> ... <eos>"`` string per input sentence, cut
        after the first ``<eos>`` when one was produced.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Look the special-token ids up once instead of on every decoded token.
    bos_id = target_tokenizer.token_to_id("<bos>")
    eos_id = target_tokenizer.token_to_id("<eos>")

    model.eval()
    with torch.no_grad():
        features = torch.tensor(
            create_features_from_text_list(
                use_tqdm=False,
                text_list=input_sentences,
                tokenizer=source_tokenizer,
                sequence_length=max_sequence_length,
            )
        )
        if features.numel() == 0:
            # Nothing to translate; view(-1, ...) cannot reshape an empty tensor.
            return []
        encoded_input_sentences = features.view(-1, max_sequence_length).to(model.device)

        predictions = []
        # Tensor.split keeps the requested batch size and lets the last batch be
        # smaller, so the input count no longer has to be divisible by it.
        for batch in tqdm(encoded_input_sentences.split(batch_size)):
            encoded_targets = [[bos_id] for _ in range(len(batch))]
            finished_sequences = [False] * len(batch)
            for _ in range(max_sequence_length):
                if all(finished_sequences):
                    break
                outputs = model(
                    src=batch,
                    trg=torch.LongTensor(encoded_targets).to(model.device),
                )
                # Only the last position predicts the next token; a single
                # tolist() avoids one device sync per sentence.
                next_words_ids = torch.argmax(outputs[:, -1, :], dim=-1).tolist()
                for i, (target_ids, next_word_id) in enumerate(
                    zip(encoded_targets, next_words_ids)
                ):
                    # Finished rows keep growing so the batch stays rectangular;
                    # the surplus is cut off after the first <eos> below.
                    target_ids.append(next_word_id)
                    if next_word_id == eos_id:
                        finished_sequences[i] = True

            for target_ids in encoded_targets:
                prediction = target_tokenizer.detokenize(
                    target_tokenizer.decode(target_ids)
                )
                if "<eos>" in prediction:
                    prediction = prediction[: prediction.index("<eos>") + len("<eos>")]
                prediction = (
                    prediction.replace("<bos>", "<bos> ")
                    .replace("<eos>", " <eos>")
                    .strip()
                )
                predictions.append(prediction)
        return predictions

Get the predictions in batches (greedy decoding), then score them.

In [24]:
test_source_sentences = test_dataset[source_language_code]
preds = batch_translate(
    model=model,
    input_sentences=test_source_sentences,
    batch_size=241,  # 241 * 5 == 1205 test sentences -> five equal batches
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [25]:
# BLEU for the batched greedy predictions computed above (passed in as `preds`).
get_blue_score(
    model=model,
    predictions=preds,
    # data
    source_sentences=test_dataset[source_language_code],
    target_sentences=test_dataset[target_language_code],
    source_language_code=source_language_code,
    target_language_code=target_language_code,
    # tokenization
    source_tokenizer=source_tokenizer,
    target_tokenizer=target_tokenizer,
    max_sequence_length=sequence_length,
    # decoding / reporting
    is_dotted=True,
    decode_with_beam_search=False,
    show_translations_for=5,
    save_predictions_and_targets=False,
)

  0%|          | 0/1205 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

Source: سأحدثكم اليوم عن 30 عاما من تاريخ الهندسة.
Prediction: i'm going to talk to you today about 30 years of the history of engineering.
Target: today i'm going to speak to you about the last 30 years of architectural history.
********************************************************************************
Source: هذا أمر كبير جدا لألخصه في 18 دقيقة.
Prediction: this is a very big thing to summarize in 18 minutes.
Target: that's a lot to pack into 18 minutes.
********************************************************************************
Source: إنه موضوع معقد ، لذلك فإننا سنتوجه مباشرة إلى مكان معقد: إلى نيو جيرسي ،
Prediction: it's a complicated topic, so we're going to head right into a complex place: to new jersey.
Target: it's a complex topic, so we're just going to dive right in at a complex place: new jersey.
********************************************************************************
Source: لأنه منذ 30 سنة ، أنا من نيوجيرسي ، كنت في السادسة من عمري ، وكنت أعيش هناك مع 

30.242

## Beam Search

In [26]:
class BeamNode:
    """One beam-search hypothesis: its token ids so far and its running score."""

    def __init__(self, tokens: list, score: float) -> None:
        # The list is stored as given (not copied).
        self.score = score
        self.tokens = tokens


class Beam:
    """Pool of candidate hypotheses (``BeamNode``-like objects) for one sentence.

    Nodes are only accumulated here; pruning happens when ``get_topk`` is
    called. The ``<eos>`` id is looked up through the module-level
    ``target_tokenizer``.
    """

    def __init__(self, beam_width=4):
        self.beam_width = beam_width
        self.nodes = []

    def add(self, node):
        """Append ``node`` to the pool without pruning."""
        self.nodes.append(node)

    def get_topk(self):
        """Return at most ``beam_width`` nodes, best ``ordering_function`` value first."""
        return heapq.nlargest(
            self.beam_width, self.nodes, key=Beam.ordering_function
        )

    def has_finished(self):
        """Return True when every stored node ends with ``<eos>`` (also True when empty)."""
        for node in self.nodes:
            if node.tokens[-1] != target_tokenizer.token_to_id("<eos>"):
                return False
        return True

    @classmethod
    def ordering_function(
        cls,
        node,
        length_penalty_alpha=0.6,
        minimum_penalty_tokens_length=5,
    ):
        """Length-normalised score used to rank nodes.

        Computes ``score / (length + minimum_penalty_tokens_length) ** alpha``,
        where ``length`` counts tokens up to and including the first ``<eos>``
        (or all tokens when there is none).
        """
        end_of_sequence = target_tokenizer.token_to_id("<eos>")
        try:
            effective_length = node.tokens.index(end_of_sequence) + 1
        except ValueError:
            # No <eos> yet: the whole sequence counts.
            effective_length = len(node.tokens)
        penalty = (
            effective_length + minimum_penalty_tokens_length
        ) ** length_penalty_alpha
        return node.score / penalty

In [27]:
def beam_search(
    model,
    encoded_input_sentences,
    target_tokenizer,
    max_sequence_length,
    beam_width,
):
    """Decode batches of encoded source sentences with beam search.

    Each sentence gets its own ``Beam``. At every step the k-th best hypothesis
    of all sentences is run through the model as one batch, and each unfinished
    hypothesis is extended with its ``beam_width`` most likely next tokens.
    Finished hypotheses are padded with another ``<eos>`` and keep their score,
    so all hypotheses in a step have the same length and can be stacked.

    This function does not switch the model to eval mode or disable gradients;
    the caller is expected to do so. ``Beam.has_finished`` and
    ``Beam.ordering_function`` look ``<eos>`` up through the module-level
    ``target_tokenizer``, so the tokenizer passed here should be that object.

    Args:
        model: Translation model, called as ``model(src=..., trg=...)``; must
            expose ``.device``.
        encoded_input_sentences: Iterable of 2-D batches of encoded source rows.
        target_tokenizer: Tokenizer providing ``token_to_id``, ``id_to_token``
            and ``detokenize``.
        max_sequence_length: Maximum number of decoding steps.
        beam_width: Hypotheses kept per sentence and expansions per hypothesis.

    Returns:
        list[str]: One ``"<bos> ... <eos>"`` string per source sentence.
    """
    # Look the special-token ids up once instead of inside the nested loops.
    bos_id = target_tokenizer.token_to_id("<bos>")
    eos_id = target_tokenizer.token_to_id("<eos>")

    predictions = []
    for batch in tqdm(encoded_input_sentences):
        sentences_count = batch.size(0)
        beams = [Beam(beam_width) for _ in range(sentences_count)]
        for beam in beams:
            beam.add(BeamNode(tokens=[bos_id], score=0.0))

        for _ in range(max_sequence_length):
            next_beams = [Beam(beam_width) for _ in range(sentences_count)]
            best_beams_nodes = [beam.get_topk() for beam in beams]
            # Only <bos> exists at the first step; afterwards every beam holds
            # at least beam_width nodes, so the first beam's count is used for all.
            hypotheses_count = min(beam_width, len(best_beams_nodes[0]))
            for k in range(hypotheses_count):
                encoded_targets = [
                    topk_nodes[k].tokens for topk_nodes in best_beams_nodes
                ]
                outputs = model(
                    src=batch,
                    trg=torch.LongTensor(encoded_targets).to(model.device),
                )
                # log_softmax instead of log(softmax(.)): same ranking, but very
                # unlikely tokens no longer underflow to log(0) = -inf.
                log_probabilities = F.log_softmax(outputs[:, -1, :], dim=-1)
                topk_scores, topk_indices = torch.topk(log_probabilities, beam_width)
                # One transfer per step instead of one .item() sync per candidate.
                batch_next_token_id = topk_indices.tolist()
                batch_next_token_score = topk_scores.tolist()

                for beam_index, beam in enumerate(next_beams):
                    parent = best_beams_nodes[beam_index][k]
                    if parent.tokens[-1] == eos_id:
                        # Already finished: pad with <eos>, score unchanged.
                        beam.add(
                            BeamNode(
                                tokens=parent.tokens + [eos_id],
                                score=parent.score,
                            )
                        )
                        continue
                    for token_id, token_score in zip(
                        batch_next_token_id[beam_index],
                        batch_next_token_score[beam_index],
                    ):
                        beam.add(
                            BeamNode(
                                tokens=parent.tokens + [token_id],
                                score=parent.score + token_score,
                            )
                        )
            beams = next_beams  # Update beams for the next timestep
            if all(beam.has_finished() for beam in beams):
                break

        for beam in beams:
            best_node = max(beam.nodes, key=Beam.ordering_function)
            translated_tokens = [
                target_tokenizer.id_to_token(token_id) for token_id in best_node.tokens
            ]
            translated_sequence = target_tokenizer.detokenize(translated_tokens).strip()
            if "<eos>" in translated_sequence:
                # Drop the <eos> padding added after the hypothesis finished.
                translated_sequence = translated_sequence[
                    : translated_sequence.index("<eos>") + len("<eos>")
                ]
            translated_sequence = (
                translated_sequence.replace("<bos>", "<bos> ")
                .replace("<eos>", " <eos>")
                .strip()
            )
            predictions.append(translated_sequence)
    return predictions


def batch_translate_with_beam_search(
    model,
    input_sentences,
    max_sequence_length=sequence_length,
    source_tokenizer=source_tokenizer,
    target_tokenizer=target_tokenizer,
    batch_size=2,
    beam_width=4,
):
    """Translate ``input_sentences`` with beam search, in mini-batches.

    Encodes the inputs once, cuts them into chunks of at most ``batch_size``
    rows (the last chunk may be smaller) and passes them to ``beam_search``
    with the model in eval mode and gradients disabled.

    Args:
        model: Translation model; must expose ``.device``.
        input_sentences: Source sentences to translate.
        max_sequence_length: Width of the encoded source rows and maximum
            number of decoding steps.
        source_tokenizer: Tokenizer used to encode the source sentences.
        target_tokenizer: Tokenizer passed on to ``beam_search``.
        batch_size: Maximum number of sentences decoded together.
        beam_width: Number of hypotheses kept per sentence.

    Returns:
        list[str]: One translated string per input sentence.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    model.eval()
    with torch.no_grad():
        features = torch.tensor(
            create_features_from_text_list(
                use_tqdm=False,
                text_list=input_sentences,
                tokenizer=source_tokenizer,
                sequence_length=max_sequence_length,
            )
        )
        if features.numel() == 0:
            # Nothing to translate; view(-1, ...) cannot reshape an empty tensor.
            return []
        encoded_input_sentences = features.view(-1, max_sequence_length).to(model.device)
        # Tensor.split keeps the requested batch size and lets the last batch be
        # smaller, so the input count no longer has to be divisible by it.
        batched_encoded_input_sentences = encoded_input_sentences.split(batch_size)
        return beam_search(
            model,
            batched_encoded_input_sentences,
            target_tokenizer,
            max_sequence_length,
            beam_width,
        )

get preds in batches with beam search

In [28]:
test_source_sentences = test_dataset[source_language_code]
preds = batch_translate_with_beam_search(
    model=model,
    input_sentences=test_source_sentences,
    batch_size=241,  # 241 * 5 == 1205 test sentences -> five equal batches
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [29]:
# Peek at the first few beam-search predictions instead of dumping all of them.
preds[:10]

['<bos> i &apos;m going to talk to you today about 30 years of the history of engineering . <eos>',
 '<bos> this is a very big thing to summarize in 18 minutes . <eos>',
 '<bos> it &apos;s a complicated topic , so we &apos;re going to direct a complex place : to new jersey . <eos>',
 '<bos> because 30 years ago , i was from new jersey , i was six years old , and i was living there with my father in a town named levittgenston , and this was a bedroom . <eos>',
 '<bos> at the corner of my room , there was a bathroom that i shared with my sister . <eos>',
 '<bos> and between my room and the bathroom , there was a balcony overlooking the living room . <eos>',
 '<bos> and that &apos;s where everybody would gather to watch tv , so every time i would go from my room to the bathroom , everyone would see me , and every time i bathed and pulled out my anchor , everybody would see me . <eos>',
 '<bos> i was looking like this . <eos>',
 '<bos> i was weird , i didn &apos;t feel secure , so i hated 

In [30]:
# BLEU for the batched beam-search predictions computed above.
# NOTE(review): decode_with_beam_search stays False even though `preds` came
# from beam search — presumably ignored when predictions are supplied; confirm.
get_blue_score(
    model=model,
    predictions=preds,
    # data
    source_sentences=test_dataset[source_language_code],
    target_sentences=test_dataset[target_language_code],
    source_language_code=source_language_code,
    target_language_code=target_language_code,
    # tokenization
    source_tokenizer=source_tokenizer,
    target_tokenizer=target_tokenizer,
    max_sequence_length=sequence_length,
    # decoding / reporting
    is_dotted=True,
    decode_with_beam_search=False,
    show_translations_for=5,
    save_predictions_and_targets=False,
)

  0%|          | 0/1205 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

Source: سأحدثكم اليوم عن 30 عاما من تاريخ الهندسة.
Prediction: i'm going to talk to you today about 30 years of the history of engineering.
Target: today i'm going to speak to you about the last 30 years of architectural history.
********************************************************************************
Source: هذا أمر كبير جدا لألخصه في 18 دقيقة.
Prediction: this is a very big thing to summarize in 18 minutes.
Target: that's a lot to pack into 18 minutes.
********************************************************************************
Source: إنه موضوع معقد ، لذلك فإننا سنتوجه مباشرة إلى مكان معقد: إلى نيو جيرسي ،
Prediction: it's a complicated topic, so we're going to direct a complex place: to new jersey.
Target: it's a complex topic, so we're just going to dive right in at a complex place: new jersey.
********************************************************************************
Source: لأنه منذ 30 سنة ، أنا من نيوجيرسي ، كنت في السادسة من عمري ، وكنت أعيش هناك مع والدي في 

30.779

Get the score without batches, one-by-one (beam search decoding).

In [31]:
# BLEU with beam-search decoding. No precomputed predictions are passed, so
# get_blue_score presumably decodes the test sentences itself.
get_blue_score(
    model=model,
    # data
    source_sentences=test_dataset[source_language_code],
    target_sentences=test_dataset[target_language_code],
    source_language_code=source_language_code,
    target_language_code=target_language_code,
    # tokenization
    source_tokenizer=source_tokenizer,
    target_tokenizer=target_tokenizer,
    max_sequence_length=sequence_length,
    # decoding / reporting
    is_dotted=True,
    decode_with_beam_search=True,
    show_translations_for=5,
    save_predictions_and_targets=False,
)

  0%|          | 0/1205 [00:00<?, ?it/s]

  0%|          | 0/1205 [00:00<?, ?it/s]

Source: سأحدثكم اليوم عن 30 عاما من تاريخ الهندسة.
Prediction: i'm going to talk to you today about 30 years of the history of engineering.
Target: today i'm going to speak to you about the last 30 years of architectural history.
********************************************************************************
Source: هذا أمر كبير جدا لألخصه في 18 دقيقة.
Prediction: this is a very big thing to summarize in 18 minutes.
Target: that's a lot to pack into 18 minutes.
********************************************************************************
Source: إنه موضوع معقد ، لذلك فإننا سنتوجه مباشرة إلى مكان معقد: إلى نيو جيرسي ،
Prediction: it's a complicated topic, so we're going to direct a complex place: to new jersey.
Target: it's a complex topic, so we're just going to dive right in at a complex place: new jersey.
********************************************************************************
Source: لأنه منذ 30 سنة ، أنا من نيوجيرسي ، كنت في السادسة من عمري ، وكنت أعيش هناك مع والدي في 

30.779