# 基于GLM的文本摘要

In [1]:
import os
# Expose only GPU index 0 to this process. This cell runs before `import torch`
# so the setting is already in place when CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

## Step1 导入相关包

In [2]:
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


## Step2 加载数据集

In [3]:
# Load the dataset that was previously saved under ./nlpcc_2017/.
# The repr below shows two text columns ('title', 'content') and 5000 rows.
ds = Dataset.load_from_disk("./nlpcc_2017/")
ds

Dataset({
    features: ['title', 'content'],
    num_rows: 5000
})

In [4]:
# Hold out 100 examples as the test split; the fixed seed keeps the split reproducible.
# NOTE(review): this rebinds `ds` from a Dataset to a DatasetDict, so re-running this
# cell without re-running the loading cell first will presumably fail -- confirm.
ds = ds.train_test_split(test_size=100, seed=42)
ds

DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 4900
    })
    test: Dataset({
        features: ['title', 'content'],
        num_rows: 100
    })
})

In [5]:
# Show the first training record: 'title' is the reference summary, 'content' the article text.
ds["train"][0]

{'title': '组图:黑河边防军人零下30℃户外训练,冰霜沾满眉毛和睫毛,防寒服上满是冰霜。',
 'content': '中国军网2014-12-1709:08:0412月16日,黑龙江省军区驻黑河某边防团机动步兵连官兵,冒着-30℃严寒气温进行体能训练,挑战极寒,锻造钢筋铁骨。该连素有“世界冠军的摇篮”之称,曾有5人24人次登上世界军事五项冠军的领奖台。(魏建顺摄)黑龙江省军区驻黑河某边防团机动步兵连官兵冒着-30℃严寒气温进行体能训练驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜官兵睫毛上都被冻上了冰霜官兵们睫毛上都被冻上了冰霜驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练'}

## Step3 数据处理

In [6]:
# Load the GLM tokenizer from a local checkpoint directory.
# trust_remote_code=True allows transformers to run the tokenizer code shipped with the
# checkpoint (the repr below shows the custom GLMChineseTokenizer class).
# NOTE(review): the absolute path is machine-specific; adjust it when running elsewhere.
tokenizer = AutoTokenizer.from_pretrained("/data/PLM/glm-large-chinese", trust_remote_code=True)
tokenizer

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


GLMChineseTokenizer(name_or_path='/data/PLM/glm-large-chinese', vocab_size=50000, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='left', special_tokens={'eos_token': '<|endoftext|>', 'unk_token': '[UNK]', 'pad_token': '<|endoftext|>', 'cls_token': '[CLS]', 'mask_token': '[MASK]', 'additional_special_tokens': ['<|startofpiece|>', '<|endofpiece|>', '[gMASK]', '[sMASK]']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50000: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	50001: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50002: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	50003: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	50004: AddedToken("[UNUSED1]", rstrip=False, lstrip=False, single_wor

In [7]:
def process_func(exmaples):
    """Turn a batch of raw records into GLM training features.

    Args:
        exmaples: A batch from ``Dataset.map(batched=True)``; its "content" entry
            holds the article texts and its "title" entry the reference summaries.

    Returns:
        The result of ``tokenizer.build_inputs_for_generation``. After mapping, the
        dataset exposes the columns 'input_ids', 'position_ids', 'attention_mask'
        and 'labels'.
    """
    # Besides the task prefix, tokenizer.mask_token must be appended to every article:
    # it marks the position where GLM generates the summary.
    prompts = []
    for article in exmaples["content"]:
        prompts.append("摘要生成: \n" + article + tokenizer.mask_token)

    # Pad/truncate every prompt to exactly 384 tokens and return PyTorch tensors.
    encoded = tokenizer(
        prompts,
        max_length=384,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # GLM-specific label handling: the tokenizer itself appends the target tokens
    # (up to max_gen_length=64) and builds the matching labels.
    return tokenizer.build_inputs_for_generation(
        encoded,
        targets=exmaples["title"],
        padding=True,
        max_gen_length=64,
    )

In [8]:
# remove_columns drops the original 'title' and 'content' columns, leaving only the
# features produced by process_func: 'input_ids', 'position_ids', 'attention_mask', 'labels'.
tokenized_ds = ds.map(
    process_func,
    batched=True,
    remove_columns=ds["train"].column_names,
)
tokenized_ds

Map:   0%|          | 0/4900 [00:00<?, ? examples/s]

Map: 100%|██████████| 4900/4900 [00:58<00:00, 83.96 examples/s]
Map: 100%|██████████| 100/100 [00:01<00:00, 84.54 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'position_ids', 'attention_mask', 'labels'],
        num_rows: 4900
    })
    test: Dataset({
        features: ['input_ids', 'position_ids', 'attention_mask', 'labels'],
        num_rows: 100
    })
})

