## 基于T5的文本摘要

In [1]:
! pip install rouge-chinese

Looking in indexes: https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple


### Step1 导入相关包

In [2]:
import os

# 设置可见的 GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,7"

import torch
from datasets import  Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

### Step2 加载数据集

In [3]:
ds = Dataset.load_from_disk("nlpcc_2017")
ds

Dataset({
    features: ['title', 'content'],
    num_rows: 5000
})

In [4]:
ds = ds.train_test_split(200, seed=42)
ds

DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 4800
    })
    test: Dataset({
        features: ['title', 'content'],
        num_rows: 200
    })
})

In [5]:
ds["train"][0]

{'title': '郴州市发布雷电橙色预警:过去2小时北湖区、苏仙区、郴州市区、桂阳县、宜章县、嘉禾县、资兴市、桂东县、汝城县已经受...',
 'content': '发布日期:2015-03-3007:55:33郴州市气象台3月30日7时52分发布雷电橙色预警信号:过去2小时北湖区、苏仙区、郴州市区、桂阳县、宜章县、嘉禾县、资兴市、桂东县、汝城县已经受雷电活动影响,并将持续,出现雷电灾害事故的可能性比较大,请注意防范。图例标准防御指南2小时内发生雷电活动的可能性很大,或者已经受雷电活动影响,且可能持续,出现雷电灾害事故的可能性比较大。1、政府及相关部门按照职责落实防雷应急措施;2、人员应当留在室内,并关好门窗;3、户外人员应当躲入有防雷设施的建筑物或者汽车内;4、切断危险电源,不要在树下、电杆下、塔吊下避雨;5、在空旷场地不要打伞,不要把农具、羽毛球拍、高尔夫球杆等扛在肩上。'}

### Step3 数据集处理

In [6]:
tokenizer = AutoTokenizer.from_pretrained("Langboat/mengzi-t5-base")
# tokenizer

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [11]:
def process_function(examples):
    contents = ["摘要生成：\n" + e for e in examples["content"]]
    inputs = tokenizer(contents, max_length=384, truncation=True)
    labels = tokenizer(text_target=examples["title"], max_length=64, truncation=True)
    inputs["labels"] = labels["input_ids"]
    return inputs

In [12]:
tokenized_ds = ds.map(process_function, batched=True)
tokenized_ds

Map:   0%|          | 0/4800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 4800
    })
    test: Dataset({
        features: ['title', 'content', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
})

In [13]:
tokenizer.decode(tokenized_ds["train"][0]["input_ids"])

'摘要生成: 发布日期:2015-03-3007:55:33郴州市气象台3月30日7时52分发布雷电橙色预警信号:过去2小时北湖区、苏仙区、郴州市区、桂阳县、宜章县、嘉禾县、资兴市、桂东县、汝城县已经受雷电活动影响,并将持续,出现雷电灾害事故的可能性比较大,请注意防范。图例标准防御指南2小时内发生雷电活动的可能性很大,或者已经受雷电活动影响,且可能持续,出现雷电灾害事故的可能性比较大。1、政府及相关部门按照职责落实防雷应急措施;2、人员应当留在室内,并关好门窗;3、户外人员应当躲入有防雷设施的建筑物或者汽车内;4、切断危险电源,不要在树下、电杆下、塔吊下避雨;5、在空旷场地不要打伞,不要把农具、羽毛球拍、高尔夫球杆等扛在肩上。</s>'

In [14]:
tokenizer.decode(tokenized_ds["train"][0]["labels"])

'郴州市发布雷电橙色预警:过去2小时北湖区、苏仙区、郴州市区、桂阳县、宜章县、嘉禾县、资兴市、桂东县、汝城县已经受...</s>'

### Step4 创建模型

In [16]:
model = AutoModelForSeq2SeqLM.from_pretrained("Langboat/mengzi-t5-base")

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

### Step5 创建评估函数

In [23]:
import numpy as np
from rouge_chinese import Rouge

rouge = Rouge()
def compute_metric(evalPred):
    predictions, labels = evalPred
    decode_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decode_preds = [" ".join(p) for p in decode_preds]
    decode_labels = [" ".join(l) for l in decode_labels]
    scores = rouge.get_scores(decode_preds, decode_labels, avg=True)

    return {
        "rouge-1": scores["rouge-1"]["f"],
        "rouge-2": scores["rouge-2"]["f"],
        "rouge-l": scores["rouge-l"]["f"]
    }

### Step6 配置训练参数

In [24]:
import logging

logging.basicConfig(level=logging.INFO)

In [27]:
args = Seq2SeqTrainingArguments(
    output_dir="./summary-t5",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end="rouge-l",
    predict_with_generate=True
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### Step7 创建训练器

In [28]:
trainer = Seq2SeqTrainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    compute_metrics=compute_metric,
    data_collator=DataCollatorForSeq2Seq(tokenizer)
)

### Step8 模型训练

In [29]:
trainer.train()



Epoch,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
1,2.303,2.421367,0.469352,0.299668,0.387081
2,2.147,2.350675,0.475941,0.305267,0.392435
3,2.0633,2.336023,0.481376,0.311972,0.396005


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


TrainOutput(global_step=150, training_loss=2.1627472178141276, metrics={'train_runtime': 177.021, 'train_samples_per_second': 81.346, 'train_steps_per_second': 0.847, 'total_flos': 6775246029127680.0, 'train_loss': 2.1627472178141276, 'epoch': 3.0})

### Step9 模型推理

In [31]:
from transformers import pipeline

pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)

In [38]:
pipe("摘要生成：\n" + ds["test"][-1]["content"], max_length=64, do_sample=True)

[{'generated_text': '美国男子同意妻子前往火星的单程之旅,计划的目的是为了开拓人类的居住区,为人类争取更多生存空间。'}]

In [34]:
ds["test"][-1]["title"]

'美男子称将把妻子送往火星:预计2026年启程,目标是开拓人类居住地;男子称虽想念妻子但任务意义更大。'