In [None]:
import torch
from nltk.tokenize import sent_tokenize
from datasets import load_dataset
from tqdm.notebook import tqdm

# The models the authors used:
from transformers import BertForMaskedLM, BertTokenizer

from blanc import BLANC_tune_summary, BLANC_tune_translation, add_results_to_json

In [None]:
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint names: English BERT for the summary experiment, multilingual BERT
# for the translation experiments.
bert_checkpoint = "bert-base-uncased"
mbert_checkpoint = "bert-base-multilingual-uncased"

# Tokenizers (both checkpoints are uncased, hence do_lower_case=True).
bert_tokenizer = BertTokenizer.from_pretrained(bert_checkpoint, do_lower_case=True)
mbert_tokenizer = BertTokenizer.from_pretrained(mbert_checkpoint, do_lower_case=True)

# Masked-LM models, moved onto DEVICE right after loading.
bert_model = BertForMaskedLM.from_pretrained(bert_checkpoint).to(DEVICE)
mbert_model = BertForMaskedLM.from_pretrained(mbert_checkpoint).to(DEVICE)

# BLANC tune for **summaries**

In [None]:
""" Datasets """

# Local JSON file with the DailyNews records; the "text" and "summary"
# columns are read by the preprocessing step further down.
DailyNews_ds = load_dataset(
    "json",
    data_files="DailyNews_300.json",
    split="train",
)
# Bare expression so the notebook shows the dataset summary as cell output.
DailyNews_ds

In [None]:
""" Preprocessing """

# One summary string per record: List[str]
summaries = DailyNews_ds["summary"]

# Each raw text is a paragraph of a few sentences; strip surrounding
# whitespace and split it into sentences, giving List[List[str]].
# NOTE: sent_tokenize needs NLTK's Punkt sentence-tokenizer data to be
# available locally.
texts = [sent_tokenize(paragraph.strip()) for paragraph in DailyNews_ds["text"]]

# Sanity check: the dataset is expected to hold exactly 300 text/summary pairs.
assert len(texts) == 300 and len(summaries) == 300

# Word-piece tokens per sentence, per text: List[List[List[str]]]
tokenized_texts = [
    list(map(bert_tokenizer.tokenize, sentences)) for sentences in texts
]

# Word-piece tokens per summary: List[List[str]]
tokenized_summaries = list(map(bert_tokenizer.tokenize, summaries))

In [None]:
""" Running the Program """

# Score every (text, summary) pair with BLANC-tune. zip() has no len(), so the
# progress bar is given the number of pairs explicitly.
tune_summary_scores = [
    BLANC_tune_summary(
        text_tokens,
        summary_tokens,
        bert_checkpoint,
        bert_model,
        bert_tokenizer,
        device=DEVICE,
    )
    for text_tokens, summary_tokens in tqdm(
        zip(tokenized_texts, tokenized_summaries), total=len(tokenized_texts)
    )
]

# Saving the results (add_results_to_json comes from the local `blanc` module;
# presumably it persists this mapping to a JSON results file -- confirm there).
summary_data = {"BLANC_tune_summary": tune_summary_scores}
add_results_to_json(summary_data)

# BLANC tune for **translations**

In [None]:
""" Datasets """

# English - French: "en-fr" configuration of news_commentary, train split.
en_fr_ds = load_dataset(
    "news_commentary",
    "en-fr",
    split="train",
)

# English - Persian (Farsi): ParsiNLU translation corpus, train split.
en_fa_ds = load_dataset(
    "persiannlp/parsinlu_translation_en_fa",
    split="train",
)

In [None]:
""" Preprocessing (English - French)"""

# Keep only the first 300 sentence pairs *before* reshaping. The map below is
# applied row by row, so selecting first yields exactly the same 300 rows while
# avoiding a pass over the entire news_commentary split.
en_fr_ds = (
    en_fr_ds.select(range(300))
    # Each row's "translation" field is a dict with "en" and "fr" entries;
    # returning it from map() adds those entries as top-level columns.
    .map(lambda example: example["translation"])
    .remove_columns(["id", "translation"])
    .rename_column("en", "sentence")
    .rename_column("fr", "translation")
)

# Tokenization
en_fr_sentences = [
    mbert_tokenizer.tokenize(sentence) for sentence in en_fr_ds["sentence"]
]  # (List[List[str]])

en_fr_translations = [
    mbert_tokenizer.tokenize(translation) for translation in en_fr_ds["translation"]
]  # (List[List[str]])

# Same sanity check as the summary preprocessing: one translation per sentence.
assert len(en_fr_sentences) == len(en_fr_translations) == 300


""" Preprocessing (English - Persian (Farsi)) """

# Drop the unused 'category' column, then unwrap "targets": each row stores its
# translation inside a list, and only the first element is kept.
en_fa_ds = en_fa_ds.remove_columns(["category"]).map(
    lambda example: {"targets": example["targets"][0]}, num_proc=4
)

# A row is discarded when ANY of the following holds:
# - its translation contains the '\u200c' symbol,
# - its source or its translation is shorter than `length_threshold` characters,
# - its source contains 'Global Voices' (headlines: they are very short and the
#   'Global Voices' part is never translated).
length_threshold = 30
filtered_en_fa_ds = en_fa_ds.filter(
    lambda example: not (
        "\u200c" in example["targets"]
        or len(example["source"]) < length_threshold
        or len(example["targets"]) < length_threshold
        or "Global Voices" in example["source"]
    ),
    num_proc=4,
)

# Give the columns the same names as the English - French data and keep the
# first 300 surviving rows.
en_fa_ds = (
    filtered_en_fa_ds.rename_column("source", "sentence")
    .rename_column("targets", "translation")
    .select(range(300))
)

# Tokenization: both results are List[List[str]]
en_fa_sentences = list(map(mbert_tokenizer.tokenize, en_fa_ds["sentence"]))

en_fa_translations = list(map(mbert_tokenizer.tokenize, en_fa_ds["translation"]))

In [None]:
""" Running the Program (English - French)"""

# Score every English sentence / French translation pair with BLANC-tune,
# using multilingual BERT. zip() has no len(), so tqdm gets the total explicitly.
tune_en_fr_scores = [
    BLANC_tune_translation(
        sentences,
        translations,
        mbert_checkpoint,
        mbert_model,
        mbert_tokenizer,
        device=DEVICE,
    )
    for sentences, translations in tqdm(
        zip(en_fr_sentences, en_fr_translations), total=len(en_fr_sentences)
    )
]

# Saving the results
# The key must name the en-fr pair: the original used the en-fa key here, so the
# French scores were stored under the same key as the Persian results.
en_fr_data = {}
en_fr_data["BLANC_tune_en_fr_translation"] = tune_en_fr_scores
add_results_to_json(en_fr_data)

In [None]:
""" Running the Program (English - Persian)"""

# Score every English sentence / Persian translation pair with BLANC-tune,
# using multilingual BERT. The pair count is passed to tqdm because zip()
# cannot report its length.
tune_en_fa_scores = [
    BLANC_tune_translation(
        source_tokens,
        target_tokens,
        mbert_checkpoint,
        mbert_model,
        mbert_tokenizer,
        device=DEVICE,
    )
    for source_tokens, target_tokens in tqdm(
        zip(en_fa_sentences, en_fa_translations), total=len(en_fa_sentences)
    )
]

# Saving the results
en_fa_data = {"BLANC_tune_en_fa_translation": tune_en_fa_scores}
add_results_to_json(en_fa_data)