Error in the docs and in combining layers #329

@marcobonici

Description

I need to train a normalizing flow on some samples and then use it as a distribution.
@Red-Portal suggested that I use Bijectors.jl. However, while trying to follow the example in the documentation, I noticed that it contains an error:

julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)])
ERROR: MethodError: no method matching length(::Inverse{Bijectors.Logit{Float64, Float64}})

Closest candidates are:
  length(!Matched::LibGit2.GitBlob)
   @ LibGit2 /opt/hostedtoolcache/julia/1.10.5/x64/share/julia/stdlib/v1.10/LibGit2/src/blob.jl:3
  length(!Matched::LibGit2.GitStatus)
   @ LibGit2 /opt/hostedtoolcache/julia/1.10.5/x64/share/julia/stdlib/v1.10/LibGit2/src/status.jl:21
  length(!Matched::DataStructures.DiBitVector)
   @ DataStructures ~/.julia/packages/DataStructures/95DJa/src/dibit_vector.jl:40

For my use case this is not important, but I thought it was worth mentioning.
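For anyone hitting the same `MethodError`: constructing `Stacked` with explicit ranges seems to sidestep the failing `length()` call. This is just a sketch of what I believe works, reusing the `Logit` setup from the docs example:

```julia
using Bijectors
using Bijectors: Logit

# Same setup as the docs example: two Logit bijectors, inverted.
bs = [Logit(0.0, 1.0), Logit(0.0, 1.0)]
ibs = inverse.(bs)

# Passing the ranges explicitly avoids calling length() on each
# Inverse{Logit}, which is what stack(ibs...) trips over.
sb = Stacked(ibs, [i:i for i in 1:length(ibs)])

sb([0.0, 1.0])  # maps each coordinate through its own inverse-Logit
```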

So I tried to use a PlanarLayer for my use case, but it did not work: the training ran, but the result was not satisfying. Then, as suggested by the tutorial, I composed a couple of layers to see if that would improve performance... and I got this error when the training started:

ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved

Stacktrace:
 [1] broadcastable(::@NamedTuple{w::Vector{Float64}, u::Vector{Float64}, b::Vector{Float64}})
   @ Base.Broadcast ./broadcast.jl:744
 [2] broadcasted
   @ ./broadcast.jl:1345 [inlined]
 [3] (::var"#1#2")(θ::PlanarLayer{Vector{Float64}, Vector{Float64}}, ∇::@NamedTuple{w::Vector{Float64}, u::Vector{Float64}, b::Vector{Float64}})
   @ Main ./In[2]:31
 [4] (::Base.var"#4#5"{var"#1#2"})(a::Tuple{PlanarLayer{Vector{Float64}, Vector{Float64}}, @NamedTuple{w::Vector{Float64}, u::Vector{Float64}, b::Vector{Float64}}})
   @ Base ./generator.jl:36
 [5] iterate(::Base.Generator{Vector{Any}, IRTools.Inner.var"#52#53"{IRTools.Inner.var"#54#55"{IRTools.Inner.Block}}})
   @ Base ./generator.jl:47 [inlined]
 [6] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{@NamedTuple{outer::PlanarLayer{Vector{Float64}, Vector{Float64}}, inner::PlanarLayer{Vector{Float64}, Vector{Float64}}}, Tuple{@NamedTuple{w::Vector{Float64}, u::Vector{Float64}, b::Vector{Float64}}, @NamedTuple{w::Vector{Float64}, u::Vector{Float64}, b::Vector{Float64}}}}}, Base.var"#4#5"{var"#1#2"}})
   @ Base ./array.jl:834
 [7] map(::Function, ::@NamedTuple{outer::PlanarLayer{Vector{Float64}, Vector{Float64}}, inner::PlanarLayer{Vector{Float64}, Vector{Float64}}}, ::Tuple{@NamedTuple{w::Vector{Float64}, u::Vector{Float64}, b::Vector{Float64}}, @NamedTuple{w::Vector{Float64}, u::Vector{Float64}, b::Vector{Float64}}})
   @ Base ./abstractarray.jl:3406
 [8] top-level scope
   @ In[2]:30

Below is an MWE to reproduce the error. I am doing what the docs suggest, but maybe I misinterpreted something? Thanks in advance for your help!

using Zygote
using Bijectors
using Distributions  # for MvNormal
using Statistics     # for mean/cov at the end
using Functors

b = PlanarLayer(2) ∘ PlanarLayer(2)
θs, reconstruct = Functors.functor(b);
       
struct NLLObjective{R,D,T}
    reconstruct::R
    basedist::D
    data::T
end


function (obj::NLLObjective)(θs...)
    transformed_dist = transformed(obj.basedist, obj.reconstruct(θs))
    return -sum(Base.Fix1(logpdf, transformed_dist), eachcol(obj.data))
end

xs = randn(2, 1000);

f = NLLObjective(reconstruct, MvNormal(2, 1), xs);
       
@info "Initial loss: $(f(θs...))"

ε = 1e-3;

for i in 1:100
    ∇s = Zygote.gradient(f, θs...)
    θs = map(θs, ∇s) do θ, ∇
        θ - ε .* ∇
    end
end
@info "Final loss: $(f(θs...))"
       
samples = rand(transformed(f.basedist, f.reconstruct(θs)), 1000);

mean(eachcol(samples)) # ≈ [0, 0]

cov(samples; dims=2)   # ≈ I
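For completeness: I suspect the `ArgumentError` comes from `θ - ε .* ∇` broadcasting over the `NamedTuple`s that `functor` returns for the composed layer (broadcasting over `NamedTuple`s is reserved in Base). A sketch of the update I would expect to work instead, using `Functors.fmap` to walk the parameter tree and the gradient tree together (my assumption, not verified):

```julia
using Functors

ε = 1e-3
# Stand-in for one layer's parameters and its gradient; the real θs/∇s
# from the MWE are NamedTuples with the same w/u/b structure.
θ = (w = [1.0, 2.0], u = [0.5, 0.5], b = [0.1])
∇ = (w = [0.1, 0.1], u = [0.0, 0.2], b = [0.05])

# fmap recurses into both NamedTuples in lockstep and applies the
# update only at the array leaves, so nothing ever broadcasts over a
# NamedTuple itself.
θ_new = fmap((p, g) -> p .- ε .* g, θ, ∇)
```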
