In [None]:
from mbart50_translate import LANG_CODE_TO_FAIRSEQ_FORMAT
from pathlib import Path
# Data directory and source-language code. Both are reused in the printed
# command below so the existence check and the command cannot drift apart.
data_dir = Path("data/wmt20")
src_lang = "en"

# Keep only the target languages for which a "wmt20.<src>-<tgt>.src" file exists.
available_languages = [
    tgt_lang
    for tgt_lang in LANG_CODE_TO_FAIRSEQ_FORMAT.keys()
    if (data_dir / f"wmt20.{src_lang}-{tgt_lang}.src").exists()
]

# Print one `run_with_slurm` command line per available target language
# (the commands are only printed here, not executed).
for tgt_lang in available_languages:
    print(f'GPUS=1 run_with_slurm {tgt_lang} $(which run_python_script) -m mbart50_translate '
          f'--num_examples=500 --data_dir="{data_dir.as_posix()}" --dump_dir="mbart50_dumps" '
          f'--src_lang="{src_lang}" --tgt_lang="{tgt_lang}" ')
print()


In [None]:
import os
# The environment is configured before transformers is imported.
# NOTE(review): absolute, user-specific cache path — consider making it configurable.
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache/"
# Empty device list — presumably to keep this session on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
# Tokenizer (source language "en_XX"), model and seq2seq collator for the
# "facebook/mbart-large-50-one-to-many-mmt" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
collator = DataCollatorForSeq2Seq(tokenizer)

In [None]:
# Second tokenizer, loaded from the many-to-one checkpoint with "he_IL" as source language.
tokenizer2 = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-one-mmt", src_lang="he_IL")

In [None]:
# Tokenize a short Hebrew string to inspect the encoding produced by tokenizer2.
tokenizer2(" בעבריצ משפט")

In [None]:
from mbart50_translate import dict_of_lists_to_list_of_dicts
import torch
# Value marking padded label positions in the tensors built below.
LABEL_PAD = -100
# Tokenize two English source sentences ...
tokenizer.src_lang = "en_XX"
batch = tokenizer(["One sentence", "Here's another sentence"])
# ... then switch the same tokenizer to Hebrew and use its input_ids as the labels.
tokenizer.src_lang = "he_IL"
batch["labels"] = tokenizer(["משפט אחד", "הנה עוד משפט"])["input_ids"]
# Convert the dict of per-field lists into one dict per example and collate.
batch = dict_of_lists_to_list_of_dicts(batch)
batch = collator(batch)
input_ids, attention_mask, labels = [batch[k] for k in ["input_ids", "attention_mask", "labels"]]
# Explicit decoder inputs: LABEL_PAD is replaced by the pad token id and a
# column of eos ids is prepended.
decoder_input_ids = torch.where(
    labels != LABEL_PAD, labels, tokenizer.pad_token_id)
start_of_generation_eos_column = tokenizer.eos_token_id * \
    labels.new_ones((labels.shape[0], 1))
decoder_input_ids = torch.concat(
    [start_of_generation_eos_column, decoder_input_ids], dim=1)
decoder_attention_mask = (
    decoder_input_ids != tokenizer.pad_token_id).int()

# Labels for the explicit-decoder-input call: a LABEL_PAD column, the labels
# without their first column, and a trailing LABEL_PAD column.
label_pad_column = LABEL_PAD * labels.new_ones((labels.shape[0], 1))
labels_for_loss = torch.concat(
    [label_pad_column, labels[:, 1:], label_pad_column], dim=1)

# Two forward passes to compare: one given only `labels`, one given the
# hand-built decoder inputs, decoder attention mask and `labels_for_loss`.
forward_just_labels = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
forward_all = model(input_ids=input_ids, attention_mask=attention_mask,
 decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask,
 labels=labels_for_loss)


In [None]:
# Show the argmax token predictions of the two forward passes side by side.
forward_just_labels.logits.argmax(-1), forward_all.logits.argmax(-1)

In [None]:
# Show the labels-only predictions next to the collated labels.
forward_just_labels.logits.argmax(-1), batch["labels"]

