7 changes: 7 additions & 0 deletions HISTORY.md
@@ -1,5 +1,12 @@
# AdvancedHMC Changelog

## 0.8.4

- Introduces an experimental way to improve the *diagonal* mass matrix adaptation using gradient information (similar to [nutpie](https://github.com/pymc-devs/nutpie)). Until a new interface for specifying the adaptation method is introduced in an upcoming breaking release, it must be initialized for a `metric` of type `DiagEuclideanMetric` via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))`.
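Sketched end to end, opting into the estimator looks roughly like the following. The target definition and the sampling boilerplate around `NutpieVar` follow the usual AdvancedHMC/LogDensityProblems setup and are illustrative, not prescriptive:

```julia
using AdvancedHMC, ForwardDiff, LogDensityProblems, LogDensityProblemsAD

# Illustrative target: a 10-dimensional standard normal.
struct Gaussian end
LogDensityProblems.logdensity(::Gaussian, θ) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(::Gaussian) = 10
LogDensityProblems.capabilities(::Type{Gaussian}) = LogDensityProblems.LogDensityOrder{0}()

ℓπ = ADgradient(:ForwardDiff, Gaussian())  # gradients are required by NutpieVar
metric = DiagEuclideanMetric(10)
h = Hamiltonian(metric, ℓπ)
θ₀ = randn(10)
integrator = Leapfrog(find_good_stepsize(h, θ₀))
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))

# Swap the default WelfordVar for the experimental gradient-based estimator:
mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))
adaptor = StanHMCAdaptor(mma, StepSizeAdaptor(0.8, integrator))
samples, stats = sample(h, kernel, θ₀, 2_000, adaptor, 1_000)
```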

## 0.8.0

- To make an MCMC transition from phasepoint `z` using trajectory `τ` (or `HMCKernel` `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)` (if using `HMCKernel`, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.8.3"
version = "0.8.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
12 changes: 8 additions & 4 deletions docs/src/api.md
@@ -31,11 +31,15 @@ where `ϵ` is the step size of leapfrog integration.
### Adaptor (`adaptor`)

- Adapt the mass matrix `metric` of the Hamiltonian dynamics: `mma = MassMatrixAdaptor(metric)`

+ This is lowered to `UnitMassMatrix`, `WelfordVar`, or `WelfordCov` based on the type of the mass matrix `metric`.
+ There is an experimental way to improve the *diagonal* mass matrix adaptation using gradient information (similar to [nutpie](https://github.com/pymc-devs/nutpie)). Until a new interface for specifying the adaptation method is introduced in an upcoming breaking release, it must be initialized for a `metric` of type `DiagEuclideanMetric` via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))`.

- Adapt the step size of the leapfrog integrator `integrator`: `ssa = StepSizeAdaptor(δ, integrator)`

+ It uses Nesterov's dual averaging with `δ` as the target acceptance rate.
- Combine the two above *naively*: `NaiveHMCAdaptor(mma, ssa)`
- Combine the first two using Stan's windowed adaptation: `StanHMCAdaptor(mma, ssa)`
@@ -60,12 +64,12 @@ sample(
Draw `n_samples` samples using the kernel `κ` under the Hamiltonian system `h`

- The randomness is controlled by `rng`.

+ If `rng` is not provided, the default random number generator (`Random.default_rng()`) will be used.

- The initial point is given by `θ`.
- The adaptor is set by `adaptor`, for which the default is no adaptation.

+ It will perform `n_adapts` steps of adaptation, for which the default is `1_000` or 10% of `n_samples`, whichever is lower.
- `drop_warmup` specifies whether to drop samples.
- `verbose` controls the verbosity.
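Put together, the adaptor constructors described in this section compose as in the following sketch (constructor names as listed above; the step size and dimension are placeholders):

```julia
using AdvancedHMC

metric = DiagEuclideanMetric(5)
integrator = Leapfrog(0.1)

mma = MassMatrixAdaptor(metric)         # lowers to WelfordVar for a DiagEuclideanMetric
ssa = StepSizeAdaptor(0.8, integrator)  # Nesterov dual averaging, target acceptance 0.8

naive = NaiveHMCAdaptor(mma, ssa)       # adapt both at every step
stan = StanHMCAdaptor(mma, ssa)         # Stan-style windowed adaptation
```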
3 changes: 2 additions & 1 deletion src/AdvancedHMC.jl
@@ -72,7 +72,7 @@ export find_good_eps
include("adaptation/Adaptation.jl")
using .Adaptation
import .Adaptation:
StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation
StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation, PositionOrPhasePoint

# Helpers for initializing adaptors via AHMC structs

@@ -114,6 +114,7 @@ export StepSizeAdaptor,
MassMatrixAdaptor,
UnitMassMatrix,
WelfordVar,
NutpieVar,
Member:

Did you change your mind about whether to export anything? If yes, I think we should do a version bump and HISTORY.md entry.

Contributor Author:

Right! I forgot, I had wanted to highlight this also to you because I was unsure about exactly what you mention!

Member:

What to mention in HISTORY.md you mean? A description of what's been changed (or in this case, added), why/what it does, and how/when to use it. The HISTORY.md entry can have a basic explanation of the gist of the change, the way you would explain it to someone who asked about it in person, e.g. through an example if that feels helpful. You can refer to the docstring for all the details of what all the optional arguments are etc.

The most important HISTORY.md entries are the ones where something is being broken/removed, there we try to give clear instructions for how to cope with the change, like how to change your code that uses the feature that is being removed. That's obviously not relevant here though.

Note that the AHMC HISTORY.md doc isn't suuuper detailed yet, but we would like to slowly improve this across TuringLang. Currently we keep detailed notes for DynamicPPL and Turing, and the other packages are quite variable.

For the version bump, I think this can be a patch version bump.

Contributor Author:

What to mention in HISTORY.md you mean?

Yeah, exactly, and also whether that would be needed and whether we'd do a version bump. Would we then also directly trigger registration? And how'd Turing.jl be affected downstream? 🤔 Just wondering what the usual workflow is here 😅

Contributor Author:

Hm, I've added a line to the HISTORY.md, but I doubt that many people would actually use this newly exported feature. For one, because most people are probably interacting with AdvancedHMC via Turing, but also because it seems to be a bit convoluted currently to switch out the default adaptation for this one, see e.g. #473 (comment).

Member:

If we do want to export NutpieVar then I would do the version bump and immediate release, because otherwise I kinda don't see the point of exporting it in the first place, since presumably that is done to give users access to it. I think Turing should be unaffected since this doesn't break the existing interface (hence bumping just the patch version is fine) and we probably wouldn't want to use this new feature in Turing (before the interface rework).

I think it's up to you to decide whether this is something AHMC should ship to users straight away, even if using it is a bit clunky, or wait for the interface changes. I'm happy to merge either way, though I think the NutpieVar docstring may still have a mention saying that it isn't exported, so that would need harmonising.

Contributor Author:

I'm happy to merge either way, though I think the NutpieVar docstring may still have a mention saying that it isn't exported, so that would need harmonising.

Oh, you're right!

WelfordCov,
NaiveHMCAdaptor,
StanHMCAdaptor,
2 changes: 1 addition & 1 deletion src/abstractmcmc.jl
@@ -196,7 +196,7 @@ function AbstractMCMC.step(

# Adapt h and spl.
tstat = stat(t)
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate)
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt=isadapted,))

# Compute next transition and state.
23 changes: 12 additions & 11 deletions src/adaptation/Adaptation.jl
@@ -4,13 +4,13 @@ export Adaptation
using LinearAlgebra: LinearAlgebra
using Statistics: Statistics

using ..AdvancedHMC: AbstractScalarOrVec
using ..AdvancedHMC: AbstractScalarOrVec, PhasePoint
using DocStringExtensions

"""
$(TYPEDEF)

Abstract type for HMC adaptors.
Abstract type for HMC adaptors.
"""
abstract type AbstractAdaptor end
function getM⁻¹ end
@@ -21,12 +21,17 @@ function initialize! end
function finalize! end
export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹

get_position(x::PhasePoint) = x.θ
get_position(x::AbstractVecOrMat{<:AbstractFloat}) = x
const PositionOrPhasePoint = Union{AbstractVecOrMat{<:AbstractFloat}, PhasePoint}

struct NoAdaptation <: AbstractAdaptor end
export NoAdaptation
include("stepsize.jl")
export StepSizeAdaptor, NesterovDualAveraging

include("massmatrix.jl")
export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, WelfordCov
export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, NutpieVar, WelfordCov

##
## Composite adaptors
@@ -47,18 +52,14 @@ getϵ(ca::NaiveHMCAdaptor) = getϵ(ca.ssa)
# TODO: implement consensus adaptor
function adapt!(
nca::NaiveHMCAdaptor,
θ::AbstractVecOrMat{<:AbstractFloat},
z_or_theta::PositionOrPhasePoint,
α::AbstractScalarOrVec{<:AbstractFloat},
)
adapt!(nca.ssa, θ, α)
adapt!(nca.pc, θ, α)
return nothing
end
function reset!(aca::NaiveHMCAdaptor)
reset!(aca.ssa)
reset!(aca.pc)
adapt!(nca.ssa, z_or_theta, α)
adapt!(nca.pc, z_or_theta, α)
return nothing
end

initialize!(adaptor::NaiveHMCAdaptor, n_adapts::Int) = nothing
finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa)

105 changes: 95 additions & 10 deletions src/adaptation/massmatrix.jl
@@ -9,16 +9,18 @@ finalize!(::MassMatrixAdaptor) = nothing

function adapt!(
adaptor::MassMatrixAdaptor,
θ::AbstractVecOrMat{<:AbstractFloat},
α::AbstractScalarOrVec{<:AbstractFloat},
z_or_theta::PositionOrPhasePoint,
::AbstractScalarOrVec{<:AbstractFloat},
is_update::Bool=true,
)
resize_adaptor!(adaptor, size(θ))
push!(adaptor, θ)
resize_adaptor!(adaptor, size(get_position(z_or_theta)))
push!(adaptor, z_or_theta)
is_update && update!(adaptor)
return nothing
end

Base.push!(a::MassMatrixAdaptor, z_or_theta::PositionOrPhasePoint) = push!(a, get_position(z_or_theta))

## Unit mass matrix adaptor

struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end
@@ -39,15 +41,14 @@ getM⁻¹(::UnitMassMatrix{T}) where {T} = LinearAlgebra.UniformScaling{T}(one(T

function adapt!(
::UnitMassMatrix,
::AbstractVecOrMat{<:AbstractFloat},
::PositionOrPhasePoint,
::AbstractScalarOrVec{<:AbstractFloat},
is_update::Bool=true,
)
return nothing
end

## Diagonal mass matrix adaptor

abstract type DiagMatrixEstimator{T} <: MassMatrixAdaptor end

getM⁻¹(ve::DiagMatrixEstimator) = ve.var
@@ -70,7 +71,7 @@ NaiveVar{T}(sz::Tuple{Int,Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Matri

NaiveVar(sz::Union{Tuple{Int},Tuple{Int,Int}}) = NaiveVar{Float64}(sz)

Base.push!(nv::NaiveVar, s::AbstractVecOrMat) = push!(nv.S, s)
Base.push!(nv::NaiveVar, s::AbstractVecOrMat{<:AbstractFloat}) = push!(nv.S, s)

reset!(nv::NaiveVar) = resize!(nv.S, 0)

@@ -135,7 +136,7 @@ function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat}
return nothing
end

function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T}
function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T<:AbstractFloat}
wv.n += 1
(; δ, μ, M, n) = wv
n = T(n)
@@ -153,6 +154,90 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat}
return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5))
end

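For reference, `WelfordVar`'s `push!` is Welford's online variance recursion; here is a self-contained sketch of the same update, with Stan's `(n + 5)` regularization from `get_estimation` omitted for clarity:

```julia
# Standalone sketch of Welford's online variance recursion,
# mirroring WelfordVar's push! above.
mutable struct SimpleWelford
    n::Int
    μ::Vector{Float64}
    M::Vector{Float64}  # running sum of squared deviations
end
SimpleWelford(d::Int) = SimpleWelford(0, zeros(d), zeros(d))

function observe!(w::SimpleWelford, x::AbstractVector{<:Real})
    w.n += 1
    δ = x .- w.μ
    w.μ .+= δ ./ w.n
    w.M .+= δ .* (x .- w.μ)  # second factor uses the updated mean
    return w
end

variance(w::SimpleWelford) = w.M ./ (w.n - 1)
```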
"""
NutpieVar

Nutpie-style diagonal mass matrix estimator (using positions and gradients).

Expected to converge faster and to a better mass matrix than [`WelfordVar`](@ref), for which it is a drop-in replacement.

Can be initialized via `NutpieVar(sz)` where `sz` is either a `Tuple{Int}` or a `Tuple{Int,Int}`.

# Fields

$(FIELDS)
"""
mutable struct NutpieVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T}
Member:

Does this have to be mutable?

Contributor Author:

I thought that it doesn't have to be - but to adhere to the implicit internal interface, having it be mutable makes implementation easier. WelfordVar e.g. is also mutable.

Contributor Author:

See e.g. here, which implies among other things the presence of a (mutable) n field.

Member:

I think the function you linked should be fine. It does a comparison on ve.n and then conditionally mutates ve.var, but it does the mutation with ve.var .=, which doesn't require changing the value of the field ve.var by making it point to a new Array, but rather changes the elements within ve.var, for which ve itself doesn't have to be mutable. Note that an immutable type can have mutable objects as its field values.

If it is an interface demand that subtypes of DiagMatrixEstimator have to be mutable then I think that should be changed because

  • I think it's better to define interfaces based on functions and methods rather than particular fields of a type. So rather than say "there has to be field .var and we must be able to mutate it" we could say "there has to be function setvar!! with argument types blahblah".
  • Mutable types are generally harder to reason about and often slower, and thus good to avoid when possible.

However, that's probably out of scope for this PR.

Contributor Author (@nsiccha, Nov 24, 2025):

Oh right, I mostly meant that ve.n has to be able to be mutated, because otherwise the conditional will always evaluate to the same value.

You can of course get around that by having some ...!! function acting on the adaptation, but I do find this a bit awkward and would prefer a solution where ve.n would just be a Ref (inside an otherwise non-mutable struct) and then the conditional would look something like nobs(ve) >= min_nobs(ve).

If it is an interface demand that subtypes of DiagMatrixEstimator have to be mutable then I think that should be changed.

I absolutely agree!

Edit: though I don't think it's a strict interface demand - for it was just so that the method using that interface was already defined (and would be used for my subtype), so I thought I'd just rely on that already present method.

Member:

for it was just so that the method using that interface was already defined (and would be used for my subtype), so I thought I'd just rely on that already present method.

I think that's fair. Would raise this as something to change later for the whole abstact type though.

You can of course get around that by having some ...!! function acting on the adaptation, but I do find this a bit awkward and would prefer a solution where ve.n would just be a Ref (inside an otherwise non-mutable struct) and then the conditional would look something like nobs(ve) >= min_nobs(ve).

Do you have a particular reason for preferring the Ref? I have the opposite preference, of avoiding mutable types and fields whenever possible. (I do find a single Ref preferable to having the whole type be mutable though.) Partially because I've seen it help performance, and partially because I find it easier to trust and understand code with as little mutable state as possible. The only downside I see is that you need to construct a lot of objects like

new_foo = ImmutableFooType(old_foo.a, old_foo.b, c_the_only_field_that_has_changed, old_foo.d, old_foo.e)

but there's no performance penalty for that and the code can be simplified with

new_foo = Accessors.@set old_foo.c = c_the_only_field_that_has_changed

Contributor Author:

Hmmm, I think it's mainly overkill and an overoptimization. Yes, it will probably be slightly more efficient, but it will also make interacting with the struct more awkward. For simple counters such as n in this case, a Ref is probably perfectly fine and not performance critical.

I do think the "better" use case for ...!! type functions is when you can clearly anticipate that the type will change from one call to the other.

Member:

I don't see it as overkill, because I see it as being simpler than using a Ref, and don't find it awkward at all. I just always think mutability makes things harder to reason about, and I feel safe and secure when surrounded by immutable things. It's like transcending our grimy, contingent reality and living in Plato's world of pure forms.

I do think the "better" use case for ...!! type functions is when you can clearly anticipate that the type will change from one call to the other.

I love me a !! function. I think most ! functions should actually be !! functions, to leave it up to the implementer of the type to choose whether the internal data structures are mutating or not.

Anyway, this is a bigger design conversation than this PR, and warrants an issue for a proper discussion.

Contributor Author:

Anyway, this is a bigger design conversation than this PR, and warrants an issue for a proper discussion.

I'm ready to die on this hill!

"Online variance estimator of the posterior positions."
position_estimator::WelfordVar{T,E,V}
"Online variance estimator of the posterior gradients."
gradient_estimator::WelfordVar{T,E,V}
"The number of observations collected so far."
n::Int
"The minimal number of observations after which the estimate of the variances can be updated."
n_min::Int
"The estimated variances - initialized to ones, updated after calling [`update!`](@ref) if `n > n_min`."
var::V
function NutpieVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::V) where {E,V}
return new{eltype(E),E,V}(
WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)),
WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)),
n, n_min, var
)
end
end

function Base.show(io::IO, ::NutpieVar{T}) where {T}
return print(io, "NutpieVar{", T, "} adaptor")
end
Comment on lines +190 to +192
Member:

The two-argument version of show should, according to the docs,

The representation used by show generally includes Julia-specific formatting and type information, and should be parseable Julia code when possible.

We break this rule all the time in TuringLang, so not too fussed about it, but I would still slightly prefer making a nice human readable version of show to be defined with the three-argument version.

Contributor Author:

Should we then simultaneously also fix e.g. WelfordVar (which was my template)?

Contributor Author:

Interestingly, the current state of the show methods is due to #466.

Member:

I'm happy to leave this as is, follow the current intra-package convention, and maybe open an issue about generally fixing our use of show. I didn't fully understand #466 at first glance, but it might be implying that something is calling print(x) when it should rather be calling something like show(io, MIME("text/plain"), x).

Contributor Author:

Yeah, I'm also unsure about #466. I'd also be in favour of opening another issue to potentially address any "weird" implementations.


function NutpieVar{T}(
sz::Union{Tuple{Int},Tuple{Int,Int}}=(2,); n_min::Int=10, var=ones(T, sz)
) where {T<:AbstractFloat}
return NutpieVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var)
end

function NutpieVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...)
return NutpieVar{Float64}(sz; kwargs...)
end

function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int,Int}) where {T<:AbstractFloat}
if size_θ != size(nv.var)
@assert nv.n == 0 "Cannot resize a var estimator when it contains samples."
Member:

This looks like something that could plausibly be hit sometimes. Could it a throw error rather than an @assert? From the docstring of @assert:

│ Warning

│ An assert might be disabled at some optimization levels. Assert should therefore only be used as a debugging tool and
│ not used for authentication verification (e.g., verifying passwords or checking array bounds). The code must not rely
│ on the side effects of running cond for the correct behavior of a function.

Contributor Author:

I didn't know that! I'd assume as before, we might then also want to fix WelfordVar and friends?

Member:

Yep, I would propose changing those too. Can be a different PR though, don't mean to turn this into a general code style refactor.

resize_adaptor!(nv.position_estimator, size_θ)
resize_adaptor!(nv.gradient_estimator, size_θ)
nv.var = ones(T, size_θ)
end
end

function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int}) where {T<:AbstractFloat}
length_θ = first(size_θ)
if length_θ != size(nv.var, 1)
@assert nv.n == 0 "Cannot resize a var estimator when it contains samples."
Member:

As above.

Contributor Author:

Same.

resize_adaptor!(nv.position_estimator, size_θ)
resize_adaptor!(nv.gradient_estimator, size_θ)
fill!(resize!(nv.var, length_θ), T(1))
end
end

function reset!(nv::NutpieVar)
nv.n = 0
reset!(nv.position_estimator)
reset!(nv.gradient_estimator)
end

Base.push!(::NutpieVar, x::AbstractVecOrMat{<:AbstractFloat}) = error("`NutpieVar` adaptation requires position and gradient information!")

function Base.push!(nv::NutpieVar, z::PhasePoint)
nv.n += 1
push!(nv.position_estimator, z.θ)
push!(nv.gradient_estimator, z.ℓπ.gradient)
return nothing
end

# Ref: https://github.com/pymc-devs/nutpie
get_estimation(nv::NutpieVar) = sqrt.(get_estimation(nv.position_estimator) ./ get_estimation(nv.gradient_estimator))
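The one-liner above is the core of the nutpie heuristic: each diagonal entry of `M⁻¹` is set to `sqrt(Var(θ) / Var(∇logp))`. For a Gaussian target, `Var(∇logp) = 1 / Var(θ)` per axis, so this recovers the posterior variance exactly; the numbers below are hypothetical, not from a real run:

```julia
# Elementwise nutpie-style estimate from (hypothetical) variance estimates:
var_θ = [4.0, 1.0]           # estimated variance of positions
var_g = [0.25, 1.0]          # estimated variance of gradients
M⁻¹ = sqrt.(var_θ ./ var_g)  # [4.0, 1.0] — recovers Var(θ) for a Gaussian
```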

## Dense mass matrix adaptor

abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end
Expand All @@ -175,7 +260,7 @@ end

NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}())

Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s)
Base.push!(nc::NaiveCov, s::AbstractVector{<:AbstractFloat}) = push!(nc.S, s)

reset!(nc::NaiveCov{T}) where {T} = resize!(nc.S, 0)

@@ -225,7 +310,7 @@ function reset!(wc::WelfordCov{T}) where {T<:AbstractFloat}
return nothing
end

function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T}
function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T<:AbstractFloat}
wc.n += 1
(; δ, μ, n, M) = wc
n = T(n)
8 changes: 4 additions & 4 deletions src/adaptation/stan_adaptor.jl
@@ -136,20 +136,20 @@ is_window_end(a::StanHMCAdaptor) = a.state.i in a.state.window_splits

function adapt!(
tp::StanHMCAdaptor,
θ::AbstractVecOrMat{<:AbstractFloat},
z_or_theta::PositionOrPhasePoint,
α::AbstractScalarOrVec{<:AbstractFloat},
)
tp.state.i += 1

adapt!(tp.ssa, θ, α)
adapt!(tp.ssa, z_or_theta, α)

resize_adaptor!(tp.pc, size(θ)) # Resize pre-conditioner if necessary.
resize_adaptor!(tp.pc, size(get_position(z_or_theta))) # Resize pre-conditioner if necessary.

# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
if is_in_window(tp)
# We accumulate stats from θ online and only trigger the update of M⁻¹ at the end of the window.
is_update_M⁻¹ = is_window_end(tp)
adapt!(tp.pc, θ, α, is_update_M⁻¹)
adapt!(tp.pc, z_or_theta, α, is_update_M⁻¹)
end

if is_window_end(tp)
4 changes: 2 additions & 2 deletions src/adaptation/stepsize.jl
@@ -174,7 +174,7 @@ end
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp
# Note: This function is not merged with `adapt!` to emphasize the fact that
# step size adaptation is not dependent on `θ`.
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
function adapt_stepsize!(
da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T}
) where {T<:AbstractFloat}
@@ -211,7 +211,7 @@ end

function adapt!(
da::NesterovDualAveraging,
θ::AbstractVecOrMat{<:AbstractFloat},
::PositionOrPhasePoint,
α::AbstractScalarOrVec{<:AbstractFloat},
)
adapt_stepsize!(da, α)