In [2]:
import transformers 
import numpy as np
#from transformers import M2M100Tokenizer, M2M100Model
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import wandb

from torch.utils.checkpoint import checkpoint

In [None]:
# Language pair used throughout the notebook (TED HRLR config, tokenizer language tokens, labels).
source_lang = "az"
target_lang = "en"

# Load the TED HRLR Azerbaijani->English corpus and the BLEU metric.
# NOTE(review): `datasets.load_metric` is deprecated (removed in datasets 3.x); its replacement is
# `evaluate.load("sacrebleu")` -- confirm against the pinned `datasets` version.
raw_datasets = load_dataset("ted_hrlr", f"{source_lang}_to_{target_lang}")
metric = load_metric("sacrebleu")

model_checkpoint = "facebook/m2m100_418M"

# M2M100's tokenizer needs the source/target language codes to emit the matching language tokens.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, src_lang=source_lang, tgt_lang=target_lang)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# We have to set the model decoding to force the target language as the bos token.
model.config.forced_bos_token_id = tokenizer.get_lang_id(target_lang)

# T5 checkpoints select the task via a text prefix; other models (e.g. M2M100) take no prefix.
# NOTE(review): T5 was not pretrained on Azerbaijani, so this prefix only matters after fine-tuning.
T5_CHECKPOINTS = {"t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"}
if model_checkpoint in T5_CHECKPOINTS:
    prefix = "translate Azerbaijani to English: "
else:
    prefix = ""

# Token-length caps applied (with truncation) to source inputs and target labels.
max_input_length = 128
max_target_length = 128

def preprocess_function(examples):
    """Tokenize a batch of translation pairs for seq2seq training.

    Args:
        examples: a batch whose "translation" entry is a list of dicts keyed by
            language code (`source_lang` / `target_lang`).

    Returns:
        The tokenizer output for the prefixed source sentences (truncated to
        `max_input_length`), plus a "labels" entry holding the input IDs of the
        target sentences, tokenized in target mode and truncated to
        `max_target_length`.
    """
    pairs = examples["translation"]
    source_texts = []
    target_texts = []
    for pair in pairs:
        source_texts.append(prefix + pair[source_lang])
        target_texts.append(pair[target_lang])

    encoded = tokenizer(source_texts, max_length=max_input_length, truncation=True)

    # Target-side tokenization must happen inside the target-tokenizer context
    # so the labels get the target-language special tokens.
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(target_texts, max_length=max_target_length, truncation=True)

    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

# Apply preprocess_function to every split in batches; each split gains the tokenizer
# outputs and the "labels" column built by preprocess_function.
tokenized_datasets = raw_datasets.map(function=preprocess_function, batched=True)

In [4]:
# Token-level out-of-vocabulary (OOV) analysis of the Azerbaijani source side: find test-set
# token IDs that never occur in the training split, and report the test sentences with most of them.
TOP_K_OOV = 10  # number of distinct test sentences (most OOV tokens first) to print

train_az = tokenized_datasets['train']
test_az = tokenized_datasets['test']

# List of sentences as array containing IDs of tokens (source side only)
s_train_ids_az = train_az['input_ids']
s_test_ids_az = test_az['input_ids']

# All tokens in the train and test set, flattened into 1-D arrays
ids_train_az = np.concatenate(s_train_ids_az).ravel()
ids_test_az = np.concatenate(s_test_ids_az).ravel()

# Unique tokens in test set
unique_tokens_test_az = np.unique(ids_test_az)

# Training-set frequency of every unique test token. One np.unique(..., return_counts=True) over
# the training tokens plus dict lookups replaces a per-token `np.count_nonzero(ids_train_az == token)`,
# which rescanned the whole training array once per unique test token (O(U * N)).
train_vocab_az, train_vocab_counts_az = np.unique(ids_train_az, return_counts=True)
train_freq_az = dict(zip(train_vocab_az.tolist(), train_vocab_counts_az.tolist()))
count_tokens_test_az = np.array([train_freq_az.get(token, 0) for token in unique_tokens_test_az.tolist()])

