In [49]:
# %%capture
# !pip install datasets

## Libraries and Dependencies

In [50]:
import pandas as pd
import torch
import nltk
from datasets import load_dataset
from tqdm.notebook import tqdm

# The models the authors used:
from transformers import BertForMaskedLM, BertTokenizer

nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [51]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

## Algorithm Implementation

In [52]:
def mask_sentence(sentence, mask_token, i, M, L_min):
    return [mask_token
            if (j - i) % M == 0
            and (len(sentence[j]) >= L_min
                 or sentence[j].startswith('##')
                 or sentence[min(j+1, len(sentence)-1)].startswith('##'))
            else sentence[j]
            for j in range(len(sentence))]

In [53]:
def BLANC_help(sentence, translation, model, tokenizer, M=6, L_min=4, sep='[SEP]', device='cpu'):
    """
    Calculates BLANC score between a given sentence and its translation using a specified model.

    Parameters:
    - sentence (List[str]): A tokenized sentence.
    - translation (List[str]): The tokenized translation.
    - model: BERT-type model
    - tokenizer: The tokenizer associated with the model used.
    - M (int): Parameter M for the algorithm (default is 6).
    - L_min (int): Minimum length requirement for masked words (default is 4).
    - sep (str): Separator between the inference help (filler/summary) and a sentence from the text (default is '[SEP]').

    Returns:
    - float: BLANC score for the given sentence and its translation.
    """

    filler = ['.'] * len(translation)
    S = [[0, 0], [0, 0]]

    for i in range(M):
        masked_sentence = mask_sentence(sentence, tokenizer.mask_token, i, M, L_min)

        input_base = filler + [sep] + masked_sentence
        input_help = translation + [sep] + masked_sentence

        tokenized_input_base = torch.tensor(tokenizer.convert_tokens_to_ids(input_base)).to(device) # Shape: [sequence_length]
        tokenized_input_help = torch.tensor(tokenizer.convert_tokens_to_ids(input_help)).to(device) # Shape: [sequence_length]

        out_base = model(input_ids=tokenized_input_base.unsqueeze(0)).logits  # Shape: [1, sequence_length, model_vocab_size]
        out_help = model(input_ids=tokenized_input_help.unsqueeze(0)).logits  # Shape: [1, sequence_length, model_vocab_size]

        out_base = torch.argmax(out_base.squeeze(0), dim=-1)  # Shape: [sequence_length]
        out_help = torch.argmax(out_help.squeeze(0), dim=-1)  # Shape: [sequence_length]

        masked_tokens = [idx for idx, word in enumerate(masked_sentence) if word == tokenizer.mask_token]

        for j in masked_tokens:
            idx = len(translation + [sep]) + j
            predicted_word_base = tokenizer.convert_ids_to_tokens(out_base[idx].item())
            predicted_word_help = tokenizer.convert_ids_to_tokens(out_help[idx].item())

            # print(f'predicted_word_base[{idx - len(translation + [sep])}]: {predicted_word_base}')
            # print(f'predicted_word_help[{idx - len(translation + [sep])}]: {predicted_word_help}')
            # print(f'sentence[{j}]: {sentence[j]}')

            k = int(predicted_word_base == sentence[j])
            m = int(predicted_word_help == sentence[j])
            S[k][m] += 1


    B = (S[0][1] - S[1][0]) / (S[0][0] + S[1][1] + S[0][1] + S[1][0])

    return B

## Datasets

In [54]:
# English - French
en_fr_ds = load_dataset('news_commentary', 'en-fr', split='train')

en_fr_df = pd.DataFrame(en_fr_ds['translation'][:300])
en_fr_df