In [None]:
def calc_loss(forward_out, decoder_input_ids, labels, pad_token_id=None, label_pad=-100):
    """Recompute the mean token loss and per-sequence mean log-probabilities from logits.

    Exactly one of ``decoder_input_ids`` / ``labels`` is used to pick the target
    tokens: ``labels`` when it is not ``None``, otherwise ``decoder_input_ids``.

    Args:
        forward_out: object exposing a ``logits`` tensor; ``log_softmax`` is
            applied over its last dimension.
        decoder_input_ids: integer tensor of decoder ids, or ``None``. Its first
            two columns (the eos that starts generation and the forced bos
            token) are dropped to obtain the targets.
        labels: integer tensor of labels, or ``None``. Its first column (the
            forced bos token) is dropped and ``label_pad`` entries are treated
            as padding.
        pad_token_id: id of the padding token. Defaults to the notebook-level
            ``tokenizer.pad_token_id`` when omitted.
        label_pad: value marking ignored label positions (default ``-100``).

    Returns:
        ``(manual_loss, sequence_logprobs)``: the negative mean log-probability
        over all non-padding target tokens, and one mean log-probability per
        sequence.
    """
    if pad_token_id is None:
        # Fall back to the notebook-level tokenizer, as the original code did.
        pad_token_id = tokenizer.pad_token_id
    logprobs = forward_out.logits.log_softmax(dim=-1)
    if labels is None:
        # without eos that starts generation, without forced bos token
        targets = decoder_input_ids[:, 2:]
        # without forced bos token, without extra predicted token at the end
        logprobs = logprobs[:, 1:-1]
    else:
        # without forced bos token
        targets = labels[:, 1:]
        # Map ignored label positions onto the pad id so that they are valid
        # gather indices and are excluded by the mask below.
        targets = targets.masked_fill(targets == label_pad, pad_token_id)
        logprobs = logprobs[:, 1:]
    targets_mask = targets != pad_token_id
    label_logprobs = logprobs.gather(index=targets.unsqueeze(-1), dim=-1).squeeze(-1)
    # Padding positions contribute nothing to the sums.
    label_logprobs = label_logprobs.masked_fill(~targets_mask, 0.0)
    sequence_logprobs = label_logprobs.sum(dim=-1) / targets_mask.sum(dim=-1)
    manual_loss = -label_logprobs.sum() / targets_mask.sum()
    return manual_loss, sequence_logprobs

# Recompute the loss for each forward pass: once from the explicit decoder
# inputs, once from the labels.
manual_loss_all, sequence_logprobs_all = calc_loss(forward_all, decoder_input_ids, None)
manual_loss_just_labels, sequence_logprobs_just_labels = calc_loss(forward_just_labels, None, labels)
# NOTE(review): this tuple is not the last expression of the cell, so the
# notebook does not display it.
manual_loss_all, forward_all.loss, manual_loss_just_labels, forward_just_labels.loss
sequence_logprobs_all, sequence_logprobs_just_labels

In [None]:
import os
# The environment is configured before transformers is imported.
# NOTE(review): absolute, user-specific cache path — consider making it configurable.
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache/"
# Empty device list — presumably to keep this session on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# (Re)load the one-to-many model and its tokenizer with "en_XX" as source language;
# this rebinds the notebook-level `model` and `tokenizer` names.
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")

In [None]:
from transformers.models.mbart50.tokenization_mbart50_fast import FAIRSEQ_LANGUAGE_CODES
# Map the first two characters of each fairseq language code to the full code.
# NOTE(review): codes sharing the same two-character prefix would overwrite each
# other in this dict — verify that none do.
LANG_CODE_TO_FAIRSEQ_FORMAT = {long_language_code[:2]: long_language_code for long_language_code in FAIRSEQ_LANGUAGE_CODES}

In [None]:
# Two English inputs of different length, tokenized below with padding=True.
article_en = ["The head of the United Nations says there is no military solution in Syria", "lol"]

# Settings for this generation experiment.
num_beams = 1
tgt_lang_code = "he"
max_output_to_input_ratio = 1.2

model_inputs = tokenizer(article_en, return_tensors="pt", padding=True)
batch_size, input_length = model_inputs["input_ids"].shape

# Id of the target-language code token, passed to generate() as forced_bos_token_id.
forced_bos_token_id = tokenizer.lang_code_to_id[LANG_CODE_TO_FAIRSEQ_FORMAT[tgt_lang_code]]

# The new-token budget is proportional to the (padded) input length.
gen_output = model.generate(
    **model_inputs,
    forced_bos_token_id=forced_bos_token_id,
    num_beams=num_beams,
    num_return_sequences=num_beams,
    max_new_tokens=int(max_output_to_input_ratio * input_length),
    return_dict_in_generate=True,
    output_scores=True,
)
# print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
# Decode including special tokens, so the full generated sequence is visible.
tokenizer.batch_decode(gen_output.sequences)

