
Flux.destructure gives MethodError when used with non-trainable parameters #1553

Closed
ctessum opened this issue Mar 30, 2021 · 2 comments
ctessum commented Mar 30, 2021

Hello,

I've experienced an error (using Flux 0.11.6) with the following code:

using Flux

struct train_part
    a
    b
end

function (a::train_part)(x) 
    a.a * a.b * x
end

Flux.@functor train_part (a,) # Specify that only the 'a' matrix is trainable.

m = Chain(
    Dense(2, 2, tanh),
    train_part(zeros(2,2), zeros(2,2)),
)

Flux.destructure(m)

This is the error message:

ERROR: MethodError: no method matching train_part(::Matrix{Float64})
Closest candidates are:
  train_part(::Any, ::Any) at REPL[12]:2
Stacktrace:
  [1] (::var"#3#4")(y::NamedTuple{(:a,), Tuple{Matrix{Float64}}})
    @ Main ~/.julia/packages/Functors/YlETM/src/functor.jl:12
  [2] fmap1(f::Function, x::train_part)
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:30
  [3] #fmap#13
    @ ~/.julia/packages/Functors/YlETM/src/functor.jl:35 [inlined]
  [4] (::Functors.var"#14#15"{IdDict{Any, Any}, Flux.var"#33#35"{Zygote.Buffer{Any, Vector{Any}}}})(x::train_part)
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:35
  [5] map
    @ ./tuple.jl:214 [inlined]
  [6] fmap1(f::Function, x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, train_part}})
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:30
  [7] fmap(f::Flux.var"#33#35"{Zygote.Buffer{Any, Vector{Any}}}, x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, train_part}}; cache::IdDict{Any, Any})
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:35
  [8] fmap
    @ ~/.julia/packages/Functors/YlETM/src/functor.jl:34 [inlined]
  [9] destructure(m::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, train_part}})
    @ Flux ~/.julia/packages/Flux/goUGu/src/utils.jl:409
 [10] top-level scope
    @ REPL[16]:1

I think what might be happening is that destructure extracts only one value from train_part, because only one of its fields is registered as trainable, and then tries to reconstruct train_part from that single value alone. Since train_part has no one-argument constructor, this fails with the MethodError above.
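That diagnosis can be checked in plain Julia, without loading Flux at all (TrainPart below is an illustrative stand-in for train_part): the generated struct only has the default two-argument constructor, so a one-argument rebuild call has nothing to dispatch to.

```julia
# Plain-Julia illustration of the failure mode: with only `a` registered as
# a functor child, the reconstructor effectively calls `train_part(a_new)`,
# a one-argument call that no constructor matches.
struct TrainPart
    a
    b
end

hasmethod(TrainPart, Tuple{Any})       # false — hence the MethodError
hasmethod(TrainPart, Tuple{Any, Any})  # true — the default constructor
```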

ToucheSir (Member) commented

Ah yes, this is a case of FluxML/Functors.jl#6 / FluxML/Functors.jl#3. It should be resolved in Functors.jl 2.x, which you can get by upgrading to Flux 0.12. If you're not able to update Flux, you can implement functor manually per FluxML/Functors.jl#3 (comment).
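A minimal sketch of the manual-functor idea from the linked Functors.jl comment, written here without loading Functors.jl (TrainPart and my_functor are illustrative stand-ins, not the library's exact API). The key point is that the reconstructor is a closure over the original object, so non-trainable fields are carried over even though only the trainable field appears among the children.

```julia
# Sketch of the workaround: a hand-written functor method would return
# (children, rebuild), where `rebuild` closes over the original `x` and
# restores the non-trainable field `b` itself.
struct TrainPart
    a
    b
end

my_functor(x::TrainPart) = (a = x.a,), y -> TrainPart(y.a, x.b)

x = TrainPart(zeros(2, 2), ones(2, 2))
children, rebuild = my_functor(x)
x2 = rebuild(map(v -> v .+ 1, children))  # update only the trainable child
# x2.a is updated; x2.b is the same array as x.b, untouched
```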

ctessum (Author) commented Mar 31, 2021

Thanks, that fixes it!

ctessum closed this as completed Mar 31, 2021