Unnamed: 0,en,fr
0,"$10,000 Gold?",L’or à 10.000 dollars l’once ?
1,SAN FRANCISCO – It has never been easy to have...,SAN FRANCISCO – Il n’a jamais été facile d’avo...
2,"Lately, with gold prices up more than 300% ove...","Et aujourd’hui, alors que le cours de l’or a a..."
3,"Just last December, fellow economists Martin F...","En décembre dernier, mes collègues économistes..."
4,Wouldn’t you know it?,Mais devinez ce qui s’est passé ?
...,...,...
295,Although Abdullah is usually referred to in th...,Bien qu'Abdallah soit généralement considéré à...
296,"The Sudairis, it seems, have apparently left t...",Ils semblent avoir laissé leur demi-frère se c...
297,For although Crown Prince Abdullah has his own...,Même si le prince héritier Abdallah bénéficie ...
298,The idea of normalizing relations with Israel ...,L'idée d'une normalisation des relations avec ...


In [72]:
# English - Persian (Farsi)

en_fa_ds = load_dataset('persiannlp/parsinlu_translation_en_fa', split='train')

# Removing the 'category' column
en_fa_ds = en_fa_ds.remove_columns(['category'])

# Removing list encapsulation
en_fa_ds = en_fa_ds.map(lambda example: {'targets': example['targets'][0]}, num_proc=4)

# Filtering out rows with the '\u200c' symbol and those where the length of either source or targets is less than a threshold
length_threshold = 10
filtered_en_fa_ds = en_fa_ds.filter(
    lambda example: '\u200c' not in example['targets']
    and len(example['source']) >= length_threshold
    and len(example['targets']) >= length_threshold,
    num_proc=4)

en_fa_df = pd.DataFrame(filtered_en_fa_ds[:300])
en_fa_df

Filter (num_proc=4):   0%|          | 0/1621665 [00:00<?, ? examples/s]

Unnamed: 0,source,targets
0,Due Thank You note by Egyptian blogger Abdel M...,بلاگر مصری عبدل منعم محمود (عربی) پس از آزاد ش...
1,He was extremely surprised and happy to receiv...,وی همچنین از دریافت تعداد بسیار زیادی پیام تبر...
2,Monem blogs under the name of “Ana Ikwan”.,منعم به دلیل اتهامات سیاسی ۴۵ روز در زندان بود.
3,Ikhwan in Egyptian Arabic means Muslim Brother...,آزادی دینی در مصر
4,"On December 16, 2006, the Supreme Administrati...",در ۱۶ دسامبر ۲۰۰۶ شورای عالی اداری مصر که دولت...
...,...,...
295,Photos are included in the description of this...,وی عکس هایی را از این آهنگرانران به چاپ رسانده...
296,Turkey: Hrant Dink Named World Press Freedom H...,ارمنستان: قهرمان آزادی بیان
297,Jordan: New Traffic Law · Global Voices,اردن: ترافیک و دولت
298,Iraq: Yahoo Account Hacked · Global Voices,عراق: ای میل هک شده


## Model and Tokenizer

In [56]:
%%capture
mbert_model = BertForMaskedLM.from_pretrained('bert-base-multilingual-uncased').to(DEVICE)
mbert_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased', do_lower_case = True)

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Preprocessing

In [57]:
# English - French

en_fr_sentences = [mbert_tokenizer.tokenize(sentence)
                   for sentence in en_fr_df['en']]  # (List[List[str]])

en_fr_translations = [mbert_tokenizer.tokenize(translation)
                      for translation in en_fr_df['fr']] # (List[List[str]])

In [74]:
# English - Persian (Farsi)

en_fa_sentences = [mbert_tokenizer.tokenize(sentence)
                   for sentence in en_fa_df['source']]  # (List[List[str]])

en_fa_translations = [mbert_tokenizer.tokenize(translation)
                      for translation in en_fa_df['targets']] # (List[List[str]])

## Running the Program

English - French

In [59]:
BLANC_help(en_fr_translations[1], en_fr_sentences[1], mbert_model, mbert_tokenizer, device=DEVICE)

0.35714285714285715

In [62]:
%%time
en_fr_scores = [BLANC_help(translation, sentence, mbert_model, mbert_tokenizer, device=DEVICE)
                for translation, sentence in tqdm(zip(en_fr_translations, en_fr_sentences), total=len(en_fr_sentences))]

  0%|          | 0/300 [00:00<?, ?it/s]

