Gradient calculation bug re-introduced in Flux v0.10.4 and Zygote v0.4.22 #1269

@alecokas

I am seeing what looks like a bug that was re-introduced between Zygote versions 0.4.20 and 0.4.22. With the newer Zygote (which is installed automatically alongside Flux 0.10.4), I am unable to obtain gradients for a convolutional VAE Flux model. I am running Julia 1.4.1.

Here is a piece of code you can use to reproduce the bug, taken from the more general discussion here and further discussion here.

module FLUXVAE

using Flux
using Flux: @epochs, binarycrossentropy
using DistributionsAD
using Distributions

# dummy data
function dummy_data()
    d = Array{Float32}(zeros((796, 512, 1, 10))) .+ 1
    batches = [reshape(d[:,:,:, i:i+4], (796, 512, 1, 5)) for i in 1:5]
end

struct Reshape
    shape
end

Reshape(args...) = Reshape(args)
(r::Reshape)(x) = reshape(x, r.shape)
Flux.@functor Reshape ()


# convolutional encoder
function encoder()
    conv1 = Conv((14, 10), 1 => 4, relu, stride = (10, 10), pad = 4)
    pool1 = MaxPool((8, 8), stride = (4, 4), pad = 2)
    conv2 = Conv((4, 3), 4 => 4, stride = (2, 2), pad = 1)
    res = Reshape(280, :)
    # enc1(X) = reshape(conv2(pool1(conv1(X))), (280, :))
    # Chain(res, conv2, pool1, conv1)
    Chain(conv1, pool1, conv2, res)
end

# decoder, I am using the one with transposed convolutions
function decoder(;dense_decoder = false)
    if dense_decoder
        dec = Dense(4, 796*512, sigmoid)
        dec1(X) = reshape(dec(X),  (796, 512, 1, :))
    else
        interaction1 = Dense(4, 280)  # specific to my setup
        res = Reshape(10, 7, 4, :)
        # int1(X) = reshape(interaction1(X), (10, 7, 4, :))
        tc1 = ConvTranspose((4, 3), 4 => 4, relu, stride = (2, 2), pad = 1)
        tc2 = ConvTranspose((8, 8), 4 => 4, relu, stride = (4, 4), pad = 2)
        tc3 = ConvTranspose((14, 10), 4 => 1, sigmoid, stride = (10, 10), pad = 4)
        dec = Chain(interaction1, tc1, tc2, tc3) # for params
        dec1 = Chain(interaction1, res, tc1, tc2, tc3)
    end
    return (dec, dec1)
end

# sample from z-distribution
z(μ::T, logσ) where {T} = μ + exp(logσ) * randn(T)
z(μ, logσ, eps) = μ + exp(logσ) * eps

# log(p(x|z)), log(p(z)), log(q(z|x))
logp_x_z1(X, z, dec1) = -sum(binarycrossentropy.(dec1(z), X))
logp_z(z::AbstractArray{T}) where {T} = sum((logpdf.(Normal(zero(T), one(T)), z)))
log_q_z_x(ϵ, log_sigma) = logpdf(Normal(zero(ϵ), one(ϵ)), ϵ) - log_sigma

# vae loss estimator
function vae_loss(enc1, dec1, μ1, logσ1)
    mu(X) = μ1(enc1(X))
    l(X) =  logσ1(enc1(X))
    e(X) = randn(eltype(X), size(l(X))) # latentdim1
    z_(X) = z.(mu(X), l(X), e(X))
    return X->-(logp_x_z1(X, z_(X), dec1) + logp_z(z_(X)) - sum(log_q_z_x.(e(X), l(X)))) * 1//5
end

# train vae1
function train!()
    enc1 = encoder()
    dec, dec1 = decoder()
    # mean and log-variance of vae1's z-variable/latent space
    μ1 = Dense(280, 4)
    logσ1 = Dense(280, 4)
    L1 = vae_loss(enc1, dec1, μ1, logσ1)
    ps1 = Flux.params(enc1, μ1, logσ1, dec1)
    batches = dummy_data()
    @epochs 3 Flux.train!(L1, ps1, zip(batches), ADAM())
end

end

When I manually downgrade to Zygote 0.4.20, the above code runs successfully; with Zygote 0.4.22, however, I see the following stack trace:

julia> FLUXVAE.train!()
[ Info: Epoch 1
ERROR: MethodError: no method matching ∇maxpool(::Array{Float64,4}, ::Array{Float32,4}, ::Array{Float32,4}, ::PoolDims{2,(8, 8),(4, 4),(2, 2, 2, 2),(1, 1)})
Closest candidates are:
  ∇maxpool(::AbstractArray{T,N}, ::AbstractArray{T,N}, ::AbstractArray{T,N}, ::PoolDims; kwargs...) where {T, N} at /home/aleco/.julia/packages/NNlib/FAI3o/src/pooling.jl:123
Stacktrace:
 [1] (::Zygote.var"#1239#1240"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Array{Float32,4},PoolDims{2,(8, 8),(4, 4),(2, 2, 2, 2),(1, 1)},Array{Float32,4}})(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/lib/nnlib.jl:74
 [2] (::Zygote.var"#2759#back#1241"{Zygote.var"#1239#1240"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Array{Float32,4},PoolDims{2,(8, 8),(4, 4),(2, 2, 2, 2),(1, 1)},Array{Float32,4}}})(::Array{Float64,4}) at /home/aleco/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [3] MaxPool at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/conv.jl:377 [inlined]
 [4] (::typeof(∂(λ)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [5] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [6] (::typeof(∂(applychain)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [7] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [8] (::typeof(∂(applychain)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [9] Chain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:38 [inlined]
 [10] (::typeof(∂(λ)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [11] l at ./REPL[11]:66 [inlined]
 [12] (::typeof(∂(λ)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [13] z_ at ./REPL[11]:68 [inlined]
 [14] (::typeof(∂(λ)))(::Array{Float64,2}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [15] #7 at ./REPL[11]:69 [inlined]
 [16] (::typeof(∂(λ)))(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [17] #175 at /home/aleco/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [18] #347#back at /home/aleco/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [19] #17 at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
 [20] (::typeof(∂(λ)))(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#50#51"{Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:177
 [22] gradient(::Function, ::Params) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:54
 [23] macro expansion at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
 [24] macro expansion at /home/aleco/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [25] train!(::Main.FLUXVAE.var"#7#12"{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Main.FLUXVAE.Reshape,ConvTranspose{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},ConvTranspose{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},ConvTranspose{2,4,typeof(σ),Array{Float32,4},Array{Float32,1}}}},Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}},Main.FLUXVAE.var"#e#10"{Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Main.FLUXVAE.var"#z_#11"{Main.FLUXVAE.var"#mu#8"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}},Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}},Main.FLUXVAE.var"#e#10"{Main.FLUXVAE.var"#l#9"{Chain{Tuple{Conv{2,4,typeof(relu),Array{Float32,4},Array{Float32,1}},MaxPool{2,4},Conv{2,4,typeof(identity),Array{Float32,4},Array{Float32,1}},Main.FLUXVAE.Reshape}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}}, ::Params, ::Base.Iterators.Zip{Tuple{Array{Array{Float32,4},1}}}, ::ADAM; cb::Flux.Optimise.var"#18#26") at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [26] train!(::Function, ::Params, ::Base.Iterators.Zip{Tuple{Array{Array{Float32,4},1}}}, ::ADAM) at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:79
 [27] macro expansion at /home/aleco/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:122 [inlined]
 [28] macro expansion at /home/aleco/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [29] train!() at ./REPL[11]:82
 [30] top-level scope at REPL[12]:1
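For reference, the failing call at the top of the trace can be reproduced in isolation. The sketch below is my reading of the type mismatch, using the `∇maxpool(::AbstractArray{T,N}, ::AbstractArray{T,N}, ::AbstractArray{T,N}, ::PoolDims)` signature quoted in the error: the pullback reaching `MaxPool` is `Float64` while the layer's input and output are `Float32`, so no method matches. The pool size here is illustrative, not the one from my model.

```julia
using NNlib

x = rand(Float32, 8, 8, 1, 1)
pdims = PoolDims(x, (2, 2))
y = maxpool(x, pdims)

dy = ones(Float64, size(y))          # Float64 cotangent, as in the trace
# ∇maxpool(dy, y, x, pdims)          # MethodError: types T don't agree
∇maxpool(Float32.(dy), y, x, pdims)  # fine once all arrays are Float32
```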

As a side note, I wasn't sure where this issue should be raised since it affects both Flux.jl and Zygote.jl. I'm happy to cross-post if appropriate.
