In [1]:
# Suppress all warnings for the rest of the session.
# NOTE(review): this also hides deprecation warnings emitted by the libraries
# used below, so API deprecations will go unnoticed.
import warnings
warnings.filterwarnings('ignore')

In [2]:
from transformers import MBart50TokenizerFast,MBartForConditionalGeneration,Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer
from datasets import load_dataset, load_metric
import numpy as np
import evaluate

In [3]:
# Checkpoint identifier used to load both the tokenizer and the model.
model_mbart = 'facebook/mbart-large-50-one-to-many-mmt'

# "en-id" configuration of the OPUS-100 corpus.
raw_datasets = load_dataset(path="Helsinki-NLP/opus-100", name="en-id")

In [4]:
# Tokenizer for the checkpoint, with en_XX as source and id_ID as target language code.
tokenizer = MBart50TokenizerFast.from_pretrained(
    model_mbart,
    src_lang="en_XX",
    tgt_lang="id_ID",
)

In [5]:
# Keys under which each "translation" record stores the source and target text.
source_lang, target_lang = "en", "id"

def preprocess(data):
  """Tokenize a batch of translation pairs for seq2seq training.

  Args:
    data: Batch whose "translation" entry is a sequence of dicts holding
      the text under the `source_lang` and `target_lang` keys.

  Returns:
    The tokenizer output for the source sentences, with the token ids of
    the target sentences stored under "labels".
  """
  inputs = [pair[source_lang] for pair in data["translation"]]
  targets = [pair[target_lang] for pair in data["translation"]]
  # `text_target` tokenizes the targets in target-language mode and stores
  # their input ids under "labels"; it replaces the deprecated
  # `as_target_tokenizer()` context manager.
  return tokenizer(inputs, text_target=targets, truncation=True)

# Run `preprocess` over the raw data in batched mode.
tokenized_datasets = raw_datasets.map(function=preprocess, batched=True)

In [6]:
# Shuffle each split with a fixed seed, then keep its first 1000 rows.
shuffled_train = tokenized_datasets["train"].shuffle(seed=42)
shuffled_test = tokenized_datasets["test"].shuffle(seed=42)

small_train_dataset = shuffled_train.select(range(1000))
small_eval_dataset = shuffled_test.select(range(1000))

In [7]:
# Load the pretrained weights for the checkpoint named in `model_mbart`.
# There is deliberately no explicit `.cuda()` call: it raises on machines
# without a CUDA device, and the Seq2SeqTrainer that receives this model
# moves it to the training device on its own.
model = MBartForConditionalGeneration.from_pretrained(model_mbart)

In [8]:
# Tunable fine-tuning settings, gathered in one place.
hyperparameters = dict(
    learning_rate=1e-5,
    batch_size=32,
    num_epochs=10,
)

# Training configuration: evaluate once per epoch and use generation when
# predicting, with the same batch size for training and evaluation.
args = Seq2SeqTrainingArguments(
    output_dir="mbart-large-50-one-to-many-mmt-finetuned-en-to-id",
    evaluation_strategy="epoch",
    predict_with_generate=True,
    learning_rate=hyperparameters['learning_rate'],
    weight_decay=0.01,
    per_device_train_batch_size=hyperparameters['batch_size'],
    per_device_eval_batch_size=hyperparameters['batch_size'],
    num_train_epochs=hyperparameters['num_epochs'],
    save_total_limit=hyperparameters['num_epochs'],
)

In [9]:
# Collator that assembles tokenized examples into batches for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [10]:
# Metrics used by `compute_metrics`: METEOR and SacreBLEU.
meteor = evaluate.load("meteor")
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from predictions and references.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        A pair ``(predictions, references)`` where predictions is a list of
        stripped strings and each reference is wrapped in its own
        single-element list.
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])

    return cleaned_preds, wrapped_labels


