implemented variational objectives #1
Conversation
src/NormalizingFlows.jl
# destructure flow for explicit access to the parameters
# destructure can result in some overhead when the flow length is large
@info "destructuring flow..."
θ_flat, re = Flux.destructure(flow)
Line 225 can be really slow, as mentioned in line 24; I tried the flatten(x) function written in https://github.com/TuringLang/AdvancedVI.jl/discussions/46#discussioncomment-5427443, which is also slow.
For reference, Flux.destructure(flow) with flow being a stack of 20 2d planar layers takes a ridiculously long time, way longer than the training time. I was wondering if there is a better way of handling the explicit parameterization in terms of performance.
From my reply to that comment:
"For reference, the execution of the following code took over an hour and I just killed the program for lack of patience."
This doesn't have anything to do with the usage of destructure; it's because you're doing ∘(...). This then causes a ComposedFunction which is n_layers nested (ComposedFunction(ComposedFunction(ComposedFunction(...)), ...); you get the gist). This blows up the compilation times.
For "large" compositions, you should instead use something like https://github.com/oschulz/FunctionChains.jl. Then you could use a Vector to represent the composition, which would make the compilation time waaaaay better (because we now "iterate" over the recursion at runtime rather than at compile time).
I also wonder if we could deal with this by subtyping Function, due to https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues.
But we can discuss this issue more somewhere else.
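To make the Vector-based suggestion concrete, here is a minimal sketch using FunctionChains.jl (assuming its fchain constructor accepts a vector of functions, which is worth double-checking against the package; the toy function f is only for illustration):

using FunctionChains

f(x) = x + 1

# Nested ComposedFunction: the whole 100-deep recursion lives in the type,
# so the compiler has to work through it at compile time.
g_nested = reduce(∘, fill(f, 100))

# Function chain backed by a Vector: the layers are iterated at run time,
# so compile time should stay flat as the number of layers grows.
g_chain = fchain(fill(f, 100))

g_nested(0) == g_chain(0)   # both apply f 100 times, so both return 100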
Two major concerns so far on the current design:
I haven't dug into the "severe limitation and performance issue" about
@sunxd3 @torfjelde what do you think of
using Bijectors, Flux, Random, Distributions, LinearAlgebra
function create_planar_flow(n_layers::Int, q₀)
d = length(q₀)
Ts = ∘([PlanarLayer(d) for _ in 1:n_layers]...)
return transformed(q₀, Ts)
end
flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I)) # create a 20-layer planar flow
theta, re = Flux.destructure(flow) # this is sooooooo slow
I wonder if we have better ways of handling the parameter flattening/unflattening, with better performance, without losing support for those AD systems (at least ForwardDiff, ReverseDiff, Zygote).
As discussed there, I don't think it's quite there yet. We can try using https://github.com/SciML/ADTypes.jl mentioned in the other PR, and then we just define the
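For a rough idea of what dispatching on ADTypes.jl backends could look like, here is a sketch; the ADTypes structs are real, but the helper name value_and_gradient and its signature are made up for illustration and are not an existing API in this package:

using ADTypes: AutoForwardDiff, AutoZygote
using ForwardDiff, Zygote

# Hypothetical helper: dispatch the gradient computation on the ADTypes struct.
function value_and_gradient(::AutoForwardDiff, f, x::AbstractVector)
    return f(x), ForwardDiff.gradient(f, x)
end

function value_and_gradient(::AutoZygote, f, x::AbstractVector)
    y, back = Zygote.pullback(f, x)
    return y, only(back(one(y)))
end

# Usage: the caller only has to pick the backend struct.
loss(θ) = sum(abs2, θ)
value_and_gradient(AutoZygote(), loss, randn(3))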
@torfjelde Thanks for your comments about the compilation time for the nested ComposedFunction.
First, to double-check my understanding of this issue: it's not the running time of this flattening (recursion through the nested compositions) that is slow, but the compilation time. This must be a dumb question, but why is the compilation time THIS much slower than the run time?
I'm also not quite clear on why subtyping can make the compilation faster in our case. (Honestly I don't think I fully understand the compiler efficiency issues; apologies for my lack of knowledge on this.) The link seems to explain why enforcing types for many functions can cause a significant burst in compilation time. This issue is related because our flow layers are all typed functions. But I don't see why subtyping these flow layers helps.
Another Q: should we just merge this into AdvancedVI.jl?
Hi Zuheng/David! I think you can worry about that later, since once
Best wishes for your GSoC!
So this would indeed be the idea :) But as @Red-Portal is saying, I think for now we just keep it separate until AdvancedVI's interface solidifies. Once that is done, we just replace the VI-related stuff in NormalizingFlows.jl with the interface from AdvancedVI. I'm very happy to have you involved in the discussion for AdvancedVI.jl though! I think you'll have some valuable insights into what we actually want from a "general variational inference"-interface.
Generally speaking, the more work you do before you actually execute the code (compile time), the faster the code will be (lower run time), right? In Julia, you can even inspect these things directly:
julia> f(x) = x^2
f (generic function with 1 method)
julia> @code_typed optimize=false f(1)
CodeInfo(
1 ─ %1 = Main.:^::Core.Const(^)
│ %2 = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│ %3 = (%2)()::Core.Const(Val{2}())
│ %4 = Base.literal_pow(%1, x, %3)::Int64
└── return %4
) => Int64
julia> @code_typed optimize=true f(1)
CodeInfo(
1 ─ %1 = Base.mul_int(x, x)::Int64
└── return %1
) => Int64
As you see there, when we do optimize=true, the x^2 call is boiled down to a single Base.mul_int. So what happens when you compose two functions?
julia> g = f ∘ f
f ∘ f
julia> @code_typed optimize=false g(1)
CodeInfo(
1 ─ %1 = Base.:(var"#_#97")::Core.Const(Base.var"#_#97")
│ %2 = Core.NamedTuple()::Core.Const(NamedTuple())
│ %3 = Base.pairs(%2)::Core.Const(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())
│ %4 = Core.tuple(%3, c)::Core.Const((Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}(), f ∘ f))
│ %5 = Core._apply_iterate(Base.iterate, %1, %4, x)::Int64
└── return %5
) => Int64
julia> @code_typed optimize=true g(1)
CodeInfo(
1 ─ %1 = Core.getfield(x, 1)::Int64
│ %2 = Base.mul_int(%1, %1)::Int64
│ %3 = Base.mul_int(%2, %2)::Int64
└── return %3
) => Int64
Now, in the optimized version, all the calls to f have been inlined away into plain Base.mul_int calls. You can also time these things:
julia> @time f(1) # <= doesn't say anything about compilation time because it's so fast
0.000001 seconds
1
julia> @time g(1) # <= now the compilation time is actually noticeable
0.004762 seconds (789 allocations: 54.535 KiB, 99.77% compilation time)
1
Note that if I re-run the above, then there's no compilation time since it's already been compiled:
julia> @time f(1)
0.000001 seconds
1
julia> @time g(1)
0.000002 seconds
1
Taking it a step further:
julia> g_many = reduce(∘, fill(f, 10))
f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f
julia> @time g_many(1)
0.008233 seconds (31.46 k allocations: 2.070 MiB, 99.69% compilation time)
1
Compilation of g_many already takes noticeably longer. Let's look at the resulting code:
julia> @code_typed optimize=false g_many(1)
CodeInfo(
1 ─ %1 = Base.:(var"#_#97")::Core.Const(Base.var"#_#97")
│ %2 = Core.NamedTuple()::Core.Const(NamedTuple())
│ %3 = Base.pairs(%2)::Core.Const(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())
│ %4 = Core.tuple(%3, c)::Core.Const((Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}(), f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f))
│ %5 = Core._apply_iterate(Base.iterate, %1, %4, x)::Int64
└── return %5
) => Int64
julia> @code_typed optimize=true g_many(1)
CodeInfo(
1 ─ %1 = Core.getfield(x, 1)::Int64
│ %2 = Base.mul_int(%1, %1)::Int64
│ %3 = Base.mul_int(%2, %2)::Int64
│ %4 = Base.mul_int(%3, %3)::Int64
│ %5 = Base.mul_int(%4, %4)::Int64
│ %6 = Base.mul_int(%5, %5)::Int64
│ %7 = Base.mul_int(%6, %6)::Int64
│ %8 = Base.mul_int(%7, %7)::Int64
│ %9 = Base.mul_int(%8, %8)::Int64
│ %10 = Base.mul_int(%9, %9)::Int64
│ %11 = Base.mul_int(%10, %10)::Int64
└── return %11
) => Int64
So in the optimized version, we have actually unrolled all the methods, and just boiled it down to a sequence of Base.mul_int calls. We can benchmark it:
julia> using BenchmarkTools
julia> @btime $f($1)
2.523 ns (0 allocations: 0 bytes)
1
julia> @btime $g_many($1)
2.716 ns (0 allocations: 0 bytes)
1
So this is pretty impressive, right? The runtime is almost the same. We can compare this to a "bad" composed implementation.
julia> struct BadComposed
fs
end
julia> function (bc::BadComposed)(x)
           y = bc.fs[1](x)
           for f in bc.fs[2:end]
               y = f(y)
           end
           return y
       end
julia> g_many_bad = BadComposed(fill(f, 10))
BadComposed([f, f, f, f, f, f, f, f, f, f])
julia> @time g_many_bad(1)
0.045895 seconds (34.42 k allocations: 2.282 MiB, 99.71% compilation time)
1
julia> @code_typed optimize=true g_many_bad(1)
CodeInfo(
1 ── %1 = Base.getfield(bc, :fs)::Any
│ %2 = Base.getindex(%1, 1)::Any
│ %3 = (%2)(x)::Any
│ %4 = Base.getfield(bc, :fs)::Any
│ %5 = Base.lastindex(%4)::Any
│ %6 = (isa)(%5, Int64)::Bool
└─── goto #8 if not %6
2 ── %8 = π (%5, Int64)
│ %9 = Base.sle_int(2, %8)::Bool
└─── goto #4 if not %9
3 ── goto #5
4 ── goto #5
5 ┄─ %13 = φ (#3 => %8, #4 => 1)::Int64
│ %14 = %new(UnitRange{Int64}, 2, %13)::UnitRange{Int64}
└─── goto #6
6 ── goto #7
7 ── goto #9
8 ── %18 = (2:%5)::Any
└─── goto #9
9 ┄─ %20 = φ (#7 => %14, #8 => %18)::Any
│ %21 = Base.getindex(%4, %20)::Any
│ %22 = Base.iterate(%21)::Any
│ %23 = (%22 === nothing)::Bool
│ %24 = Base.not_int(%23)::Bool
└─── goto #12 if not %24
10 ┄ %26 = φ (#9 => %22, #11 => %31)::Any
│ %27 = φ (#9 => %3, #11 => %30)::Any
│ %28 = Core.getfield(%26, 1)::Any
│ %29 = Core.getfield(%26, 2)::Any
│ %30 = (%28)(%27)::Any
│ %31 = Base.iterate(%21, %29)::Any
│ %32 = (%31 === nothing)::Bool
│ %33 = Base.not_int(%32)::Bool
└─── goto #12 if not %33
11 ─ goto #10
12 ┄ %36 = φ (#10 => %30, #9 => %3)::Any
└─── return %36
) => Any
Now, the compile-time for the above BadComposed is not great either, but look what happens when we crank the number of functions up to 100:
julia> g_super_many = reduce(∘, fill(f, 100))
f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f ∘ f
julia> @time g_super_many(1)
0.121201 seconds (486.78 k allocations: 31.423 MiB, 6.13% gc time, 99.61% compilation time)
1
julia> g_super_many_bad = BadComposed(fill(f, 100))
BadComposed([f, f, f, f, f, f, f, f, f, f … f, f, f, f, f, f, f, f, f, f])
julia> @time g_super_many_bad(1)
0.000024 seconds (101 allocations: 1.625 KiB)
1
The difference here is that for g_super_many, Julia specializes on the full (deeply nested) ComposedFunction type and so has to compile a new method for it. In contrast, BadComposed does not specialize on the functions it holds (its fs field is untyped), so nothing new needs to be compiled. That is:
julia> @code_typed optimize=true g_super_many_bad(1) # same as before!
CodeInfo(
1 ── %1 = Base.getfield(bc, :fs)::Any
│ %2 = Base.getindex(%1, 1)::Any
│ %3 = (%2)(x)::Any
│ %4 = Base.getfield(bc, :fs)::Any
│ %5 = Base.lastindex(%4)::Any
│ %6 = (isa)(%5, Int64)::Bool
└─── goto #8 if not %6
2 ── %8 = π (%5, Int64)
│ %9 = Base.sle_int(2, %8)::Bool
└─── goto #4 if not %9
3 ── goto #5
4 ── goto #5
5 ┄─ %13 = φ (#3 => %8, #4 => 1)::Int64
│ %14 = %new(UnitRange{Int64}, 2, %13)::UnitRange{Int64}
└─── goto #6
6 ── goto #7
7 ── goto #9
8 ── %18 = (2:%5)::Any
└─── goto #9
9 ┄─ %20 = φ (#7 => %14, #8 => %18)::Any
│ %21 = Base.getindex(%4, %20)::Any
│ %22 = Base.iterate(%21)::Any
│ %23 = (%22 === nothing)::Bool
│ %24 = Base.not_int(%23)::Bool
└─── goto #12 if not %24
10 ┄ %26 = φ (#9 => %22, #11 => %31)::Any
│ %27 = φ (#9 => %3, #11 => %30)::Any
│ %28 = Core.getfield(%26, 1)::Any
│ %29 = Core.getfield(%26, 2)::Any
│ %30 = (%28)(%27)::Any
│ %31 = Base.iterate(%21, %29)::Any
│ %32 = (%31 === nothing)::Bool
│ %33 = Base.not_int(%32)::Bool
└─── goto #12 if not %33
11 ─ goto #10
12 ┄ %36 = φ (#10 => %30, #9 => %3)::Any
└─── return %36
) => Any
Now, there are some caveats to what I explained above, e.g. Julia doesn't actually inline the full 100 functions; it will only "unroll" these sorts of recursions at compile time up to a certain point (it still infers the types though). And no need to preface with "maybe this is a stupid question" or anything; when it comes to these things there really aren't any stupid questions :) Julia is simple to get up and running with, but it's also very complex in the sense that you can always keep digging and optimizing. I came mostly from a Python background myself, so I'm still trying to figure this stuff out even after 3 years of Julia 😅
This is a ridiculously good explanation for my question and really elucidates many phenomena I observed! Appreciate this awesome answer @torfjelde!
Awesome stuff @zuhengxu :)
I'll let you do the honour of hitting the big green merge button
Really well done in here! I've been following a little bit and this looks really wonderful.
TODOs:
- ELBO with mini-batch subsampling in logp (will work on this in the next PR; see the sketch below)
- ADTypes.jl
- example/
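For the first TODO, a rough sketch of what an ELBO estimator with mini-batch subsampling of logp might look like; every name here is hypothetical and not the package's API, and the N/B rescaling assumes the likelihood factorizes over conditionally i.i.d. observations so the mini-batch sum is an unbiased estimate of the full one:

using Distributions, Random, StatsBase

# Hypothetical single-sample ELBO estimate with a subsampled log-likelihood.
function subsampled_elbo(rng::AbstractRNG, q, logprior, loglike, data, B::Int)
    N = length(data)
    batch = sample(rng, data, B; replace=false)   # mini-batch of observations
    x = rand(rng, q)                              # single draw from the variational q
    # unbiased estimate of sum_i log p(data_i | x), rescaled from B to N terms
    loglik_hat = (N / B) * sum(loglike(x, y) for y in batch)
    return logprior(x) + loglik_hat - logpdf(q, x)
end

In practice one would average this over several draws from q (and fresh mini-batches) before taking a gradient step.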