pytorch-seq2seq slower than OpenNMT-py #27

Closed
kylegao91 opened this issue Jul 18, 2017 · 7 comments

Comments

@kylegao91
Contributor

Benchmarked the two implementations on WMT's newstest2013, German to English. See the training logs in the gist. Setting aside the accuracy differences, pytorch-seq2seq is roughly 10 times slower than OpenNMT-py.

PetrochukM commented Jul 19, 2017

@kylegao91 We modified pytorch-seq2seq in a private implementation and were able to match OpenNMT-py's speed.

Things that made a big difference:

  • Replacing fixed-length batching with variable-length batching gave a 3-4x speedup. Pooling similar-sized examples together reduces the amount of padding per batch. We implemented this with torchtext (see the sketch after this list).
  • A faster loss function, similar to OpenNMT's memory-efficient loss. Instead of looping row by row and evaluating the loss batch-size times, we flattened the targets from 2D to 1D and the outputs from 3D to 2D, then evaluated the loss once for the entire batch:
        # (seq len, batch size, dictionary size) -> (seq len * batch size, dictionary size)
        outputs = outputs.view(-1, outputs.size(2))
        # (seq len, batch size) -> (seq len * batch size)
        targets = targets.view(-1)
        # one loss evaluation over the whole flattened batch
        self.criterion(outputs, targets)
  • Removing the Python loop in DecoderRNN that updated sequence lengths, replacing it with tensor operations:
            # indices of sequences in the batch that emitted EOS at this step
            eos_batches = symbols.view(-1).data.eq(self.eos_idx).nonzero()
            if eos_batches.dim() > 0:
                # (n, 1) => (n)
                eos_batches = eos_batches.view(-1)
                lengths[eos_batches] = len(sequence_symbols)
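A minimal sketch of the variable-sized batching idea using torchtext's BucketIterator; the field names, file path, and keyword arguments here are illustrative assumptions, not the exact private implementation:

    from torchtext import data

    # Tokenized source/target fields; include_lengths lets the encoder pack sequences.
    SRC = data.Field(include_lengths=True)
    TGT = data.Field()

    # Assumed TSV file with one "source<TAB>target" pair per line.
    train = data.TabularDataset(path='train.de-en.tsv', format='tsv',
                                fields=[('src', SRC), ('tgt', TGT)])
    SRC.build_vocab(train, max_size=50000)
    TGT.build_vocab(train, max_size=50000)

    # BucketIterator pools examples of similar source length into the same
    # batch, which minimizes the padding added per batch.
    train_iter = data.BucketIterator(
        train, batch_size=64,
        sort_key=lambda ex: len(ex.src))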

@kylegao91
Contributor Author

@Deepblue129 Thanks a lot! I will try out these ideas.

@kylegao91
Contributor Author

@Deepblue129 Regarding the third point, your code ignores the condition di < lengths[b_idx], so the values stored in lengths can be longer than they should be. See #32 for a modified version.
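For reference, a sketch of the corrected update (along the lines of #32, not necessarily verbatim): only update a length if the sequence is still open, i.e. it has not already emitted EOS at an earlier step. Here `step` (the current decoding step) and `lengths` (a numpy array initialized to the maximum length) are assumptions:

    eos_batches = symbols.data.eq(self.eos_idx)
    if eos_batches.dim() > 0:
        eos_batches = eos_batches.cpu().view(-1).numpy()
        # only touch sequences whose EOS has not been seen yet (lengths > step)
        update_idx = ((lengths > step) & eos_batches) != 0
        lengths[update_idx] = len(sequence_symbols)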

cclauss commented Jul 20, 2017

If speed is essential, why not step up to Python 3.6 or PyPy? Both are faster than Python 2.7.

@kylegao91
Contributor Author

@cclauss A faster interpreter would be like speeding up an O(N^2) sorting algorithm: it runs faster, but it is still O(N^2) rather than O(N log N). The slowdown here is algorithmic, not interpreter overhead.
We definitely need to support Python 3 though, and we would appreciate any contribution to that end.

PetrochukM commented Jul 21, 2017
@kylegao91 kylegao91 modified the milestones: Sprint 2, Sprint 1 Jul 31, 2017
@kylegao91 kylegao91 modified the milestones: Sprint 3, Sprint 2 Aug 31, 2017
@kylegao91
Contributor Author

Now with #32, #55, and #73, this issue is resolved.