# Link the tokens to their frequencies
token_freq_test_az = dict(zip(unique_tokens_test_az, count_tokens_test_az))
# Sort the dictionary by value (ascending frequency)
token_freq_test_az = {k: v for k, v in sorted(token_freq_test_az.items(), key=lambda item: item[1])}
# OOV tokens: test tokens with frequency 0 in the training split
oov_tokens_test_az = {k: v for k, v in token_freq_test_az.items() if v == 0}
# Plain-int set for O(1) membership tests against the token IDs of each sentence
oov_token_set_az = {int(token) for token in oov_tokens_test_az}

# Single pass over the test sentences: count OOV tokens per sentence and keep those with at least one
oov_indices_test_az = []
oov_s_test_ids_az = []
oov_count_by_index_az = {}
for i, sentence in enumerate(s_test_ids_az):
    n_oov = sum(token in oov_token_set_az for token in sentence)
    if n_oov > 0:
        oov_indices_test_az.append(i)
        oov_s_test_ids_az.append(sentence)
        oov_count_by_index_az[i] = n_oov

# Test-set index -> sentence, for every sentence that contains OOV tokens
oov_test_az = dict(zip(oov_indices_test_az, oov_s_test_ids_az))

# Distinct sentence (tuple of IDs) -> number of OOV tokens. Identical sentences collapse into one
# key, so all of their test-set indices are remembered alongside for the index report below.
oov_s_test_ids_az_count = {}
test_indices_by_sentence_az = {}
for i, sentence in oov_test_az.items():
    key = tuple(sentence)
    oov_s_test_ids_az_count[key] = oov_count_by_index_az[i]
    test_indices_by_sentence_az.setdefault(key, []).append(i)

# Sort by OOV count, most first (sorted is stable, so ties keep test-set order)
oov_s_test_ids_az_count = {k: v for k, v in sorted(oov_s_test_ids_az_count.items(), key=lambda item: item[1], reverse=True)}
top_oov_sentences_az = list(oov_s_test_ids_az_count)[:TOP_K_OOV]

# Convert the top sentences from ids to tokens and print them along the number of oov tokens
print(f'\n== Printing the {TOP_K_OOV} sentences with the most OOV tokens ==\n')
for s in top_oov_sentences_az:
    print('== Sentence ==')
    print(tokenizer.decode(s), oov_s_test_ids_az_count[s])
    print('== END Sentence ==')
print('\n== END ==\n')

# Test-set indices of the printed sentences, in printed order. A sentence occurring several times
# in the test set contributes all of its indices (ascending), so the list can exceed TOP_K_OOV.
oov_test_az_index = [i for s in top_oov_sentences_az for i in test_indices_by_sentence_az[s]]

print('\n== Printing indices ==\n')
print(oov_test_az_index)

[5947   59   42 ...    3    1 5947]
{114: 0, 389: 0, 567: 0, 634: 0, 641: 0, 847: 0, 887: 0, 956: 0, 1001: 0, 1056: 0, 1103: 0, 1152: 0, 1243: 0, 1291: 0, 1787: 0, 1798: 0, 1910: 0, 2072: 0, 2135: 0, 2158: 0, 2355: 0, 2394: 0, 2607: 0, 2707: 0, 2738: 0, 2891: 0, 3024: 0, 3152: 0, 3270: 0, 3360: 0, 3383: 0, 3741: 0, 3748: 0, 3907: 0, 3964: 0, 4076: 0, 4088: 0, 4226: 0, 4235: 0, 4741: 0, 4754: 0, 4886: 0, 5110: 0, 5210: 0, 5213: 0, 5293: 0, 5296: 0, 5372: 0, 5670: 0, 6214: 0, 6373: 0, 6470: 0, 6752: 0, 6816: 0, 6953: 0, 6967: 0, 7059: 0, 7112: 0, 7277: 0, 7529: 0, 8047: 0, 8216: 0, 8259: 0, 8444: 0, 8642: 0, 8668: 0, 8750: 0, 9640: 0, 9718: 0, 9734: 0, 9745: 0, 9907: 0, 10129: 0, 10583: 0, 10672: 0, 10720: 0, 10781: 0, 11103: 0, 11359: 0, 11863: 0, 12184: 0, 12363: 0, 12454: 0, 12620: 0, 12628: 0, 12647: 0, 12878: 0, 13379: 0, 13702: 0, 13708: 0, 13726: 0, 13766: 0, 13824: 0, 14846: 0, 15477: 0, 15742: 0, 15901: 0, 15911: 0, 16358: 0, 17036: 0, 17327: 0, 17398: 0, 17585: 0, 17837: 0, 182