
Conversation


@zuhengxu zuhengxu commented Jun 5, 2023

TODOs:

  • variational obj
    • ELBO
    • ELBO with mini-batch subsampling in logp (will be addressed in the next PR)
    • MLE
  • user interface
    • migrate to ADTypes.jl
  • training loops
  • one example to demonstrate the user interface (planar flow learning a 2d banana target; see the sketch after this list)
  • tests
  • docs
  • clean example/
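
For reference, a rough sketch of what that banana example could look like (the target density, layer count, and especially the train_flow name and signature are hypothetical placeholders, not the final interface):

using Bijectors, Distributions, LinearAlgebra

# hypothetical 2d banana-shaped target; illustrative, not necessarily the density used in example/
logp(x) = logpdf(Normal(0f0, 2f0), x[1]) +
          logpdf(Normal(0.1f0 * (x[1]^2 - 4f0), 1f0), x[2])

q₀ = MvNormal(zeros(Float32, 2), I)
flow = transformed(q₀, ∘([PlanarLayer(2) for _ in 1:10]...))

# hypothetical training call: maximize the ELBO over the flow parameters
# train_flow(elbo, flow, logp, n_samples_per_iter; max_iters = 10_000)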

@zuhengxu zuhengxu requested review from sunxd3 and torfjelde June 10, 2023 07:56
# destructure flow for explicit access to the parameters
# destructure can result in some overhead when the flow length is large
@info "destructuring flow..."
θ_flat, re = Flux.destructure(flow)
@zuhengxu zuhengxu Jun 10, 2023

Line 225 can be really slow as mentioned in line 24; I tried the flatten(x) function written in https://github.com/TuringLang/AdvancedVI.jl/discussions/46#discussioncomment-5427443, which is also slow.

For reference, Flux.destructure(flow) with flow being 20 2d planar layers takes a ridiculously long time, way longer than the training itself. I was wondering if there is a better way of handling the explicit parameterization in terms of performance.

@torfjelde

From my reply to that comment:

For reference, the execution of the following code took over an hour and I just killed the program for lack of patience

This doesn't have anything to do with the usage of destructure; it's because you're doing ∘(...).

This then creates a ComposedFunction that is nested n_layers deep (ComposedFunction(ComposedFunction(ComposedFunction(...), ...)); you get the gist). This blows up the compilation times.

For "large" compilations, you should instead use something like https://github.com/oschulz/FunctionChains.jl. Then you could use a Vector for to represent the composition which would make the compilation time waaaaay better (because now "iterate" over the recursion at runtime rather than at compile-time).

I also wonder if we could deal with this by subtyping Function, cf. https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues 🤔

But we can discuss this issue more somewhere else 👍

@zuhengxu

Two major concerns so far with the current design:

  1. [AbstractDifferentiation.jl]: For consistency of the code and ease of switching between different AD systems, I used AbstractDifferentiation.jl. It is funny that AdvancedVI.jl just did the same thing by migrating to AbstractDifferentiation.jl. But as discussed in Basic rewrite of the package 2023 edition AdvancedVI.jl#45 (comment), it doesn't seem to be ideal in terms of performance.

I haven't dug into the "severe limitation and performance issue" of AbstractDifferentiation.jl, but one thing I did realize is that we are missing many nice configuration options for some of the AD backends (e.g., compiled tapes for ReverseDiff.jl). (Please let me know if I'm claiming wrong things here, which is very likely.)

@sunxd3 @torfjelde what do you think of AbstractDifferentiation.jl? Should we keep what I have there or implement a specialized grad! function as is done in AdvancedVI.jl?

  2. [Flux.destructure]: As I mentioned in 5a7df3d#r1225209163, this is ridiculously slow when destructuring a long flow. For reference, the execution of the following code took over an hour and I just killed the program for lack of patience:
using Bijectors, Flux, Random, Distributions, LinearAlgebra

function create_planar_flow(n_layers::Int, q₀)
    d = length(q₀)
    Ts = ∘([PlanarLayer(d) for _ in 1:n_layers]...)
    return transformed(q₀, Ts)
end

flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I)) # create a 20-layer planar flow
theta, re = Flux.destructure(flow) # this is sooooooo slow

I wonder if we have better ways of handling the parameter flattening/unflattening with better performance, without losing support for those AD systems (at least ForwardDiff, ReverseDiff, and Zygote).
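
A quick sanity check I can think of (a sketch): if the slowness is compilation rather than the flattening logic itself, then destructuring a second, identically typed flow should be fast, since the method would already be compiled.

flow2 = create_planar_flow(20, MvNormal(zeros(Float32, 2), I))
@time Flux.destructure(flow2)  # fast if the first call was paying a one-time compilation cost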

@torfjelde

what do you think of AbstractDifferentiation.jl? Should we keep what I have there or implement a specialized grad! function as is done in AdvancedVI.jl?

As discussed there, I don't think it's quite there yet. We can try using https://github.com/SciML/ADTypes.jl mentioned in the other PR, and then we just define the grad! function or whatever, similar to AdvancedVI.jl.
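
Roughly something like this (just a sketch, not AdvancedVI's actual code; value_and_gradient! is a hypothetical name, and only the ADTypes structs are the real API):

using ADTypes, DiffResults
import ForwardDiff, Zygote

# dispatch on the ADTypes tag to select the backend
function value_and_gradient!(::AutoForwardDiff, f, x::AbstractVector, out::DiffResults.MutableDiffResult)
    ForwardDiff.gradient!(out, f, x)  # fills both value and gradient
    return out
end

function value_and_gradient!(::AutoZygote, f, x::AbstractVector, out::DiffResults.MutableDiffResult)
    y, back = Zygote.pullback(f, x)
    DiffResults.value!(out, y)
    DiffResults.gradient!(out, first(back(one(y))))
    return out
end

# usage: out = DiffResults.GradientResult(x); value_and_gradient!(AutoZygote(), f, x, out)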

@zuhengxu

@torfjelde Thanks for your comments about the compilation time of the nested ComposedFunction(...); FunctionChains.jl does address the issue. I have a few follow-up questions on this, which might be trivial to most Julia developers, so I'm asking them here.

This then creates a ComposedFunction that is nested n_layers deep (ComposedFunction(ComposedFunction(ComposedFunction(...), ...)); you get the gist). This blows up the compilation times.

For "large" compilations, you should instead use something like https://github.com/oschulz/FunctionChains.jl. Then you could use a Vector for to represent the composition which would make the compilation time waaaaay better (because now "iterate" over the recursion at runtime rather than at compile-time).

First, to double-check my understanding of this issue: it's not the run time of this flattening (recursion through the nested compositions) that is slow, but the compilation time. This must be a dumb question, but why is the compilation time THIS much slower than the run time?

I also wonder if we could deal with this by subtyping Function due to https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues 🤔

I'm also not quite clear on why subtyping can make the compilation faster in our case. (Honestly, I don't think I fully understand the compiler efficiency issues; apologies for my lack of knowledge on this.) The link seems to explain why specializing on many function types can cause a significant burst in compilation time. This issue is related because our flow layers are all typed callables. But I don't see why subtyping these flow layers helps.

@zuhengxu

Another Q: should we just merge NormalizingFlows.jl into AdvancedVI.jl directly? NFs are just a special case of VI, and I don't see any reason why the two packages shouldn't share the same user interface. I originally thought the development of AdvancedVI.jl had ceased for some reason. But given that it's also being rewritten now, why don't we join forces and just make normalizing flows another VI algorithm?

@Red-Portal

Hi Zuheng/David!

I think you can worry about that later, since once AdvancedVI's interface solidifies, it should be straightforward to combine the two.

Best wishes for your GSoC!

@torfjelde

Another Q: should we just merge NormalizingFlows.jl into AdvancedVI.jl directly? NFs are just a special case of VI, and I don't see any reason why the two packages shouldn't share the same user interface. I originally thought the development of AdvancedVI.jl had ceased for some reason. But given that it's also being rewritten now, why don't we join forces and just make normalizing flows another VI algorithm?

So this would indeed be the idea:) But as @Red-Portal is saying, I think for now we just keep it separate until AdvancedVI's interface solidifies. Once that is done, we just replace the VI-related stuff in NormalizingFlows.jl with the interface from AdvancedVI.

I'm very happy to have you involved in the discussion for AdvancedVI.jl though! I think you'll have some valuable insights into what we actually want from a "general variational inference"-interface.

@torfjelde

It's not the run time of this flattening (recursion through the nested compositions) that is slow, but the compilation time. This must be a dumb question, but why is the compilation time THIS much slower than the run time?

Generally speaking, the more work you do before you actually execute the code (compile-time), the faster the code will be (lower run-time), right?

In Julia, you can even inspect these things directly:

julia> f(x) = x^2
f (generic function with 1 method)

julia> @code_typed optimize=false f(1)
CodeInfo(
1 ─ %1 = Main.:^::Core.Const(^)
│   %2 = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│   %3 = (%2)()::Core.Const(Val{2}())
│   %4 = Base.literal_pow(%1, x, %3)::Int64
└──      return %4
) => Int64

julia> @code_typed optimize=true f(1)
CodeInfo(
1 ─ %1 = Base.mul_int(x, x)::Int64
└──      return %1
) => Int64

As you can see there, when we do optimize=true, the entire body is reduced to a simple multiplication of two integers.

So what happens when you compose two functions?

julia> g = f ∘ f
f ∘ f

julia> @code_typed optimize=false g(1)
CodeInfo(
1 ─ %1 = Base.:(var"#_#97")::Core.Const(Base.var"#_#97")
│   %2 = Core.NamedTuple()::Core.Const(NamedTuple())
│   %3 = Base.pairs(%2)::Core.Const(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())
│   %4 = Core.tuple(%3, c)::Core.Const((Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}(), f ∘ f))
│   %5 = Core._apply_iterate(Base.iterate, %1, %4, x)::Int64
└──      return %5
) => Int64

julia> @code_typed optimize=true g(1)
CodeInfo(
1 ─ %1 = Core.getfield(x, 1)::Int64
│   %2 = Base.mul_int(%1, %1)::Int64
│   %3 = Base.mul_int(%2, %2)::Int64
└──      return %3
) => Int64

Now, in the optimized version, all the calls to f are completely gone! But clearly we had to do some additional work here, as we had to effectively inline the calls to f in addition to optimizing the operations inside f.

You can also time these things:

julia> @time f(1)  # <= doesn't say anything about compilation time because it's so fast
  0.000001 seconds
1

julia> @time g(1)  # <= now the compilation time is actually noticeable
  0.004762 seconds (789 allocations: 54.535 KiB, 99.77% compilation time)
1

Note that if I re-run the above, then there's no compilation time since it's already been compiled:

julia> @time f(1)
  0.000001 seconds
1

julia> @time g(1)
  0.000002 seconds
1

Taking it a step further:

julia> g_many = reduce(∘, fill(f, 10))
f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f

julia> @time g_many(1)
  0.008233 seconds (31.46 k allocations: 2.070 MiB, 99.69% compilation time)
1

Compilation of g_many would be even slower if it weren't for the fact that we have already compiled f and g (by calling them).

Let's look at the resulting code:

julia> @code_typed optimize=false g_many(1)
CodeInfo(
1 ─ %1 = Base.:(var"#_#97")::Core.Const(Base.var"#_#97")
│   %2 = Core.NamedTuple()::Core.Const(NamedTuple())
│   %3 = Base.pairs(%2)::Core.Const(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())
│   %4 = Core.tuple(%3, c)::Core.Const((Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}(), f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f))
│   %5 = Core._apply_iterate(Base.iterate, %1, %4, x)::Int64
└──      return %5
) => Int64

julia> @code_typed optimize=true g_many(1)
CodeInfo(
1 ─ %1  = Core.getfield(x, 1)::Int64
│   %2  = Base.mul_int(%1, %1)::Int64
│   %3  = Base.mul_int(%2, %2)::Int64
│   %4  = Base.mul_int(%3, %3)::Int64
│   %5  = Base.mul_int(%4, %4)::Int64
│   %6  = Base.mul_int(%5, %5)::Int64
│   %7  = Base.mul_int(%6, %6)::Int64
│   %8  = Base.mul_int(%7, %7)::Int64
│   %9  = Base.mul_int(%8, %8)::Int64
│   %10 = Base.mul_int(%9, %9)::Int64
│   %11 = Base.mul_int(%10, %10)::Int64
└──       return %11
) => Int64

So in the optimized version, we have actually unrolled all the methods, and just boiled it down to a sequence of mul_int! That's quite a bit of work at compilation time:)

We can benchmark it:

julia> using BenchmarkTools

julia> @btime $f($1)
  2.523 ns (0 allocations: 0 bytes)
1

julia> @btime $g_many($1)
  2.716 ns (0 allocations: 0 bytes)
1

So this is pretty impressive, right? The runtime is almost the same.

We can compare this to a "bad" composed implementation.

julia> struct BadComposed
           fs
       end

julia> function (bc::BadComposed)(x)
           y = bc.fs[1](x)
           for f in bc.fs[2:end]
               y = f(y)
           end
           return y
       end


julia> g_many_bad = BadComposed(fill(f, 10))
BadComposed([f, f, f, f, f, f, f, f, f, f])

julia> @time g_many_bad(1)
  0.045895 seconds (34.42 k allocations: 2.282 MiB, 99.71% compilation time)
1

julia> @code_typed optimize=true g_many_bad(1)
CodeInfo(
1 ── %1  = Base.getfield(bc, :fs)::Any
│    %2  = Base.getindex(%1, 1)::Any
│    %3  = (%2)(x)::Any
│    %4  = Base.getfield(bc, :fs)::Any
│    %5  = Base.lastindex(%4)::Any
│    %6  = (isa)(%5, Int64)::Bool
└───       goto #8 if not %6
2 ── %8  = π (%5, Int64)
│    %9  = Base.sle_int(2, %8)::Bool
└───       goto #4 if not %9
3 ──       goto #5
4 ──       goto #5
5 ┄─ %13 = φ (#3 => %8, #4 => 1)::Int64
│    %14 = %new(UnitRange{Int64}, 2, %13)::UnitRange{Int64}
└───       goto #6
6 ──       goto #7
7 ──       goto #9
8 ── %18 = (2:%5)::Any
└───       goto #9
9 ┄─ %20 = φ (#7 => %14, #8 => %18)::Any
│    %21 = Base.getindex(%4, %20)::Any
│    %22 = Base.iterate(%21)::Any
│    %23 = (%22 === nothing)::Bool
│    %24 = Base.not_int(%23)::Bool
└───       goto #12 if not %24
10 ┄ %26 = φ (#9 => %22, #11 => %31)::Any
│    %27 = φ (#9 => %3, #11 => %30)::Any
│    %28 = Core.getfield(%26, 1)::Any
│    %29 = Core.getfield(%26, 2)::Any
│    %30 = (%28)(%27)::Any
│    %31 = Base.iterate(%21, %29)::Any
│    %32 = (%31 === nothing)::Bool
│    %33 = Base.not_int(%32)::Bool
└───       goto #12 if not %33
11 ─       goto #10
12 ┄ %36 = φ (#10 => %30, #9 => %3)::Any
└───       return %36
) => Any

BadComposed is not type-stable, i.e. the Julia compiler cannot infer the types of each instruction. The reason is that, as you can see on the first line of the @code_typed output, bc.fs is inferred as Any (the field is untyped). This in turn means that Julia can't really optimize much (it can no longer look up the implementation required for the input at compile time, but has to do it at run time), effectively falling back to Python-like performance.

Now, the compile time for the above BadComposed is still slower than for g_many, but the scaling is different:

julia> g_super_many = reduce(∘, fill(f, 100))
f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f

julia> @time g_super_many(1)
  0.121201 seconds (486.78 k allocations: 31.423 MiB, 6.13% gc time, 99.61% compilation time)
1

julia> g_super_many_bad = BadComposed(fill(f, 100))
BadComposed([f, f, f, f, f, f, f, f, f, f  …  f, f, f, f, f, f, f, f, f, f])

julia> @time g_super_many_bad(1)
  0.000024 seconds (101 allocations: 1.625 KiB)
1

The difference here is that for g_super_many, whose type is a ComposedFunction{ComposedFunction{...}, typeof(f)} nested 100 times, the Julia compiler can still figure out the exact type of every operation in there, and so it specializes, i.e. it compiles machine code specifically for this 100-times-nested ComposedFunction!

In contrast, it does not specialize for BadComposed because in that case it cannot determine the types and falls back to its implementation for Any. But we already compiled that for g_many_bad because it was the same case there!

That is:

  • g_many_bad and g_super_many_bad compile to the same code.
  • g_many and g_super_many compile to two different implementations.
julia> @code_typed optimize=true g_super_many_bad(1)  # same as before!
CodeInfo(
1 ── %1  = Base.getfield(bc, :fs)::Any
│    %2  = Base.getindex(%1, 1)::Any
│    %3  = (%2)(x)::Any
│    %4  = Base.getfield(bc, :fs)::Any
│    %5  = Base.lastindex(%4)::Any
│    %6  = (isa)(%5, Int64)::Bool
└───       goto #8 if not %6
2 ── %8  = π (%5, Int64)
│    %9  = Base.sle_int(2, %8)::Bool
└───       goto #4 if not %9
3 ──       goto #5
4 ──       goto #5
5 ┄─ %13 = φ (#3 => %8, #4 => 1)::Int64
│    %14 = %new(UnitRange{Int64}, 2, %13)::UnitRange{Int64}
└───       goto #6
6 ──       goto #7
7 ──       goto #9
8 ── %18 = (2:%5)::Any
└───       goto #9
9 ┄─ %20 = φ (#7 => %14, #8 => %18)::Any
│    %21 = Base.getindex(%4, %20)::Any
│    %22 = Base.iterate(%21)::Any
│    %23 = (%22 === nothing)::Bool
│    %24 = Base.not_int(%23)::Bool
└───       goto #12 if not %24
10 ┄ %26 = φ (#9 => %22, #11 => %31)::Any
│    %27 = φ (#9 => %3, #11 => %30)::Any
│    %28 = Core.getfield(%26, 1)::Any
│    %29 = Core.getfield(%26, 2)::Any
│    %30 = (%28)(%27)::Any
│    %31 = Base.iterate(%21, %29)::Any
│    %32 = (%31 === nothing)::Bool
│    %33 = Base.not_int(%32)::Bool
└───       goto #12 if not %33
11 ─       goto #10
12 ┄ %36 = φ (#10 => %30, #9 => %3)::Any
└───       return %36
) => Any

Now, there are some caveats to what I explained above; e.g. Julia doesn't actually completely inline the full 100 functions, but will only "unroll" these sorts of recursions at compile time up to a certain point (it still infers the types, though).
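
Tying this back to the FunctionChains.jl suggestion: the trick is to keep BadComposed's run-time iteration but make the container's element type concrete, so inference still succeeds. A minimal sketch (TypedComposed is just an illustrative name):

# Like BadComposed, but parameterized: fs::Vector{F} is concretely typed when the
# layers share one type, so the loop stays type-stable while the chain length
# never enters the type; compile time no longer grows with depth.
struct TypedComposed{F}
    fs::Vector{F}
end

function (tc::TypedComposed)(x)
    y = x
    for fi in tc.fs
        y = fi(y)  # type-stable: the compiler knows each fi isa F
    end
    return y
end

g_typed = TypedComposed(fill(f, 100))
g_typed(1)  # compiles once, regardless of the chain length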

And no need to preface with "maybe this is a stupid question" or anything; when it comes to these things there really aren't any stupid questions:) Julia is simple to get up and running with, but it's also very complex in the sense that you can always keep digging and optimizing. I came mostly from a Python background myself, so I'm still trying to figure this stuff out even after 3 years of Julia 😅

@zuhengxu

In Julia, you can even inspect these things directly: ....

This is a ridiculously good explanation for my question and really elucidates many phenomena I observed! Appreciate this awesome answer @torfjelde!

zuhengxu and others added 24 commits June 29, 2023 11:30
Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
@torfjelde torfjelde left a comment

Awesome stuff @zuhengxu :)

@torfjelde

I'll let you do the honour of hitting the big green merge button

@zuhengxu zuhengxu merged commit 588c184 into TuringLang:main Jun 30, 2023
@cpfiffer

cpfiffer commented Jul 5, 2023

Really well done in here! I've been following a little bit and this looks really wonderful.
