Implements a simple Nutpie style adaptation (using both positions and gradients, but not changing the schedule). #473
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,16 +9,18 @@ finalize!(::MassMatrixAdaptor) = nothing | |
|
|
||
| function adapt!( | ||
| adaptor::MassMatrixAdaptor, | ||
| - θ::AbstractVecOrMat{<:AbstractFloat}, | ||
| - α::AbstractScalarOrVec{<:AbstractFloat}, | ||
| + z_or_theta::PositionOrPhasePoint, | ||
| + ::AbstractScalarOrVec{<:AbstractFloat}, | ||
| is_update::Bool=true, | ||
| ) | ||
| - resize_adaptor!(adaptor, size(θ)) | ||
| - push!(adaptor, θ) | ||
| + resize_adaptor!(adaptor, size(get_position(z_or_theta))) | ||
| + push!(adaptor, z_or_theta) | ||
| is_update && update!(adaptor) | ||
| return nothing | ||
| end | ||
|
|
||
| Base.push!(a::MassMatrixAdaptor, z_or_theta::PositionOrPhasePoint) = push!(a, get_position(z_or_theta)) | ||
|
|
||
| ## Unit mass matrix adaptor | ||
|
|
||
| struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end | ||
|
|
@@ -39,15 +41,14 @@ getM⁻¹(::UnitMassMatrix{T}) where {T} = LinearAlgebra.UniformScaling{T}(one(T | |
|
|
||
| function adapt!( | ||
| ::UnitMassMatrix, | ||
| - ::AbstractVecOrMat{<:AbstractFloat}, | ||
| + ::PositionOrPhasePoint, | ||
| ::AbstractScalarOrVec{<:AbstractFloat}, | ||
| is_update::Bool=true, | ||
| ) | ||
| return nothing | ||
| end | ||
|
|
||
| ## Diagonal mass matrix adaptor | ||
|
|
||
| abstract type DiagMatrixEstimator{T} <: MassMatrixAdaptor end | ||
|
|
||
| getM⁻¹(ve::DiagMatrixEstimator) = ve.var | ||
|
|
@@ -70,7 +71,7 @@ NaiveVar{T}(sz::Tuple{Int,Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Matri | |
|
|
||
| NaiveVar(sz::Union{Tuple{Int},Tuple{Int,Int}}) = NaiveVar{Float64}(sz) | ||
|
|
||
| - Base.push!(nv::NaiveVar, s::AbstractVecOrMat) = push!(nv.S, s) | ||
| + Base.push!(nv::NaiveVar, s::AbstractVecOrMat{<:AbstractFloat}) = push!(nv.S, s) | ||
|
|
||
| reset!(nv::NaiveVar) = resize!(nv.S, 0) | ||
|
|
||
|
|
@@ -135,7 +136,7 @@ function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat} | |
| return nothing | ||
| end | ||
|
|
||
| - function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T} | ||
| + function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T<:AbstractFloat} | ||
| wv.n += 1 | ||
| (; δ, μ, M, n) = wv | ||
| n = T(n) | ||
|
|
@@ -153,6 +154,90 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat} | |
| return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5)) | ||
| end | ||
|
|
||
| """ | ||
| NutpieVar | ||
|
|
||
| Nutpie-style diagonal mass matrix estimator (using positions and gradients). | ||
|
|
||
| Expected to converge faster and to a better mass matrix than [`WelfordVar`](@ref), for which it is a drop-in replacement. | ||
|
|
||
| Can be initialized via `NutpieVar(sz)` where `sz` is either a `Tuple{Int}` or a `Tuple{Int,Int}`. | ||
|
|
||
| # Fields | ||
|
|
||
| $(FIELDS) | ||
| """ | ||
| mutable struct NutpieVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T} | ||
|
Member
Does this have to be
Contributor
Author
I thought that it doesn't have to be - but to adhere to the implicit internal interface, having it be mutable makes implementation easier.
Contributor
Author
See e.g. here, which implies among other things the presence of a (mutable)
Member
I think the function you linked should be fine. It does a comparison on If it is an interface demand that subtypes of
However, that's probably out of scope for this PR.
Contributor
Author
Oh right, I mostly meant that You can of course get around that by having some
I absolutely agree! Edit: though I don't think it's a strict interface demand. For me it was just that the method using that interface was already defined (and would be used for my subtype), so I thought I'd just rely on that already present method.
Member
I think that's fair. Would raise this as something to change later for the whole abstract type though.
Do you have a particular reason for preferring the `new_foo = ImmutableFooType(old_foo.a, old_foo.b, c_the_only_field_that_has_changed, old_foo.d, old_foo.e)` but there's no performance penalty for that and the code can be simplified with `new_foo = Accessors.@set old_foo.c = c_the_only_field_that_has_changed`
Contributor
Author
Hmmm, I think it's mainly overkill and an overoptimization. Yes, it will probably be slightly more efficient, but it will also make interacting with the struct more awkward. For simple counters such as I do think the "better" use case for
Member
I don't see it as overkill, because I see it as being simpler than using a
I love me a Anyway, this is a bigger design conversation than this PR, and warrants an issue for a proper discussion.
Contributor
Author
I'm ready to die on this hill!
mhauru marked this conversation as resolved.
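The trade-off debated in this thread can be sketched as follows. This is a minimal illustration only, assuming two hypothetical counter types (`MutableCounter`, `ImmutableCounter`) that are not part of AdvancedHMC; the `Accessors.@set` form mentioned in the thread is shown as a comment so that the block runs with Base alone.

```julia
# Hypothetical types for illustration only; they are not AdvancedHMC types.
mutable struct MutableCounter
    n::Int
    n_min::Int
end

struct ImmutableCounter
    n::Int
    n_min::Int
end

# Mutable style: update the field in place and return the same object.
increment!(c::MutableCounter) = (c.n += 1; c)

# Immutable style: build a new value, copying the unchanged fields by hand.
increment(c::ImmutableCounter) = ImmutableCounter(c.n + 1, c.n_min)

# With Accessors.jl installed, the same immutable update could be written as
#     using Accessors
#     increment(c::ImmutableCounter) = @set c.n = c.n + 1

m = increment!(MutableCounter(0, 10))
i = increment(ImmutableCounter(0, 10))
```

The mutable version keeps call sites short, while the immutable version makes every change an explicit new value; which is preferable for the adaptor types is the design question left open above.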
|
||
| "Online variance estimator of the posterior positions." | ||
| position_estimator::WelfordVar{T,E,V} | ||
| "Online variance estimator of the posterior gradients." | ||
| gradient_estimator::WelfordVar{T,E,V} | ||
| "The number of observations collected so far." | ||
| n::Int | ||
| "The minimal number of observations after which the estimate of the variances can be updated." | ||
| n_min::Int | ||
| "The estimated variances - initialized to ones, updated after calling [`update!`](@ref) if `n > n_min`." | ||
| var::V | ||
| function NutpieVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::V) where {E,V} | ||
| return new{eltype(E),E,V}( | ||
| WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)), | ||
| WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)), | ||
| n, n_min, var | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| function Base.show(io::IO, ::NutpieVar{T}) where {T} | ||
| return print(io, "NutpieVar{", T, "} adaptor") | ||
| end | ||
|
Comment on lines +190 to +192
Member
The two-argument version of
We break this rule all the time in TuringLang, so not too fussed about it, but I would still slightly prefer making a nice human readable version of
Contributor
Author
Should we then simultaneously also fix e.g.
Contributor
Author
Interestingly, the current state of the show methods is due to #466.
Member
I'm happy to leave this as is, follow the current intra-package convention, and maybe open an issue about generally fixing our use of
Contributor
Author
Yeah, I'm also unsure about #466. I'd also be in favour of opening another issue to potentially address any "weird" implementations. |
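For context, and assuming this thread refers to the usual Julia convention for `show`, the distinction can be sketched as below. The type `ToyAdaptor` is hypothetical and only stands in for an adaptor; it is not the PR's `NutpieVar`.

```julia
# Hypothetical type for illustration; not AdvancedHMC's NutpieVar.
struct ToyAdaptor{T<:AbstractFloat}
    var::Vector{T}
end

# Two-argument `show`: compact representation used by `repr` and `print`,
# conventionally something that could be parsed back into the value.
function Base.show(io::IO, a::ToyAdaptor{T}) where {T}
    return print(io, "ToyAdaptor{", T, "}(", a.var, ")")
end

# Three-argument `show`: human-readable display, used e.g. at the REPL.
function Base.show(io::IO, ::MIME"text/plain", a::ToyAdaptor{T}) where {T}
    return print(io, "ToyAdaptor{", T, "} adaptor with ", length(a.var), " dimensions")
end

a = ToyAdaptor([1.0, 2.0])
compact = repr(a)                     # uses the two-argument method
pretty = repr(MIME("text/plain"), a)  # uses the three-argument method
```

Under that convention, a message like `"NutpieVar{Float64} adaptor"` would belong in the three-argument method, which matches the follow-up suggested in the thread rather than anything this PR changes.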
||
|
|
||
| function NutpieVar{T}( | ||
| sz::Union{Tuple{Int},Tuple{Int,Int}}=(2,); n_min::Int=10, var=ones(T, sz) | ||
| ) where {T<:AbstractFloat} | ||
| return NutpieVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var) | ||
| end | ||
|
|
||
| function NutpieVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) | ||
| return NutpieVar{Float64}(sz; kwargs...) | ||
| end | ||
|
|
||
| function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int,Int}) where {T<:AbstractFloat} | ||
| if size_θ != size(nv.var) | ||
| @assert nv.n == 0 "Cannot resize a var estimator when it contains samples." | ||
|
Member
This looks like something that could plausibly be hit sometimes. Could it throw an error rather than an
Contributor
Author
I didn't know that! I'd assume as before, we might then also want to fix
Member
Yep, I would propose changing those too. Can be a different PR though, don't mean to turn this into a general code style refactor. |
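A minimal sketch of the suggested alternative, assuming a hypothetical helper `check_resizable` that is not in the PR: the Julia documentation notes that `@assert` may be disabled at some optimization levels, whereas an explicit `throw` always runs.

```julia
# Hypothetical helper, not the PR's code: refuse to resize once samples exist.
function check_resizable(n::Int)
    # An explicit `throw` is always executed, unlike `@assert`,
    # which the Julia docs say may be disabled at some optimization levels.
    n == 0 || throw(ArgumentError("Cannot resize a var estimator when it contains samples."))
    return nothing
end

check_resizable(0)  # no samples collected yet, so this passes silently

# With samples present, the call raises an ArgumentError that callers can catch.
err = try
    check_resizable(3)
catch e
    e
end
```

The choice of `ArgumentError` here is an assumption; the thread does not say which exception type would be used.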
||
| resize_adaptor!(nv.position_estimator, size_θ) | ||
| resize_adaptor!(nv.gradient_estimator, size_θ) | ||
| nv.var = ones(T, size_θ) | ||
| end | ||
| end | ||
|
|
||
| function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int}) where {T<:AbstractFloat} | ||
| length_θ = first(size_θ) | ||
| if length_θ != size(nv.var, 1) | ||
| @assert nv.n == 0 "Cannot resize a var estimator when it contains samples." | ||
|
Member
As above.
Contributor
Author
Same. |
||
| resize_adaptor!(nv.position_estimator, size_θ) | ||
| resize_adaptor!(nv.gradient_estimator, size_θ) | ||
| fill!(resize!(nv.var, length_θ), T(1)) | ||
| end | ||
| end | ||
|
|
||
| function reset!(nv::NutpieVar) | ||
| nv.n = 0 | ||
| reset!(nv.position_estimator) | ||
| reset!(nv.gradient_estimator) | ||
| end | ||
|
|
||
| Base.push!(::NutpieVar, x::AbstractVecOrMat{<:AbstractFloat}) = error("`NutpieVar` adaptation requires position and gradient information!") | ||
|
|
||
| function Base.push!(nv::NutpieVar, z::PhasePoint) | ||
| nv.n += 1 | ||
| push!(nv.position_estimator, z.θ) | ||
| push!(nv.gradient_estimator, z.ℓπ.gradient) | ||
| return nothing | ||
| end | ||
|
|
||
| # Ref: https://github.com/pymc-devs/nutpie | ||
| get_estimation(nv::NutpieVar) = sqrt.(get_estimation(nv.position_estimator) ./ get_estimation(nv.gradient_estimator)) | ||
|
|
||
| ## Dense mass matrix adaptor | ||
|
|
||
| abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end | ||
|
|
@@ -175,7 +260,7 @@ end | |
|
|
||
| NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}()) | ||
|
|
||
| - Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s) | ||
| + Base.push!(nc::NaiveCov, s::AbstractVector{<:AbstractFloat}) = push!(nc.S, s) | ||
|
|
||
| reset!(nc::NaiveCov{T}) where {T} = resize!(nc.S, 0) | ||
|
|
||
|
|
@@ -225,7 +310,7 @@ function reset!(wc::WelfordCov{T}) where {T<:AbstractFloat} | |
| return nothing | ||
| end | ||
|
|
||
| - function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T} | ||
| + function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T<:AbstractFloat} | ||
| wc.n += 1 | ||
| (; δ, μ, n, M) = wc | ||
| n = T(n) | ||
|
|
||
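The `get_estimation` method for `NutpieVar` in the diff above takes the element-wise square root of the ratio between the position-variance and gradient-variance estimates. The following is a minimal sketch of why that ratio is a sensible variance estimate, assuming a hypothetical one-dimensional Gaussian target and plain sample variances from `Statistics` rather than the regularized `WelfordVar` estimates the PR uses.

```julia
using Statistics

# Hypothetical toy data: draws from a 1-D Gaussian with standard deviation σ.
σ = 2.0
positions = σ .* [-1.5, -0.5, 0.0, 0.5, 1.5]

# The gradient of the log density of N(0, σ²) at θ is -θ / σ².
gradients = -positions ./ σ^2

# Nutpie-style diagonal estimate: sqrt(var(positions) / var(gradients)).
estimate = sqrt(var(positions) / var(gradients))
```

For a Gaussian target the gradient is a rescaled position, so `var(gradients) == var(positions) / σ^4` and the estimate equals `σ^2`, the target variance, regardless of how well the draws themselves estimate it.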
Did you change your mind about whether to export anything? If yes, I think we should do a version bump and HISTORY.md entry.
Right! I forgot, I had wanted to highlight this also to you because I was unsure about exactly what you mention!
What to mention in HISTORY.md you mean? A description of what's been changed (or in this case, added), why/what it does, and how/when to use it. The HISTORY.md entry can have a basic explanation of the gist of the change, the way you would explain it to someone who asked about it in person, e.g. through an example if that feels helpful. You can refer to the docstring for all the details of what all the optional arguments are etc.
The most important HISTORY.md entries are the ones where something is being broken/removed, there we try to give clear instructions for how to cope with the change, like how to change your code that uses the feature that is being removed. That's obviously not relevant here though.
Note that the AHMC HISTORY.md doc isn't suuuper detailed yet, but we would like to slowly improve this across TuringLang. Currently we keep detailed notes for DynamicPPL and Turing, and the other packages are quite variable.
For the version bump, I think this can be a patch version bump.
Yeah, exactly, and also whether that would be needed and whether we'd do a version bump. Would we then also directly trigger registration? And how'd Turing.jl be affected downstream? 🤔 Just wondering what the usual workflow is here 😅
Hm, I've added a line to the HISTORY.md, but I doubt that many people would actually use this newly exported feature. For one, because most people are probably interacting with AdvancedHMC via Turing, but also because it seems to be a bit convoluted currently to switch out the default adaptation for this one, see e.g. #473 (comment).
If we do want to export `NutpieVar` then I would do the version bump and immediate release, because otherwise I kinda don't see the point of exporting it in the first place, since presumably that is done to give users access to it. I think Turing should be unaffected since this doesn't break the existing interface (hence bumping just the patch version is fine) and we probably wouldn't want to use this new feature in Turing (before the interface rework).
I think it's up to you to decide if this is something AHMC should ship to users straight away, even if using it is clunky, or wait for the interface changes. I'm happy to merge either way, though I think the `NutpieVar` docstring may still have a mention saying that it isn't exported, so that would need harmonising.
Oh, you're right!