Skip to content

Commit

Permalink
fix bug on minibatch dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanemac authored and paraynaud committed Jul 29, 2022
1 parent 06cbfee commit 51c63bc
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
"""
flag_dim(x)
Returns true if x has 3 dimensions.
This function is used to reshape X in `create_minibatch(X, Y, minibatch_size)` in case x has only 3 dimensions.
"""
flag_dim(x) = length(size(x)) == 3

"""
create_minibatch(X, Y, minibatch_size)
Create a minibatch's iterator of the data `X`, `Y` of size `1/minibatch_size * length(Y)`.
"""
create_minibatch(x_data, y_data, minibatch_size) =
minibatch(x_data, y_data, minibatch_size; xsize = (size(x_data, 1), size(x_data, 2), size(x_data, 3), :))
function create_minibatch(x_data, y_data, minibatch_size; _reshape::Bool=flag_dim(x_data))
mb = minibatch(x_data, y_data, minibatch_size; xsize = (size(x_data, 1), size(x_data, 2), size(x_data, 3), :))
if _reshape
mb.xsize = (size(x_data, 1), size(x_data, 2), 1, :) # To force x_data to take 1 as third dimension
end
return mb
end

"""
vector_params(chain :: C) where C <: Chain
Expand Down

0 comments on commit 51c63bc

Please sign in to comment.