Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vi should not work with Matrixvariate distributions #1545

Merged
merged 7 commits into from
Aug 29, 2022

Conversation

torfjelde
Copy link
Member

Issue

TuringLang/Bijectors.jl#169

This happens because Bijectors.Stacked is only being compatible with 0- and 1-dimensional inputs. To "fix" this I've added a Vec bijector which wraps any higher-than-1-dimensional bijector with a "flattening" operation.

Now works

julia> using Turing

julia> @model function demo(xs)
           Σ ~ LKJ(size(xs, 1), 1)

           for i = 1:size(xs, 2)
               xs[:, i] ~ MvNormal(Σ)
           end

           return Σ
       end;

julia> m = demo(randn(2, 100));

julia> # This now works
       q = Turing.Variational.meanfield(m);

julia> rand(q)
4-element Array{Float64,1}:
  1.0
 -0.43345701408941034
 -0.43345701408941034
  1.0000000000000002

julia> advi = ADVI(1, 1000)
ADVI{AdvancedVI.ForwardDiffAD{40}}(1, 1000)

julia> q = vi(m, advi);
┌ Info: [ADVI] Should only be seen once: optimizer created for θ
└   objectid(θ) = 0x8016f7ba5cf582c2
[ADVI] Optimizing...100% Time: 0:00:03

julia> rand(q)
4-element Array{Float64,1}:
 1.0
 0.07902135820833789
 0.07902135820833789
 1.0

Questions

  • I'm a bit split on the Vec bijector as it's a bit "hacky". Ideally we'd be able to specify both input- and output-dimensionality (in full generality, ideally spaces rather than just dimensionality, but this seems non-trivial), but we did not do this from the get-go due to additional complexity for both user and implementer. So should we keep Vec here as a way of only fixing this particular issue or should I move it to Bijectors.jl so it's available there to? I'm slightly in favour of the last option.

@coveralls
Copy link

coveralls commented Feb 15, 2021

Pull Request Test Coverage Report for Build 568545271

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 0 of 12 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.5%) to 73.839%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/variational/advi.jl 0 12 0.0%
Totals Coverage Status
Change from base Build 553759810: 0.5%
Covered Lines: 954
Relevant Lines: 1292

💛 - Coveralls

@codecov
Copy link

codecov bot commented Feb 15, 2021

Codecov Report

Merging #1545 (3219380) into master (5990fae) will decrease coverage by 0.72%.
The diff coverage is 27.77%.

@@            Coverage Diff             @@
##           master    #1545      +/-   ##
==========================================
- Coverage   82.21%   81.49%   -0.73%     
==========================================
  Files          21       21              
  Lines        1406     1421      +15     
==========================================
+ Hits         1156     1158       +2     
- Misses        250      263      +13     
Impacted Files Coverage Δ
src/variational/advi.jl 75.43% <27.77%> (-22.19%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@yebai
Copy link
Member

yebai commented Jul 5, 2022

@torfjelde is this related to the more general batching-issue?

@torfjelde
Copy link
Member Author

@torfjelde is this related to the more general batching-issue?

In a very indirect way, yeah. It's related because we want to drop the assumption that input and output are of the same shapes from Bijectors.jl, which we really need something like Batching.jl to do properly.

@yebai yebai merged commit ba33e48 into master Aug 29, 2022
@delete-merged-branch delete-merged-branch bot deleted the tor/vi-for-matrixvariate branch August 29, 2022 22:26
@yebai
Copy link
Member

yebai commented Aug 29, 2022

CI is passing after using the most recent Bijector release. Let's merge this PR for now since it's taking quite a bit of time already; the code only affects advi. Let's make further improvements in new PRs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants