In [None]:
import os

# Both variables are set before `transformers` is imported, as in the original cell order.
os.environ.update(
    {
        "XDG_CACHE_HOME": "/home/olab/tomerronen1/xdg_cache/",
        "CUDA_VISIBLE_DEVICES": "",
    }
)

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# English -> many-languages mBART-50 checkpoint; the tokenizer is configured for English source text.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/mbart-large-50-one-to-many-mmt"
)
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/mbart-large-50-one-to-many-mmt",
    src_lang="en_XX",
)

In [None]:
from transformers.models.mbart50.tokenization_mbart50_fast import FAIRSEQ_LANGUAGE_CODES
# Index every fairseq language code by its first two characters
# (for a code such as "hi_IN" the key is "hi").
LANG_CODE_TO_FAIRSEQ_FORMAT = dict(
    (fairseq_code[:2], fairseq_code) for fairseq_code in FAIRSEQ_LANGUAGE_CODES
)

In [None]:
# Source sentences to translate; the second one is deliberately tiny so the batch needs padding.
article_en = ["The head of the United Nations says there is no military solution in Syria", "lol"]

num_beams = 2
tgt_lang_code = "he"
# Upper bound on generated length, expressed as a multiple of the (padded) input length.
max_output_to_input_ratio = 1.2

model_inputs = tokenizer(article_en, return_tensors="pt", padding=True)
batch_size, input_length = model_inputs["input_ids"].shape

# Generation is forced to start with the target-language token.
forced_bos_token_id = tokenizer.lang_code_to_id[LANG_CODE_TO_FAIRSEQ_FORMAT[tgt_lang_code]]

gen_output = model.generate(
    **model_inputs,
    forced_bos_token_id=forced_bos_token_id,
    num_beams=num_beams,
    # Return every beam, so the output holds batch_size * num_beams sequences.
    num_return_sequences=num_beams,
    # Scale the budget by the input length. The previous expression multiplied the
    # ratio by the constant 1.2, which always truncated to a single new token.
    max_new_tokens=int(max_output_to_input_ratio * input_length),
    return_dict_in_generate=True,
    output_scores=True,
)
tokenizer.batch_decode(gen_output.sequences)

In [None]:
def flatten(nested_list: list[list]) -> list:
    """Return one flat list holding the items of every sublist, in order."""
    flat = []
    for sublist in nested_list:
        flat.extend(sublist)
    return flat

import torch
# Entries of special_tokens_map that are bare strings get wrapped in a one-item
# list so every entry can be flattened uniformly.
special_tokens = flatten(
    [
        toks if not isinstance(toks, str) else [toks]
        for toks in tokenizer.special_tokens_map.values()
    ]
)
# Look the tokens up in the vocabulary and keep the ids as a tensor.
special_token_ids = torch.tensor(tokenizer.convert_tokens_to_ids(special_tokens))

In [None]:
# View the generated sequences as (batch_size, num_beams, -1).
sequences = gen_output.sequences.view(batch_size, num_beams, -1)
# Display the tokens the tokenizer produces for the string "lol".
tokenizer.convert_ids_to_tokens(tokenizer("lol")["input_ids"])

In [None]:
# View the output as (batch_size, num_beams, -1) and
# drop the eos token that starts generation.
sequences = gen_output.sequences.view(batch_size, num_beams, -1)[:, :, 1:]
# Remove pad tokens: the outer list follows the first axis, the inner list holds
# one list of token ids per entry along the second axis.
sequences = [
    [hypothesis[hypothesis != tokenizer.pad_token_id].tolist() for hypothesis in hypotheses]
    for hypotheses in sequences
]
# Sequence scores reshaped to (batch_size, num_beams) and converted to nested lists.
scores = gen_output.sequences_scores.view(batch_size, num_beams).tolist()



In [None]:
from datasets import Dataset
# Two-row toy dataset: a source sentence column and an id column.
dataset = Dataset.from_dict(
    {
        "src_sentence": ["מדובר בחתול נאה מאוד", "אלליי, איזו מרשתת!"],
        "id": ["a", "b"],
    }
)
# dataset = dataset.map(tokenizer_many_to_en, batched=True, input_columns=["src_sentence"])
# tokenizer.batch_decode(model_many_to_en.generate(input_ids=torch.tensor([dataset[1]["input_ids"]]),
#                           attention_mask=torch.tensor([dataset[1]["attention_mask"]])))


In [None]:
import os
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache/"
# Index rows 1 and 0 of the "src_sentence" column and take the first result.
# The value is not bound to a name and this is not the last expression of the cell.
dataset.with_format(columns=["src_sentence"])[[1,0]]["src_sentence"][0]
from transformers import AutoTokenizer
# NOTE(review): this rebinds the global `tokenizer` (loaded from the mBART-50 checkpoint
# in the first cell) to the "facebook/bart-base" tokenizer. Later cells read
# `tokenizer.lang_code_to_id[...]` -- confirm the bart-base tokenizer provides that
# attribute, or bind this tokenizer to a different name.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

