Skip to content

Commit

Permalink
Merge #1772
Browse files Browse the repository at this point in the history
1772: Expand RNN/LSTM/GRU docs r=ToucheSir a=mcognetta

This PR adds expanded documentation to the RNN/LSTM/GRU/GRUv3 docs, resolving #1696.

It addresses the `in` and `out` parameter meanings and adds a warning about a common gotcha (not calling reset when batch sizes are changed).

Co-authored-by: Marco Cognetta <cognetta.marco@gmail.com>
  • Loading branch information
bors[bot] and mcognetta committed Nov 23, 2021
2 parents 2053274 + fbd2ad2 commit 66a84ef
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,57 @@ end
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
# Examples
```jldoctest
julia> r = RNN(3, 5)
Recur(
RNNCell(3, 5, tanh), # 50 parameters
) # Total: 4 trainable arrays, 50 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
julia> r(rand(Float32, 3)) |> size
(5,)
julia> Flux.reset!(r);
julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```
!!! warning "Batch size changes"
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example:
```julia
julia> r = RNN(3, 5)
Recur(
RNNCell(3, 5, tanh), # 50 parameters
) # Total: 4 trainable arrays, 50 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
julia> r.state |> size
(5, 1)
julia> r(rand(Float32, 3)) |> size
(5,)
julia> r.state |> size
(5, 1)
julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
julia> r.state |> size # state shape has changed
(5, 10)
julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
(50,)
```
"""
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
Recur(m::RNNCell) = Recur(m, m.state0)
Expand Down Expand Up @@ -178,8 +229,32 @@ Base.show(io::IO, l::LSTMCell) =
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
# Examples
```jldoctest
julia> l = LSTM(3, 5)
Recur(
LSTMCell(3, 5), # 190 parameters
) # Total: 5 trainable arrays, 190 parameters,
# plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
julia> l(rand(Float32, 3)) |> size
(5,)
julia> Flux.reset!(l);
julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```
!!! warning "Batch size changes"
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
Recur(m::LSTMCell) = Recur(m, m.state0)
Expand Down Expand Up @@ -243,8 +318,32 @@ Base.show(io::IO, l::GRUCell) =
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v1 of the referenced paper.
The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
# Examples
```jldoctest
julia> g = GRU(3, 5)
Recur(
GRUCell(3, 5), # 140 parameters
) # Total: 4 trainable arrays, 140 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 792 bytes.
julia> g(rand(Float32, 3)) |> size
(5,)
julia> Flux.reset!(g);
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```
!!! warning "Batch size changes"
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
Recur(m::GRUCell) = Recur(m, m.state0)
Expand Down Expand Up @@ -297,8 +396,32 @@ Base.show(io::IO, l::GRUv3Cell) =
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.
The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
# Examples
```jldoctest
julia> g = GRUv3(3, 5)
Recur(
GRUv3Cell(3, 5), # 140 parameters
) # Total: 5 trainable arrays, 140 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 848 bytes.
julia> g(rand(Float32, 3)) |> size
(5,)
julia> Flux.reset!(g);
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```
!!! warning "Batch size changes"
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)
Expand Down

0 comments on commit 66a84ef

Please sign in to comment.