Skip to content

Add Stochastic Gradient HMC #428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# AdvancedHMC Changelog

## 0.9.0

- Stochastic-gradient-based methods `SGHMC` and `SGLD` are now supported in AdvancedHMC.jl. Note that Turing.jl provides samplers with the same names, so when using the two packages together, qualify the sampler with the package that exports it (e.g. `AdvancedHMC.SGHMC`).

## 0.8.0

- To make an MCMC transition from phasepoint `z` using trajectory `τ` (or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)` (if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.8.0"
version = "0.9.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"

[compat]
AdvancedHMC = "0.8"
AdvancedHMC = "0.9"
Documenter = "1"
DocumenterCitations = "1"
DocumenterCitations = "1"
2 changes: 1 addition & 1 deletion src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
@@ -125,7 +125,7 @@ include("sampler.jl")
export sample

include("constructors.jl")
export HMCSampler, HMC, NUTS, HMCDA
export HMCSampler, HMC, NUTS, HMCDA, SGHMC

include("abstractmcmc.jl")

122 changes: 122 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -205,6 +205,120 @@
return Transition(t.z, tstat), newstate
end

# State threaded through successive `AbstractMCMC.step` calls of the `SGHMC` sampler.
# It bundles the iteration counter, the latest transition, the metric, the kernel,
# the adaptor and the velocity vector consumed by the next SGHMC update.
struct SGHMCState{
TTrans<:Transition,
TMetric<:AbstractMetric,
TKernel<:AbstractMCMCKernel,
TAdapt<:Adaptation.AbstractAdaptor,
T<:AbstractVector{<:Real},
}
"Index of current iteration."
i::Int
"Current [`Transition`](@ref)."
transition::TTrans
"Current [`AbstractMetric`](@ref), possibly adapted."
metric::TMetric
"Current [`AbstractMCMCKernel`](@ref)."
κ::TKernel
"Current [`AbstractAdaptor`](@ref)."
adaptor::TAdapt
# Velocity vector carried over to the next step; the stateful `step` method reads it
# as `state.velocity` and stores the freshly computed velocity in the new state.
velocity::T
end
# Accessors for the components stored in an `SGHMCState`.

# Return the adaptor held by `state`.
function getadaptor(state::SGHMCState)
    return state.adaptor
end

# Return the metric held by `state`.
function getmetric(state::SGHMCState)
    return state.metric
end

# Return the integrator nested inside the kernel's trajectory (`state.κ.τ`).
function getintegrator(state::SGHMCState)
    trajectory = state.κ.τ
    return trajectory.integrator
end

Check warning on line 229 in src/abstractmcmc.jl

Codecov / codecov/patch

src/abstractmcmc.jl#L227-L229

Added lines #L227 - L229 were not covered by tests

# Initial step of the `SGHMC` sampler (no previous state).
#
# Builds the metric, Hamiltonian, integrator, kernel and adaptor from `spl`, draws
# an initial transition from `initial_params`, wraps everything in an `SGHMCState`
# with iteration index 0, and delegates to the stateful `step` method, whose
# `(sample, state)` result is returned.
#
# Keywords:
# - `initial_params`: starting parameters, resolved through `make_initial_params`
#   (default `nothing`).
# - `kwargs...`: forwarded unchanged to the stateful `step` method.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC;
    initial_params=nothing,
    kwargs...,
)
    # Unpack model
    logdensity = model.logdensity

    # Define metric
    metric = make_metric(spl, logdensity)

    # Construct the hamiltonian using the initial metric
    hamiltonian = Hamiltonian(metric, model)

    # Compute initial sample and state.
    initial_params = make_initial_params(rng, spl, logdensity, initial_params)
    ϵ = make_step_size(rng, spl, hamiltonian, initial_params)
    integrator = make_integrator(spl, ϵ)

    # Make kernel
    κ = make_kernel(spl, integrator)

    # Make adaptor
    adaptor = make_adaptor(spl, metric, integrator)

    # Get an initial sample (the returned Hamiltonian is not needed here).
    _, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)

    # Start from zero velocity. Passing `initial_params` itself would make the first
    # update treat the position as the velocity and alias the parameter vector.
    initial_velocity = zero(initial_params)
    state = SGHMCState(0, t, metric, κ, adaptor, initial_velocity)

    return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

# Subsequent step of the `SGHMC` sampler: advance `state` by one iteration and
# return `(sample, newstate)`.
#
# Keywords:
# - `n_adapts`: number of adaptation steps, forwarded to `adapt!` (default `0`).
# - `nadapts`: rejected with an `ArgumentError` pointing the caller to `n_adapts`.
function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC,
    state::SGHMCState;
    n_adapts::Int=0,
    kwargs...,
)
    if haskey(kwargs, :nadapts)
        throw(
            ArgumentError(
                "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
            ),
        )
    end

    i = state.i + 1
    t_old = state.transition
    adaptor = state.adaptor
    κ = state.κ
    metric = state.metric

    # Reconstruct hamiltonian.
    h = Hamiltonian(metric, model)

    # Work on a copy so the position stored in the previous transition is not mutated
    # by the in-place update below.
    θ = copy(t_old.z.θ)

    # Compute gradient of log density at the current position.
    grad = last(LogDensityProblems.logdensity_and_gradient(model.logdensity, θ))

    # Update latent variables and velocity according to
    # equation (15) of Chen et al. (2014)
    v = state.velocity
    η = spl.learning_rate
    α = spl.momentum_decay
    noise = randn(rng, eltype(v), length(v))
    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* noise
    θ .+= newv

    # Make new transition. The phase point pairs the updated position with the
    # updated velocity (`newv`), not the stale velocity from the previous state.
    z = phasepoint(h, θ, newv)
    t = transition(rng, h, κ, z)

    # Adapt h and spl.
    # NOTE(review): `θ` here is the position before `transition`; confirm whether
    # `t.z.θ` was intended.
    tstat = stat(t)
    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
    tstat = merge(tstat, (is_adapt=isadapted,))

    # Compute next sample and state.
    sample = Transition(t.z, tstat)
    newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)

    return sample, newstate
end

################
### Callback ###
################
@@ -392,6 +506,10 @@
return NoAdaptation()
end

# `SGHMC` always gets `NoAdaptation()`; the sampler, metric and integrator
# arguments are accepted only to select this method and are not inspected.
make_adaptor(::SGHMC, ::AbstractMetric, ::AbstractIntegrator) = NoAdaptation()

function make_adaptor(
spl::HMCSampler, metric::AbstractMetric, integrator::AbstractIntegrator
)
@@ -417,3 +535,7 @@
# An `HMCSampler` already carries its kernel, so return it as-is; `integrator` is unused.
make_kernel(spl::HMCSampler, integrator::AbstractIntegrator) = spl.κ

# Build the kernel for `SGHMC`: an `HMCKernel` wrapping an end-point trajectory
# (`Trajectory{EndPointTS}`) over `integrator` that runs a fixed
# `spl.n_leapfrog` steps.
function make_kernel(spl::SGHMC, integrator::AbstractIntegrator)
    termination = FixedNSteps(spl.n_leapfrog)
    trajectory = Trajectory{EndPointTS}(integrator, termination)
    return HMCKernel(trajectory)
end
45 changes: 45 additions & 0 deletions src/constructors.jl
Original file line number Diff line number Diff line change
@@ -163,3 +163,48 @@ function HMCDA(δ, λ; integrator=:leapfrog, metric=:diagonal)
end

sampler_eltype(::HMCDA{T}) where {T} = T

########### Static Hamiltonian Monte Carlo ###########

#############
### SGHMC ###
#############
"""
    SGHMC(learning_rate::Real, momentum_decay::Real, n_leapfrog::Int; integrator = :leapfrog, metric = :diagonal)

Stochastic Gradient Hamiltonian Monte Carlo sampler

# Fields

$(FIELDS)

# Notes

For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1402.4102)):

- Chen, Tianqi, Emily Fox, and Carlos Guestrin. "Stochastic gradient hamiltonian monte carlo." International conference on machine learning. PMLR, 2014.
"""
struct SGHMC{
    T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}
} <: AbstractHMCSampler
    "Learning rate for the gradient descent."
    learning_rate::T
    "Momentum decay rate."
    momentum_decay::T
    "Number of leapfrog steps."
    n_leapfrog::Int
    "Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
    integrator::I
    "Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
    metric::M
end

# Convenience constructor for `SGHMC`.
#
# The element type is obtained from `determine_sampler_eltype` applied to all
# arguments; `learning_rate` and `momentum_decay` are converted to that type before
# the struct is built. `integrator` and `metric` are keywords defaulting to
# `:leapfrog` and `:diagonal`.
function SGHMC(
    learning_rate, momentum_decay, n_leapfrog; integrator=:leapfrog, metric=:diagonal
)
    F = determine_sampler_eltype(
        learning_rate, momentum_decay, n_leapfrog, integrator, metric
    )
    lr = F(learning_rate)
    decay = F(momentum_decay)
    return SGHMC(lr, decay, n_leapfrog, integrator, metric)
end

# The element type of an `SGHMC` sampler is its first type parameter `T`.
function sampler_eltype(::SGHMC{T}) where {T}
    return T
end
24 changes: 24 additions & 0 deletions test/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ using Statistics: mean
nuts = NUTS(0.8)
hmc = HMC(100; integrator=Leapfrog(0.05))
hmcda = HMCDA(0.8, 0.1)
sghmc = SGHMC(0.01, 0.1, 100)

integrator = Leapfrog(1e-3)
κ = AdvancedHMC.make_kernel(nuts, integrator)
@@ -111,6 +112,29 @@ using Statistics: mean

@test m_est_hmc ≈ [49 / 24, 7 / 6] atol = RNDATOL

# Draw `n_adapts + n_samples` samples with the SGHMC sampler and check that the
# mean of the (back-transformed) draws matches the expected values.
samples_sghmc = AbstractMCMC.sample(
rng,
model,
sghmc,
n_adapts + n_samples;
n_adapts=n_adapts,
initial_params=θ_init,
progress=false,
verbose=false,
)

# Transform back to original space (overwrites each sample's `θ` in place).
# NOTE: We're not correcting for the `logabsdetjac` here, but since
# we're only interested in the mean it doesn't matter.
for t in samples_sghmc
t.z.θ .= invlink_gdemo(t.z.θ)
end
m_est_sghmc = mean(samples_sghmc) do t
t.z.θ
end

@test m_est_sghmc ≈ [49 / 24, 7 / 6] atol = RNDATOL

samples_custom = AbstractMCMC.sample(
rng,
model,
Loading
Oops, something went wrong.