Support for linking distributions with embedded support #461

Closed
torfjelde opened this issue Feb 6, 2023 · 10 comments

@torfjelde
Member

For certain distributions, the random variable represented by the Distribution has support of lower dimension than its return type indicates; that is, the returned realizations are embedded in a higher-dimensional space.

For example, LKJ is a distribution over correlation matrices. Correlation matrices are required to be symmetric positive-definite (PD) and to have 1s along the diagonal. Symmetry means we only have (n choose 2) + n degrees of freedom, and the unit diagonal removes a further n, leaving us with only (n choose 2) degrees of freedom. That is, the space of correlation matrices has dimension (n choose 2), not n × n as might be suggested by the Matrix{Float64} returned by rand(::LKJ)!
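
As a quick sanity check of that dimension count, with n = 3:

julia> n = 3;

julia> n^2               # number of entries in the Matrix{Float64} returned by rand(LKJ(n, 1))
9

julia> binomial(n, 2)    # actual degrees of freedom of an n×n correlation matrix
3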

For SimpleVarInfo this is trivial to support, because SimpleVarInfo only contains the realizations themselves and no information about the distributions, etc. Therefore, with something like TuringLang/Bijectors.jl#246, things just work:

julia> using DynamicPPL, Distributions, Bijectors

julia> # Switch the bijector used for LKJ to the `VecCorrBijector` from the aforementioned PR.
       Bijectors.bijector(::LKJ) = Bijectors.VecCorrBijector();

julia> @model demo() = x ~ LKJ(3, 1);

julia> model = demo();

julia> vi = SimpleVarInfo(model);

julia> # Now it's a matrix.
       vi[@varname(x)]
3×3 Matrix{Float64}:
  1.0         -0.00803721  -0.849602
 -0.00803721   1.0          0.00190424
 -0.849602     0.00190424   1.0

julia> vi_transformed = link!!(vi, model);

julia> # Now it's a vector.
       vi_transformed[@varname(x)]
3-element Vector{Float64}:
 -0.00803738468434096
 -1.2547213956880081
 -0.0093368799126288

julia> logjoint(model, vi_transformed)  # (✓) Works!
-3.515748926181343

In contrast, with VarInfo things are not so simple:

julia> vi = VarInfo(model);

julia> vi[@varname(x)]
3×3 Matrix{Float64}:
 1.0        0.382085   0.607741
 0.382085   1.0       -0.173265
 0.607741  -0.173265   1.0

julia> vi_transformed = link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 3 elements to 9 destinations
Stacktrace:
...

With VarInfo there are multiple challenges:

  1. link!! occurs in-place and expects the same shape as the original (untransformed) value.
  2. getindex(vi, vn, dist) uses reconstruct(dist, val) to reshape the underlying flattened representation in VarInfo into what dist expects. This happens before the value is passed to the bijector/transformation, so if we're working with a Vector (because we're in transformed space) and call reconstruct(dist, val::Vector), we get back a Matrix, and the inverse transformation, which expects a Vector, fails. We could look into adding the transformation to the reconstruct call, i.e. letting (dist, transform)-pairs define the reconstruct rather than just dist, but then the problem is that in VarInfo whether a variable is transformed or not is decided at runtime, which in turn causes type-instabilities (reconstruct would return a Vector in some cases and a Matrix in others, decided at runtime; see the sketch below).
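
As a rough illustration of that type-stability issue (the helper below is purely hypothetical, not actual DynamicPPL code):

# Hypothetical transform-aware `reconstruct`: the return type depends on a runtime
# flag, so Julia cannot infer whether the result is a Vector or a Matrix.
function reconstruct_maybe_linked(dist, val::AbstractVector, linked::Bool)
    return linked ? val : reshape(val, size(dist))
end

Since linked would come from istrans(vi, vn) at runtime, the call site ends up with a Union{Vector{Float64}, Matrix{Float64}} return type.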

So we need a good way of doing this with VarInfo, and I figured I'd open an issue so we can discuss it in more detail together.

@devmotion
Member

Regarding the first point: maybe we could still link!! in place but fix the problematic call that causes the DimensionMismatch, replacing it with something like copyto!, which will only update the first few elements (but not necessarily all). AFAICT the number of elements that link!! wants to set is always less than or equal to the number of elements in the unlinked representation, so that should always work. Additionally, I think we need something (an additional field that stores the length of the link!!ed objects?) to make sure that when extracting (and un-linking) values from the linked object we only retrieve the first m elements, not all of them.
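
A rough array-level illustration of the copyto! idea (not actual VarInfo code; the sizes are just those of the LKJ(3, 1) example above):

julia> dest = zeros(9);                    # flat storage sized for the unlinked 3×3 value

julia> y = [0.1, -1.2, 0.3];               # linked value with only binomial(3, 2) = 3 elements

julia> copyto!(dest, 1, y, 1, length(y));  # overwrites only the first 3 elements

julia> dest[1:3]
3-element Vector{Float64}:
  0.1
 -1.2
  0.3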

@torfjelde
Member Author

So that's easy enough to do! We can just change the corresponding range prior to calling setval! in _link!, i.e. change

y, logjac = with_logabsdet_jacobian(b, x)
setval!(vi, vectorize(dist, y), vn)

But this still leaves the problem of getindex.

Because VarInfo stores everything in a flattened representation, we need to reconstruct. But now we might want the flattened representation (or any other representation for that matter) in the step between "getting the raw value from VarInfo" and "returning the transformed value in the correct format" in getindex(vi, vn, dist). But there's no way to do this in a type-stable manner because istrans(vi, vn) is determined at runtime.

@torfjelde
Member Author

There might be a way if we move the reconstruct into the link/invlink somehow, but all of this is quickly becoming very nasty.
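
The underlying idea would be that the transformation itself already handles the Vector ↔ Matrix shape change, so no separate reconstruct step is needed. A rough sketch, assuming the VecCorrBijector from the PR mentioned above:

using Bijectors, Distributions

dist = LKJ(3, 1)
b = Bijectors.VecCorrBijector()  # maps a correlation Matrix to a flat Vector
x = rand(dist)                   # 3×3 correlation matrix
y = b(x)                         # Vector with binomial(3, 2) = 3 elements
x_again = inverse(b)(y)          # back to a 3×3 Matrix, without calling reconstruct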

@yebai
Member

yebai commented Feb 6, 2023

I think it is ok if we stop adding new features to VarInfo but only support the new LKJ transform for SimpleVarInfo.

@torfjelde
Member Author

But the current VarInfo behavior is actually a bug, right? 😕

@yebai
Member

yebai commented Feb 6, 2023

I understand that we can default to SimpleVarInfo for all HMC samplers, right? If so, users won't use VarInfo unless they explicitly choose to.

@torfjelde
Member Author

I understand that we can default to SimpleVarInfo for all HMC samplers, right? If so, users won't use VarInfo unless they explicitly choose to.

Not when used with Gibbs 😕 Plus, SimpleVarInfo is not as performant as VarInfo when the LHS of ~ isn't just a single variable, i.e. when there's indexing, etc. (unless you seed SimpleVarInfo with the values by hand).

@yebai
Member

yebai commented Feb 6, 2023

Not when used with Gibbs

We can use a combination of SimpleVarInfo + VarInfo in Gibbs, if I understand all the pieces correctly; that is:

  • init step: run the model once to get UntypedVarInfo / SimpleVarInfo{Dict}
  • Gibbs step: for MH and HMC steps, we can construct a SimpleVarInfo from the init step's returned UntypedVarInfo, and then treat everything else as observed data via the condition API (see the sketch after this list).
  • Gibbs step: for other samplers like PG we keep using VarInfo, by treating everything else as data
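
For concreteness, a very rough sketch of the MH/HMC sub-step (model and values are made up; it assumes the condition API and the NamedTuple constructor of SimpleVarInfo):

using DynamicPPL, Distributions

@model function gdemo()
    x ~ LKJ(3, 1)
    y ~ Normal()
end

model = gdemo()

# Variables not sampled in this sub-step are fixed to their current Gibbs state
# via `condition`; here we just pick an arbitrary value for `y`.
conditioned_model = condition(model, (; y = 0.5))

# The SimpleVarInfo then only needs to carry the variable(s) being sampled.
svi = SimpleVarInfo((; x = rand(LKJ(3, 1))))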

Plus, SimpleVarInfo is not as performant as VarInfo when the LHS of ~ isn't just a single variable, i.e. when there's indexing, etc. (unless you seed SimpleVarInfo with the values by hand).

This is probably fine for now but can be improved later via tracing the model once to seed SimpleVarInfo.

@torfjelde
Member Author

we can construct a SimpleVarInfo from the init step's returned UntypedVarInfo

Not possible 😕 (unless we write some pretty significant pieces of code to reconstruct container values from a sequence of VarNames)

bors bot pushed a commit that referenced this issue Apr 26, 2023
Attempt at addressing #461.

I think the approach here is somewhat correct, but it's currently very dirty because of the intricacies of the `VarInfo` implementation. This can be cleaned up, but it will take some effort. Until I've done so, I'm leaving this as a draft PR.

We also most certainly need to do some benchmarking before merging this as it could lead to some additional overhead.

NOTE: This is based on #457 

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
@yebai
Member

yebai commented Aug 4, 2023

Fixed by #462

@yebai yebai closed this as completed Aug 4, 2023