Merge pull request #439 from will-am/chinese_poetry
Add preprocessor for generating Chinese poetry.
lcy-seso committed Nov 20, 2017
2 parents 66f18a1 + 3bbe91d commit ede5a04
Showing 6 changed files with 217 additions and 14 deletions.
112 changes: 111 additions & 1 deletion generate_chinese_poetry/README.md
@@ -1 +1,111 @@
[TBD]
# Chinese Classical Poetry Generation

## Introduction
Based on an encoder-decoder neural network model, this example trains on the Complete Tang Poems in a line-to-line (sequence-to-sequence) fashion, so that given one line of a poem, the model generates the next line.

Both the encoder and the decoder use stacked bi-directional LSTMs (3 layers each by default) with an attention mechanism.

The directory structure of this example is as follows:

```text
.
├── data               # training data and the dictionary
│   ├── download.sh    # downloads the raw data
├── README.md          # this document
├── index.html         # this document in HTML format
├── preprocess.py      # preprocesses the raw data
├── generate.py        # script for generating poem lines
├── network_conf.py    # model definition
├── reader.py          # data reading interface
├── train.py           # training script
└── utils.py           # utility functions
```

## Data Processing
### Raw Data Source
This example uses the Complete Tang Poems collected in the [chinese-poetry database](https://github.com/chinese-poetry/chinese-poetry) as training data, about 54,000 Tang poems in total.

### Downloading the Raw Data
```bash
cd data && ./download.sh && cd ..
```
### Data Preprocessing
```bash
python preprocess.py --datadir data/raw --outfile data/poems.txt --dictfile data/dict.txt
```

After the script finishes, the processed training data poems.txt and the dictionary dict.txt are generated. The dictionary is built at the character level and keeps only characters that appear at least 10 times.

Each line of poems.txt holds one Tang poem in three tab-separated columns: title, author, and content. Within the content, poem lines are separated by `.`.

Training data example:
```text
登鸛雀樓 王之渙 白日依山盡.黃河入海流.欲窮千里目.更上一層樓
觀獵 李白 太守耀清威.乘閑弄晚暉.江沙橫獵騎.山火遶行圍.箭逐雲鴻落.鷹隨月兔飛.不知白日暮.歡賞夜方歸
晦日重宴 陳嘉言 高門引冠蓋.下客抱支離.綺席珍羞滿.文場翰藻摛.蓂華彫上月.柳色藹春池.日斜歸戚里.連騎勒金羈
```

During training, each poem line is used as the model input and the following line as the prediction target.
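
The snippet below is a minimal sketch, not the actual [reader.py](./reader.py) implementation, of how such (input line, target line) pairs can be assembled from poems.txt; the helper name `make_pairs` is introduced here only for illustration.

```python
# -*- coding: utf-8 -*-
# Sketch only: build (current line, next line) training pairs from poems.txt.
import io


def make_pairs(poems_path):
    with io.open(poems_path, "r", encoding="utf8") as f:
        for line in f:
            # Each line: title \t author \t line1.line2.line3...
            title, author, content = line.strip().split("\t")
            sentences = [s for s in content.split(".") if s]
            for cur, nxt in zip(sentences[:-1], sentences[1:]):
                yield cur, nxt


if __name__ == "__main__":
    for src, tgt in make_pairs("data/poems.txt"):
        print("%s -> %s" % (src, tgt))
        break
```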


## Model Training
The command-line options of the training script [train.py](./train.py) can be listed with `python train.py --help`. The main options are:
- `num_passes`: number of training passes
- `batch_size`: batch size
- `use_gpu`: whether to use the GPU
- `trainer_count`: number of trainers, 1 by default
- `save_dir_path`: directory in which models are saved, the models directory under the current directory by default
- `encoder_depth`: depth of the encoder LSTM stack, 3 by default
- `decoder_depth`: depth of the decoder LSTM stack, 3 by default
- `train_data_path`: path of the training data
- `word_dict_path`: path of the dictionary
- `init_model_path`: path of the initial model; not needed when training from scratch

### Running the Training
```bash
python train.py \
    --num_passes 50 \
    --batch_size 256 \
    --use_gpu True \
    --trainer_count 1 \
    --save_dir_path models \
    --train_data_path data/poems.txt \
    --word_dict_path data/dict.txt \
    2>&1 | tee train.log
```
After each pass finishes, the model parameters are saved under the models directory, and the training log is written to train.log.

### Selecting the Best Model Parameters
Find the pass with the lowest cost and use that pass's model parameters for the subsequent generation step.
```bash
python -c 'import utils; utils.find_optiaml_pass("./train.log")'
```
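
If you prefer to inspect the log directly, the sketch below shows one way to pick the lowest-cost pass. It assumes log lines of the form `Pass %d, Batch %d, Cost %f` as written by [train.py](./train.py) and is not the implementation in utils.py; `find_lowest_cost_pass` is a hypothetical helper.

```python
# -*- coding: utf-8 -*-
# Sketch only: pick the pass with the lowest average batch cost from train.log,
# assuming lines such as "Pass 12, Batch 30, Cost 5.1234, {...}".
import re
import collections


def find_lowest_cost_pass(log_path="train.log"):
    costs = collections.defaultdict(list)
    pattern = re.compile(r"Pass (\d+), Batch (\d+), Cost ([0-9.]+)")
    with open(log_path) as f:
        for line in f:
            match = pattern.search(line)
            if match:
                costs[int(match.group(1))].append(float(match.group(3)))
    # Average the batch costs within each pass and return the smallest one.
    averages = dict((p, sum(c) / len(c)) for p, c in costs.items())
    return min(averages, key=averages.get)


if __name__ == "__main__":
    print("Lowest-cost pass: %d" % find_lowest_cost_pass())
```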

## Generating Poem Lines
Use the [generate.py](./generate.py) script to generate the next line for an input poem line; its command-line options can be listed with `python generate.py --help`.
The main options are:
- `model_path`: file containing the trained model parameters
- `word_dict_path`: path of the dictionary
- `test_data_path`: path of the input data
- `batch_size`: batch size, 1 by default
- `beam_size`: beam width used in beam search, 5 by default
- `save_file`: path for saving the output
- `use_gpu`: whether to use the GPU

### Running the Generation
For example, to use the line `孤帆遠影碧空盡` as the input for predicting the next line, save it in a file `input.txt` and run:
```bash
python generate.py \
    --model_path models/pass_00049.tar.gz \
    --word_dict_path data/dict.txt \
    --test_data_path input.txt \
    --save_file output.txt
```
The generated results are saved in the file `output.txt`. For the example input above, the generated lines are shown below; each row lists the beam search score followed by the generated characters:
```text
-9.6987 萬 壑 清 風 黃 葉 多
-10.0737 萬 里 遠 山 紅 葉 深
-10.4233 萬 壑 清 波 紅 一 流
-10.4802 萬 壑 清 風 黃 葉 深
-10.9060 萬 壑 清 風 紅 葉 多
```
11 changes: 11 additions & 0 deletions generate_chinese_poetry/data/download.sh
@@ -0,0 +1,11 @@
#!/bin/bash
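# Fetch the chinese-poetry corpus and keep only the Tang poetry JSON files.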

git clone https://github.com/chinese-poetry/chinese-poetry.git

if [ ! -d raw ]
then
    mkdir raw
fi

mv chinese-poetry/json/poet.tang.* raw/
rm -rf chinese-poetry
8 changes: 5 additions & 3 deletions generate_chinese_poetry/generate.py
@@ -28,7 +28,7 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
        for j in xrange(beam_size):
            end_pos = gen_sen_idx[i * beam_size + j]
            fout.write("%s\n" % ("%.4f\t%s" % (beam_result[0][i][j], " ".join(
                id_to_text[w] for w in beam_result[1][start_pos:end_pos]))))
                id_to_text[w] for w in beam_result[1][start_pos:end_pos - 1]))))
            start_pos = end_pos + 2
        fout.write("\n")
        fout.flush()
@@ -80,9 +80,11 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
        encoder_hidden_dim=512,
        decoder_depth=3,
        decoder_hidden_dim=512,
        is_generating=True,
        bos_id=0,
        eos_id=1,
        max_length=9,
        beam_size=beam_size,
        max_length=10)
        is_generating=True)

    inferer = paddle.inference.Inference(
        output_layer=beam_gen, parameters=parameters)
8 changes: 5 additions & 3 deletions generate_chinese_poetry/network_conf.py
@@ -73,8 +73,10 @@ def encoder_decoder_network(word_count,
                            encoder_hidden_dim,
                            decoder_depth,
                            decoder_hidden_dim,
                            bos_id,
                            eos_id,
                            max_length,
                            beam_size=10,
                            max_length=15,
                            is_generating=False):
    src_emb = paddle.layer.embedding(
        input=paddle.layer.data(
@@ -106,8 +108,8 @@
            name=decoder_group_name,
            step=_attended_decoder_step,
            input=group_inputs + [gen_trg_emb],
            bos_id=0,
            eos_id=1,
            bos_id=bos_id,
            eos_id=eos_id,
            beam_size=beam_size,
            max_length=max_length)

76 changes: 76 additions & 0 deletions generate_chinese_poetry/preprocess.py
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
import os
import io
import re
import json
import click
import collections


def build_vocabulary(dataset, cutoff=0):
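    # Count how often each character occurs in all poem lines, keep those that
    # appear at least `cutoff` times, and prepend the special tokens.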
    dictionary = collections.defaultdict(int)
    for data in dataset:
        for sent in data[2]:
            for char in sent:
                dictionary[char] += 1
    dictionary = filter(lambda x: x[1] >= cutoff, dictionary.items())
    dictionary = sorted(dictionary, key=lambda x: (-x[1], x[0]))
    vocab, _ = list(zip(*dictionary))
    return (u"<s>", u"<e>", u"<unk>") + vocab


@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw data")
@click.option("--outfile", type=str, help="Path to save the training data")
@click.option("--dictfile", type=str, help="Path to save the dictionary file")
def preprocess(datadir, outfile, dictfile):
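    # Clean each poem with the regular expressions defined below, which strip
    # editorial annotations such as parenthesised notes and 〖...〗 blocks.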
    dataset = []
    note_pattern1 = re.compile(u"(.*?)", re.U)
    note_pattern2 = re.compile(u"〖.*?〗", re.U)
    note_pattern3 = re.compile(u"-.*?-。?", re.U)
    note_pattern4 = re.compile(u"(.*$", re.U)
    note_pattern5 = re.compile(u"。。.*)$", re.U)
    note_pattern6 = re.compile(u"。。", re.U)
    note_pattern7 = re.compile(u"[《》「」\[\]]", re.U)
    print("Load raw data...")
    for fn in os.listdir(datadir):
        with io.open(os.path.join(datadir, fn), "r", encoding="utf8") as f:
            for data in json.load(f):
                title = data['title']
                author = data['author']
                p = "".join(data['paragraphs'])
                p = "".join(p.split())
                p = note_pattern1.sub(u"", p)
                p = note_pattern2.sub(u"", p)
                p = note_pattern3.sub(u"", p)
                p = note_pattern4.sub(u"", p)
                p = note_pattern5.sub(u"。", p)
                p = note_pattern6.sub(u"。", p)
                p = note_pattern7.sub(u"", p)
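                # Skip poems that still contain unhandled annotation or noise
                # characters after cleaning.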
                if (p == u"" or u"{" in p or u"}" in p or u"{" in p or
                        u"}" in p or u"、" in p or u":" in p or u";" in p or
                        u"!" in p or u"?" in p or u"●" in p or u"□" in p or
                        u"囗" in p or u")" in p):
                    continue
                paragraphs = re.split(u"。|,", p)
                paragraphs = filter(lambda x: len(x), paragraphs)
                if len(paragraphs) > 1:
                    dataset.append((title, author, paragraphs))

print("Construct vocabularies...")
vocab = build_vocabulary(dataset, cutoff=10)
with io.open(dictfile, "w", encoding="utf8") as f:
for v in vocab:
f.write(v + "\n")

print("Write processed data...")
with io.open(outfile, "w", encoding="utf8") as f:
for data in dataset:
title = data[0]
author = data[1]
paragraphs = ".".join(data[2])
f.write("\t".join((title, author, paragraphs)) + "\n")


if __name__ == "__main__":
    preprocess()
16 changes: 9 additions & 7 deletions generate_chinese_poetry/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_initial_model(model_path, parameters):
@click.option(
"--decoder_depth",
default=3,
help="The number of stacked LSTM layers in encoder.")
help="The number of stacked LSTM layers in decoder.")
@click.option(
"--train_data_path", required=True, help="The path of trainning data.")
@click.option(
@@ -75,10 +75,9 @@ def train(num_passes,
    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)

    # define optimization method and the trainer instance
    optimizer = paddle.optimizer.AdaDelta(
        learning_rate=1e-3,
        gradient_clipping_threshold=25.0,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-4,
        regularization=paddle.optimizer.L2Regularization(rate=1e-5),
        model_average=paddle.optimizer.ModelAverage(
            average_window=0.5, max_average_window=2500))

@@ -88,7 +87,10 @@
        encoder_depth=encoder_depth,
        encoder_hidden_dim=512,
        decoder_depth=decoder_depth,
        decoder_hidden_dim=512)
        decoder_hidden_dim=512,
        bos_id=0,
        eos_id=1,
        max_length=9)

    parameters = paddle.parameters.create(cost)
    if init_model_path:
@@ -113,7 +115,7 @@ def event_handler(event):
                        (event.pass_id, event.batch_id))
                save_model(trainer, save_path, parameters)

            if not event.batch_id % 5:
            if not event.batch_id % 10:
                logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