CPU times: user 1min 8s, sys: 292 ms, total: 1min 9s
Wall time: 1min 25s


In [63]:
en_fr_scores

[0.0,
 0.35714285714285715,
 0.125,
 0.21621621621621623,
 -0.2,
 0.09523809523809523,
 0.0,
 0.05555555555555555,
 0.0,
 0.0,
 0.0,
 0.0,
 0.07692307692307693,
 0.23076923076923078,
 0.10526315789473684,
 0.0,
 0.2857142857142857,
 0.10256410256410256,
 0.0,
 0.0,
 0.13043478260869565,
 0.1836734693877551,
 0.17647058823529413,
 0.0,
 0.02857142857142857,
 0.13043478260869565,
 0.25,
 0.08,
 0.09523809523809523,
 0.04,
 0.0,
 0.07407407407407407,
 0.05263157894736842,
 0.0,
 0.0,
 0.0,
 0.08333333333333333,
 0.0,
 0.0,
 0.0,
 0.08333333333333333,
 0.0,
 0.0,
 0.0,
 0.16666666666666666,
 0.13043478260869565,
 0.0,
 0.16666666666666666,
 0.10344827586206896,
 0.07692307692307693,
 0.15384615384615385,
 0.0,
 0.26666666666666666,
 0.18181818181818182,
 0.029411764705882353,
 0.10526315789473684,
 0.13333333333333333,
 0.16129032258064516,
 0.21428571428571427,
 0.15625,
 0.14285714285714285,
 0.05405405405405406,
 0.1388888888888889,
 0.16666666666666666,
 0.14285714285714285,
 0.0,
 0.0

English - Persian (Farsi)

In [76]:
BLANC_help(en_fa_translations[1], en_fa_sentences[1], mbert_model, mbert_tokenizer, device=DEVICE)

0.25

In [77]:
%%time
en_fa_scores = [BLANC_help(translation, sentence, mbert_model, mbert_tokenizer, device=DEVICE)
                for translation, sentence in tqdm(zip(en_fa_translations, en_fa_sentences), total=len(en_fa_sentences))]

  0%|          | 0/300 [00:00<?, ?it/s]

CPU times: user 53 s, sys: 247 ms, total: 53.2 s
Wall time: 57.8 s


In [78]:
en_fa_scores

[0.07407407407407407,
 0.25,
 0.1,
 0.0,
 0.047619047619047616,
 0.125,
 0.045454545454545456,
 0.0,
 0.08333333333333333,
 0.0,
 0.0,
 0.0,
 0.1,
 -0.2,
 -0.14285714285714285,
 0.2,
 0.0,
 0.0,
 0.0,
 0.0625,
 0.0,
 0.09090909090909091,
 0.0,
 0.0,
 0.125,
 0.043478260869565216,
 -0.16666666666666666,
 0.0,
 0.0,
 0.0,
 0.05555555555555555,
 0.1875,
 0.07142857142857142,
 0.034482758620689655,
 0.75,
 0.0,
 0.0,
 0.0,
 0.07692307692307693,
 0.0,
 0.2,
 0.0,
 0.0,
 0.0,
 0.09090909090909091,
 0.09090909090909091,
 0.4,
 0.05333333333333334,
 0.0,
 0.19047619047619047,
 0.05555555555555555,
 0.2857142857142857,
 0.0,
 0.07692307692307693,
 0.0,
 0.3333333333333333,
 0.0,
 0.0,
 0.25,
 0.0,
 0.0,
 0.0,
 0.0,
 -0.1111111111111111,
 -0.1111111111111111,
 0.09090909090909091,
 0.0,
 0.16666666666666666,
 0.125,
 0.1111111111111111,
 0.09090909090909091,
 -0.046511627906976744,
 0.0,
 -0.1111111111111111,
 0.0,
 0.0,
 0.06666666666666667,
 0.0,
 0.36363636363636365,
 -0.08333333333333333,
 0