In [None]:
import torch
# Feed the generated sequences back to the model as decoder inputs.
input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]
decoder_input_ids = gen_output.sequences
decoder_attention_mask = (decoder_input_ids != tokenizer.pad_token_id).int()
# Labels: a -100 column, the decoder ids from column 2 on (id 1 replaced by -100),
# and a trailing -100 column.
# NOTE(review): the literal 1 presumably stands for tokenizer.pad_token_id — confirm.
labels_for_loss = torch.concat([-100*torch.ones((batch_size,1), dtype=int),torch.where(decoder_input_ids != 1, decoder_input_ids, -100)[:,2:], -100*torch.ones((batch_size,1), dtype=int)], dim=-1)
forward_out = model(input_ids=input_ids, attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask,
                    labels=labels_for_loss)

# Per-sequence mean log-probability of the decoder tokens, padding excluded.
logprobs = forward_out.logits.log_softmax(dim=-1)
logprobs = logprobs[:,1:-1]  # without forced bos token, without extra predicted token at the end
labels = decoder_input_ids[:,2:].unsqueeze(-1)  # without eos that starts generation, without forced bos token
labels_mask = (labels != tokenizer.pad_token_id)
label_logprobs = logprobs.gather(index=labels, dim=-1)
# Zero the log-probabilities at padding positions before summing.
label_logprobs = torch.where(labels_mask, label_logprobs, logprobs.new([0.])).squeeze(-1)
sequence_prob = label_logprobs.sum(dim=-1) / labels_mask.squeeze(-1).sum(dim=-1)



In [None]:
from mbart50_translate import MBart50Translator

In [None]:
# forward_out.loss is a mean negative log-likelihood, so the manually averaged
# label log-probability is negated to be directly comparable with it.
forward_out.loss, -label_logprobs.sum() / labels_mask.sum()
# labels_for_loss, decoder_input_ids, forward_out.logits.argmax(-1)

In [None]:
# Forward pass with the first column dropped from both the decoder ids and their mask.
forward_out2 = model(input_ids=input_ids, attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids[:,1:], decoder_attention_mask=decoder_attention_mask[:,1:])


In [None]:
# NOTE(review): the tuple below is not the cell's last expression, so it is not displayed.
forward_out2.logits.argmax(dim=-1), gen_output.sequences
# Decode the argmax predictions of forward_out2 back to text.
tokenizer.batch_decode(forward_out2.logits.argmax(dim=-1))

In [None]:
from transformers import MBartForConditionalGeneration
from torch.nn import CrossEntropyLoss

In [None]:
# Per-sequence mean log-probability of the decoder tokens, padding excluded.
# NOTE(review): these statements repeat the computation at the end of an
# earlier cell of this notebook verbatim.
logprobs = forward_out.logits.log_softmax(dim=-1)
logprobs = logprobs[:,1:-1]  # without forced bos token, without extra predicted token at the end
labels = decoder_input_ids[:,2:].unsqueeze(-1)  # without eos that starts generation, without forced bos token
labels_mask = (labels != tokenizer.pad_token_id)
label_logprobs = logprobs.gather(index=labels, dim=-1)
# Zero the log-probabilities at padding positions before summing.
label_logprobs = torch.where(labels_mask, label_logprobs, logprobs.new([0.])).squeeze(-1)
sequence_prob = label_logprobs.sum(dim=-1) / labels_mask.squeeze(-1).sum(dim=-1)



In [None]:
import torch
# Value marking ignored label positions.
LABEL_PAD = -100
# NOTE(review): the next two expressions are evaluated but neither stored nor displayed.
forward_out.logits.gather(dim=-1, index=decoder_input_ids.unsqueeze(-1))
forward_out.logits[:,1:].argmax(dim=-1), decoder_input_ids[:,2:]
# Build label-like ids: a LABEL_PAD column, the decoder ids from column 2 on and
# a trailing LABEL_PAD column; pad-token entries are then replaced by LABEL_PAD.
label_pad_column = LABEL_PAD * decoder_input_ids.new_ones((batch_size,1))
faux_labels = torch.concat([label_pad_column, decoder_input_ids[:, 2:], label_pad_column], dim=1)
faux_labels = torch.where(faux_labels != tokenizer.pad_token_id, faux_labels, LABEL_PAD)
# Show the argmax predictions next to the constructed labels.
forward_out.logits.argmax(dim=-1), faux_labels


In [None]:
def flatten(nested_list: list[list]) -> list:
    """Concatenate the sub-lists of ``nested_list`` into one flat list (one level only)."""
    flat = []
    for sublist in nested_list:
        flat.extend(sublist)
    return flat

