In [None]:
#加载所需要的库
%pip install transformers datasets evaluate tokenizer sacrebleu sentencepiece accelerate sacremoses

In [None]:
import os

# Run from the Google Drive root so the relative output paths used later
# ("./checkpoint", "./Translation-en2zh/") are written under it.
# NOTE(review): assumes Google Drive is already mounted at /content/drive (Colab).
WORK_DIR = "/content/drive/MyDrive"
os.chdir(WORK_DIR)

In [None]:
import numpy as np
import torch
import evaluate
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Pretrained English->Chinese checkpoint that the rest of the notebook fine-tunes.
# Alternative base model (valid Hub repo id, single slash):
# checkpoint = "Helsinki-NLP/opus-mt-en-zh"
checkpoint = "Jyshen/Translation_en2zh"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Training data 1: DDDSSS/en-zh-dataset (English -> Chinese translation pairs)
from datasets import Dataset
DS_dataset = load_dataset("DDDSSS/en-zh-dataset")


def DS_pfunc(examples):
    """Tokenize a batch of translation pairs for seq2seq training.

    Args:
        examples: a batched-map batch whose "translation" column holds dicts
            with "en" and "zh" entries.

    Returns:
        The tokenizer output: English text as encoder inputs and Chinese text
        as ``labels``, both truncated to 512 tokens.
    """
    pairs = examples["translation"]
    inputs = [pair["en"] for pair in pairs]
    targets = [pair["zh"] for pair in pairs]
    return tokenizer(inputs, text_target=targets, max_length=512, truncation=True)


# Only the "train" split is used. Tokenize just that split and drop the raw
# columns inside the same map pass, rather than copying every column into
# Python lists and rebuilding a Dataset from them. select_columns pins exactly
# the three columns the model consumes.
model_inputs = (
    DS_dataset["train"]
    .map(DS_pfunc, batched=True, remove_columns=DS_dataset["train"].column_names)
    .select_columns(["input_ids", "attention_mask", "labels"])
)

In [None]:
# Training data 2: CoQA stories (English "story" -> Chinese "story_zh")
from datasets import concatenate_datasets
from huggingface_hub import notebook_login
# Interactive Hugging Face Hub login.
# NOTE(review): presumably required to download the silk-road datasets; confirm.
notebook_login()

CoQA_dataset = load_dataset("silk-road/Luotuo-QA-A-CoQA-Chinese")


def Co_pfunc(examples):
    """Tokenize English stories as inputs and their Chinese versions as labels.

    NOTE(review): whole stories may exceed 512 tokens. Source and target are
    truncated independently, so a truncated pair may no longer be a faithful
    translation of each other. Consider filtering over-long stories.
    """
    return tokenizer(
        list(examples["story"]),
        text_target=list(examples["story_zh"]),
        max_length=512,
        truncation=True,
    )


# Tokenize only the "train" split and keep just the model columns, avoiding a
# round-trip of every column through Python lists.
coqa_tokenized = (
    CoQA_dataset["train"]
    .map(Co_pfunc, batched=True, remove_columns=CoQA_dataset["train"].column_names)
    .select_columns(["input_ids", "attention_mask", "labels"])
)
# Not idempotent: re-running this cell appends the CoQA rows a second time.
model_inputs = concatenate_datasets([model_inputs, coqa_tokenized])

In [None]:
# Training data 3: Dolly instructions (English "instruction" -> Chinese "instruction_zh")
Dolly_dataset = load_dataset("silk-road/chinese-dolly-15k")


def Do_pfunc(examples):
    """Tokenize English instructions as inputs and their Chinese versions as labels (max 512 tokens)."""
    return tokenizer(
        list(examples["instruction"]),
        text_target=list(examples["instruction_zh"]),
        max_length=512,
        truncation=True,
    )


# Tokenize only the "train" split and keep just the model columns, avoiding a
# round-trip of every column through Python lists.
dolly_tokenized = (
    Dolly_dataset["train"]
    .map(Do_pfunc, batched=True, remove_columns=Dolly_dataset["train"].column_names)
    .select_columns(["input_ids", "attention_mask", "labels"])
)
# Not idempotent: re-running this cell appends the Dolly rows a second time.
model_inputs = concatenate_datasets([model_inputs, dolly_tokenized])

