-
-
Notifications
You must be signed in to change notification settings - Fork 617
Closed
Description
I am seeing what looks like a bug which has been re-introduced between Zygote versions 0.4.20 and 0.4.22. By using the newer version of Zygote (which is automatically installed with Flux 0.10.4), I am unable to obtain gradients from a convolutional VAE Flux model. I am running Julia 1.4.1.
Here is a piece of code which you can use to reproduce the bug, taken from the more general discussion here, with further discussion here.
module FLUXVAE
using Flux
using Flux: @epochs, binarycrossentropy
using DistributionsAD
using Distributions
# dummy data
function dummy_data()
    # Ten 796×512 single-channel "images", every pixel set to 1f0.
    data = ones(Float32, 796, 512, 1, 10)
    # Five overlapping minibatches of 5 samples each (slices k:k+4 for k = 1…5).
    return [data[:, :, :, k:(k + 4)] for k in 1:5]
end
# Minimal reshape "layer" so a reshape step can sit inside a Flux `Chain`.
struct Reshape
# target shape tuple, e.g. (280, :); untyped field — assumed a Tuple via the
# varargs constructor below
shape
end
# Convenience constructor: Reshape(10, 7, 4, :) stores the dims as one tuple.
Reshape(args...) = Reshape(args)
# Calling the layer reshapes the input array to the stored target shape.
(r::Reshape)(x) = reshape(x, r.shape)
# Register with Flux; the empty tuple declares no trainable fields, so
# `Flux.params` will not try to collect `shape` as a parameter.
Flux.@functor Reshape ()
# convolutional encoder
# Convolutional encoder: 796×512×1 image → 280-element flat feature vector.
function encoder()
    feature    = Conv((14, 10), 1 => 4, relu, stride = (10, 10), pad = 4)
    downsample = MaxPool((8, 8), stride = (4, 4), pad = 2)
    mix        = Conv((4, 3), 4 => 4, stride = (2, 2), pad = 1)
    flatten    = Reshape(280, :)
    return Chain(feature, downsample, mix, flatten)
end
# decoder, I am using the one with transposed convolutions
# Decoder mapping the 4-dim latent back to a 796×512×1 image.
#
# Returns `(dec, dec1)`: `dec` groups the parameterised layers (for
# `Flux.params` collection) and `dec1` is the full callable pipeline
# including the parameter-free reshape step.
function decoder(; dense_decoder = false)
    if dense_decoder
        dec = Dense(4, 796 * 512, sigmoid)
        # Use an anonymous function instead of a local method definition:
        # the original mixed `dec1(X) = ...` here with `dec1 = Chain(...)`
        # in the other branch, making `dec1` both a local method and a plain
        # binding in the same scope — a construct Julia's lowering handles
        # fragilely (conditional local method definitions are hoisted).
        dec1 = X -> reshape(dec(X), (796, 512, 1, :))
    else
        interaction1 = Dense(4, 280) # latent -> flattened conv feature map
        res = Reshape(10, 7, 4, :)   # un-flatten to the conv feature shape
        tc1 = ConvTranspose((4, 3), 4 => 4, relu, stride = (2, 2), pad = 1)
        tc2 = ConvTranspose((8, 8), 4 => 4, relu, stride = (4, 4), pad = 2)
        tc3 = ConvTranspose((14, 10), 4 => 1, sigmoid, stride = (10, 10), pad = 4)
        dec = Chain(interaction1, tc1, tc2, tc3) # parameterised layers only
        dec1 = Chain(interaction1, res, tc1, tc2, tc3)
    end
    return (dec, dec1)
end
# sample from z-distribution
# Reparameterisation trick: one sample of z ~ N(μ, exp(logσ)²).
# Draws its own standard-normal noise with the same scalar type as μ.
z(μ::T, logσ) where {T} = exp(logσ) * randn(T) + μ
# Deterministic variant: the caller supplies the standard-normal draw `noise`.
z(μ, logσ, noise) = exp(logσ) * noise + μ
# log(p(x|z)), log(p(z)), log(q(z|x))
# log p(x|z): reconstruction term — negative total binary cross-entropy
# between the decoded sample and the data.
function logp_x_z1(X, z, dec1)
    reconstruction = dec1(z)
    return -sum(binarycrossentropy.(reconstruction, X))
end
# log p(z): standard-normal prior density, summed over all latent entries.
function logp_z(z::AbstractArray{T}) where {T}
    prior = Normal(zero(T), one(T))
    return sum(logpdf.(Ref(prior), z))
end
# log q(z|x) contribution of one noise draw ϵ, via change of variables:
# standard-normal log-density at ϵ minus the log of the scale.
function log_q_z_x(ϵ, log_sigma)
    return logpdf(Normal(zero(ϵ), one(ϵ)), ϵ) - log_sigma
