Add dense pre-conditioner #607
src/samplers/support/hmc_core.jl
Outdated
| d = length(vi[spl]) | ||
| A = pc.covar | ||
| # TODO: the type conversion below is due to a possible bug in Julia.LinearAlgebra.cholesky for Float64 | ||
| A = Matrix{Float32}(pc.covar) |
I ran into a weird issue today. I got an error saying that my matrix A is not Hermitian. I then printed the matrix, copied and pasted it into the REPL, and called the cholesky function again, which worked. So I suspected this was related to precision and tried forcing the type to Float32 here, after which the error disappeared.
This is normal. If you know the matrix is symmetric up to some precision, wrap it in a Symmetric wrapper before passing it to cholesky.
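A minimal sketch of that suggestion, assuming a small hand-written matrix rather than the PR's `pc.covar`: one off-diagonal entry is perturbed to mimic floating-point asymmetry, so a plain `cholesky` call fails while the `Symmetric`-wrapped call succeeds.

```julia
using LinearAlgebra

# Illustrative positive-definite matrix (not taken from the PR).
B = [2.0 0.5 0.1;
     0.5 1.5 0.2;
     0.1 0.2 1.0]

# Perturb one off-diagonal entry so that A is no longer exactly symmetric.
A = copy(B)
A[1, 2] += 1e-12

# Record whether the plain call throws; the exact exception type is not relied on.
plain_call_failed = try
    cholesky(A)
    false
catch
    true
end

# Symmetric(A) reads only the upper triangle, so the asymmetry is ignored.
F = cholesky(Symmetric(A))
```

Because `Symmetric` only reinterprets the existing storage, this avoids the precision loss that a conversion to `Float32` would introduce.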
Thanks!
| return (n / ((n + 5) * (n - 1))) .* M .+ 1e-3 * (5 / (n + 5)) | ||
| end | ||
| abstract type CovarEstimator{TI<:Integer,TF<:Real} end |
Do we need two abstract types (VarEstimator, CovarEstimator)? Perhaps consider merging them into one abstract type.
We currently have both a naive covariance computation and the Welford one. The DensePreConditioner can use either of them.
| wc.n += 1 | ||
| δ = s .- wc.μ | ||
| wc.μ .+= δ ./ wc.n | ||
| wc.M .+= (s .- wc.μ) * δ' |
Looks correct to me!
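For reference, the update above can be checked against `Statistics.cov` with a self-contained sketch. The `WelfordSketch` type and `add_sample!` function are hypothetical stand-ins for the PR's `WelfordCovar`, whose full definition is not shown here.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical stand-in for the PR's WelfordCovar (field names follow the diff).
mutable struct WelfordSketch
    n::Int
    μ::Vector{Float64}
    M::Matrix{Float64}
end
WelfordSketch(d::Int) = WelfordSketch(0, zeros(d), zeros(d, d))

# One Welford step: update the count, the running mean, and the
# accumulated sum of outer products of deviations.
function add_sample!(wc::WelfordSketch, s::AbstractVector{<:Real})
    wc.n += 1
    δ = s .- wc.μ                 # deviation from the old mean
    wc.μ .+= δ ./ wc.n            # new mean
    wc.M .+= (s .- wc.μ) * δ'     # (deviation from new mean) * (deviation from old mean)'
    return wc
end

Random.seed!(1)
X = randn(200, 3)                 # 200 samples of dimension 3, one per row
wc = WelfordSketch(3)
for i in 1:size(X, 1)
    add_sample!(wc, X[i, :])
end

# Unregularised sample covariance recovered from the accumulator.
sample_cov = wc.M ./ (wc.n - 1)
```

Up to rounding, `sample_cov` should agree with `cov(X)` and `wc.μ` with the column means of `X`.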
| function get_covar(wc::WelfordCovar{TI,TF})::Matrix{TF} where {TI<:Integer,TF<:Real} | ||
| n, M = wc.n, wc.M | ||
| @assert n >= 2 "Cannot get variance with only one sample" | ||
| return (n / ((n + 5) * (n - 1))) .* M + 1e-3 * (5 / (n + 5)) * LinearAlgebra.I |
Looks correct to me!
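The return expression can be read as a weighted blend of the sample covariance `M / (n - 1)` and a small multiple of the identity, with weights `n / (n + 5)` and `5 / (n + 5)`. A minimal sketch that spells this out, where `regularized_covar` is a hypothetical name and not a function from the PR:

```julia
using LinearAlgebra

# Hypothetical helper: same arithmetic as the diff, split into named steps.
function regularized_covar(M::AbstractMatrix{<:Real}, n::Integer)
    @assert n >= 2 "Cannot get covariance with only one sample"
    sample_cov = M ./ (n - 1)      # unbiased sample covariance
    w = n / (n + 5)                # weight on the data; tends to 1 as n grows
    return w .* sample_cov + 1e-3 * (1 - w) * I   # 1 - w == 5 / (n + 5)
end

M = [2.0 0.0; 0.0 4.0]
C = regularized_covar(M, 5)
```

With few samples the identity term keeps the estimate away from singular, and its influence fades as `n` increases.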
| # NOTE: related Hamiltonian change: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp | ||
| function gen_momentum_sampler(vi::VarInfo, spl::Sampler, pc::DensePreConditioner) | ||
| d = length(vi[spl]) |
Perhaps drop d = length(vi[spl]), since d = size(A, 1)?
I looked over the code. The gen_momentum_sampler methods for UnitPreConditioner and for the case without a pre-conditioner have to use d = length(vi[spl]). I feel it's more consistent to use this in all gen_momentum_sampler methods. Also, we cannot simplify the function signature, because the other gen_momentum_sampler methods need spl and vi.
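For context on where `d` and `A` enter, here is a rough sketch of one way a dense-metric momentum draw could look. It is an assumption for illustration, not the body of the PR's `gen_momentum_sampler`: `sample_dense_momentum` is a hypothetical name, and the convention that momentum has covariance `inv(covar)` is assumed. It takes `d` from the matrix itself, as the first comment suggests.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not from the PR): draw p ~ N(0, inv(covar)).
function sample_dense_momentum(rng::AbstractRNG, covar::AbstractMatrix{<:Real})
    d = size(covar, 1)                     # dimension taken from the matrix
    U = cholesky(Symmetric(covar)).U       # covar == U' * U
    z = randn(rng, d)                      # standard normal draw
    return U \ z                           # Cov(p) = inv(U) * inv(U)' = inv(covar)
end

covar = [2.0 0.3; 0.3 1.0]
p = sample_dense_momentum(MersenneTwister(0), covar)
```

Keeping `vi` and `spl` in the signature, as argued above, would only change how `d` is obtained; the sampling step itself depends on the matrix alone.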
test/adapt.jl/covar_estimator.jl
Outdated
| covar = get_covar(wc) | ||
| @test covar ≈ LinearAlgebra.diagm(0 => ones(D)) atol=0.5 |
Perhaps a tighter bound than 0.5?
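As a rough guide to how tight the bound can be: for arrays, `≈` with `atol` compares the norm of the difference against `atol`, and that norm shrinks as the number of samples grows. A small sketch, where `cov_error` is a hypothetical helper and plain `Statistics.cov` stands in for the PR's `get_covar`:

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical helper: norm of the difference between the sample covariance
# of n standard-normal draws in D dimensions and the identity.
function cov_error(rng::AbstractRNG, n::Integer, D::Integer)
    X = randn(rng, n, D)           # one sample per row
    return norm(cov(X) - I)        # the quantity `≈ ... atol=...` bounds
end

rng = MersenneTwister(42)
for n in (100, 10_000)
    println("n = ", n, ": norm(cov - I) = ", cov_error(rng, n, 3))
end
```

Running this with the test's actual sample count and dimension would suggest a tolerance with some safety margin, rather than a fixed 0.5.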
| end | ||
| var = get_var(ve) | ||
| var = get_var(wv) |
Perhaps a tighter bound than 0.5?
Thanks, @xukai92. I have reviewed the code and it looks correct! Ready to merge once the minor improvement suggestions are addressed!
No description provided.