In [1]:
# %%capture
# %pip install datasets
# %pip install transformers

## Libraries and Dependencies

In [2]:
import pandas as pd
import torch
import nltk
from nltk.tokenize import sent_tokenize
from datasets import load_dataset
from tqdm.notebook import tqdm

# The models the authors used:
from transformers import BertForMaskedLM, BertTokenizer
from transformers import AlbertForMaskedLM, AlbertTokenizer

nltk.download('punkt')

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /Users/nazanin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cpu')

## Algorithm Implementation

In [None]:
def mask_sentence(sentence, mask_token, i, M, L_min):
    """
    Return a copy of `sentence` in which selected tokens are replaced by `mask_token`.

    Only positions j with (j - i) % M == 0 are candidates. A candidate is masked when
    at least one of the following holds:
    - the token has at least `L_min` characters,
    - the token itself starts with '##',
    - the following token starts with '##' (for the last position the token is
      compared with itself, because the look-ahead index is clamped to the end).

    Parameters:
    - sentence (List[str]): A tokenized sentence.
    - mask_token (str): The token written in place of every masked token.
    - i (int): Offset of the first candidate position.
    - M (int): Distance between candidate positions.
    - L_min (int): Minimum token length that qualifies a token for masking on its own.

    Returns:
    - List[str]: A new list of the same length as `sentence`; the input is not modified.
    """
    last_position = len(sentence) - 1
    masked = []
    for position, token in enumerate(sentence):
        is_candidate = (position - i) % M == 0
        if is_candidate and (
            len(token) >= L_min
            or token.startswith('##')
            or sentence[min(position + 1, last_position)].startswith('##')
        ):
            masked.append(mask_token)
        else:
            masked.append(token)
    return masked

In [None]:
def BLANC_help(sentence, translation, model, tokenizer, M=6, L_min=4, sep='[SEP]', device='cpu'):
    """
    Calculates BLANC score between a given sentence and its translation using a specified model.

    For every offset i in range(M) the sentence is masked with `mask_sentence` and the
    model fills the masks twice: once with a filler of '.' tokens (as long as the
    translation) placed before the separator, and once with the translation placed
    there. Each masked token is then counted according to whether the filler-prefixed
    and the translation-prefixed predictions reproduce the original token.

    Parameters:
    - sentence (List[str]): A tokenized sentence; this is the text that gets masked.
    - translation (List[str]): The tokenized translation, used as the inference help.
    - model: BERT-type model; called as model(input_ids=...) and expected to expose `.logits`.
    - tokenizer: The tokenizer associated with the model used.
    - M (int): Parameter M for the algorithm (default is 6).
    - L_min (int): Minimum length requirement for masked words (default is 4).
    - sep (str): Separator between the inference help (filler/summary) and a sentence from the text (default is '[SEP]').
    - device (str or torch.device): Device the input tensors are moved to (default is 'cpu').

    Returns:
    - float: (tokens recovered only with the translation - tokens recovered only with
      the filler) / (number of masked tokens). Returns 0.0 when no token of the
      sentence qualifies for masking, instead of dividing by zero.
    """
    # NOTE(review): the inputs are built without a leading '[CLS]' or trailing '[SEP]'
    # token; confirm this is intended for the model in use.
    filler = ['.'] * len(translation)
    mask_token = tokenizer.mask_token

    # The filler and the translation have the same length, so the same offset locates
    # sentence position j inside both model inputs.
    sentence_offset = len(translation) + 1

    # S[k][m] counts masked tokens: k is 1 when the filler-prefixed prediction is
    # correct, m is 1 when the translation-prefixed prediction is correct.
    S = [[0, 0], [0, 0]]

    for i in range(M):
        masked_sentence = mask_sentence(sentence, mask_token, i, M, L_min)
        masked_tokens = [idx for idx, word in enumerate(masked_sentence) if word == mask_token]
        if not masked_tokens:
            # Nothing to predict for this offset; skip the two forward passes.
            continue

        input_base = filler + [sep] + masked_sentence
        input_help = translation + [sep] + masked_sentence

        tokenized_input_base = torch.tensor(tokenizer.convert_tokens_to_ids(input_base)).to(device)  # Shape: [sequence_length]
        tokenized_input_help = torch.tensor(tokenizer.convert_tokens_to_ids(input_help)).to(device)  # Shape: [sequence_length]

        # Inference only: do not record an autograd graph for the forward passes.
        with torch.no_grad():
            out_base = model(input_ids=tokenized_input_base.unsqueeze(0)).logits  # Shape: [1, sequence_length, model_vocab_size]
            out_help = model(input_ids=tokenized_input_help.unsqueeze(0)).logits  # Shape: [1, sequence_length, model_vocab_size]

        out_base = torch.argmax(out_base.squeeze(0), dim=-1)  # Shape: [sequence_length]
        out_help = torch.argmax(out_help.squeeze(0), dim=-1)  # Shape: [sequence_length]

        for j in masked_tokens:
            idx = sentence_offset + j
            predicted_word_base = tokenizer.convert_ids_to_tokens(out_base[idx].item())
            predicted_word_help = tokenizer.convert_ids_to_tokens(out_help[idx].item())

            k = int(predicted_word_base == sentence[j])
            m = int(predicted_word_help == sentence[j])
            S[k][m] += 1

    total = S[0][0] + S[0][1] + S[1][0] + S[1][1]
    if total == 0:
        # No token was masked (e.g. empty sentence or only short tokens).
        return 0.0

    return (S[0][1] - S[1][0]) / total

## Datasets

