Refine the timing of tf's PaddingRNN. (PaddlePaddle#7)
* Refine the timing of tf's PaddingRNN.

* Refine the tf codes.

* Enable the profiler of lstm_tf.
Xreki committed Apr 1, 2019
1 parent 062b404 commit bb44dd4
Showing 9 changed files with 608 additions and 813 deletions.
36 changes: 15 additions & 21 deletions PaddingRNN/lstm_tf/README.md
@@ -5,9 +5,9 @@
```text
.
├── README.md # documentation
├── ptb_lm.py # training script
├── train.py # training script
├── reader.py # data reader
└── lm_model.py # model definition file
└── ptb_lm_model.py # model definition file
```


@@ -22,38 +22,32 @@ http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
See the TensorFlow documentation: https://www.tensorflow.org/tutorials/sequences/recurrent

## Training
Change the parameters in debug.sh to
CUDA_VISIBLE_DEVICES='0' python ptb_lm.py large False
Change the parameters in run.sh to
`CUDA_VISIBLE_DEVICES='0' python train.py --model_type large --inference_only False`
Run the command
bash debug.sh
bash run.sh
to start training the model on a single GPU.

Change the parameters in debug.sh to
CUDA_VISIBLE_DEVICES='' python ptb_lm.py large False
Change the parameters in run.sh to
`CUDA_VISIBLE_DEVICES='' python train.py --model_type large --inference_only False`
Run the command
bash debug.sh
bash run.sh
to start training the model on the CPU.

## Inference
Change the parameters in debug.sh to
CUDA_VISIBLE_DEVICES='0' python ptb_lm.py large True
Change the parameters in run.sh to
`CUDA_VISIBLE_DEVICES='0' python train.py --model_type large --inference_only True`
Run the command
bash debug.sh
bash run.sh
to start inference on a single GPU.

Change the parameters in debug.sh to
CUDA_VISIBLE_DEVICES='' python ptb_lm.py large True
Change the parameters in run.sh to
`CUDA_VISIBLE_DEVICES='' python train.py --model_type large --inference_only True`
Run the command
bash debug.sh
bash run.sh
to start inference on the CPU.

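## debug.sh parameters
The arguments after ptb_lm.py:
The first argument selects the model size; large is the large model.
The second argument specifies whether to run inference only.


The implementation uses a two-layer LSTM. For the specific parameters and network configuration, see the settings in train.py and lm_model.py
The implementation uses a two-layer LSTM. For the specific parameters and network configuration, see the settings in train.py and ptb_lm_model.py.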


## Comparison with TensorFlow results
118 changes: 60 additions & 58 deletions PaddingRNN/lstm_tf/args.py
@@ -26,72 +26,74 @@ def parse_args():
"--model_type",
type=str,
default="small",
help="model_type [test|small|med|big]")
help="model_type [test|small|medium|large]")
# parser.add_argument(
# "--embed_size",
# type=int,
# default=300,
# help="The dimension of embedding table. (default: %(default)d)")
# parser.add_argument(
# "--hidden_size",
# type=int,
# default=300,
# help="The size of rnn hidden unit. (default: %(default)d)")
# parser.add_argument(
# "--batch_size",
# type=int,
# default=32,
# help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--embed_size",
"--max_epoch",
type=int,
default=300,
help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
"--hidden_size",
type=int,
default=300,
help="The size of rnn hidden unit. (default: %(default)d)")
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--pass_num",
type=int,
default=5,
default=0,
help="The pass number to train. (default: %(default)d)")
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--use_gpu",
type=distutils.util.strtobool,
default=True,
help="Whether to use gpu. (default: %(default)d)")
parser.add_argument(
"--debug",
action='store_true',
help="Whether to print debug info. (default: %(default)d)")
# parser.add_argument(
# "--learning_rate",
# type=float,
# default=0.001,
# help="Learning rate used to train the model. (default: %(default)f)")
# parser.add_argument(
# "--use_gpu",
# type=distutils.util.strtobool,
# default=True,
# help="Whether to use gpu. (default: %(default)d)")
# parser.add_argument(
# "--debug",
# action='store_true',
# help="Whether to print debug info. (default: %(default)d)")
parser.add_argument(
"--profile",
action='store_true',
type=bool,
default=False,
help='Early exit if profile = True. (default: %(default)d)')
parser.add_argument(
"--save_dir",
type=str,
default="model",
help="Specify the path to save trained models.")
# parser.add_argument(
# "--save_dir",
# type=str,
# default="model",
# help="Specify the path to save trained models.")
parser.add_argument(
"--inference_only",
action='store_true',
help="start inference only")
parser.add_argument(
"--save_interval",
type=int,
default=1,
help="Save the trained model every n passes."
"(default: %(default)d)")
parser.add_argument('--trainset', nargs='+', help='train dataset')
parser.add_argument('--devset', nargs='+', help='dev dataset')
parser.add_argument('--testset', nargs='+', help='test dataset')
parser.add_argument('--vocab_dir', help='dict')
parser.add_argument('--max_p_num', type=int, default=5)
parser.add_argument('--max_a_len', type=int, default=200)
parser.add_argument('--max_p_len', type=int, default=500)
parser.add_argument('--max_q_len', type=int, default=9)
parser.add_argument('--doc_num', type=int, default=1)
parser.add_argument('--single_doc', action='store_true')
parser.add_argument('--simple_net', type=int, default=0)
parser.add_argument('--para_init', action='store_true')
type=bool,
default=False,
help="if use inference only")
# parser.add_argument(
# "--save_interval",
# type=int,
# default=1,
# help="Save the trained model every n passes."
# "(default: %(default)d)")
# parser.add_argument('--trainset', nargs='+', help='train dataset')
# parser.add_argument('--devset', nargs='+', help='dev dataset')
# parser.add_argument('--testset', nargs='+', help='test dataset')
# parser.add_argument('--vocab_dir', help='dict')
# parser.add_argument('--max_p_num', type=int, default=5)
# parser.add_argument('--max_a_len', type=int, default=200)
# parser.add_argument('--max_p_len', type=int, default=500)
# parser.add_argument('--max_q_len', type=int, default=9)
# parser.add_argument('--doc_num', type=int, default=1)
# parser.add_argument('--single_doc', action='store_true')
# parser.add_argument('--simple_net', type=int, default=0)
# parser.add_argument('--para_init', action='store_true')
parser.add_argument('--log_path', help='path of the log file. If not set, logs are printed to console')
args = parser.parse_args()
return args
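
A note on the `type=bool` options in the diff above: `argparse` applies the `type` callable to the raw command-line string, and `bool("False")` is `True` because any non-empty string is truthy. A command such as `--inference_only False` would therefore likely parse as `True`. The sketch below reproduces this with a stand-alone parser and shows one possible workaround. The `str2bool` and `build_parser` helpers are hypothetical names introduced for illustration; they are not part of this commit.

```python
import argparse


def str2bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool.

    Hypothetical helper for illustration; not part of the commit.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser(bool_type):
    """Build a minimal parser with one boolean-like option."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_only", type=bool_type, default=False)
    return parser


# type=bool calls bool("False"), which is truthy for any non-empty string.
naive = build_parser(bool).parse_args(["--inference_only", "False"])

# The explicit converter maps the string to the intended value.
fixed = build_parser(str2bool).parse_args(["--inference_only", "False"])

print(naive.inference_only)
print(fixed.inference_only)
```

An explicit converter keeps the `--flag True/False` command-line shape used in the README. Switching back to `action='store_true'` would also avoid the problem, but it would change how the scripts have to be invoked.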
6 changes: 0 additions & 6 deletions PaddingRNN/lstm_tf/debug2.sh

This file was deleted.