In [9]:
tokenizer.decode(tokenized_ds["train"][0]["input_ids"]) # Note: input_ids start with [CLS], whereas the labels start at <|startofpiece|>

'[CLS] 摘要生成: 中国军网2014-12-1709:08:0412月16日,黑龙江省军区驻黑河某边防团机动步兵连官兵,冒着-30°C严寒气温进行体能训练,挑战极寒,锻造钢筋铁骨。该连素有“世界冠军的摇篮”之称,曾有5人24人次登上世界军事五项冠军的领奖台。(魏建顺摄)黑龙江省军区驻黑河某边防团机动步兵连官兵冒着-30°C严寒气温进行体能训练驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜驻黑河某边防团机动步兵连官兵严寒中户外训练,防寒服上满是冰霜官兵睫毛上都被冻上了冰霜官兵们睫毛上都被冻上了冰霜驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练驻黑河某边防团机动步兵连官兵严寒中进行户外体能训练[MASK]<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endofte

In [10]:
print(tokenized_ds["train"][0]["labels"]) # Decoding these directly would raise an error because of the -100 entries

[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -10

In [11]:
# input_ids look fairly ordinary:
# 50002 is [CLS] and marks the start of the inputs; 50003 is [MASK]; 50000 is <|endoftext|>, which marks the end of the inputs and also serves as [PAD] for both inputs and labels;
# 50006 is <|startofpiece|> and marks the start of the labels; 50007 is <|endofpiece|> and marks the end of the labels. Note that 50006 itself is not counted as part of the labels!
print(tokenized_ds["train"][0]["input_ids"])
# The position encoding becomes clear once you read how GLM works: for the input_ids part the
# positions are encoded normally, followed by the position of [MASK] repeated max_gen_length times;
# for the labels part there is a run of zeros followed by normal positions up to max_gen_length
# (presumably this second description refers to the block-position row -- confirm).
print(tokenized_ds["train"][0]["position_ids"]) 

[50002, 43358, 23217, 4490, 43383, 576, 43790, 43593, 1251, 2979, 10422, 1902, 43383, 1976, 43383, 2638, 64, 43491, 195, 43498, 43359, 12929, 9872, 45218, 43979, 43965, 44221, 31855, 43895, 4828, 9404, 43905, 17586, 43359, 25716, 6905, 44801, 43573, 39991, 5316, 74, 20977, 995, 43359, 2265, 44003, 44773, 43359, 29329, 13922, 44210, 44394, 43361, 43655, 43905, 21178, 43430, 91, 1534, 43360, 23052, 43432, 12292, 43359, 31750, 43402, 43371, 369, 15386, 11946, 91, 2227, 37620, 1534, 43360, 43952, 44069, 43820, 3700, 45176, 43555, 44302, 44415, 43396, 12929, 9872, 45218, 43979, 43965, 44221, 31855, 43895, 4828, 9404, 43905, 17586, 25716, 6905, 44801, 43573, 39991, 5316, 74, 20977, 995, 45218, 43979, 43965, 44221, 31855, 43895, 4828, 9404, 43905, 17586, 39991, 43378, 4620, 995, 43359, 44010, 44773, 43674, 43387, 32058, 44508, 45333, 45218, 43979, 43965, 44221, 31855, 43895, 4828, 9404, 43905, 17586, 39991, 43378, 4620, 995, 43359, 44010, 44773, 43674, 43387, 32058, 44508, 45333, 17586, 15022

## Step4 创建模型

In [12]:
# Load the GLM weights from the same local checkpoint directory as the tokenizer.
# trust_remote_code=True allows the model class bundled with the checkpoint to be used
# (the repr in the next cell shows GLMForConditionalGeneration).
model = AutoModelForSeq2SeqLM.from_pretrained("/data/PLM/glm-large-chinese", trust_remote_code=True)

In [13]:
model # GLM has a single Transformer stack (the repr shows only GLMStack), but it is used as a prefix LM rather than as an encoder-only model

GLMForConditionalGeneration(
  (glm): GLMModel(
    (word_embeddings): VocabEmbedding()
    (transformer): GLMStack(
      (embedding_dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(1025, 1024)
      (block_position_embeddings): Embedding(1025, 1024)
      (layers): ModuleList(
        (0-23): 24 x GLMBlock(
          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attention): SelfAttention(
            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
            (attention_dropout): Dropout(p=0.1, inplace=False)
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (output_dropout): Dropout(p=0.1, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
            (dense_4h_to_h): Linear(in_fe

## Step6 配置训练参数

GLM 不太适合在训练过程中进行评估，所以这里没有定义 `compute_metrics`（因此跳过了 Step5 创建评估函数），评估改为在训练结束后单独进行。

In [14]:
# Training configuration; no evaluation-related settings are given here.
args = Seq2SeqTrainingArguments(
    output_dir="./summary_glm",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=8,  # 4 x 8 = 32 samples per optimizer step on each device
    logging_steps=8,                # the training log below reports the loss every 8 steps
    num_train_epochs=1 # Note: the default value is 3
)

## Step7 创建训练器

In [15]:
# Build the trainer. Neither an eval dataset nor a metric function is passed,
# so only the training loop is configured.
trainer = Seq2SeqTrainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds["train"],
    tokenizer=tokenizer,
    #data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer) # Mainly performs padding, which has already been done above
)  

Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


## Step8 模型训练

In [16]:
# Run fine-tuning; the table below lists the logged training loss and the final TrainOutput.
trainer.train()

Step,Training Loss
8,1.9781
16,1.8071
24,1.7341
32,1.683
40,1.7534
48,1.556
56,1.8072
64,1.5954
72,1.7246
80,1.5922


TrainOutput(global_step=153, training_loss=1.6627078874438417, metrics={'train_runtime': 1003.8592, 'train_samples_per_second': 4.881, 'train_steps_per_second': 0.152, 'total_flos': 4653015575298048.0, 'train_loss': 1.6627078874438417, 'epoch': 1.0})

## Step9 模型推理

In [17]:
# Single-example smoke test on the last article of the test split.
input_text = ds["test"][-1]["content"]
# Same prompt layout as in process_func: task prefix + article + [MASK].
inputs = tokenizer("摘要生成: \n" + input_text + tokenizer.mask_token, return_tensors="pt")
# Adds the GLM-specific generation inputs (e.g. position_ids) for up to 64 generated tokens.
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=64)
# The model is still in training mode after trainer.train() (eval() is otherwise only
# called in a later cell), so switch to eval mode here to disable dropout while generating.
model.eval()
# Follow whatever device the model lives on instead of assuming the literal "cuda".
inputs = inputs.to(model.device)
# eos_token_id must be set to eop_token_id (<|endofpiece|>) rather than the tokenizer's eos_token_id.
output = model.generate(**inputs, max_new_tokens=64, eos_token_id=tokenizer.eop_token_id, do_sample=True)
tokenizer.decode(output[0].tolist())

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50007 for open-end generation.


'[CLS] 摘要生成: 中国证券网讯(记者严政)中国重工6月14日晚间公告称,公司日前接到控股股东中国船舶重工集团公司(简称“中船重工”)通知,中船重工拟对其自身相关业务进行调整,部分业务涉及到公司。其中,公司拟以持有的动力相关资产进行对外投资,参与中船重工拟打造的动力业务平台公司。公司上述对外投资的方案不构成重大资产重组,也不涉及公司发行股份。中国重工表示,目前方案还需进一步论证,存在不确定性。为保证公平信息披露,维护投资者利益,避免造成公司股价异常波动,经公司申请,自6月12日下午开市起公司股票停牌。同时公司将与中船重工保持密切联系,尽快确认是否进行上述事项,并于股票停牌之日起5个工作日内(含停牌当日)公告事项进展情况。[MASK]<|endoftext|> <|startofpiece|> 中国重工公告称,接到中船重工通知,公司拟以持有的动力资产进行对外投资参与中船重工拟打造的动力业务平台公司,目前方案需进一步论证;公司股将于12日下午开市起停牌。 <|endofpiece|>'

In [18]:
import torch

# Switch to evaluation mode (disables dropout) before batch prediction.
model = model.eval()

def predict_test(dataset=None, max_gen_length=64, do_sample=True, verbose=True):
    """Generate one summary per record with the fine-tuned GLM model.

    Args:
        dataset: Iterable of records that each provide a "content" text. When
            omitted, ``ds["test"]`` is used, matching the original no-argument call.
        max_gen_length: Upper bound on generated tokens; used both when building the
            generation inputs and as ``max_new_tokens``.
        do_sample: Forwarded to ``model.generate``. With sampling enabled (the
            default) the summaries can differ from run to run.
        verbose: When True, print a running counter after every example.

    Returns:
        list[str]: The generated summaries, in the order of ``dataset``.
    """
    if dataset is None:
        dataset = ds["test"]

    # Make the function self-contained: disable dropout and follow the model's device
    # rather than relying on an earlier cell and a hard-coded "cuda".
    model.eval()
    device = model.device

    predict = []
    with torch.inference_mode():
        for d in dataset:
            inputs = tokenizer("摘要生成: \n" + d["content"] + tokenizer.mask_token, return_tensors="pt")
            # A plain tokenizer call is not enough: build_inputs_for_generation is what adds
            # position_ids and appends <|startofpiece|> (50006) after [MASK] and <|endoftext|>.
            inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=max_gen_length)
            inputs = inputs.to(device)
            # tokenizer.eop_token_id is 50007 (<|endofpiece|>); generation stops there.
            output = model.generate(
                **inputs,
                max_new_tokens=max_gen_length,
                eos_token_id=tokenizer.eop_token_id,
                do_sample=do_sample,
            )
            decoded = tokenizer.decode(output[0].tolist())
            # Keep only the text after <|startofpiece|> and drop the closing marker.
            summary = decoded.split("<|startofpiece|>")[1].replace("<|endofpiece|>", "").strip()
            predict.append(summary)
            if verbose:
                print("curID:", len(predict))
    return predict

In [None]:
# Generate summaries for the whole test split (one generate() call per example).
result = predict_test()

In [20]:
# Display the generated summaries.
result

['媒体称IS公布1400名西方政界人士名单,包括美国国务院和国防部人员,其中不乏美国政界重要人士,其威胁将杀1400人。',
 '宿松县2公职人员吸毒被开除党籍、行政撤职,县纪委通报:全县党政机关一把手要引以为戒。',
 '媒体称黑龙江省“伪基站”犯罪数量已超100万户,每天影响手机用户达190万个。',
 '北京明天实行机动车尾号轮换,周一至周五限行2和7!',
 '今日10时许,松北宾馆门前一对母子打成一团,母亲怒撕儿子上衣;警方随后介入后,三人和好离开现场',
 '苏州吴江一初中副校长泄考题被停职:泄题人系副校长亲戚,题目自己先泄露给亲戚,随后将其散播,导致题目传至网上。',
 '承德广播电视台2名处级干部严重违纪被开除党籍、开除公职,其涉嫌犯罪问题及线索移送司法机关查处,目前正进一步办理中',
 '曝李帅佩斯微博首秀写汉字,网友:真佩服这大汉,字间距堪比尺子',
 '调查表明,穆沙拉夫称“伊斯兰国”控制了至少两个主要小麦产区,已攫取100多万吨小麦,占伊拉克年消费量的1/5',
 '辽宁省气象台发布大风黄色预警:预计今日夜间到明天白天渤海海面,将有西北风9级,阵风10-11级。请有关单位和人员作好防范准备...',
 '上海市委书记会见缅甸全国民主联盟主席,介绍浦东发展情况;缅甸称上海是发展最快经济体之一。中联部副部长、上海市委秘书长尹弘参加会见。',
 '湖南凤凰有位122岁老寿星,一生生育13个孩子,最小的仅活至18岁,被评为“湖南省十大寿星之首”。',
 '杭州拱墅区大关街道正科级干部包某被举报涉嫌强奸他人,已被开除党籍、行政撤职。',
 '腰缠万贯”藏腰间逃法海 女子藏16斤金条闯海关',
 '武汉一男大学生摆摊招聘女友,称想找女友无需太夸张,其前女友曾在该校招聘广告上留下QQ号码',
 '荆州市发布霾黄色预警:目前我市部分地区已经出现能见度小于5公里的霾,未来仍将持续,请注意防范。...',
 '[拌饭]成龙广告代言受辱,曝广告台词被篡改成“龙爪手”;朱孝天晒拼图疑公布恋情;大s晒照为好友庆祝生仔',
 '山东莱州农民屋顶建50千瓦光伏发电系统,获得国家政策支持,平均每度电上网电价1元左右',
 '晋中:姐姐失踪4天,家人四处寻找弟弟,孩子无外伤,警方正寻人',
 '5月份,郑州市商品住宅销售均价达10243元/平方米,涨幅29.04%;其中郑东

In [21]:
from rouge_chinese import Rouge

rouge = Rouge()

# Before scoring with Rouge, the Chinese characters have to be separated by spaces.
# (Original author's note: the zero BLEU/Rouge scores seen with BART and GPT-2 were
# caused by extra spaces -- this refers to experiments not shown in this notebook.)
# NOTE(review): `docode_preds` is presumably a typo for `decode_preds`; the name is
# kept unchanged in case other cells refer to it.
docode_preds = list(map(" ".join, result))
decode_labels = list(map(" ".join, ds["test"]["title"]))
scores = rouge.get_scores(docode_preds, decode_labels, avg=True)
# Report only the F-measure of each Rouge variant.
{metric: scores[metric]["f"] for metric in ("rouge-1", "rouge-2", "rouge-l")}

{'rouge-1': 0.49659550301358446,
 'rouge-2': 0.3007335561279025,
 'rouge-l': 0.4081939467759942}