end
# vae loss estimator
# Build the per-batch VAE loss closure (negative single-sample ELBO,
# scaled by 1/5 — the minibatch size used in `train!`).
#
# Fix vs. the original: the original closure re-evaluated `e(X)` and
# `z_(X)` inside each term, so the reconstruction, prior, and entropy
# terms all used *different* random draws (an inconsistent ELBO
# estimate) and the encoder forward pass ran up to four times.
# Here the encoder output and the noise ϵ are computed once and shared.
function vae_loss(enc1, dec1, μ1, logσ1)
    return function (X)
        h = enc1(X)                       # encoder forward pass, once
        μ = μ1(h)                         # latent mean
        logσ = logσ1(h)                   # latent log-scale
        ϵ = randn(eltype(X), size(logσ))  # one shared standard-normal draw
        zs = z.(μ, logσ, ϵ)               # reparameterised latent sample
        elbo = logp_x_z1(X, zs, dec1) + logp_z(zs) - sum(log_q_z_x.(ϵ, logσ))
        return -elbo * 1//5
    end
end
# train vae1
# Assemble the VAE and run three epochs of training on the dummy batches.
function train!()
    enc1 = encoder()
    _, dec1 = decoder()     # only the full pipeline is needed here
    # Heads mapping the 280-dim code to the 4-dim latent mean / log-scale.
    μ1 = Dense(280, 4)
    logσ1 = Dense(280, 4)
    loss = vae_loss(enc1, dec1, μ1, logσ1)
    θ = Flux.params(enc1, μ1, logσ1, dec1)
    data = dummy_data()
    # Flux.train! expects an iterable of argument tuples, hence `zip`.
    @epochs 3 Flux.train!(loss, θ, zip(data), ADAM())
end
end
When I manually downgrade to using Zygote 0.4.20, the above code is run successfully, however I see the following stack trace using Zygote 0.4.22:
julia> FLUXVAE.train!()
[ Info: Epoch 1
ERROR: MethodError: no method matching ∇maxpool(::Array{Float64,4}, ::Array{Float32,4}, ::Array{Float32,4}, ::PoolDims{2,(8, 8),(4, 4),(2, 2, 2, 2),(1, 1)})
Closest candidates are:
∇maxpool(::AbstractArray{T,N}, ::AbstractArray{T,N}, ::AbstractArray{T,N}, ::PoolDims; kwargs...) where {T, N} at /home/aleco/.julia/packages/NNlib/FAI3o/src/pooling.jl:123
Stacktrace:
[1] (::Zygote.var"#1239#1240"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Array{Float32,4},PoolDims{2,(8, 8),(4, 4),(2, 2, 2, 2),(1, 1)},Array{Float32,4}})(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/lib/nnlib.jl:74
[2] (::Zygote.var"#2759#back#1241"{Zygote.var"#1239#1240"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Array{Float32,4},PoolDims{2,(8, 8),(4, 4),(2, 2, 2, 2),(1, 1)},Array{Float32,4}}})(::Array{Float64,4}) at /home/aleco/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[3] MaxPool at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/conv.jl:377 [inlined]
[4] (::typeof(∂(λ)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[5] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
[6] (::typeof(∂(applychain)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[7] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
[8] (::typeof(∂(applychain)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[9] Chain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:38 [inlined]
[10] (::typeof(∂(λ)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[11] l at ./REPL[11]:66 [inlined]
[12] (::typeof(∂(λ)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[13] z_ at ./REPL[11]:68 [inlined]
[14] (::typeof(∂(λ)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[15] #7 at ./REPL[11]:69 [inlined]
[16] (::typeof(∂(λ)))(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[17] #175 at /home/aleco/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
[18] #347#back at /home/aleco/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[19] #17 at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
[20] (::typeof(∂(λ)))(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[21] (::Zygote.var"#50#51"{Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:177
[22] gradient(::Function, ::Params) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:54
[23] macro expansion at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
[24] macro expansion at /home/aleco/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
[25] train!(::Main.FLUXVAE.var"#7#12"{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Main.FLUXVAE.Reshape,ConvTranspose{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},ConvTranspose{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},ConvTranspose{2,4,typeof(σ),Array{Float32,4},Array{Float32,1}}}},Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}},Main.FLUXVAE.var"#e#10"{Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Main.FLUXVAE.var"#z_#11"{Main.FLUXVAE.var"#mu#8"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}},Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}},Main.FLUXVAE.var"#e#10"{Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}}, ::Params, ::Base.Iterators.Zip{Tuple{Array{Array{Float32,4},1}}}, ::ADAM; cb::Flux.Optimise.var"#18#26") at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
[26] train!(::Function, ::Params, ::Base.Iterators.Zip{Tuple{Array{Array{Float32,4},1}}}, ::ADAM) at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:79
[27] macro expansion at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:122 [inlined]
[28] macro expansion at /home/aleco/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
[29] train!() at ./REPL[11]:82
[30] top-level scope at REPL[12]:1
As a side note, I wasn't sure where this issue should be raised, since it affects both Flux.jl and Zygote.jl. I'm happy to cross-post if appropriate.
Metadata
Metadata
Assignees
Labels
No labels