I tried to reuse PyTorch weights in Flux and made the forward output fully consistent. But unfortunately, the gradients came out wrong. I eventually traced this to an incorrect gradient in the strided MaxPool. The following gist demonstrates it:
https://gist.github.com/yiyuezhuo/94a08637a6a80ac4a25eb89e5e605e16
The compare_master.ipynb notebook in the gist shows the comparison results for the current implementation.

Just replacing the `=` with `+=` removes most of the inconsistency.
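To make the bug concrete, here is a minimal sketch (not Flux's actual code) of a 2-D max-pool backward pass over one channel, comparing the buggy `=` write with the `+=` fix. The kernel size and stride are chosen so that windows overlap and one entry is the maximum of several windows:

```python
import numpy as np

def maxpool_backward(x, dy, k, stride, accumulate):
    """Route each window's upstream gradient to that window's max entry."""
    dx = np.zeros_like(x, dtype=float)
    out_h = (x.shape[0] - k) // stride + 1
    out_w = (x.shape[1] - k) // stride + 1
    for i in range(out_h):
        for j in range(out_w):
            win = x[i*stride:i*stride+k, j*stride:j*stride+k]
            r, c = np.unravel_index(np.argmax(win), win.shape)
            if accumulate:
                dx[i*stride + r, j*stride + c] += dy[i, j]  # correct: sum over windows
            else:
                dx[i*stride + r, j*stride + c] = dy[i, j]   # buggy: overwrites
    return dx

# The centre entry is the maximum of all four overlapping 2x2 windows.
x = np.array([[0., 0., 0.],
              [0., 9., 0.],
              [0., 0., 0.]])
dy = np.ones((2, 2))  # one upstream gradient per window

print(maxpool_backward(x, dy, 2, 1, accumulate=True)[1, 1])   # 4.0: all four windows contribute
print(maxpool_backward(x, dy, 2, 1, accumulate=False)[1, 1])  # 1.0: later writes clobber earlier ones
```

With `=`, whichever window is visited last wins and the other three contributions are silently dropped, which is exactly the mismatch against PyTorch.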

But some inconsistency still remains. It comes from the difference between PyTorch's row-major and Flux/Julia's column-major layout: each follows its natural traversal order to find the first maximal element, and applies the gradient there. Reversing the iteration order therefore removes almost all of the remaining inconsistency.
But this modification may hurt performance for only a marginal gain.
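The residual discrepancy only matters for windows with tied maxima: a row-major scan and a column-major scan meet the tied entries in a different order, so the gradient lands on a different element. A small numpy sketch, with flatten order standing in for PyTorch's vs Julia's memory layout:

```python
import numpy as np

# A window whose maximum value appears twice, at (0, 1) and (1, 0).
win = np.array([[0., 5.],
                [5., 1.]])

# Row-major (PyTorch-style) scan hits the tie at (0, 1) first;
# column-major (Julia-style) scan hits it at (1, 0) first.
row_major = np.unravel_index(np.argmax(win.flatten(order='C')), win.shape, order='C')
col_major = np.unravel_index(np.argmax(win.flatten(order='F')), win.shape, order='F')

print(row_major)  # (0, 1)
print(col_major)  # (1, 0)
```

Since ties are rare with real-valued weights, chasing exact agreement here costs iteration-order changes for little practical benefit.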
The present unit tests fail to detect this error, since reshape(1:n, ...)-style examples cannot exercise the case in which a single entry is the maximum of several windows. That is exactly the case where the `=` overwrites the gradients already routed to that entry.
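To see why such inputs hide the bug: in a strictly increasing array, every window's maximum sits at a distinct location, so `=` and `+=` write to disjoint entries and produce identical results. A quick check (hypothetical helper, 2x2 windows with stride 1):

```python
import numpy as np

def window_max_positions(x, k, stride):
    """Absolute position of the max entry of each k-by-k window."""
    pos = []
    for i in range((x.shape[0] - k) // stride + 1):
        for j in range((x.shape[1] - k) // stride + 1):
            win = x[i*stride:i*stride+k, j*stride:j*stride+k]
            r, c = np.unravel_index(np.argmax(win), win.shape)
            pos.append((i*stride + r, j*stride + c))
    return pos

# reshape(1:n, ...)-style input: strictly increasing, so every window's
# max lives at a different location and `=` vs `+=` cannot be told apart.
mono = np.arange(1, 10).reshape(3, 3)
p1 = window_max_positions(mono, 2, 1)
print(len(p1) == len(set(p1)))  # True: all positions distinct

# One large entry shared by all four windows: positions collide,
# which is the case the current tests never exercise.
shared = np.zeros((3, 3)); shared[1, 1] = 9.
p2 = window_max_positions(shared, 2, 1)
print(len(p2) == len(set(p2)))  # False: collisions
```

A regression test only needs one input of the second kind to catch the overwrite.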
I have prepared two branches for a pull request:

- fix_maxpool_gradient (change `=` to `+=`)
- fix_maxpool_gradient_further (additionally change the iteration order)

Which one should I open a PR for? (Well, since this issue has been ignored entirely, maybe no PR will be accepted. But the gradient is still wrong, so a fix would be nice imo.)
