Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finish generating chinese poetry #439

Merged
merged 14 commits into from
Nov 20, 2017
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 141 additions & 1 deletion generate_chinese_poetry/README.md
Original file line number Diff line number Diff line change
@@ -1 +1,141 @@
[TBD]
# 中国古诗生成

## 简介
基于编码器-解码器(encoder-decoder)神经网络模型,利用全唐诗进行诗句-诗句(sequence to sequence)训练,实现给定诗句后,生成下一诗句。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用一两句话描述一下默认的网络结构信息,例如默认几层LSTM encoder/decoder,是否带attention。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经更新README,增加了简要描述


以下是本例的简要目录结构及说明:

```text
.
├── data # 存储训练数据及字典
│ ├── download.sh # 下载原始数据
├── models # 存储训练好的模型
├── README.md # 文档
├── index.html # 文档(html格式)
├── preprocess.py # 原始数据预处理
├── generate.py # 生成诗句脚本
├── network_conf.py # 模型定义
├── reader.py # 数据读取接口
├── train.py # 训练脚本
└── utils.py # 定义实用工具函数
```

## 数据处理
### 原始数据来源
本例使用[中华古诗词数据库](https://github.com/chinese-poetry/chinese-poetry)中收集的全唐诗作为训练数据,共有约5.4万首唐诗。

### 原始数据下载
```bash
cd data && ./download.sh && cd ..
```
### 数据预处理
```bash
python preprocess.py --datadir data/raw --outfile data/poems.txt --dictfile data/dict.txt
```

上述脚本执行完后将生成处理好的训练数据poems.txt和数据字典dict.txt。poems.txt中每行为一首唐诗的信息,分为三列,分别为题目、作者、诗内容。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 数据字典 --> 字典。
  2. 默认情况下,字典如何构建?分词/分字?字频率统计,默认截断频率是多少,提供一些基本的信息。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经更新README,增加了字典构建的描述

在诗内容中,诗句之间用`.`分隔。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"." 分隔之后,训练数据的构造策略是什么?谁是源谁是目标?请解释一下数据策略。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经更新README,增加了数据构建的简要描述


训练数据示例:
```text
登鸛雀樓 王之渙 白日依山盡,黃河入海流.欲窮千里目,更上一層樓
觀獵 李白 太守耀清威,乘閑弄晚暉.江沙橫獵騎,山火遶行圍.箭逐雲鴻落,鷹隨月兔飛.不知白日暮,歡賞夜方歸
晦日重宴 陳嘉言 高門引冠蓋,下客抱支離.綺席珍羞滿,文場翰藻摛.蓂華彫上月,柳色藹春池.日斜歸戚里,連騎勒金羈
```

## 模型训练
训练脚本[train.py](./train.py)中的命令行参数如下:
```
Usage: train.py [OPTIONS]

Options:
--num_passes INTEGER Number of passes for the training task.
--batch_size INTEGER The number of training examples in one
forward/backward pass.
--use_gpu TEXT Whether to use gpu to train the model.
--trainer_count INTEGER The thread number used in training.
--save_dir_path TEXT The path to saved the trained models.
--encoder_depth INTEGER The number of stacked LSTM layers in encoder.
--decoder_depth INTEGER The number of stacked LSTM layers in decoder.
--train_data_path TEXT The path of trainning data. [required]
--word_dict_path TEXT The path of word dictionary. [required]
--init_model_path TEXT The path of a trained model used to initialized all
the model parameters.
--help Show this message and exit.
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • 48 ~ 64 行删去。其它例子后面会考虑各自进行修改。
    • 这个命令行参数只是直接复制粘贴了python train.py --help 的运行结果,并没有提供比这个更多的信息,如果需要,用户可以自行执行脚本查看。只需要在README中提醒用户查看即可。
    • 直接复制粘贴也会让代码修改情况下,这里需要同步,增加工作量。

### 参数说明
- `num_passes`: 训练pass数
- `batch_size`: batch大小
- `use_gpu`: 是否使用GPU
- `trainer_count`: trainer数目,默认为1
- `save_dir_path`: 模型存储路径,默认为当前目录下models目录
- `encoder_depth`: 模型中编码器LSTM深度,默认为3
- `decoder_depth`: 模型中解码器LSTM深度,默认为3
- `train_data_path`: 训练数据路径
- `word_dict_path`: 数据字典路径
- `init_model_path`: 初始模型路径,从头训练时无需指定

### 训练执行
```bash
python train.py \
--num_passes 20 \
--batch_size 256 \
--use_gpu True \
--trainer_count 1 \
--save_dir_path models \
--train_data_path data/poems.txt \
--word_dict_path data/dict.txt \
2>&1 | tee train.log
```
每个pass训练结束后,模型参数将保存在models目录下。训练日志保存在train.log中。

### 最优模型参数
寻找cost最小的pass,使用该pass对应的模型参数用于后续预测。
```bash
python -c 'import utils; utils.find_optiaml_pass("./train.log")'
```

## 生成诗句
使用[generate.py](./generate.py)脚本对输入诗句生成下一诗句,
命令行参数如下:
```
Usage: generate.py [OPTIONS]

Options:
--model_path TEXT The path of the trained model for generation.
--word_dict_path TEXT The path of word dictionary. [required]
--test_data_path TEXT The path of input data for generation. [required]
--batch_size INTEGER The number of testing examples in one forward pass in
generation.
--beam_size INTEGER The beam expansion in beam search.
--save_file TEXT The file path to save the generated results.
[required]
--use_gpu TEXT Whether to use GPU in generation.
--help Show this message and exit.
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • 104 ~ 115 行删去。原因同上。

  • 脚本 generate.py 的详细命令行参数请通过执行 python generate.py --help进行查阅。这里对重要参数进行说明。(后面如果需要说明请使用中文。)

### 参数说明
- `model_path`: 训练好的模型参数文件
- `word_dict_path`: 数据字典路径
- `test_data_path`: 输入数据路径
- `batch_size`: batch大小,默认为1
- `beam_size`: beam search中搜索范围大小,默认为5
- `save_file`: 输出保存路径
- `use_gpu`: 是否使用GPU

### 执行生成
例如将诗句 `白日依山盡,黃河入海流` 保存在文件 `input.txt` 中作为预测下句诗的输入,执行命令:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要这样构造源和目标。
源:“白日依山尽” --> 目标:"黄河如海流"。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改构建方法,并重新训练了模型,根据模型训练效果调整了默认训练参数,更新了例子

```bash
python generate.py \
--model_path models/pass_00014.tar.gz \
--word_dict_path data/dict.txt \
--test_data_path input.txt \
--save_file output.txt
```
生成结果将保存在文件 `output.txt` 中。对于上述示例输入,生成的诗句如下:
```text
-21.2048 不 知 身 外 事 , 何 處 是 閑 遊
-21.3982 不 知 身 外 事 , 何 處 是 何 由
-21.6564 不 知 身 外 事 , 何 處 是 何 求
-21.7312 不 知 身 外 事 , 何 事 是 何 求
-22.1956 不 知 身 外 事 , 何 處 是 人 愁
```
Loading