In [15]:
import os
os.environ["http_proxy"] = "http://127.0.0.1:8889"
os.environ["https_proxy"] = "http://127.0.0.1:8889"

In [16]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset, DatasetDict

In [17]:
ds = Dataset.load_from_disk("datasets/nlpcc_2017")
ds

Dataset({
    features: ['title', 'content'],
    num_rows: 5000
})

In [18]:
ds[0]

{'title': '澳大利亚央行将利率降至纪录低点,以应对疲软的经济前景,并遏制澳元进一步走强。',
 'content': '澳大利亚央行将利率降至纪录低点,以应对疲软的经济前景,并遏制澳元进一步走强。05/0513:37|评论(0)A+澳大利亚央行周二发布声明称,将关键利率由2.25%调降至2%,符合此前交易员及接受彭博调查的29位经济学家中25位的预期。据彭博社报道,上月澳央行官员曾警告,矿业之外的行业投资可能下滑。澳大利亚政府不太可能推出新的刺激措施,来扶助受本币升值和铁矿石价格下跌打击而低于潜在水平的经济增长。“鉴于大宗商品价格下跌,矿业投资还可能有低于当前预期的风险,”预计到降息的澳新银行高级经济学家FelicityEmmett在决议公布前编写的研究报告中称。他表示此次决议可能反映出“央行经济增长预估轨迹有所下调”。'}

In [19]:
ds = ds.train_test_split(test_size=0.1)
ds

DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['title', 'content'],
        num_rows: 500
    })
})

In [20]:
tokenizer = AutoTokenizer.from_pretrained("Langboat/mengzi-t5-base")

def process_func(examples):
    contents = ["摘要生成：\n" + c for c in examples["content"]]
    inputs = tokenizer(contents, truncation=True, max_length=384)
    labels = tokenizer(text_target=examples["title"], max_length=64, truncation=True)
    inputs["labels"] = labels["input_ids"]
    return inputs

In [21]:
tokenized_ds = ds.map(process_func, batched=True)
tokenized_ds

Map: 100%|██████████| 4500/4500 [00:00<00:00, 12102.32 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 12326.69 examples/s]


DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['title', 'content', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 500
    })
})

In [22]:
tokenizer.decode(tokenized_ds["train"][0]["labels"])

'预计明天白天到22日,北疆沿天山一带等地最高气温维持在40°C以上,吐鲁番最高气温将升至47°C</s>'

In [23]:
model = AutoModelForSeq2SeqLM.from_pretrained("Langboat/mengzi-t5-base")

In [46]:
import numpy as np
from rouge_chinese import Rouge

rouge = Rouge()

def compute_metric(evalPred):
    predictions, labels = evalPred
    decode_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decode_preds = [" ".join(pred) for pred in decode_preds]
    decode_labels = [" ".join(label) for label in decode_labels]
    scores = rouge.get_scores(decode_preds, decode_labels, avg=True)
    return {
        "rouge-1": scores["rouge-1"]["f"],
        "rouge-2": scores["rouge-2"]["f"],
        "rouge-l": scores["rouge-l"]["f"],
    }

In [47]:
args = Seq2SeqTrainingArguments(
    output_dir="./models/summarization",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    per_device_eval_batch_size=8,
    logging_strategy="steps",
    logging_steps=25,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="rouge-l",
    predict_with_generate=True
)

In [48]:
trainer = Seq2SeqTrainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    compute_metrics=compute_metric,
)

In [49]:
import torch
torch.cuda.empty_cache()
trainer.train()

Epoch,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
0,1.7614,2.082058,0.472156,0.310989,0.396397
1,1.5814,2.082574,0.483001,0.322162,0.406702
2,1.4252,2.086323,0.48276,0.32198,0.404416


TrainOutput(global_step=420, training_loss=1.4240690078054155, metrics={'train_runtime': 812.708, 'train_samples_per_second': 16.611, 'train_steps_per_second': 0.517, 'total_flos': 5467300639617024.0, 'train_loss': 1.4240690078054155, 'epoch': 2.99})

In [64]:
from transformers import pipeline

pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)


In [67]:
ds['test'][-1]['content']

'中新网4月27日电据台湾“今日新闻网”报道,台中市大里区发生一起人伦悲剧,疑似债务压力加上房子求售不成,75岁老母亲带2名儿子,在豪宅内烧炭自杀,现场未留下遗书。据报道,台中市大里区新芳路1栋透天别墅传出恶臭,当地里长报警处理,警方、消防员破门而入发现,1名妇人与2名儿子陈尸在二楼卧房,3人身体发黑浮肿,初估死亡时间逾5天。警方调查发现,老妇人9年前以800万(新台币,下同)买下该栋透天别墅,大儿子患有小儿麻痹与智能障碍,小儿子则是没工作,沉溺在线游戏。地方人士表示,半年前小儿子找上中介,说欠债300万元,要卖现居的透天别墅,开价1880万元,但高于市场行情因此买卖破局。房仲业者表示,虽然最后愿意降价,但坚持最低价1550万元,仍高于市场行情,导致无法成交。上月底取消房屋代售委任,上周五房仲上门询问是否再托售,始终无人应门,直到看新闻才得知该家人烧炭自杀。昨(26日)晚因无家属处理后事,社会局请礼仪公司将3人遗体运至殡仪馆存放,检察官将会同法医相验厘清确切的死因。'

In [69]:

pipe("摘要生成：\n" + ds["test"][-1]["content"], max_length=64)

[{'generated_text': '75岁老母亲带2儿子在豪宅内烧炭自杀,现场未留下遗书;当地里长称老人欠债300万买下该栋透天别墅,3人生前无家属处理后事。'}]