def compute_metrics(eval_preds):
    """Compute BLEU, METEOR and mean generation length for an evaluation pass.

    Args:
        eval_preds: Pair ``(preds, labels)`` of token-id arrays. ``preds``
            may itself be a tuple, in which case its first element is used.

    Returns:
        Dict with the keys "bleu", "gen_len" and "meteor", each rounded to
        four decimal places.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is a padding marker that cannot be decoded. Replace it with the
    # pad token in the predictions as well as the labels, so decoding never
    # sees a negative id and the length count below ignores that padding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(
        decoded_preds, decoded_labels)

    bleu_result = metric.compute(predictions=decoded_preds,
                                 references=decoded_labels)
    meteor_result = meteor.compute(
        predictions=decoded_preds, references=decoded_labels)

    # Generated length = number of non-pad tokens in each prediction.
    prediction_lens = [np.count_nonzero(
        pred != tokenizer.pad_token_id) for pred in preds]

    result = {
        'bleu': bleu_result['score'],
        'gen_len': np.mean(prediction_lens),
        'meteor': meteor_result['meteor'],
    }
    return {k: round(v, 4) for k, v in result.items()}

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\user\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\user\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\user\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [11]:
# Assemble the trainer from the model, training arguments, the two 1000-row
# subsets, the collator, the tokenizer and the metric function.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
# trainer.train()

  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 2.1864113807678223, 'eval_bleu': 9.2087, 'eval_gen_len': 14.747, 'eval_meteor': 0.2429, 'eval_runtime': 66.0176, 'eval_samples_per_second': 15.147, 'eval_steps_per_second': 0.485, 'epoch': 1.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 2.0488064289093018, 'eval_bleu': 18.451, 'eval_gen_len': 13.771, 'eval_meteor': 0.4071, 'eval_runtime': 990.2109, 'eval_samples_per_second': 1.01, 'eval_steps_per_second': 0.032, 'epoch': 2.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.9852286577224731, 'eval_bleu': 24.5361, 'eval_gen_len': 13.173, 'eval_meteor': 0.4966, 'eval_runtime': 947.3509, 'eval_samples_per_second': 1.056, 'eval_steps_per_second': 0.034, 'epoch': 3.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.9715750217437744, 'eval_bleu': 25.659, 'eval_gen_len': 12.982, 'eval_meteor': 0.5097, 'eval_runtime': 878.4208, 'eval_samples_per_second': 1.138, 'eval_steps_per_second': 0.036, 'epoch': 4.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.9833670854568481, 'eval_bleu': 25.6724, 'eval_gen_len': 12.976, 'eval_meteor': 0.5141, 'eval_runtime': 863.4577, 'eval_samples_per_second': 1.158, 'eval_steps_per_second': 0.037, 'epoch': 5.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 2.0055649280548096, 'eval_bleu': 25.7811, 'eval_gen_len': 12.749, 'eval_meteor': 0.5102, 'eval_runtime': 442.4429, 'eval_samples_per_second': 2.26, 'eval_steps_per_second': 0.072, 'epoch': 6.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 2.024784803390503, 'eval_bleu': 25.5804, 'eval_gen_len': 12.702, 'eval_meteor': 0.5095, 'eval_runtime': 399.8456, 'eval_samples_per_second': 2.501, 'eval_steps_per_second': 0.08, 'epoch': 7.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 2.0386133193969727, 'eval_bleu': 25.6233, 'eval_gen_len': 12.685, 'eval_meteor': 0.5097, 'eval_runtime': 403.252, 'eval_samples_per_second': 2.48, 'eval_steps_per_second': 0.079, 'epoch': 8.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 2.0506856441497803, 'eval_bleu': 25.5965, 'eval_gen_len': 12.69, 'eval_meteor': 0.5108, 'eval_runtime': 395.0228, 'eval_samples_per_second': 2.531, 'eval_steps_per_second': 0.081, 'epoch': 9.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 2.054755687713623, 'eval_bleu': 25.4659, 'eval_gen_len': 12.686, 'eval_meteor': 0.5098, 'eval_runtime': 399.8797, 'eval_samples_per_second': 2.501, 'eval_steps_per_second': 0.08, 'epoch': 10.0}
{'train_runtime': 9587.6313, 'train_samples_per_second': 1.043, 'train_steps_per_second': 0.033, 'train_loss': 1.6048397064208983, 'epoch': 10.0}


TrainOutput(global_step=320, training_loss=1.6048397064208983, metrics={'train_runtime': 9587.6313, 'train_samples_per_second': 1.043, 'train_steps_per_second': 0.033, 'total_flos': 697985092288512.0, 'train_loss': 1.6048397064208983, 'epoch': 10.0})

320/320 [2:39:47<00:00, 7.31s/it]
{'eval_loss': 2.1864113807678223, 'eval_bleu': 9.2087, 'eval_gen_len': 14.747, 'eval_meteor': 0.2429, 'eval_runtime': 66.0176, 'eval_samples_per_second': 15.147, 'eval_steps_per_second': 0.485, 'epoch': 1.0}
{'eval_loss': 2.0488064289093018, 'eval_bleu': 18.451, 'eval_gen_len': 13.771, 'eval_meteor': 0.4071, 'eval_runtime': 990.2109, 'eval_samples_per_second': 1.01, 'eval_steps_per_second': 0.032, 'epoch': 2.0}
{'eval_loss': 1.9852286577224731, 'eval_bleu': 24.5361, 'eval_gen_len': 13.173, 'eval_meteor': 0.4966, 'eval_runtime': 947.3509, 'eval_samples_per_second': 1.056, 'eval_steps_per_second': 0.034, 'epoch': 3.0}
{'eval_loss': 1.9715750217437744, 'eval_bleu': 25.659, 'eval_gen_len': 12.982, 'eval_meteor': 0.5097, 'eval_runtime': 878.4208, 'eval_samples_per_second': 1.138, 'eval_steps_per_second': 0.036, 'epoch': 4.0}
{'eval_loss': 1.9833670854568481, 'eval_bleu': 25.6724, 'eval_gen_len': 12.976, 'eval_meteor': 0.5141, 'eval_runtime': 863.4577, 'eval_samples_per_second': 1.158, 'eval_steps_per_second': 0.037, 'epoch': 5.0}
{'eval_loss': 2.0055649280548096, 'eval_bleu': 25.7811, 'eval_gen_len': 12.749, 'eval_meteor': 0.5102, 'eval_runtime': 442.4429, 'eval_samples_per_second': 2.26, 'eval_steps_per_second': 0.072, 'epoch': 6.0}
{'eval_loss': 2.024784803390503, 'eval_bleu': 25.5804, 'eval_gen_len': 12.702, 'eval_meteor': 0.5095, 'eval_runtime': 399.8456, 'eval_samples_per_second': 2.501, 'eval_steps_per_second': 0.08, 'epoch': 7.0}
{'eval_loss': 2.0386133193969727, 'eval_bleu': 25.6233, 'eval_gen_len': 12.685, 'eval_meteor': 0.5097, 'eval_runtime': 403.252, 'eval_samples_per_second': 2.48, 'eval_steps_per_second': 0.079, 'epoch': 8.0}
{'eval_loss': 2.0506856441497803, 'eval_bleu': 25.5965, 'eval_gen_len': 12.69, 'eval_meteor': 0.5108, 'eval_runtime': 395.0228, 'eval_samples_per_second': 2.531, 'eval_steps_per_second': 0.081, 'epoch': 9.0}
{'eval_loss': 2.054755687713623, 'eval_bleu': 25.4659, 'eval_gen_len': 12.686, 'eval_meteor': 0.5098, 'eval_runtime': 399.8797, 'eval_samples_per_second': 2.501, 'eval_steps_per_second': 0.08, 'epoch': 10.0}
{'train_runtime': 9587.6313, 'train_samples_per_second': 1.043, 'train_steps_per_second': 0.033, 'train_loss': 1.6048397064208983, 'epoch': 10.0}
TrainOutput(global_step=320, training_loss=1.6048397064208983, metrics={'train_runtime': 9587.6313, 'train_samples_per_second': 1.043, 'train_steps_per_second': 0.033, 'total_flos': 697985092288512.0, 'train_loss': 1.6048397064208983, 'epoch': 10.0})

In [12]:
# trainer.save_model('opus-mt-en-id-finetuned-en-to-id')

Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}


In [16]:
model_path = 'opus-mt-en-id-finetuned-en-to-id'
src_text = ["I hope we all passed NLP. Hendrik is the best lecturer in Calvin Institute of Technology!!!"]

# Load the tokenizer and model stored at `model_path`.
tokenizer = MBart50TokenizerFast.from_pretrained(model_path, src_lang="en_XX")
model = MBartForConditionalGeneration.from_pretrained(model_path)

# Token id of the id_ID language code, forced as the first generated token.
indonesian_token_id = tokenizer.lang_code_to_id["id_ID"]

model_inputs = tokenizer(src_text, return_tensors="pt")
generated_tokens = model.generate(
    **model_inputs,
    forced_bos_token_id=indonesian_token_id,
    max_new_tokens=360,
)
translation = tokenizer.batch_decode(
    generated_tokens,
    skip_special_tokens=True,
)

print(translation)

['Aku harap kita semua lulus NLP. Hendrik adalah dosen terbaik di Institut Teknologi Calvin!!!']