In [None]:
def dict_of_lists_to_list_of_dicts(d: dict[str, list]) -> list[dict]:
    """Convert a column-oriented mapping into a list of row dicts.

    Args:
        d: Mapping from column name to the list of values in that column.

    Returns:
        One dict per row, each with the same keys as ``d``. If the columns have
        different lengths the result stops at the shortest one (``zip``
        semantics); an empty mapping yields an empty list.
    """
    column_names = list(d)
    return [dict(zip(column_names, row)) for row in zip(*d.values())]


from transformers import DataCollatorForSeq2Seq
collator = DataCollatorForSeq2Seq(tokenizer)
# Apply the tokenizer to the "src_sentence" column in batched mode.
dataset = dataset.map(tokenizer, input_columns="src_sentence", batched=True)
# Take rows 1 and 0 (in that order) of the two model-input columns and turn the
# resulting dict of lists into one dict per row before collating.
batch = dict_of_lists_to_list_of_dicts(
    dataset.with_format(columns=["input_ids", "attention_mask"])[[1, 0]]
)
collator(batch)

In [None]:
# Append one hand-written row that supplies a value for each of the four columns.
extra_row = {
    "src_sentence": "aaa",
    "id": "g",
    "input_ids": [3, 4, 4],
    "attention_mask": [1, 1, 1],
}
dataset = dataset.add_item(extra_row)
dataset

In [None]:
from datasets import Dataset
from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
# Two-sentence toy dataset, tokenized one example at a time via the "text" column.
dataset = Dataset.from_dict({"text": ["This is a setnence.", "How many woods are there in the woods?"]})
dataset = dataset.map(tokenizer, input_columns="text")
# Add a per-example forced BOS id column (the "hi_IN" language id) unless it already exists.
if "forced_bos_token_id" not in dataset.column_names:
    dataset = dataset.add_column("forced_bos_token_id", [tokenizer.lang_code_to_id["hi_IN"]] * len(dataset))
# predict_with_generate=True is passed so that prediction goes through generation.
trainer_args = Seq2SeqTrainingArguments(output_dir='/tmp/lol', predict_with_generate=True)
trainer = Seq2SeqTrainer(model, args=trainer_args, data_collator=DataCollatorForSeq2Seq(tokenizer))

In [None]:
import torch

# Generation settings forced onto every call that goes through `custom_generate`.
generation_kwargs = dict(
    forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"],
    length_penalty=1.0,
    num_beams=2,
    num_return_sequences=2,
)


def custom_generate(*args, **kwargs):
    """Generate several sequences per input and pack them into one row per input.

    Calls the original ``generate`` (saved as ``model.orig_generate``) with
    ``generation_kwargs`` overriding any caller-supplied keyword of the same
    name. A ``-100`` separator is appended to every returned sequence, and the
    result is reshaped so that all sequences belonging to the same input are
    concatenated along the last axis.

    Returns:
        Tensor with one row per input, holding that input's
        ``num_return_sequences`` sequences, each followed by a ``-100`` marker.
    """
    kwargs = {**kwargs, **generation_kwargs}
    # Rows are grouped by the number of sequences actually returned per input,
    # taken from the effective kwargs instead of a separate hard-coded constant.
    num_return_sequences = kwargs.get("num_return_sequences", 1)
    generated_tokens = model.orig_generate(*args, **kwargs)
    # new_full keeps the separator column on the same device and dtype as the tokens.
    separator = generated_tokens.new_full((generated_tokens.shape[0], 1), -100)
    generated_tokens = torch.hstack([generated_tokens, separator])
    batch_size = generated_tokens.shape[0] // num_return_sequences
    return generated_tokens.reshape(batch_size, -1)


# Save the original method once; the guard keeps a re-run of this cell from
# storing the wrapper as "original" and recursing forever.
if not hasattr(model, "orig_generate"):
    model.orig_generate = model.generate
model.generate = custom_generate

In [None]:
# Run prediction over `dataset` with the trainer and keep the `.predictions` field.
concatenated_preds = trainer.predict(dataset).predictions
# Display the shape of the predictions.
concatenated_preds.shape

In [None]:
# Show the tokens of the first generated sequence (attribute access, as in the other cells).
tokenizer.convert_ids_to_tokens(gen_output.sequences[0])

In [None]:
# Tokenizer of the many-to-one mBART-50 checkpoint.
tokenizer_many_to_en = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
# Source language is set to the fairseq code that the two-letter key "he" maps to.
tokenizer_many_to_en.src_lang = LANG_CODE_TO_FAIRSEQ_FORMAT["he"]
# Display the tokens produced for a sample sentence.
tokenizer_many_to_en.convert_ids_to_tokens(tokenizer_many_to_en("זהו משפט בעברית.")["input_ids"])

In [None]:
import numpy as np
# Each element is repeated three times in a row.
np.repeat(["a", "fff"], repeats=3)

In [None]:
import bert_score
from pathlib import Path
# Names of the entries inside the "rescale_baseline" directory that sits next to
# the installed bert_score package file.
rescale_baseline_dir = Path(bert_score.__file__).parent / "rescale_baseline"
bertscore_baseline_languages = [entry.name for entry in rescale_baseline_dir.iterdir()]
bertscore_baseline_languages

