In [None]:
!pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121
!pip install transformers==4.37.0
!pip install datasets==2.21.0
!pip install accelerate==0.21.0
!pip install rouge==1.0.1
!pip install tqdm==4.66.5
!pip install jieba==0.42.1

In [None]:
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from datasets import load_dataset
from rouge import Rouge
import numpy as np
import pickle
import os
import jieba

In [None]:
# Pull the LCSTS corpus from the Hugging Face Hub.
# The splits are kept as Dataset objects (no to_list()), because the
# preprocessing step later in the notebook relies on Dataset.map().
data_name = "hugcyp/LCSTS"
raw_train, raw_valid = (
    load_dataset(data_name, split=split_name)
    for split_name in ("train", "validation")
)

# Evaluating on the whole validation split during fine-tuning is slow,
# so the first 100 validation examples serve as a quick evaluation set.
raw_small_valid = raw_valid.select(range(100))

In [None]:
# Checkpoint used for Chinese abstractive summarization: multilingual T5 (small).
model_name = "google/mt5-small"

# Preprocessing takes a while, so its output is cached on disk with pickle.
# The cache file names embed the checkpoint name without its "google/" prefix.
model_prefix = model_name.split("/")[1]
train_saved_pkl = f"train_{model_prefix}.pkl"
valid_saved_pkl = f"val_{model_prefix}.pkl"
small_val_saved_pkl = f"val_{model_prefix}_100.pkl"

# Tokenizer that matches the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Preprocessing is expensive, so its result is cached in three pickle files.
# The cache is used only when ALL three files are present; a partial cache
# (for example after an interrupted run) is rebuilt instead of crashing on a
# missing file.
# NOTE: pickle.load can execute arbitrary code contained in the file it reads.
# Only load cache files that this notebook produced itself.
cache_files = (train_saved_pkl, valid_saved_pkl, small_val_saved_pkl)
if all(os.path.exists(path) for path in cache_files):
    with open(train_saved_pkl, "rb") as f:
        train = pickle.load(f)
    with open(valid_saved_pkl, "rb") as f:
        valid = pickle.load(f)
    with open(small_val_saved_pkl, "rb") as f:
        small_valid = pickle.load(f)
else:
    # (full-width punctuation, ASCII replacement) pairs
    token_replacement = [
        ["：", ":"],
        ["，", ","],
        ["“", '"'],
        ["”", '"'],
        ["？", "?"],
        ["……", "..."],
        ["！", "!"],
    ]

    def replace_tokens(examples):
        """Replace full-width punctuation in the "text" and "summary" columns.

        The substitution keeps the tokenizer from producing too many [UNK]
        tokens. `examples` is a batch (column name -> list of strings); the
        two columns are overwritten with the cleaned strings and the same
        batch object is returned.
        """
        for k in ["text", "summary"]:
            cleaned = []
            for sentence in examples[k]:
                for old, new in token_replacement:
                    sentence = sentence.replace(old, new)
                cleaned.append(sentence)
            # Assign the rebuilt list back instead of mutating the list that
            # examples[k] returned, so the result does not depend on the
            # batch object handing out the same list on every access.
            examples[k] = cleaned
        return examples

    def preprocess_function(examples):
        """Tokenize a batch of (text, summary) pairs for seq2seq training.

        Returns the tokenized inputs with the tokenized summaries stored
        under "labels".
        """
        examples = replace_tokens(examples)
        # No padding here: DataCollatorForSeq2Seq pads every training batch
        # dynamically, so padding inside map() would only inflate each example
        # to the longest text of its map batch.
        model_inputs = tokenizer(examples["text"], truncation=True)
        labels = tokenizer(
            text_target=examples["summary"],
            max_length=200,
            truncation=True,
        )
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    train = raw_train.map(preprocess_function, batched=True)
    valid = raw_valid.map(preprocess_function, batched=True)
    small_valid = raw_small_valid.map(preprocess_function, batched=True)

    with open(train_saved_pkl, "wb") as f:
        pickle.dump(train, f)
    with open(valid_saved_pkl, "wb") as f:
        pickle.dump(valid, f)
    with open(small_val_saved_pkl, "wb") as f:
        pickle.dump(small_valid, f)