import torch
# Collect every special token string: single-string entries of the map are
# wrapped in a list so flatten() can merge them with the list-valued entries.
special_tokens = flatten([[toks] if isinstance(toks, str) else toks
                          for toks in tokenizer.special_tokens_map.values()])
# Convert the tokens to ids and store them as a tensor.
special_token_ids = tokenizer.convert_tokens_to_ids(special_tokens)
special_token_ids = torch.tensor(special_token_ids)

In [None]:
# View the generated sequences as (batch_size, num_beams, -1).
sequences = gen_output.sequences.view(batch_size, num_beams, -1)
# Show the tokens the tokenizer produces for the string "lol".
tokenizer.convert_ids_to_tokens(tokenizer("lol")["input_ids"])

In [None]:
# View the generated sequences as (batch_size, num_beams, -1).
sequences = gen_output.sequences.view(batch_size, num_beams, -1)
sequences = sequences[:, :, 1:]  # drop the eos token that starts generation
# Strip pad tokens and turn every sequence into a plain list of ids.
sequences = [[seq[seq != tokenizer.pad_token_id].tolist() for seq in beam] for beam in sequences]
# NOTE(review): sequences_scores is presumably only returned by beam search —
# confirm it is available when num_beams == 1.
scores = gen_output.sequences_scores.view(batch_size, num_beams).tolist()



In [None]:
from datasets import Dataset
# Tiny in-memory dataset with two Hebrew source sentences and string ids.
dataset = Dataset.from_dict({"src_sentence": ["מדובר בחתול נאה מאוד", "אלליי, איזו מרשתת!"], "id": ["a", "b"]})
# dataset = dataset.map(tokenizer_many_to_en, batched=True, input_columns=["src_sentence"])
# tokenizer.batch_decode(model_many_to_en.generate(input_ids=torch.tensor([dataset[1]["input_ids"]]),
#                           attention_mask=torch.tensor([dataset[1]["attention_mask"]])))


In [None]:
import os
# NOTE(review): absolute, user-specific cache path — consider making it configurable.
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache/"
# NOTE(review): this expression is not the cell's last one, so its value is discarded.
dataset.with_format(columns=["src_sentence"])[[1,0]]["src_sentence"][0]
from transformers import AutoTokenizer
# Rebinds the notebook-level `tokenizer` to the "facebook/bart-base" tokenizer.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

In [None]:
def dict_of_lists_to_list_of_dicts(d: dict[str, list]) -> list[dict]:
    """Turn a mapping of column name -> list of values into a list of row dicts.

    Row ``i`` of the result maps every key of ``d`` to the ``i``-th element of
    that key's list. Because the columns are combined with ``zip``, the result
    has as many rows as the shortest column; an empty mapping yields ``[]``.
    """
    # The keys are the same for every row, so they are materialised once.
    keys = list(d.keys())
    return [dict(zip(keys, vals)) for vals in zip(*d.values())]


from transformers import DataCollatorForSeq2Seq
# Collator built around the current notebook-level `tokenizer`.
collator = DataCollatorForSeq2Seq(tokenizer)
# Tokenize the "src_sentence" column in batched mode.
dataset = dataset.map(tokenizer, input_columns="src_sentence", batched=True)
# Select rows 1 and 0 (in that order), restricted to the two model-input columns.
batch = dataset.with_format(columns=["input_ids", "attention_mask"])[[1,0]]
# Convert to one dict per example and collate; the collated batch is the cell output.
batch = dict_of_lists_to_list_of_dicts(batch)
collator(batch)

In [None]:
# Append one hand-written row to the dataset and display the result.
# NOTE(review): `dataset` is rebound to the extended dataset, so re-running this
# cell would presumably append the row again.
dataset = dataset.add_item({"src_sentence": "aaa", "id": "g", "input_ids": [3,4,4], "attention_mask": [1,1,1]})
dataset

In [None]:
from datasets import Dataset
from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
# Two-sentence English dataset, tokenized row by row through the "text" column.
dataset = Dataset.from_dict({"text": ["This is a setnence.", "How many woods are there in the woods?"]})
dataset = dataset.map(tokenizer, input_columns="text")
# Add a constant forced_bos_token_id column (the "hi_IN" language id) unless it
# already exists, so re-running the cell does not try to add it twice.
# NOTE(review): this reads tokenizer.lang_code_to_id, so it presumably expects the
# notebook-level `tokenizer` to be an mBART-50 tokenizer at this point — confirm.
if "forced_bos_token_id" not in dataset.column_names:
    dataset = dataset.add_column("forced_bos_token_id", [tokenizer.lang_code_to_id["hi_IN"]] * len(dataset))