In [None]:
from pathlib import Path
# String form of the path with a plain suffix appended.
f"{Path('a')}b"

In [None]:
import numpy as np

def flatten(nested_list: list[list]) -> list:
    """Collect the items of each sublist into a single list, preserving order."""
    items = []
    for sublist in nested_list:
        for item in sublist:
            items.append(item)
    return items

num_beams = 2
preds = []
for pred in concatenated_preds:
    # Cut the row right after every -100, so each chunk ends with its separator.
    different_beams = np.array_split(pred, np.flatnonzero(pred == -100) + 1)
    # The first num_beams chunks are the beams; whatever follows is padding.
    # Slicing by count (rather than dropping only the last chunk) also covers rows
    # padded with -100, which would otherwise add empty predictions.
    different_beams = different_beams[:num_beams]
    for beam_pred in different_beams:
        # Strip the -100 separator before decoding.
        beam_pred = beam_pred[beam_pred != -100]
        preds.append(beam_pred)

tokenizer.batch_decode(preds, skip_special_tokens=True)

In [None]:
# trainer.data_collator(dataset.to_list())
# trainer.data_collator(dataset.to_dict(orient="list"))
# Keep only the two model-input columns, turn them into one dict per row, then collate.
model_input_frame = dataset.to_pandas()[["input_ids", "attention_mask"]]
batch = trainer.data_collator(model_input_frame.to_dict(orient="records"))
# trainer.data_collator([dataset[i] for i in range(len(dataset))])

In [None]:
# Call generate on the trainer's model with the collated batch as keyword arguments.
# NOTE(review): if the cell that assigns `model.generate = custom_generate` has run and
# `trainer.model` is that same object, this goes through the patched function -- confirm.
res = trainer.model.generate(**batch)
res

In [None]:
# NOTE(review): this rebinds `preds`, a name another cell uses for a list of per-beam
# token arrays; here it holds the object returned by `trainer.predict`.
preds = trainer.predict(dataset)
# Display the shape of the predictions.
preds.predictions.shape

In [None]:
# NOTE(review): the global `generated_tokens` is first assigned in a cell further down,
# so on a fresh top-to-bottom run this name is not yet defined here.
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

In [None]:
article_en = "The head of the United Nations says there is no military solution in Syria"

model_inputs = tokenizer(article_en, return_tensors="pt")
generation_params = {"num_beams": 5, "length_penalty": 1.0}

# Translate from English to Hindi, then from English to Chinese; the only thing
# that changes between the two calls is the forced first token.
# hi_IN => 'संयुक्त राष्ट्र के नेता कहते हैं कि सीरिया में कोई सैन्य समाधान नहीं है'
# zh_CN => '联合国首脑说,叙利亚没有军事解决办法'
for target_language in ("hi_IN", "zh_CN"):
    generated_tokens = model.generate(
        **model_inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id[target_language],
        **generation_params,
    )
    print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))



In [None]:
# NOTE(review): `model_many_to_en` is first assigned in a cell below this one, so this
# cell depends on out-of-order execution.
# NOTE(review): 250004 is a hard-coded token id, presumably the English language code of
# this tokenizer -- confirm, and consider looking it up by name instead.
model_many_to_en.generate(**tokenizer_many_to_en("אני חתול", return_tensors="pt"), forced_bos_token_id=250004)

In [None]:
# Load the many-to-one mBART-50 checkpoint.
# NOTE(review): the following cell repeats this exact load; one of the two is redundant.
model_many_to_en = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")


In [None]:
# Model and tokenizer are loaded from the same many-to-one checkpoint.
many_to_en_checkpoint = "facebook/mbart-large-50-many-to-one-mmt"
model_many_to_en = AutoModelForSeq2SeqLM.from_pretrained(many_to_en_checkpoint)
tokenizer_many_to_en = AutoTokenizer.from_pretrained(many_to_en_checkpoint)

In [None]:
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."

# translate Hindi to English
tokenizer_many_to_en.src_lang = "hi_IN"
# Encode with the tokenizer whose src_lang was just set; the original cell encoded
# with the unrelated global `tokenizer`, so the Hindi source setting was ignored.
encoded_hi = tokenizer_many_to_en(article_hi, return_tensors="pt")
generated_tokens = model_many_to_en.generate(**encoded_hi, **generation_params)
print(tokenizer_many_to_en.batch_decode(generated_tokens, skip_special_tokens=True))
# => "The head of the UN says there is no military solution in Syria."

# translate Arabic to English
tokenizer_many_to_en.src_lang = "ar_AR"
encoded_ar = tokenizer_many_to_en(article_ar, return_tensors="pt")
generated_tokens = model_many_to_en.generate(**encoded_ar, **generation_params)
# Decode with the same tokenizer that encoded the input, matching the Hindi section.
print(tokenizer_many_to_en.batch_decode(generated_tokens, skip_special_tokens=True))
# => "The Secretary-General of the United Nations says there is no military solution in Syria."

