Chinese PLATO-2: single-GPU training works, but single-machine multi-GPU training exits after a certain number of steps with no useful error message #79

Closed
jidlin opened this issue Sep 5, 2021 · 7 comments

jidlin commented Sep 5, 2021

Chinese dialogue data, about 4 million examples. A single GPU can run a full epoch, but single-machine 4-GPU training exits after a certain number of steps.

Environment:
paddlepaddle-gpu==2.0.1
cuda==11.0
cudnn==8.0

The terminal error is:

INFO 2021-09-05 21:51:40,245 launch_utils.py:327] terminate all the procs
ERROR 2021-09-05 21:51:40,245 launch_utils.py:584] ABORT!!! Out of all 4 trainers, the trainer process with rank=[3] was aborted. Please check its log.
INFO 2021-09-05 21:51:43,248 launch_utils.py:327] terminate all the procs

The error in work_log.3 is as follows:

--------------------------------------
C++ Traceback (most recent call last):
--------------------------------------
0   paddle::framework::ParallelExecutor::Run(std::vector<std::string, std::allocator<std::string > > const&, bool)
1   paddle::framework::details::ScopeBufferedSSAGraphExecutor::Run(std::vector<std::string, std::allocator<std::string > > const&, bool)
2   paddle::framework::details::FastThreadedSSAGraphExecutor::Run(std::vector<std::string, std::allocator<std::string > > const&, bool)
3   paddle::framework::BlockingQueue<unsigned long>::Pop()
4   paddle::framework::SignalHandle(char const*, int)
5   paddle::platform::GetCurrentTraceBackString[abi:cxx11]()

----------------------
Error Message Summary:
----------------------
FatalError: `Termination signal` is detected by the operating system.
  [TimeInfo: *** Aborted at 1630683022 (unix time) try "date -d @1630683022" if you are using GNU date ***]
  [SignalInfo: *** SIGTERM (@0x3e800000a5e) received by PID 2812 (TID 0x7f718f576b80) from PID 2654 ***]

jidlin commented Sep 6, 2021

Also, when the 4M dataset is split into two 2M sub-datasets, multi-GPU training can finish an epoch as well; with the full 4M dataset I did not observe any out-of-memory symptoms either.

sserdoubleh (Collaborator) commented:

Does it reproduce reliably? At roughly how many steps does the error occur? Could you check the error output on the other GPUs and see whether anything looks abnormal?


jidlin commented Sep 6, 2021

Does it reproduce reliably? At roughly how many steps does the error occur? Could you check the error output on the other GPUs and see whether anything looks abnormal?

It reproduces reliably: with batch_size 16000 the error occurs at around step 3270. GPU memory looks normal, and single-GPU training runs fine.

Training config attached:

# job settings

job_script="./scripts/distributed/train.sh"

# task settings

model=UnifiedTransformer
task=DialogGeneration

vocab_path="./package/dialog_cn/vocab.txt"
spm_model_file="./package/dialog_cn/spm.model"
train_file="./data/example/train_filelist"
valid_file="./data/example/valid_filelist"
data_format="raw"
file_format="filelist"
config_path="./package/dialog_cn/12L.json"

# training settings

is_cn="true"
in_tokens="true"
batch_size=16000
lr=1e-3
warmup_steps=4000
weight_decay=0.01
num_epochs=2

log_steps=1
validation_steps=1000
save_steps=5000

log_dir="./log"
save_path="./output"
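A note on the batching in this config: with in_tokens="true", batch_size=16000 is interpreted as a per-batch token budget rather than an example count (this is how token-based batching works in ERNIE-style Paddle readers; treat it as an assumption here). A minimal sketch of that rule, illustrative only and not Knover's actual reader:

# Minimal sketch (not Knover's actual reader), assuming in_tokens=true means
# batch_size is a per-batch token budget rather than a number of examples.
def token_budget_batches(tokenized_examples, max_tokens=16000):
    """Yield batches whose total token count stays within max_tokens."""
    batch, num_tokens = [], 0
    for example in tokenized_examples:  # each example is a list of token ids
        if batch and num_tokens + len(example) > max_tokens:
            yield batch
            batch, num_tokens = [], 0
        batch.append(example)
        num_tokens += len(example)
    if batch:
        yield batch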


jidlin commented Sep 6, 2021

Does it reproduce reliably? At roughly how many steps does the error occur? Could you check the error output on the other GPUs and see whether anything looks abnormal?

I just checked: with multiple GPUs, GPU memory stays stable and utilization is basically 100%, but CPU memory usage keeps rising steadily and I don't know why. With a single GPU, both CPU and GPU memory usage are stable.


jidlin commented Sep 6, 2021

Confirmed: the failure is caused by CPU memory growing continuously, but there is no OOM message in the logs, which is strange. Why would CPU memory keep increasing during distributed training?
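One way to confirm this kind of steady per-trainer growth is to log each worker process's resident set size (RSS) over time. A minimal monitoring sketch using the third-party psutil package; the process-name filter and interval are illustrative assumptions, not part of Knover:

# Minimal monitoring sketch (not part of Knover): periodically print the RSS of
# every local process whose command line matches a filter, so memory growth can
# be correlated with training steps.
import time

import psutil  # third-party: pip install psutil

def log_trainer_rss(name_filter="train.py", interval_s=60):
    """Print the resident set size of every process matching name_filter."""
    while True:
        for proc in psutil.process_iter(["pid", "cmdline", "memory_info"]):
            cmdline = " ".join(proc.info["cmdline"] or [])
            if name_filter in cmdline:
                rss_gb = proc.info["memory_info"].rss / 1024 ** 3
                print(f"pid={proc.info['pid']} rss={rss_gb:.2f} GiB")
        time.sleep(interval_s)

if __name__ == "__main__":
    log_trainer_rss()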

sserdoubleh (Collaborator) commented:

return self.cached[text]

This may be caused by the cache here: it holds on to the tokenization results and keeps consuming CPU memory.
For training on a large dataset, it is best to use knover/tools/pre_tokenize.py to tokenize the whole dataset first and then set data_format="tokenized" (alternatively, use knover/tools/pre_numericalize.py with data_format="numerical"; see docs/usage.md).
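The quoted line points at an unbounded, per-process memoization of tokenization results. A hypothetical sketch of the pattern, illustrative only and not Knover's actual tokenizer, showing why memory grows when most training lines are distinct:

# Illustrative sketch only (not Knover's code): a cache keyed by raw text keeps one
# entry per distinct training line, so with millions of mostly unique lines the
# per-process cache grows steadily until the OS terminates a worker.
class CachingTokenizer:
    def __init__(self, tokenize_fn):
        self.tokenize_fn = tokenize_fn
        self.cached = {}

    def tokenize(self, text):
        if text not in self.cached:
            self.cached[text] = self.tokenize_fn(text)  # never evicted
        return self.cached[text]

Pre-tokenizing offline, as suggested above, removes the need for any such cache during training.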


jidlin commented Sep 6, 2021

return self.cached[text]

This may be caused by the cache here: it holds on to the tokenization results and keeps consuming CPU memory.
For training on a large dataset, it is best to use knover/tools/pre_tokenize.py to tokenize the whole dataset first and then set data_format="tokenized" (alternatively, use knover/tools/pre_numericalize.py with data_format="numerical"; see docs/usage.md).

Thanks! After pre-tokenizing first, CPU memory is indeed stable and the model is now training. It might be worth adding this to the documentation: for datasets larger than ~4M examples, pre-tokenizing first is recommended.
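For readers hitting the same issue, the rough shape of offline pre-tokenization is sketched below using the sentencepiece model referenced in the config. The input/output file names and line layout are hypothetical; the actual tool to use is knover/tools/pre_tokenize.py as described in docs/usage.md:

# Rough illustration of offline pre-tokenization (the real tool is
# knover/tools/pre_tokenize.py; the file layout here is simplified/hypothetical).
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="./package/dialog_cn/spm.model")

with open("train.raw", encoding="utf8") as fin, \
     open("train.tokenized", "w", encoding="utf8") as fout:
    for line in fin:
        # Tokenize once, offline, so the training reader never needs a per-text cache.
        pieces = sp.encode(line.strip(), out_type=str)
        fout.write(" ".join(pieces) + "\n")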

jidlin closed this as completed Sep 8, 2021