Block parameters for RNNs (for cudnn paths) #1855
Conversation
Errors are in utils tests. Maybe I should comment those out to just have recurrent tests?
Something else to think about: when using CUDNN, the weight matrix is not organized the way it is in the CPU implementation. I have yet to think of a good solution for this. The best I can think of is not guaranteeing how W is organized in the docs/user-facing details. We can have a user-facing API (through views, like I've done here) and make guarantees on that.
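For illustration, a minimal sketch of what such a view-based, user-facing API could look like. Everything here is hypothetical (the `RNNCellBlock` name, the column layout of `W`, the accessor names); it is not the PR's actual code.

```julia
# Hedged sketch: keep one contiguous parameter block `W`, as cudnn expects,
# and only make guarantees about the view-based accessors.
struct RNNCellBlock{F,A<:AbstractMatrix,S}
    σ::F
    W::A          # monolithic block, size nout × (nin + nout + 1)
    state0::S
    nin::Int
    nout::Int
end

# User-facing views; how `W` itself is laid out stays an implementation detail.
Wi(m::RNNCellBlock) = view(m.W, :, 1:m.nin)
Wh(m::RNNCellBlock) = view(m.W, :, m.nin+1:m.nin+m.nout)
b(m::RNNCellBlock)  = view(m.W, :, m.nin+m.nout+1)

function (m::RNNCellBlock)(h, x)
    h = m.σ.(Wi(m) * x .+ Wh(m) * h .+ b(m))
    return h, h
end
```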
Sorry, dumping more info/thoughts here. There is a lot of metadata required to pass to cudnn. I'm wondering if we should have a Flux-side metadata struct to deal with these particularities, especially if we decide to make multi-layer RNNs part of the same interface. But maybe these can be taken care of in the parametric types.
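As a purely speculative illustration of such a metadata struct, the fields below are guesses at what a cudnn RNN descriptor typically needs; this is not an existing Flux or NNlibCUDA type.

```julia
# Speculative sketch of a Flux-side metadata struct for the cudnn path.
Base.@kwdef struct CUDNNRNNMeta
    mode::Symbol = :rnn_tanh     # :rnn_tanh, :rnn_relu, :gru, :lstm
    input_size::Int
    hidden_size::Int
    nlayers::Int = 1
    dropout::Float64 = 0.0
    bidirectional::Bool = false
end
```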
I haven't had time yet to go through this PR, but regarding the metadata question:
If the metadata is CUDA-specific, one place to put it would be NNlibCUDA. It would also be cool to have a multi-layer RNN interface, but given the complexity of that I wouldn't worry about solving it in this PR :)
Updated to use an efficient view adjoint based on multigate. I also changed some things to better represent Flux's current state for recurrent layers; I was behind on some things.
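Roughly, the idea of such a view adjoint is that the pullback of slicing the block scatters the cotangent back into an array of the block's size. A hedged sketch (the `blockview` helper is hypothetical; Flux's actual `multigate` is more involved and accumulates gate gradients into a shared buffer):

```julia
using ChainRulesCore

# Hypothetical helper: view a column range of the block, with an adjoint that
# writes the cotangent into the corresponding slice of a zero array, so the
# gradient w.r.t. `W` keeps the block's shape.
blockview(W::AbstractMatrix, cols) = view(W, :, cols)

function ChainRulesCore.rrule(::typeof(blockview), W::AbstractMatrix, cols)
    function blockview_pullback(Δ)
        dW = zero(W)
        dW[:, cols] .= unthunk(Δ)
        return (NoTangent(), dW, NoTangent())
    end
    return blockview(W, cols), blockview_pullback
end
```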
src/layers/recurrent.jl (Outdated)

```julia
Wi::VA
Wh::VA
b::VB
W::A
```
Speculative, but instead of storing W as a field, I wonder if this could just call parent([Wi|Wh|b]). Possibly too janky?
That feels a bit too janky for my tastes. I agree the struct is a bit cluttered. We could override getproperty for these, but that might also not be great. I think having these is preferable to not.
One thing we maybe should do with these changes is rename W to weight so it is consistent with other parts of Flux.
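For concreteness, one way the `getproperty` overload could look, building on the hypothetical `RNNCellBlock` sketched earlier (again illustrative, not the PR's code), with the block surfaced as `weight`:

```julia
# Illustrative only: expose `Wi`/`Wh`/`b` as computed views over the stored
# block, and surface the block itself as `weight` for naming consistency.
function Base.getproperty(m::RNNCellBlock, s::Symbol)
    W, nin, nout = getfield(m, :W), getfield(m, :nin), getfield(m, :nout)
    s === :weight && return W
    s === :Wi && return view(W, :, 1:nin)
    s === :Wh && return view(W, :, nin+1:nin+nout)
    s === :b  && return view(W, :, nin+nout+1)
    return getfield(m, s)
end
```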
test/layers/recurrent.jl (Outdated)

```julia
+ rnn.cell.b)),
rnn.cell.Wh)
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
bptt = gradient(rnn.cell.W) do W
```
Some formatting and indentation funniness to work out after the implementation is locked in. These new test cases look much cleaner though!
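As a hedged sketch of the kind of check these tests are after (a standalone toy example, not the PR's test code): the gradient taken w.r.t. the whole block should agree with the gradient taken w.r.t. a view of it on the corresponding slice, and be zero elsewhere.

```julia
using Flux, Test

W = randn(Float32, 4, 7)        # toy block: [Wi (4×3) | Wh (4×4)]
x = randn(Float32, 3, 1)

g_block = gradient(W -> sum(view(W, :, 1:3) * x), W)[1]
g_view  = gradient(Wi -> sum(Wi * x), W[:, 1:3])[1]

@test g_block[:, 1:3] ≈ g_view            # same entries on the Wi slice
@test all(iszero, g_block[:, 4:7])        # untouched part of the block
```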
I just realized this new version requires the use of a bias unit. I'm not sure how to solve that w/o a flag of some kind.
The bias thing is an implementation pain for other frameworks as well. I recall either some documentation or a discussion thread about how much Keras had to bend over backwards to make it work.
I dealt with the bias issue by dispatching on constructors. I haven't tested on GPU, but I expect it should work.
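A speculative sketch of what "dispatching on the constructors" might look like for the hypothetical `RNNCellBlock` above; the `bias` keyword and the choice to keep (but zero) the bias column are guesses, not the PR's actual approach.

```julia
using Flux

# With the default `bias = true` the last column of the block holds the bias;
# with `bias = false` the column is still allocated (cudnn wants the packed
# layout) but zero-initialised.
function RNNCellBlock((nin, nout)::Pair, σ = tanh;
                      init = Flux.glorot_uniform, bias = true)
    W = init(nout, nin + nout + 1)
    bias || (W[:, end] .= 0)
    return RNNCellBlock(σ, W, zeros(Float32, nout, 1), nin, nout)
end
```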
A few questions:

- Are we sure that going through cudnn is worth it here, i.e. how much faster is it than the current Flux recurrent kernels?
- If we go with this, we should make sure that using views interacts nicely with the rest of the stack (Functors/`gpu`, and taking gradients through BPTT).
- Do we need to store the views in the struct at all? Is that for retrocompatibility? Can we overload getproperty instead?
Good point! I do not have a benchmark comparing Flux RNNs/GRUs/LSTMs to cudnn, but this white paper does a good comparison over a variety of kernels. What I've found consistently online is that cudnn can be significantly faster than other kernels (compared with Theano, Torch, and TensorFlow based kernels). We could write a benchmark comparing the current Flux RNNs to cudnn, but because these haven't been optimized I would be very surprised if they were competitive in the forward and backward pass. We could likely replicate the cudnn kernels in pure Julia to avoid blocking the weights, but that would likely require a significant amount of engineering work. Another solution to avoid blocking could be to have an intermediary GPU array we copy the non-blocked weights to in the cudnn call, but that seemed less ideal and might make multi-layer cudnn support more challenging.
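A rough sketch of the kind of benchmark mentioned here, assuming a recent Flux where `RNN(in => out)` is available; a real comparison against the cudnn kernel itself would call into CUDA.CUDNN directly, which is omitted.

```julia
using Flux, CUDA, BenchmarkTools

rnn_cpu = Flux.RNN(128 => 128)
rnn_gpu = Flux.gpu(rnn_cpu)

x_cpu = randn(Float32, 128, 32)     # features × batch
x_gpu = Flux.gpu(x_cpu)

@btime $rnn_cpu($x_cpu)             # current CPU kernel
@btime CUDA.@sync $rnn_gpu($x_gpu)  # current generic GPU kernel (not cudnn)
```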
Yes, the custom functor handles that. I have not yet tested how we currently deal with the bias unit on GPUs, because my server has been out of commission. I'm hoping for that to be resolved today so I can test there.
We have multiple test cases already for bptt gradients. Any other cases you would want to see?
Definitely for compatibility and convenience. When moving to the cudnn algorithm the weights have a different organization than ours here. I was hoping to get feedback on the getproperty idea but I think it got buried in the feed. I'm more optimistic about overloading getproperty than I seem there: #1855 (comment).
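On the "custom functor handles that" point above, a hedged guess at what such a functor could look like for the hypothetical block cell: only the packed block (and the initial state) are exposed as leaves, so `gpu`/`cpu` and optimisers see `W` rather than the views. This is an assumption about the design, not the PR's code.

```julia
using Functors

# Only `W` and `state0` are functored leaves; the views are recomputed lazily.
function Functors.functor(::Type{<:RNNCellBlock}, m)
    reconstruct(xs) = RNNCellBlock(m.σ, xs.W, xs.state0, m.nin, m.nout)
    return (W = m.W, state0 = m.state0), reconstruct
end
```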
It seems that a lot of the complexity here comes from taking the storage cuDNN wants (a packed array block) and de-sugaring it into what we like (separate blocks). Would it not be easier to reverse the logic? We store separate fields for each semantic block. The constructor by default will make these fields views of a monolithic block. The individual views are what we take gradients w.r.t., and writes to them should handle the accumulation and opt state correctly? We can have a level of indirection for this; for cuDNN and functoring, we define whatever glue is needed. (Apologies if this was already discussed and I've missed something obvious.)
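A rough sketch of that reversed logic, under the same assumptions as the earlier sketches (hypothetical names, not a worked-out implementation); note that naive functoring would map each view separately and break the aliasing, which is where the custom cuDNN/functor glue would come in.

```julia
using Flux

# Semantic fields are stored as views into one packed block.
struct BlockRNNCell{F,M<:SubArray,V<:SubArray,S}
    σ::F
    Wi::M
    Wh::M
    b::V
    state0::S
end

function BlockRNNCell(nin::Int, nout::Int, σ = tanh; init = Flux.glorot_uniform)
    W  = init(nout, nin + nout + 1)          # the packed block cuDNN wants
    Wi = view(W, :, 1:nin)
    Wh = view(W, :, nin+1:nin+nout)
    b  = view(W, :, nin+nout+1)
    return BlockRNNCell(σ, Wi, Wh, b, zeros(Float32, nout, 1))
end

# The cuDNN path can recover the packed buffer from any of the views.
packed(m::BlockRNNCell) = parent(m.Wi)
```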
This changes the RNN implementation so that it stores all its parameters in a block. We include helper views to access specific parts of the weight matrix more easily. This comes from some rumblings I've been doing to make cudnn paths possible for Flux RNNs. The idea would be to port all these changes to the other recurrent cells, but that should happen after we are happy with the move.
Some points to consider:
I also cleaned up some of the testing code for recurrent cells. Because:
Also, a test currently fails (the eltype test), which I thought we wanted to move away from anyway. But maybe I'm misremembering.
PR Checklist