# Trainer used only for prediction with generate().
trainer_args = Seq2SeqTrainingArguments(output_dir='/tmp/lol', predict_with_generate=True)
trainer = Seq2SeqTrainer(model, args=trainer_args, data_collator=DataCollatorForSeq2Seq(tokenizer))

In [None]:
import torch

# Keyword arguments forced onto every generate() call made through the wrapper below.
generation_kwargs = dict(forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"], length_penalty=1.0, num_beams=2, num_return_sequences=2)


def custom_generate(*args, **kwargs):
    """Call the model's original generate() and pack all returned sequences of one example into one row.

    ``generation_kwargs`` override any caller-supplied keyword arguments. A
    column of -100 is appended to the generated ids as an end-of-sequence
    separator, and the rows are then reshaped so that each input example
    occupies a single row containing its ``num_return_sequences`` sequences
    back to back.
    """
    sequences_per_example = generation_kwargs["num_return_sequences"]
    kwargs = {**kwargs, **generation_kwargs}
    generated_tokens = model.orig_generate(*args, **kwargs)
    # new_full keeps the dtype and device of the generated ids.
    separator_column = generated_tokens.new_full((generated_tokens.shape[0], 1), -100)
    generated_tokens = torch.hstack([generated_tokens, separator_column])
    batch_size = generated_tokens.shape[0] // sequences_per_example
    generated_tokens = generated_tokens.reshape(batch_size, -1)
    return generated_tokens


# Keep a handle on the real generate() before patching it. The guard keeps a
# re-run of this cell from storing the wrapper itself as `orig_generate`.
if not hasattr(model, "orig_generate"):
    model.orig_generate = model.generate
model.generate = custom_generate

In [None]:
# Run prediction through the trainer and show the shape of the returned predictions.
concatenated_preds = trainer.predict(dataset).predictions
concatenated_preds.shape

In [None]:
# Show the token strings of the first generated sequence.
tokenizer.convert_ids_to_tokens(gen_output["sequences"][0])

In [None]:
# Tokenizer of the many-to-one checkpoint, with the source language set to the
# fairseq code that LANG_CODE_TO_FAIRSEQ_FORMAT maps "he" to.
tokenizer_many_to_en = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
tokenizer_many_to_en.src_lang = LANG_CODE_TO_FAIRSEQ_FORMAT["he"]
# Show the tokens produced for a Hebrew sentence.
tokenizer_many_to_en.convert_ids_to_tokens(tokenizer_many_to_en("זהו משפט בעברית.")["input_ids"])

In [None]:
import numpy as np
# Quick check of np.repeat: each element is repeated 3 times.
np.repeat(["a","fff"], 3)

In [None]:
import bert_score
from pathlib import Path
# Names of the entries inside the "rescale_baseline" directory next to the
# installed bert_score package.
bertscore_baseline_languages = [path.name for path in (Path(bert_score.__file__).parent / "rescale_baseline").iterdir()]
bertscore_baseline_languages

In [None]:
%load_ext autoreload
%autoreload 2
import os
# NOTE(review): absolute, user-specific cache path — consider making it configurable.
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache/"
from mbart50_translate import MBart50Translator
# Small CPU translator instance for en -> ta, used by the cells that call `runner`.
runner = MBart50Translator(device="cpu", num_examples=200, batch_size=2, data_dir="data/wmt20", src_lang="en", tgt_lang="ta")


In [None]:
import torch
# Hard-coded example batch: four rows of token ids, the matching attention
# mask, and label ids (id 1 fills the tail of the shorter rows).
input_ids=torch.tensor([[250004,     62,  75533,  13416,   4568,   6602,  14037,     99,  19713,
          35389,  77987,  27941,      7,     23,     70,   7082,   1902,   2809,
          61689,   5281,    111,  17688,    538,     10,   1192,    202,   1916,
            707, 162753,    449,   2363,  44828,  26255,    645,     10,  14922,
            111,  22759,   5369,      5,      2,      1,      1,      1,      1,
              1],
        [250004,   6300,   1177,     33,   7582,   3640,    136,  18982,  11075,
              4,  12638,  15889, 125413,   1221,  27154,     67,   2363,   7175,
              4,    678,  56480,   8035,  19667,     19,   8305,  40101,     10,
          85727,   1118,    707,  72761, 233547,     20, 117934,     10,  15889,
             28,  27591,    818,  12126,   7175,  21771,  32316,      7,      5,
              2],
        [250004,   1529,     25,      7,   7730,     47,    186,  37515,      4,
           1284,    450,  22027,     25,     18,  16401,    398,   5792,   4989,
             23,    903,   6712,      5,      2,      1,      1,      1,      1,
              1,      1,      1,      1,      1,      1,      1,      1,      1,
              1,      1,      1,      1,      1,      1,      1,      1,      1,
              1],
        [250004,  11853,      4,     10,  27150,    332,  74703,      7,  41896,
             99,     10,  33816,   8752,      6,  92621,   9149,  10519,     10,
              6,  44720,  53470,     53,  43613,  20016,     15,  22489,     73,
            434,  54969,     83,  49726,     71,   1660, 107314,   4049,      4,
         179493,     10,  57571,    384,      9,  38184,    194,      2,      1,
              1]])
attention_mask=torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
labels=torch.tensor([[250004,     62,  75533,  13416,   4568,   6602,  14037,     99,
          19713,  35389,  77987,  27941,      7,     23,     70,   7082,   1902,
           2809,  61689,   5281,    111,  17688,    538,     10,   1192,    202,
           1916,    707, 162753,    449,   2363,  44828,  26255,    645,     10,
          14922,    111,  22759,   5369,      5,      2,      1,      1,      1,
              1,      1],
        [250004,   6300,   1177,     33,   7582,   3640,    136,  18982,
          11075,      4,  12638,  15889, 125413,   1221,  27154,     67,   2363,
           7175,      4,    678,  56480,   8035,  19667,     19,   8305,  40101,
             10,  85727,   1118,    707,  72761, 233547,     20, 117934,     10,
          15889,     28,  27591,    818,  12126,   7175,  21771,  32316,      7,
              5,      2],
        [250004,   1529,     25,      7,   7730,     47,    186,  37515,
              4,   1284,    450,  22027,     25,     18,  16401,    398,   5792,
           4989,     23,    903,   6712,      5,      2,      1,      1,      1,
              1,      1,      1,      1,      1,      1,      1,      1,      1,
              1,      1,      1,      1,      1,      1,      1,      1,      1,
              1,      1],
        [250004,  11853,      4,     10,  27150,    332,  74703,      7,
          41896,     99,     10,  33816,   8752,      6,  92621,   9149,  10519,
             10,      6,  44720,  53470,     53,  43613,  20016,     15,  22489,
             73,    434,  54969,     83,  49726,     71,   1660, 107314,   4049,
              4, 179493,     10,  57571,    384,      9,  38184,    194,      2,
              1,      1]])
# Replace id 1 in the labels by -100 before handing them to the runner.
# NOTE(review): the literal 1 presumably stands for the pad token id — confirm.
labels = torch.where(labels != 1, labels, -100)
forward_args, forward_out = runner._calculate_target_logprobs(input_ids, attention_mask, labels)

# NOTE(review): this tuple is not the cell's last expression, so it is not displayed.
forward_args["labels"], labels, forward_out.logits.argmax(-1)

# Rebuild decoder inputs from the labels: LABEL_PAD -> pad token, eos column prepended.
# NOTE(review): LABEL_PAD and `tokenizer` are not defined in this cell; they come
# from whatever earlier cell last assigned them.
decoder_input_ids = torch.where(labels != LABEL_PAD, labels, tokenizer.pad_token_id)
eos_column = tokenizer.eos_token_id * labels.new_ones((labels.shape[0], 1))
decoder_input_ids = torch.concat([eos_column, decoder_input_ids], dim=1)  # eos token marks start of generation

# Manual loss: negative mean log-probability of the non-padding target tokens.
logprobs = forward_out.logits.log_softmax(dim=-1)
logprobs = logprobs[:,1:-1]  # without forced bos token, without extra predicted token at the end
# Note that `labels` is rebound here to the gather indices derived from decoder_input_ids.
labels = decoder_input_ids[:,2:].unsqueeze(-1)  # without eos that starts generation, without forced bos token
labels_mask = (labels != tokenizer.pad_token_id)
label_logprobs = logprobs.gather(index=labels, dim=-1)
label_logprobs = torch.where(labels_mask, label_logprobs, logprobs.new([0.])).squeeze(-1)
sequence_logprobs = label_logprobs.sum(dim=-1) / labels_mask.squeeze(-1).sum(dim=-1)
manual_loss = -label_logprobs.sum() / labels_mask.sum()
manual_loss

In [None]:
# print(forward_out.logits.shape)
# NOTE(review): only the last expression of the cell is displayed; the first
# tuple below is evaluated and discarded.
forward_out.logits.log_softmax(-1)[0,1,62], label_logprobs
labels.squeeze(), logprobs.argmax(-1), forward_args["labels"], forward_out.logits.argmax(-1)

In [None]:
# Invoke the runner's metrics calculation step (a private method of the runner).
runner._run_metrics_calculation()

In [None]:
# One hard-coded record containing generated ids/text, source and target
# sentences, model inputs and metric values (all three rouge scores are 0.0).
example = {"gen_sequence": [250021, 94, 5216, 4, 819, 35, 38626, 22238, 9542, 4, 9309, 743, 4, 414, 1097, 129, 60927, 1730, 29, 11373, 2192, 23054, 1339, 20, 13398, 1266, 5, 2], "gen_score": -0.3600703179836273, "gen_text": "«Я, как и многие другие люди, верю, что это повлияет на мирный процесс», - сказал он.", "src_sentence": "\"I, along with many other people, believe that it will affect the peace process,\" he said.", "tgt_sentence": "«Я, как и многие другие люди, убежден, что это скажется на мирном процессе», – заявил он.", "id": 115, "input_ids": [250004, 44, 568, 4, 33233, 678, 5941, 3789, 3395, 4, 18822, 450, 442, 1221, 52490, 70, 88669, 9433, 4, 58, 764, 2804, 5, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "bertscore_f1": 0.8917006850242615, "bertscore_precision": 0.9027066230773926, "bertscore_recall": 0.8809599876403809, "bleu_score": 0.4372411072254181, "rougeL_score": 0.0, "rouge2_score": 0.0, "rouge1_score": 0.0}
from utils import rouge, sacrebleu
# Generated text and reference text of the record.
gen, tgt = example["gen_text"], example["tgt_sentence"]

def remove_wat(s):
    """Return ``s`` with every '«' and '»' character removed."""
    dropped_chars = {ord('«'): None, ord('»'): None}
    return s.translate(dropped_chars)

# Strip the guillemets from both texts before scoring.
gen, tgt = remove_wat(gen), remove_wat(tgt)

# NOTE(review): only the final expression of this cell is displayed; the
# intermediate expression statements below are evaluated and discarded.
rouge(pred=gen, label=tgt, rouge_key="rouge1")
# sacrebleu(pred=gen, label=tgt)
gen, tgt
x = 'Я, как и многие другие люди, верю, что это повлияет на мирный'
y = 'Я, как и многие другие люди, убежден, что это скажется на мирном'
rouge(pred="a a a a", label="a a a b", rouge_key="rouge1")
# Characters of x that are not alphanumeric.
[c for c in list(x) if not c.isalnum()]
import re
# Show the generated text next to a copy in which every non-word character is
# replaced by a space.
gen, re.sub(r'\W', ' ', gen)


In [None]:
from datasets import Dataset
# Load a dumped en-ru JSONL file as a Dataset and display it.
# NOTE(review): absolute, user-specific path — consider deriving it from a configurable base directory.
dataset = Dataset.from_json("/home/olab/tomerronen1/git_repos/last_projects_playground/confidence_estimation/mbart50_dumps/wmt20_en-ru_200examples.jsonl")
dataset

In [None]:
import torch
# Allocate a tensor of one million ones on device "cuda:2".
# NOTE(review): the device index is hard-coded; this fails on machines without that device.
x = torch.ones(1000000).to("cuda:2")

In [None]:
!nvidia-smi

In [None]:
# Drop the reference to the tensor bound to `x`.
del x

In [None]:
import os
# NOTE(review): absolute, user-specific paths below — consider making them configurable.
os.environ["XDG_CACHE_HOME"] = "/home/olab/tomerronen1/xdg_cache/"
from datasets import Dataset
# Load the dumped en-ru JSONL file as a Dataset and display it.
ds = Dataset.from_json("/home/olab/tomerronen1/git_repos/last_projects_playground/confidence_estimation/mbart50_dumps/wmt20_en-ru_200examples.jsonl")
ds

In [None]:
from bert_score import BERTScorer
# CPU BERTScorer using the "microsoft/deberta-xlarge-mnli" model with baseline rescaling.
bertscore_model = BERTScorer(model_type="microsoft/deberta-xlarge-mnli", lang="en", rescale_with_baseline=True, device="cpu")
# Inspect the model type stored on the scorer (a private attribute).
bertscore_model._model_type

In [None]:
import numpy as np

def flatten(nested_list: list[list]) -> list:
    """Merge the inner lists of ``nested_list`` into a single list, one level deep."""
    merged: list = []
    for inner in nested_list:
        merged += list(inner)
    return merged

# NOTE(review): num_beams is assigned here but not read anywhere else in this cell.
num_beams = 2
preds = []
# Each row of concatenated_preds is cut right after every -100 entry; the -100
# entries are then removed from each piece and the pieces are collected in `preds`.
for pred in concatenated_preds:
    different_beams = np.array_split(pred, np.flatnonzero(pred == -100) + 1)
    different_beams = different_beams[:-1]  # last one is padding
    for beam_pred in different_beams:
        beam_pred = beam_pred[beam_pred != -100]
        preds.append(beam_pred)

# Decode every collected sequence, skipping special tokens.
tokenizer.batch_decode(preds, skip_special_tokens=True)

In [None]:
# trainer.data_collator(dataset.to_list())
# trainer.data_collator(dataset.to_dict(orient="list"))
# Collate the two model-input columns, converted through pandas into a list of
# per-row dicts; the commented-out lines around this one are alternative attempts.
batch = trainer.data_collator(dataset.to_pandas()[["input_ids", "attention_mask"]].to_dict(orient="records"))
# trainer.data_collator([dataset[i] for i in range(len(dataset))])

In [None]:
# Call generate() on the trainer's model with the collated batch and display the result.
res = trainer.model.generate(**batch)
res

In [None]:
# Run prediction through the trainer and show the shape of the predictions array.
# NOTE(review): this rebinds `preds`, which another cell of this notebook uses as a list.
preds = trainer.predict(dataset)
preds.predictions.shape

In [None]:
# Decode `generated_tokens`, skipping special tokens.
# NOTE(review): in this notebook `generated_tokens` is only assigned in later
# cells, so this cell fails on a fresh top-to-bottom run.
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

In [None]:
article_en = "The head of the United Nations says there is no military solution in Syria"

generation_params = {"num_beams": 5, "length_penalty": 1.0}
model_inputs = tokenizer(article_en, return_tensors="pt")

# Translate the English sentence into Hindi and then into Chinese by forcing the
# corresponding language-code token as the first generated token.
# Expected output per target language:
#   hi_IN => 'संयुक्त राष्ट्र के नेता कहते हैं कि सीरिया में कोई सैन्य समाधान नहीं है'
#   zh_CN => '联合国首脑说,叙利亚没有军事解决办法'
for _target_lang_code in ("hi_IN", "zh_CN"):
    generated_tokens = model.generate(
        **model_inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id[_target_lang_code],
        **generation_params
    )
    print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))