In [4]:
# English-French configuration of the News Commentary parallel corpus, training split.
news_commentary_ds = load_dataset(
    path='news_commentary',
    name='en-fr',
    split='train',
)
news_commentary_ds

Dataset({
    features: ['id', 'translation'],
    num_rows: 209479
})

In [5]:
# Number of sentence pairs kept for the experiment.
N_PAIRS = 300

# Slice the rows first and pick the column afterwards, so only the first N_PAIRS rows
# are read rather than selecting the full 'translation' column and then slicing it.
parallel_df = pd.DataFrame(news_commentary_ds[:N_PAIRS]['translation'])
parallel_df

Unnamed: 0,en,fr
0,"$10,000 Gold?",L’or à 10.000 dollars l’once ?
1,SAN FRANCISCO – It has never been easy to have...,SAN FRANCISCO – Il n’a jamais été facile d’avo...
2,"Lately, with gold prices up more than 300% ove...","Et aujourd’hui, alors que le cours de l’or a a..."
3,"Just last December, fellow economists Martin F...","En décembre dernier, mes collègues économistes..."
4,Wouldn’t you know it?,Mais devinez ce qui s’est passé ?
...,...,...
295,Although Abdullah is usually referred to in th...,Bien qu'Abdallah soit généralement considéré à...
296,"The Sudairis, it seems, have apparently left t...",Ils semblent avoir laissé leur demi-frère se c...
297,For although Crown Prince Abdullah has his own...,Même si le prince héritier Abdallah bénéficie ...
298,The idea of normalizing relations with Israel ...,L'idée d'une normalisation des relations avec ...


## Model and Tokenizer

In [6]:
# Tokenizer and masked-LM weights come from the same 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model = model.to(DEVICE)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Preprocessing

In [9]:
# Tokenize both sides of every pair with the model's tokenizer.
sentences = list(map(tokenizer.tokenize, parallel_df['en']))  # (List[List[str]]) English side
translations = list(map(tokenizer.tokenize, parallel_df['fr']))  # (List[List[str]]) French side

In [10]:
# print(f'longest sentence: {max([len(sent) for sent in sentences])}')
# print(f'longest translation: {max([len(sent) for sent in translations])}')

## Running the Program

In [11]:
# BLANC_help expects (sentence, translation): the English tokens are the text that is
# masked, and the French tokens are the help placed before the separator.
BLANC_help(sentences[0], translations[0], model, tokenizer, device=DEVICE)

predicted_word_base[4]: .
predicted_word_help[4]: once
sentence[4]: gold


0.0

In [13]:
%%time
# One score per pair. Arguments follow BLANC_help's (sentence, translation) order:
# the English tokens are masked, the French tokens serve as the help.
scores = [BLANC_help(sentence, translation, model, tokenizer, device=DEVICE)
          for sentence, translation in zip(sentences, translations)]

predicted_word_base[4]: .
predicted_word_help[4]: once
sentence[4]: gold
predicted_word_base[6]: been
predicted_word_help[6]: been
sentence[6]: been
predicted_word_base[12]: opinion
predicted_word_help[12]: ##e
sentence[12]: conversation
predicted_word_base[1]: francisco
predicted_word_help[1]: francisco
sentence[1]: francisco
predicted_word_base[7]: possible
predicted_word_help[7]: possible
sentence[7]: easy
predicted_word_base[13]: about
predicted_word_help[13]: about
sentence[13]: about
predicted_word_base[9]: have
predicted_word_help[9]: have
sentence[9]: have
predicted_word_base[15]: value
predicted_word_help[15]: nature
sentence[15]: value
predicted_word_base[5]: always
predicted_word_help[5]: always
sentence[5]: never
predicted_word_base[11]: real
predicted_word_help[11]: real
sentence[11]: rational
predicted_word_base[17]: .
predicted_word_help[17]: art
sentence[17]: gold
predicted_word_base[0]: now
predicted_word_help[0]: but
sentence[0]: lately
predicted_word_base[6]: more
pr

In [14]:
# Show the score computed for each sentence pair.
scores

[0.0,
 -0.1,
 0.0,
 0.13333333333333333,
 0.0,
 0.0,
 -0.09090909090909091,
 0.0,
 0.047619047619047616,
 0.0,
 0.14285714285714285,
 0.08333333333333333,
 0.05555555555555555,
 0.1,
 -0.07142857142857142,
 0.1,
 0.0,
 0.047619047619047616,
 0.0,
 0.0,
 0.0,
 0.26666666666666666,
 0.21428571428571427,
 0.0,
 0.047619047619047616,
 0.045454545454545456,
 0.0,
 0.0625,
 0.0625,
 0.05263157894736842,
 0.0,
 -0.047619047619047616,
 0.058823529411764705,
 0.0,
 0.0,
 0.0,
 0.0,
 -0.05555555555555555,
 0.0,
 0.09523809523809523,
 0.1111111111111111,
 0.0,
 0.0,
 0.0,
 0.125,
 -0.13333333333333333,
 0.0,
 0.0,
 0.0,
 0.045454545454545456,
 0.0,
 -0.09523809523809523,
 0.07142857142857142,
 0.2,
 -0.12,
 0.058823529411764705,
 0.1,
 0.125,
 0.0,
 0.0,
 0.0,
 -0.038461538461538464,
 0.043478260869565216,
 0.1111111111111111,
 0.0,
 0.0,
 0.0,
 -0.125,
 0.09090909090909091,
 0.18181818181818182,
 0.045454545454545456,
 -0.16666666666666666,
 0.0,
 0.07692307692307693,
 -0.058823529411764705,
 0.