In [None]:
# Training data 4: MMC4 image captions (English "caption" -> Chinese "caption_zh")
MMC4_dataset = load_dataset("silk-road/MMC4-130k-chinese-image")


def M4_pfunc(examples):
    """Tokenize English captions as inputs and their Chinese versions as labels (max 512 tokens)."""
    return tokenizer(
        list(examples["caption"]),
        text_target=list(examples["caption_zh"]),
        max_length=512,
        truncation=True,
    )


# Tokenize only the "train" split and keep just the model columns, avoiding a
# round-trip of every column through Python lists.
mmc4_tokenized = (
    MMC4_dataset["train"]
    .map(M4_pfunc, batched=True, remove_columns=MMC4_dataset["train"].column_names)
    .select_columns(["input_ids", "attention_mask", "labels"])
)
# Not idempotent: re-running this cell appends the MMC4 rows a second time.
model_inputs = concatenate_datasets([model_inputs, mmc4_tokenized])

In [None]:
# Hold out 20% of the combined corpus for per-epoch evaluation; the fixed seed
# keeps the split reproducible. Not idempotent: after one run `model_inputs`
# is a DatasetDict, so re-running this line fails.
model_inputs = model_inputs.train_test_split(test_size=0.2, shuffle=True, seed=2023)

# Pads inputs and labels per batch. `model` is left unset because no model
# object is loaded yet. The argument expects a model *instance*, not a
# checkpoint name: the collator ignores objects that lack
# `prepare_decoder_input_ids_from_labels`, so a string does nothing.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# sacreBLEU scorer used by compute_metrics during evaluation.
metric = evaluate.load("sacrebleu")

In [None]:
def postprocess_text(preds, labels):
    """Prepare decoded strings for sacreBLEU.

    Trims surrounding whitespace from every prediction and reference. Each
    reference is wrapped in its own single-element list, because sacreBLEU
    accepts several references per prediction.

    Returns:
        (cleaned_preds, references) as a tuple of two lists.
    """
    cleaned_preds = [text.strip() for text in preds]
    references = []
    for text in labels:
        references.append([text.strip()])
    return cleaned_preds, references

def compute_metrics(eval_preds):
    """Compute corpus BLEU and mean generated length for one evaluation pass.

    Args:
        eval_preds: ``(preds, labels)`` as passed by the Trainer. ``preds`` may
            itself be a tuple whose first element holds the generated token ids.

    Returns:
        dict with ``"bleu"`` (sacreBLEU ``score``) and ``"gen_len"`` (mean
        number of non-pad tokens per prediction), each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # When the Trainer concatenates generated sequences of different lengths
    # across batches, it pads them with -100. The tokenizer cannot decode that
    # id, so map it to the pad token first, as is already done for the -100
    # ignore-index in the labels.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": bleu["score"]}

    # Generated length ignores padding (which now also covers the former -100 fill).
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

# Load the pretrained seq2seq model; trainer.train() below fine-tunes this object.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Build the collator with the model *instance* so it can derive
# decoder_input_ids from the padded labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

batchsize = 32
training_args = Seq2SeqTrainingArguments(
    output_dir="./checkpoint",
    # `eval_strategy` replaces `evaluation_strategy`, which recent transformers
    # releases removed (the install cell does not pin a version).
    eval_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=batchsize,
    per_device_eval_batch_size=batchsize,
    weight_decay=0.01,
    # save_total_limit=3,  # uncomment to keep only the newest checkpoints
    num_train_epochs=2,
    # Evaluate with model.generate so compute_metrics can score BLEU.
    predict_with_generate=True,
    # fp16 mixed precision requires a CUDA device; fall back to fp32 otherwise.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    save_strategy="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=model_inputs["train"],
    eval_dataset=model_inputs["test"],
    # `processing_class` replaces the deprecated `tokenizer=` argument.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Save the fine-tuned weights (the Trainer updates `model` in place) together
# with the tokenizer, so the directory can be reloaded with from_pretrained.
# Named `save_dir` to avoid shadowing the builtin `dir`.
save_dir = "./Translation-en2zh/"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)