
An issue about multi-GPU #235

Closed
luyaojie opened this issue Sep 6, 2017 · 4 comments

luyaojie commented Sep 6, 2017

Hello, all.

I encountered an error when I ran OpenNMT with the multi-GPU option:

python train.py -data data/demo -save_model demo-model -word_vec_size 620 -gpuid 0 1 2 3

Traceback (most recent call last):
  File "train.py", line 309, in <module>
    main()
  File "train.py", line 270, in main
    model.encoder.embeddings.load_pretrained_vectors(opt.pre_word_vecs_enc)
  File "/home1/yaojie/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 262, in __getattr__
    type(self).__name__, name))
AttributeError: 'DataParallel' object has no attribute 'encoder'


srush commented Sep 7, 2017 via email


dalegebit commented Sep 7, 2017

Your problem can be fixed by moving the line model.encoder.embeddings.load_pretrained_vectors(opt.pre_word_vecs_enc) so that it runs before model is wrapped in DataParallel.
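A minimal sketch of why the AttributeError occurs (plain Python, no torch needed, with a hypothetical FakeDataParallel stand-in): DataParallel stores the wrapped network as self.module and does not forward arbitrary attribute lookups to it, so attributes must be loaded before wrapping, or reached via .module.

```python
class Encoder:
    def load_pretrained_vectors(self, path):
        return "loaded %s" % path

class Model:
    def __init__(self):
        self.encoder = Encoder()

class FakeDataParallel:
    """Stand-in for torch.nn.DataParallel: it keeps the wrapped model
    as `self.module` and has no `encoder` attribute of its own."""
    def __init__(self, module):
        self.module = module

model = Model()
# Works: load pretrained vectors BEFORE wrapping.
model.encoder.load_pretrained_vectors("vecs.enc")

model = FakeDataParallel(model)
try:
    model.encoder  # fails after wrapping: the wrapper has no `encoder`
except AttributeError:
    pass
# Alternatively, reach the wrapped model through `.module`.
model.module.encoder.load_pretrained_vectors("vecs.enc")
```

The same two options apply to the real torch.nn.DataParallel: either do all direct attribute access before wrapping, or go through model.module afterwards.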

There are two main problems concerning multi-GPU training:

  1. When a batch is split across GPUs, the maximum sequence length of each sub-batch may not equal that of the whole batch. So when the outputs are recovered as padded LongTensors through unpack, the recovered sizes may differ between sub-batches. I recommend fixing this by passing an extra argument that explicitly indicates the maximum length of the whole batch, and then concatenating extra padding onto each output up to that length. See: Pad PackedSequences to original batch length, pytorch/pytorch#1591. You can also refer to my implementation: https://github.com/dalegebit/OpenNMT-py/blob/d09323599b9ff5759b0daf08d814118faf0716c1/onmt/Models.py#L217
  2. PyTorch currently doesn't support DataParallel returning dicts or instances of custom classes. I have fixed this: Allow DataParallel to return dicts and other instances of iterable custom classes, pytorch/pytorch#2511. In the meantime, you should turn RNNDecoderState into an iterable: https://github.com/dalegebit/OpenNMT-py/blob/d09323599b9ff5759b0daf08d814118faf0716c1/onmt/Models.py#L508
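The padding problem in point 1 can be illustrated with a torch-free sketch (the function name unpack_to_length is hypothetical): each replica unpacks its sub-batch to its local maximum length, so the recovered shapes disagree unless every replica pads out to the whole batch's maximum.

```python
PAD = 0

def unpack_to_length(sub_batch, total_length):
    """Pad each sequence (a list of token ids) out to `total_length`,
    mimicking unpacking to a fixed, whole-batch maximum length."""
    return [seq + [PAD] * (total_length - len(seq)) for seq in sub_batch]

# A batch whose longest sequence (length 4) lands in the first split:
batch = [[1, 2, 3, 4], [5, 6], [7, 8, 9], [10]]
total_length = max(len(s) for s in batch)

# DataParallel-style scatter into two sub-batches:
splits = [batch[:2], batch[2:]]

# Without the extra argument, each split would pad only to its own local
# max (4 vs 3); with it, every replica's output has the same width:
outputs = [unpack_to_length(s, total_length) for s in splits]
assert all(len(seq) == total_length for out in outputs for seq in out)
```

This is the same idea behind passing the whole-batch maximum length through to the unpack step so that the gathered outputs can be concatenated.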


srush commented Sep 7, 2017

@dalegebit we would love a PR if you have one.

@dalegebit

Sure

marcotcr pushed a commit to marcotcr/OpenNMT-py that referenced this issue Sep 20, 2017
@vince62s vince62s closed this as completed Aug 2, 2018