Conversation

@sethaxen
Member

As discussed on Slack, it would be very useful to expose more AdvancedHMC configuration options to the NUTS, HMC, and HMCDA convenience constructors here. In particular, it would be nice to be able to customize the integrator, metric (values), adaptor, or termination criterion.

As far as I can tell, the main challenge here is that the initial step size and metric are model-dependent, while NUTS etc. should be model-independent. The approach this PR takes is to introduce default metric types that, if not replaced by a specific user-provided object, are internally constructed from the model.

Currently this is only set up for NUTS. If this seems like a useful general approach to other devs, I'll continue with the HMC and HMCDA samplers.
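
A rough sketch of the idea (the type and function names below are illustrative only, not the PR's actual code): a placeholder metric type stands in for "use the default", and is only turned into a concrete AdvancedHMC metric once the model's dimension is known inside the sampler.

using AdvancedHMC

# Hypothetical placeholder for "construct the default metric from the model later".
struct DefaultMetric{T}
    metricT::Type{T}  # e.g. AdvancedHMC.DiagEuclideanMetric
end

# User-provided metric objects pass through unchanged...
resolve_metric(metric, d::Int) = metric
# ...while the placeholder is resolved using the model's dimension `d`.
resolve_metric(default::DefaultMetric, d::Int) = default.metricT(d)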

Example:
julia> using Turing, AdvancedHMC

julia> @model function foo()
           x ~ filldist(Normal() * 1000, 10)
       end
foo (generic function with 2 methods)

julia> chns = sample(foo(), Turing.NUTS(), 1000; save_state=true)  # same defaults
┌ Info: Found initial step size
└   ϵ = 1638.4
Sampling 100%|████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.49 seconds
Compute duration  = 0.49 seconds
parameters        = x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean         std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol    Float64     Float64    Float64   Float64     Float64   Float64       Float64 

        x[1]   -47.6449   1016.1548    32.1336   25.7694   1412.7727    0.9990     2895.0259
        x[2]    -3.2566   1022.3188    32.3286   21.1134   2216.4932    0.9990     4541.9943
        x[3]    14.5863   1031.0343    32.6042   27.1184   1719.4190    0.9993     3523.3996
        x[4]   -15.5568   1003.9893    31.7489   22.8348   1227.8352    0.9992     2516.0558
        x[5]    10.5619   1012.4916    32.0178   29.7603   1304.7482    0.9990     2673.6643
        x[6]    51.3310   1054.6483    33.3509   23.1870   1410.7269    0.9992     2890.8337
        x[7]   -11.5721    969.3929    30.6549   25.1998   1729.1101    0.9992     3543.2585
        x[8]   -18.5535   1077.2640    34.0661   31.5005   1439.2385    1.0008     2949.2593
        x[9]   -10.6400    999.5810    31.6095   26.3672   1299.9401    0.9990     2663.8117
       x[10]   -15.7143    999.1795    31.5968   26.6365   1511.8314    0.9991     3098.0152

Quantiles
  parameters         2.5%       25.0%      50.0%      75.0%       97.5% 
      Symbol      Float64     Float64    Float64    Float64     Float64 

        x[1]   -2045.2328   -737.6996   -58.0066   601.9530   1963.7858
        x[2]   -1934.4399   -719.0586   -11.8093   700.2544   1897.8098
        x[3]   -2058.7987   -686.3523   -27.1321   742.4192   2113.6651
        x[4]   -2038.9502   -687.3642     3.0949   637.3003   1958.2605
        x[5]   -1970.9327   -673.6777    25.0485   707.0809   2029.6881
        x[6]   -1958.0350   -663.9977    93.2271   798.4898   1994.8763
        x[7]   -1959.5376   -644.9694     7.7956   617.2869   1993.9105
        x[8]   -2101.6400   -778.0808   -37.4446   742.2334   2099.6595
        x[9]   -2111.1562   -687.0762     1.1653   615.2214   1982.0659
       x[10]   -2003.7428   -671.0353   -28.7991   654.2518   1913.1991


julia> chns.info.samplerstate.hamiltonian.metric  # default metric is adapted
DiagEuclideanMetric([1.0351872421358126e6, 1.11 ...])

julia> chns.info.samplerstate.kernel.τ.integrator  # default integrator's step size is adapted
Leapfrog(ϵ=0.933)

julia> integrator = JitteredLeapfrog(0.2, 0.1)
JitteredLeapfrog(ϵ0=0.2, jitter=0.1, ϵ=0.2)

julia> chns = sample(foo(), Turing.NUTS(; integrator, metricT=AdvancedHMC.DenseEuclideanMetric, adaptor=StepSizeAdaptor(0.8, integrator)), 1000; save_state=true)
Sampling 100%|████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.13 seconds
Compute duration  = 1.13 seconds
parameters        = x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean         std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol    Float64     Float64    Float64   Float64     Float64   Float64       Float64 

        x[1]   -24.8085    956.4862    30.2468   29.3441   1077.4406    0.9996      953.4872
        x[2]   -11.3046   1060.5540    33.5377   26.0390   1026.2949    0.9990      908.2256
        x[3]   -19.6913   1009.4485    31.9216   34.7410   1026.4851    0.9992      908.3939
        x[4]    10.9409   1044.2046    33.0206   31.3643   1201.9143    0.9991     1063.6410
        x[5]   -49.8740    999.2622    31.5994   32.9384   1165.3464    0.9992     1031.2800
        x[6]     4.5491    961.3624    30.4009   30.6247   1021.8930    0.9990      904.3301
        x[7]    67.1029    998.4159    31.5727   35.1677    937.8898    0.9991      829.9909
        x[8]     1.9122   1002.4672    31.7008   30.6341   1078.1651    0.9990      954.1284
        x[9]   -34.8320   1008.0000    31.8758   35.5399    830.5491    1.0031      734.9992
       x[10]   -22.7870    975.3529    30.8434   31.6282    873.5094    0.9991      773.0171

Quantiles
  parameters         2.5%       25.0%      50.0%      75.0%       97.5% 
      Symbol      Float64     Float64    Float64    Float64     Float64 

        x[1]   -1859.0547   -665.3047    -1.1679   625.7305   1812.4707
        x[2]   -2045.1465   -693.2438    -3.2864   696.4378   2066.3285
        x[3]   -1951.5782   -696.4558   -48.6351   661.1009   2024.9583
        x[4]   -2047.1058   -718.1381    37.8368   747.6024   1979.9771
        x[5]   -1947.0055   -709.0552   -45.4802   623.8888   1929.2428
        x[6]   -1964.3656   -648.1510    22.1216   665.4752   1815.8038
        x[7]   -1929.9386   -597.4158    62.1090   784.5026   1924.7035
        x[8]   -2064.1684   -652.1909    43.2797   671.0167   1795.4633
        x[9]   -2006.4426   -703.5645   -15.7514   624.2890   1962.0789
       x[10]   -2082.6195   -646.1086   -13.8338   622.7712   1896.3478


julia> chns.info.samplerstate.hamiltonian.metric  # specified metric is not adapted
DenseEuclideanMetric(diag=[1.0, 1.0, 1.0, 1.0, 1.0, 1 ...])

julia> chns.info.samplerstate.kernel.τ.integrator  # specified integrator's step size is adapted
JitteredLeapfrog(ϵ0=837.0, jitter=0.1, ϵ=0.2)

cc @cpfiffer @yebai

@sethaxen
Member Author

Some things I don't like about the current interface:

  • To customize the adaptor, if I create the StepSizeAdaptor myself, then I also need to create the integrator myself, and I lose Turing's mechanism for finding a good initial step size.
  • If I provide the integrator or metric, I may want to disable the step size or metric adaptation, respectively. Currently this requires creating the adaptor myself, but I'd prefer a simpler way to disable one or both adaptation schemes (one possible shape is sketched below).
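
For the second point, something like the following would be convenient (these keywords are purely hypothetical, not something this PR implements):

# Hypothetical interface sketch: `adapt_metric` and `adapt_stepsize` are made-up
# keyword names, shown only to illustrate the kind of control I'd like.
sampler = Turing.NUTS(;
    metric = AdvancedHMC.DenseEuclideanMetric(10),
    adapt_metric = false,    # keep the user-supplied metric fixed
    adapt_stepsize = true,   # still adapt the step size
)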

@yebai
Member

yebai commented Apr 26, 2022

@sethaxen apologies for the slow response, I'll take a look later this week.

Member

@yebai left a comment

Many thanks, @sethaxen - very nice improvements. I like your design in general. Only some (very) minor comments below. I am happy to merge as-is if you like.

EDIT: seems that we need to fix some tests before merging.

Δ_max::Float64
ϵ::Float64 # (initial) step size
metric::metricT
integrator::I
Member

A minor comment: integrator seems a bit redundant with ϵ here. Maybe consider removing the ϵ field in this type, and storing the information in DefaultIntegrator when a user specifies ϵ in the constructor?
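
Something along these lines is what I have in mind (a sketch only, not a concrete suggestion):

# Let the default-integrator placeholder carry the user-specified initial step
# size, so the sampler type no longer needs a separate ϵ field.
struct DefaultIntegrator
    ϵ::Float64  # 0.0 could mean "find a good step size automatically"
end
DefaultIntegrator() = DefaultIntegrator(0.0)
# The sampler type would then keep only `integrator::I`, where `I` is either a
# DefaultIntegrator or a concrete AHMC.AbstractIntegrator.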

NUTS{AD}(-1, 0.65; kwargs...)
end

function NUTS{AD}(
Member

Thanks for cleaning up the NUTS constructor.

Comment on lines +42 to +44
###
### Default options
###
Member

Wouldn't it make sense to add the default options to AdvancedHMC?

Member Author

Yeah, I think that would make for a cleaner design in the end.

Comment on lines +405 to +408
TS<:AHMC.AbstractTrajectorySampler,
TC<:AHMC.AbstractTerminationCriterion,
I<:AHMC.AbstractIntegrator,
A<:AHMC.AbstractAdaptor
Member

Is it sufficient to make these type parameters without corresponding fields (apart from I)? I.e., do the types contain all information or do we need values?
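
(For reference, a type parameter without a corresponding field is possible; it then carries only type information, recoverable via dispatch:)

struct Example{TS,I}
    integrator::I   # backed by a field, so a value is stored
end
# TS has no field, so it must be given explicitly and is only available as a type:
ex = Example{Symbol,Float64}(0.1)
trajectory_sampler_type(::Example{TS}) where {TS} = TS  # returns Symbol for `ex`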

Δ_max::Float64,
ϵ::Float64,
::Type{metricT},
metricT::Type,
Member

Why did you change this?

end

function NUTS{AD}(kwargs...) where AD
function NUTS{AD}(; kwargs...) where AD
Member

It seems the current code did not work? Otherwise this would be a breaking change (positional arguments replaced with keyword arguments).

Comment on lines +452 to 457
function NUTS{AD}(δ::Float64; kwargs...) where AD
    NUTS{AD}(-1, δ; kwargs...)
end

function NUTS{AD}(kwargs...) where AD
function NUTS{AD}(; kwargs...) where AD
NUTS{AD}(-1, 0.65; kwargs...)
Member

It seems this can be simplified to (can't make a suggestion)

function NUTS{AD}(δ::Float64=0.65; kwargs...) where AD
    NUTS{AD}(-1, δ; kwargs...)
end

Comment on lines +654 to 663
if spl.alg.integrator isa DefaultIntegrator
    if iszero(spl.alg.ϵ)
        ϵ = AHMC.find_good_stepsize(h, θ_init)
        @info "Found initial step size" ϵ
    else
        ϵ = spl.alg.ϵ
    end
else
    ϵ = spl.alg.ϵ
    ϵ = AHMC.step_size(spl.alg.integrator)
end
Member

It seems this could be refactored into a separate function since it's the same as above?
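
For example, something along these lines (just a sketch, the helper name is made up):

# Shared helper for both code paths that pick the initial step size.
function initial_stepsize(alg, h, θ_init)
    alg.integrator isa DefaultIntegrator || return AHMC.step_size(alg.integrator)
    iszero(alg.ϵ) || return alg.ϵ
    ϵ = AHMC.find_good_stepsize(h, θ_init)
    @info "Found initial step size" ϵ
    return ϵ
end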


return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
adaptor = as_concrete(spl.alg.adaptor, metric; n_adapts=spl.alg.n_adapts, ϵ=ϵ, δ=spl.alg.δ)
return HMCState(vi, 0, 0, kernel.τ, h, spl.alg.adaptor, t.z)
Member

I guess this should be

Suggested change
return HMCState(vi, 0, 0, kernel.τ, h, spl.alg.adaptor, t.z)
return HMCState(vi, 0, 0, kernel.τ, h, adaptor, t.z)

@yebai
Member

yebai commented Jun 6, 2023

Closed in favour of #1997

@yebai closed this Jun 6, 2023
@yebai deleted the fullahmc branch June 6, 2023 09:59