In [None]:
# Set up the evaluation metric: one Rouge scorer instance, created once and
# reused by compute_metrics() on every evaluation pass.
rouge_metric = Rouge()

In [None]:
def compute_metrics(eval_pred):
    """Compute word-level ROUGE F-scores and the mean generated length.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of token-id arrays, one
            row per evaluated example.

    Returns:
        A dict with one ``"<metric>_f"`` entry per metric reported by
        ``rouge_metric.get_scores`` (its F-score averaged over all examples)
        plus ``"gen_len"``, the mean number of non-pad tokens per prediction.
    """
    pred_ids, label_ids = eval_pred
    pad_id = tokenizer.pad_token_id

    # -100 is not a vocabulary id, so it must be mapped back to the pad token
    # before decoding. Labels use -100 for ignored positions; generated
    # predictions can contain it too when the Trainer concatenates batches of
    # different lengths.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    # Decode token ids to sentences
    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # We use jieba to perform word-level evaluations with ROUGE
    hypotheses = [" ".join(jieba.lcut(o)) for o in decoded_preds]
    references = [" ".join(jieba.lcut(t)) for t in decoded_labels]

    # Set `avg=True` to compute the average scores for all samples
    result = rouge_metric.get_scores(hypotheses, references, avg=True)
    score = {f"{rouge_i}_f": v["f"] for rouge_i, v in result.items()}

    # Average generation length, counted on the token ids. Counting on the
    # segmented strings would compare str to int and always give 1.
    prediction_lens = [np.count_nonzero(ids != pad_id) for ids in pred_ids]
    score["gen_len"] = np.mean(prediction_lens)
    return score

In [None]:
# DataCollatorForSeq2Seq pads each batch dynamically and pads the labels with -100.
# It does the job a hand-written collate_fn would do for a DataLoader.
#
# The optional `model` argument is left out on purpose: it expects the model
# object (used to pre-compute decoder inputs from the labels), not the
# checkpoint name. A plain string is silently ignored there, and the model is
# only created in a later cell.

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [None]:
# Use AutoModelForSeq2SeqLM for T5: load the pretrained checkpoint named by
# `model_name` as a sequence-to-sequence model; this is the model that the
# trainer below fine-tunes.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:
# Fine-tuning configuration, grouped by concern.
training_args = Seq2SeqTrainingArguments(
    # Output locations and logging
    output_dir="./results/mt5",
    logging_dir=f"./logs/{model_prefix}",
    logging_steps=1,
    push_to_hub=False,
    # Evaluation / checkpoint schedule (both counted in optimizer steps)
    evaluation_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=10000,
    save_total_limit=3,
    # Optimization
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Evaluate on generated sequences so compute_metrics can decode them
    predict_with_generate=True,
)

# The small validation subset keeps the periodic evaluations fast.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train,
    eval_dataset=small_valid,
    compute_metrics=compute_metrics,
)

# Overrides the private GPU count stored on the training arguments.
# NOTE(review): presumably meant to keep training on a single device; confirm.
trainer.args._n_gpu = 1
trainer.train()

In [None]:
# After training, we evaluate the model on the full validation set.
# compute_metrics() names the ROUGE F-scores "rouge-1_f", "rouge-2_f" and
# "rouge-l_f", and predict() prefixes every metric with "test_", so the keys
# to read are "test_rouge-<n>_f" (looking up "test_rouge-<n>" raises KeyError).
results = trainer.predict(valid)
for metric in ["1", "2", "l"]:
    rouge_item = f"test_rouge-{metric}_f"
    print(f"{rouge_item}: ", results.metrics[rouge_item])