
Call of Flux.stack results in StackOverflowError for approx. 6000 sequence elements of an LSTM model output #1585

Open
VoSiLk opened this issue Apr 28, 2021 · 7 comments


@VoSiLk

VoSiLk commented Apr 28, 2021

I may have discovered a bug. When I simulate the trained model on a test data set longer than 10000 data points, I get a StackOverflowError from the Flux.stack function:

https://discourse.julialang.org/t/how-to-arrange-data-for-time-series-forecasting-mini-batching-without-violoating-the-gpu-memory-for-a-lstm/54171/8

using JLD2
using Flux
using CUDA

X_train = rand(80000, 6)
Y_train = rand(80000, 1)

X_test = rand(60000, 6)
Y_test = rand(60000, 1)

N, no_features = size(X_train)

# batching parameters
gpu_or_cpu = cpu

batch_size = 1
seq_len = 1000
num_batches = Int(floor(N / (batch_size + seq_len)))

X_batched = [[Float32.(hcat([X_train[j+i+ k*(seq_len+batch_size), :] for j in 0:batch_size-1]...)) for i in 1:seq_len] for k in 0:num_batches-1] |> gpu_or_cpu
Y_batched = [Float32.(vcat([Y_train[j+seq_len+ k*(seq_len+batch_size)] for j in 0:batch_size-1]...)) for k in 0:num_batches-1] |> gpu_or_cpu

if gpu_or_cpu == gpu
    CUDA.allowscalar(false)
end

# convert to cpu or gpu
X_batched = X_batched |> gpu_or_cpu
Y_batched = Y_batched |> gpu_or_cpu
data_train = zip(X_batched, Y_batched)

# select optimizer
opt = ADAM(0.001, (0.9, 0.999))

function loss(X, Y)
    Flux.reset!(model)
    mse_val = sum(abs2.(Flux.stack(model.(X), 1)[end, :] .- Y))
    return mse_val
end

# ini of the model
model = Chain(LSTM(no_features, 70), LSTM(70, 70), LSTM(70, 70), Dense(70, 1, relu)) |> gpu_or_cpu
ps = Flux.params(model)
Flux.reset!(model)

# train one epoch
@time Flux.train!(loss, ps, data_train, opt)
x_test = [vec(Float32.(X_test[i, :])) for i in 1:size(X_test, 1)]

y_model = Flux.stack(model.(x_test), 1)

The error doesn’t occur with:

x_test = [vec(Float32.(X_test[i, :])) for i in 1:size(X_test, 1)]

y_model = vcat(model.(x_test)...)

Used versions of Julia and Flux:
Julia 1.6.0
Flux v0.12.2
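For reference, here is a minimal splat-free sketch of the same workaround (my own illustration, not part of the original report; the `outputs` array is a stand-in for `model.(x_test)`). `reduce(vcat, xs)` has a specialized method for collections of arrays, so no single call ever receives tens of thousands of arguments the way `vcat(xs...)` does:

```julia
# Stand-in for model.(x_test): 60000 one-element Float32 vectors.
outputs = [rand(Float32, 1) for _ in 1:60_000]

# Pairwise/specialized concatenation instead of splatting 60000 arguments.
y_model = reduce(vcat, outputs)
```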

@DhairyaLGandhi
Member

What is the length it doesn't occur with? Could you post the complete stacktrace?

@VoSiLk
Author

VoSiLk commented Apr 28, 2021

It starts occurring somewhere between a length of 5000 and 6000.

StackOverflowError:
Stacktrace:
 [1] _cat_size_shape(::Tuple{Bool}, ::Tuple{Int64, Int64}, ::Matrix{Float32}, ::Matrix{Float32}, ::Vararg{Mat
   @ Base .\abstractarray.jl:1602
 [2] cat_size_shape(::Tuple{Bool}, ::Matrix{Float32}, ::Matrix{Float32}, ::Vararg{Matrix{Float32}, N} where N
   @ Base .\abstractarray.jl:1600
 [3] _cat_t(::Int64, ::Type{Float32}, ::Matrix{Float32}, ::Vararg{Matrix{Float32}, N} where N)
   @ Base .\abstractarray.jl:1646
   @ Base .\abstractarray.jl:1643
 [5] _cat(::Int64, ::Matrix{Float32}, ::Vararg{Matrix{Float32}, N} where N)
   @ SparseArrays C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\SparseArrays\src\sparsev
 [6] cat(::Matrix{Float32}, ::Vararg{Matrix{Float32}, N} where N; dims::Int64)
   @ Base .\abstractarray.jl:1781
 [7] stack(xs::Vector{Vector{Float32}}, dim::Int64)
   @ Flux ~\.julia\packages\Flux\6BByF\src\utils.jl:476
 [8] top-level scope
   @ F:\DataBasedModeling\FluxLSTMGPU4Upload.jl:172
 [9] eval
   @ .\boot.jl:360 [inlined]

@DhairyaLGandhi
Member

Interesting, this seems to be happening in Base.

@darsnack
Member

I've encountered this error before when splatting inside cat for large lengths. The solution is to do reduce((x, y) -> cat(...), ...). Our stack code should be updated not to use splatting.
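The suggested `reduce`-based fix can be sketched as follows. This is a hypothetical illustration, not Flux's actual implementation: `unsqueeze_dim` and `stack_nosplat` are made-up names, and it assumes the v0.12 `stack(xs, dim)` semantics of inserting a new singleton axis at `dim` and concatenating along it:

```julia
# Insert a singleton dimension at `dim` (mirrors what Flux.unsqueeze does).
unsqueeze_dim(x, dim) = reshape(x, (size(x)[1:dim-1]..., 1, size(x)[dim:end]...))

# reduce keeps every cat call binary, so the argument list never grows with
# length(xs) and the splat limit in Base is never hit.
stack_nosplat(xs, dim) =
    reduce((a, b) -> cat(a, b; dims = dim), (unsqueeze_dim(x, dim) for x in xs))
```

With this shape of code, a 60000-element input just performs 59999 binary `cat` calls instead of one call with 60000 splatted arguments (a smarter version would preallocate the output, since repeated binary `cat` allocates quadratically).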

@darsnack
Member

Looks like the limit on number of splattable arguments is known: JuliaLang/julia#30796 (comment)

@DhairyaLGandhi
Member

Well, that confirms it's a Base thing.

@darsnack
Member

Well, a Base thing that is not going to be addressed, so we still need to update our code so that we don't hit the error.