In [None]:
# Force the English language-code token ("en_XX") as the first generated token;
# the id is looked up from the tokenizer instead of hard-coding 250004.
# NOTE(review): model_many_to_en is only created in later cells of this
# notebook, so this cell fails on a fresh top-to-bottom run.
model_many_to_en.generate(**tokenizer_many_to_en("אני חתול", return_tensors="pt"),
                          forced_bos_token_id=tokenizer_many_to_en.convert_tokens_to_ids("en_XX"))

In [None]:
# Load the many-to-one mBART-50 translation model.
model_many_to_en = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")


In [None]:
# Load the many-to-one mBART-50 model together with its tokenizer.
# NOTE(review): the model load repeats the previous cell of this notebook.
model_many_to_en = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
tokenizer_many_to_en = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")

In [None]:
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."

# Both translations encode and decode with tokenizer_many_to_en — the tokenizer
# whose src_lang is set right before each call — rather than with the unrelated
# notebook-level `tokenizer`.

# translate Hindi to English
tokenizer_many_to_en.src_lang = "hi_IN"
encoded_hi = tokenizer_many_to_en(article_hi, return_tensors="pt")
generated_tokens = model_many_to_en.generate(**encoded_hi, **generation_params)
print(tokenizer_many_to_en.batch_decode(generated_tokens, skip_special_tokens=True))
# => "The head of the UN says there is no military solution in Syria."

# translate Arabic to English
tokenizer_many_to_en.src_lang = "ar_AR"
encoded_ar = tokenizer_many_to_en(article_ar, return_tensors="pt")
generated_tokens = model_many_to_en.generate(**encoded_ar, **generation_params)
print(tokenizer_many_to_en.batch_decode(generated_tokens, skip_special_tokens=True))
# => "The Secretary-General of the United Nations says there is no military solution in Syria."

