diff --git a/docs/make.jl b/docs/make.jl index c0abd2c..b3028a0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -30,6 +30,7 @@ makedocs(; "Home" => "index.md", "Basic usage" => [ "Function minimisation" => "function_minimisation.md" + "Distribution sampling" => "distribution_sampling.md" "Method parameters" => "method_parameters.md" "Stopping criteria" => "stopping_criteria.md" "Particle initialisation" => "particle_initialisation.md" diff --git a/docs/src/distribution_sampling.md b/docs/src/distribution_sampling.md new file mode 100644 index 0000000..c00e2a3 --- /dev/null +++ b/docs/src/distribution_sampling.md @@ -0,0 +1,43 @@ +# Distribution sampling + +ConsensusBasedX.jl also provides consensus-based sampling, [J. A. Carrillo, F. Hoffmann, A. M. Stuart, and U. Vaes (2022)](https://onlinelibrary.wiley.com/doi/10.1111/sapm.12470). The package exports `sample`, which behaves exactly as `minimise` in [Function minimisation](@ref). It assumes you have defined a function `f(x::AbstractVector)` that takes a single vector argumemt `x` of length `D = length(x)`. + +For instance, if `D = 2`, you can sample `exp(-f)` by running: +```julia +out = sample(f, D = 2, extended_output=true) +out.sample +``` +[Full-code example](https://github.com/PdIPS/ConsensusBasedX.jl/blob/main/examples/basic_usage/sample_with_keywords.jl). + +!!! note + You must always provide `D`. + + +## Using a `config` object + +For more advanced usage, you will select several options. You can pass these as extra keyword arguments to `sample`, or you can create a `NamedTuple` called `config` and pass that: +```julia +config = (; D = 2, extended_output=true) +out = sample(f, config) +out.sample +``` +[Full-code example](https://github.com/PdIPS/ConsensusBasedX.jl/blob/main/examples/basic_usage/sample_with_config.jl). + +!!! note + If you pass a `Dict` instead, it will be converted to a `NamedTuple` automatically. + + +## Running on minimisation mode + +Consensus-based sampling can also be used for minimisation. If you want to run it in that mode, pass the option `CBS_mode = :minimise`. + + +## Method reference + +```@index +Pages = ["distribution_sampling.md"] +``` + +```@docs +ConsensusBasedX.sample +``` diff --git a/docs/src/function_minimisation.md b/docs/src/function_minimisation.md index 8aeb025..b96f97a 100644 --- a/docs/src/function_minimisation.md +++ b/docs/src/function_minimisation.md @@ -49,7 +49,6 @@ Full-code examples are provided for the [keyword](https://github.com/PdIPS/Conse ## Method reference - ```@index Pages = ["function_minimisation.md"] ``` diff --git a/docs/src/low_level.md b/docs/src/low_level.md index 4b9090a..8c9b8f5 100644 --- a/docs/src/low_level.md +++ b/docs/src/low_level.md @@ -49,3 +49,20 @@ The full reference is: ```@docs ConsensusBasedX.ConsensusBasedOptimisationCache ``` + +## `ConsensusBasedSampling` + +The `ConsensusBasedSampling` struct (of type `CBXMethod`) defines the details of the *consensus-based sampling method* (function evaluations, covariance matrix...). + +```@docs +ConsensusBasedX.ConsensusBasedSampling +``` + +### `ConsensusBasedSamplingCache` + +`ConsensusBasedSampling` requires a cache, `ConsensusBasedSamplingCache` (of type `CBXMethodCache`). This can be constructed with [`ConsensusBasedX.construct_method_cache`](@ref). + +The full reference is: +```@docs +ConsensusBasedX.ConsensusBasedSamplingCache +``` diff --git a/docs/src/summary_options.md b/docs/src/summary_options.md index fb291b8..697e8cc 100644 --- a/docs/src/summary_options.md +++ b/docs/src/summary_options.md @@ -38,3 +38,7 @@ See [Stopping criteria](@ref). - `extended_output::Bool = false` controls the output, and by default returns only the computed minimiser. `extended_output = true` returns additional information, see [Extended output](@ref). - `parallelisation = :NoParallelisation` controls the parallelisation of the `minimise` routine, switched off by default. `parallelisation=:EnsembleParallelisation` enables parallelisation, see [Parallelisation](@ref). - `verbosity::Int = 0` is the verbosity level. `verbosity = 0` produces no output to console. `verbosity = 1` produces some output. + +## Consensus-based sampling options + +- `CBS_mode = :sampling` controls the mode of consensus-based sampling. If you want to perform a minimisation, pass `CBS_mode = :minimise` instead. diff --git a/examples/basic_usage/sample_to_minimise.jl b/examples/basic_usage/sample_to_minimise.jl new file mode 100644 index 0000000..d651f07 --- /dev/null +++ b/examples/basic_usage/sample_to_minimise.jl @@ -0,0 +1,6 @@ +using ConsensusBasedX + +f(x) = ConsensusBasedX.Ackley(x, shift = 1) + +config = (; D = 2, N = 20, CBS_mode = :minimise) +sample(f, config) diff --git a/examples/basic_usage/sample_with_config.jl b/examples/basic_usage/sample_with_config.jl new file mode 100644 index 0000000..5714727 --- /dev/null +++ b/examples/basic_usage/sample_with_config.jl @@ -0,0 +1,7 @@ +using ConsensusBasedX + +f(x) = ConsensusBasedX.Ackley(x, shift = 1) + +config = (; D = 2, N = 20, extended_output = true) +out = sample(f, config) +out.sample diff --git a/examples/basic_usage/sample_with_keywords.jl b/examples/basic_usage/sample_with_keywords.jl new file mode 100644 index 0000000..10dfc27 --- /dev/null +++ b/examples/basic_usage/sample_with_keywords.jl @@ -0,0 +1,6 @@ +using ConsensusBasedX + +f(x) = ConsensusBasedX.Ackley(x, shift = 1) + +out = sample(f, D = 2, N = 20, extended_output = true) +out.sample diff --git a/src/CBO/CBO.jl b/src/CBO/CBO.jl index a4a4312..da5d360 100644 --- a/src/CBO/CBO.jl +++ b/src/CBO/CBO.jl @@ -7,7 +7,7 @@ Fields: - `f`, the objective function. - `correction<:CBXCorrection`, a correction term. - - `α::Float64`, the the exponential weight parameter. + - `α::Float64`, the exponential weight parameter. - `λ::Float64`, the drift strengh. - `σ::Float64`, the noise strengh. """ diff --git a/src/CBO/CBO_method.jl b/src/CBO/CBO_method.jl index eaa068b..a771d81 100644 --- a/src/CBO/CBO_method.jl +++ b/src/CBO/CBO_method.jl @@ -42,7 +42,7 @@ function compute_CBO_update!( particle_dynamic_cache::ParticleDynamicCache, m::Int, ) where {TF, TCorrection} - @expand particle_dynamic_cache D N X dX Δt root2Δt + @expand particle_dynamic_cache D N X dX Δt root_2Δt @expand method correction λ σ @expand method_cache consensus consensus_energy distance energy @@ -59,7 +59,7 @@ function compute_CBO_update!( λ * (consensus[m][d] - X[m][n][d]) * correction(energy[m][n] - consensus_energy[m]) + - root2Δt * σ * distance[m][n] * randn() + root_2Δt * σ * distance[m][n] * randn() end end return nothing @@ -72,7 +72,7 @@ function compute_CBO_update!( particle_dynamic_cache::ParticleDynamicCache, m::Int, ) where {TF, TCorrection} - @expand particle_dynamic_cache D N X dX Δt root2Δt + @expand particle_dynamic_cache D N X dX Δt root_2Δt @expand method correction λ σ @expand method_cache consensus consensus_energy energy @@ -81,7 +81,7 @@ function compute_CBO_update!( dX[m][n][d] = (consensus[m][d] - X[m][n][d]) * ( Δt * λ * correction(energy[m][n] - consensus_energy[m]) + - root2Δt * σ * randn() + root_2Δt * σ * randn() ) end end diff --git a/src/CBO/is_method_pending.jl b/src/CBO/is_method_pending.jl index 212dc17..12691a0 100644 --- a/src/CBO/is_method_pending.jl +++ b/src/CBO/is_method_pending.jl @@ -1,6 +1,6 @@ function is_method_pending( - method::ConsensusBasedOptimisation, - method_cache::ConsensusBasedOptimisationCache, + method::CBXMethod, + method_cache::CBXMethodCache, particle_dynamic::ParticleDynamic, particle_dynamic_cache::ParticleDynamicCache, m::Int, @@ -39,8 +39,8 @@ function is_method_pending( end function is_method_pending_energy_threshold( - method::ConsensusBasedOptimisation, - method_cache::ConsensusBasedOptimisationCache, + method::CBXMethod, + method_cache::CBXMethodCache, particle_dynamic::ParticleDynamic, particle_dynamic_cache::ParticleDynamicCache, m::Int, @@ -50,8 +50,8 @@ function is_method_pending_energy_threshold( end function is_method_pending_energy_tolerance( - method::ConsensusBasedOptimisation, - method_cache::ConsensusBasedOptimisationCache, + method::CBXMethod, + method_cache::CBXMethodCache, particle_dynamic::ParticleDynamic, particle_dynamic_cache::ParticleDynamicCache, m::Int, @@ -62,8 +62,8 @@ function is_method_pending_energy_tolerance( end function is_method_pending_max_evaluations( - method::ConsensusBasedOptimisation, - method_cache::ConsensusBasedOptimisationCache, + method::CBXMethod, + method_cache::CBXMethodCache, particle_dynamic::ParticleDynamic, particle_dynamic_cache::ParticleDynamicCache, m::Int, diff --git a/src/CBS/CBS.jl b/src/CBS/CBS.jl new file mode 100644 index 0000000..6f7192c --- /dev/null +++ b/src/CBS/CBS.jl @@ -0,0 +1,249 @@ +""" +```julia +ConsensusBasedSampling +``` + +Fields: + + - `f`, the objective function. + - `α::Float64`, the exponential weight parameter. + - `λ::Float64`, the mode parameter. `λ = 1 / (1 + α)` corresponds to `CBS_mode = :sampling`, and `λ = 1` corresponds to `CBS_mode = :minimise`. +""" +mutable struct ConsensusBasedSampling{TF} <: CBXMethod + f::TF + α::Float64 + λ::Float64 +end + +@config function construct_CBS(f; α::Real = 10, CBS_mode = :sampling) + @assert α >= 0 + if Symbol(CBS_mode) == :sampling + λ = 1 / (1 + α) + elseif Symbol(CBS_mode) in + [:minimise, :minimisation, :optimise, :optimisation] + λ = 1 + else + explanation = "`CBS_mode` should be either `:sampling` or `:minimise`." + throw(ArgumentError(explanation)) + end + return ConsensusBasedSampling(f, float(α), float(λ)) +end + +""" +```julia +ConsensusBasedSamplingCache{T} +``` + +**It is strongly recommended that you do not construct `ConsensusBasedSamplingCache` by hand.** Instead, use [`ConsensusBasedX.construct_method_cache`](@ref). + +Fields: + + - `consensus::Vector{Vector{T}}`, the consensus point of each ensemble. + - `consensus_energy::Vector{T}`, the energy (value of the objective function) of each consensus point. + - `consensus_energy_previous::Vector{T}`, the previous energy. + - `energy::Vector{Vector{T}}`, the energy of each particle. + - `exponents::Vector{Vector{T}}`, an exponent used to compute `logsums`. + - `logsums::Vector{T}`, a normalisation factor for `weights`. + - `noise::Vector{Vector{T}}`, a vector to contain the noise of one iteration. + - `root_covariance::Vector{Matrix{T}}`, the matrix square root of the weighted covariance of the particles. + - `weights::Vector{Vector{T}}`, the exponential weight of each particle. + - `energy_threshold::Float64`, the energy threshold. + - `energy_tolerance::Float64`, the energy tolerance. + - `max_evaluations::Float64`, the maximum number of `f` evaluations. + - `evaluations::Vector{Int}`, the current number of `f` evaluations. + - `exp_minus_Δt::Float64`, the time-stepping parameter. + - `noise_factor::Float64`, the noise multiplier. +""" +mutable struct ConsensusBasedSamplingCache{T} <: CBXMethodCache + consensus::Vector{Vector{T}} + consensus_energy::Vector{T} + consensus_energy_previous::Vector{T} + energy::Vector{Vector{T}} + exponents::Vector{Vector{T}} + logsums::Vector{T} + noise::Vector{Vector{T}} + root_covariance::Vector{Matrix{T}} + weights::Vector{Vector{T}} + + energy_threshold::Float64 + energy_tolerance::Float64 + max_evaluations::Float64 + + evaluations::Vector{Int} + + exp_minus_Δt::Float64 + noise_factor::Float64 +end + +@config function construct_method_cache( + X₀::AbstractArray, + method::ConsensusBasedSampling, + particle_dynamic::ParticleDynamic; + D::Int, + N::Int, + M::Int, + energy_threshold::Real = -Inf, + energy_tolerance::Real = 1e-8, + max_evaluations::Real = Inf, +) + @assert energy_tolerance >= 0 + @assert max_evaluations >= 0 + + type = deep_eltype(X₀) + consensus = nested_zeros(type, M, D) + consensus_energy = nested_zeros(type, M) + consensus_energy_previous = nested_zeros(type, M) + energy = nested_zeros(type, M, N) + exponents = nested_zeros(type, M, N) + logsums = nested_zeros(type, M) + noise = nested_zeros(type, M, D) + root_covariance = nested_zeros(type, M, (D, D)) + weights = nested_zeros(type, M, N) + + evaluations = zeros(Int, M) + + exp_minus_Δt = 0.0 + noise_factor = 0.0 + + method_cache = ConsensusBasedSamplingCache{type}( + consensus, + consensus_energy, + consensus_energy_previous, + energy, + exponents, + logsums, + noise, + root_covariance, + weights, + energy_threshold, + energy_tolerance, + max_evaluations, + evaluations, + exp_minus_Δt, + noise_factor, + ) + + return method_cache +end + +function set_Δt!( + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, + Δt::Real, +) + method_cache.exp_minus_Δt = exp(-Δt) + method_cache.noise_factor = sqrt((1 - method_cache.exp_minus_Δt^2) / method.λ) + return nothing +end + +function initialise_method_cache!( + X₀::AbstractArray, + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, +) + @expand particle_dynamic_cache M + @expand method_cache evaluations + + for m ∈ 1:M + evaluations[m] = 0 + end + return nothing +end + +function initialise_method!( + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, +) + @expand method_cache consensus_energy_previous + + for m ∈ 1:(particle_dynamic_cache.M) + compute_CBO_consensus!( + method, + method_cache, + particle_dynamic, + particle_dynamic_cache, + m, + ) + compute_CBS_root_covariance!( + method, + method_cache, + particle_dynamic, + particle_dynamic_cache, + m, + ) + consensus_energy_previous[m] = Inf + end + return nothing +end + +function compute_method_step!( + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, + m::Int, +) + compute_CBS_update!( + method, + method_cache, + particle_dynamic, + particle_dynamic_cache, + m, + ) + return nothing +end + +function finalise_method_step!( + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, + m::Int, +) + compute_CBO_consensus!( + method, + method_cache, + particle_dynamic, + particle_dynamic_cache, + m, + ) + compute_CBS_root_covariance!( + method, + method_cache, + particle_dynamic, + particle_dynamic_cache, + m, + ) + return nothing +end + +function wrap_output( + X₀::AbstractArray, + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, +) + ensemble_minimiser = method_cache.consensus + minimiser = sum(ensemble_minimiser) / length(ensemble_minimiser) + initial_particles = X₀ + final_particles = particle_dynamic_cache.X + sample = final_particles + return (; + minimiser, + ensemble_minimiser, + initial_particles, + final_particles, + method, + method_cache, + particle_dynamic, + particle_dynamic_cache, + sample, + ) +end diff --git a/src/CBS/CBS_method.jl b/src/CBS/CBS_method.jl new file mode 100644 index 0000000..6432fc6 --- /dev/null +++ b/src/CBS/CBS_method.jl @@ -0,0 +1,53 @@ +function compute_CBS_root_covariance!( + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, + m::Int, +) + @expand particle_dynamic_cache D N X + @expand method_cache consensus root_covariance weights + + for d ∈ 1:D, d2 ∈ 1:D + root_covariance[m][d2, d] = 0 + end + + for n ∈ 1:N + for d ∈ 1:D, d2 ∈ d:D + root_covariance[m][d2, d] += + weights[m][n] * + (X[m][n][d] - consensus[m][d]) * + (X[m][n][d2] - consensus[m][d2]) + root_covariance[m][d, d2] = root_covariance[m][d2, d] + end + end + + root_covariance[m] .= real.(sqrt(root_covariance[m])) + + return nothing +end + +function compute_CBS_update!( + method::ConsensusBasedSampling, + method_cache::ConsensusBasedSamplingCache, + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, + m::Int, +) + @expand particle_dynamic_cache D N X dX + @expand method_cache consensus exp_minus_Δt noise noise_factor root_covariance + + for n ∈ 1:N + for d ∈ 1:D + noise[m][d] = noise_factor * randn() + end + LinearAlgebra.mul!(dX[m][n], root_covariance[m], noise[m]) + + for d ∈ 1:D + dX[m][n][d] += + (exp_minus_Δt - 1) * X[m][n][d] + (1 - exp_minus_Δt) * consensus[m][d] + end + end + + return nothing +end diff --git a/src/ConsensusBasedX.jl b/src/ConsensusBasedX.jl index 8fa57de..b5713d7 100644 --- a/src/ConsensusBasedX.jl +++ b/src/ConsensusBasedX.jl @@ -22,6 +22,7 @@ include("./interface/maximise.jl") include("./interface/minimise.jl") include("./interface/optimise.jl") include("./interface/parse_config.jl") +include("./interface/sample.jl") include("./dynamics/ParticleDynamics.jl") include("./dynamics/benchmark_dynamic.jl") @@ -33,6 +34,9 @@ include("./CBO/CBO_method.jl") include("./CBO/corrections.jl") include("./CBO/is_method_pending.jl") +include("./CBS/CBS.jl") +include("./CBS/CBS_method.jl") + include("./ConsensusBasedXLowLevel.jl") export ConsensusBasedXLowLevel include("./ConsensusBasedXPlots.jl") diff --git a/src/dynamics/ParticleDynamics.jl b/src/dynamics/ParticleDynamics.jl index 76db73b..63ef987 100644 --- a/src/dynamics/ParticleDynamics.jl +++ b/src/dynamics/ParticleDynamics.jl @@ -36,8 +36,8 @@ Fields: - `X`, the particle array. - `dX`, the time derivative array. - `Δt::Float64`, the time step. - - `rootΔt::Float64`, the square root of the time step. - - `root2Δt::Float64`, the square root of twice the time step. + - `root_Δt::Float64`, the square root of the time step. + - `root_2Δt::Float64`, the square root of twice the time step. - `max_iterations::Float64`, the maximum number of iterations. - `max_time::Float64`, the maximal time. - `iteration::Vector{Int}`, the vector containing the iteration count per ensemble. @@ -61,8 +61,8 @@ mutable struct ParticleDynamicCache{ dX::TdX Δt::Float64 - rootΔt::Float64 - root2Δt::Float64 + root_Δt::Float64 + root_2Δt::Float64 max_iterations::Float64 max_time::Float64 @@ -89,8 +89,8 @@ end X = deep_zero(X₀) dX = deep_zero(X₀) Δt = 0.0 - rootΔt = 0.0 - root2Δt = 0.0 + root_Δt = 0.0 + root_2Δt = 0.0 iteration = zeros(Int, M) @@ -111,8 +111,8 @@ end X, dX, Δt, - rootΔt, - root2Δt, + root_Δt, + root_2Δt, float(max_iterations), float(max_time), iteration, @@ -148,7 +148,7 @@ function initialise_particle_dynamic_cache!( @expand particle_dynamic_cache M iteration deep_copyto!(particle_dynamic_cache.X, X₀) - set_Δt!(particle_dynamic_cache, particle_dynamic.Δt) + set_Δt!(particle_dynamic, particle_dynamic_cache, particle_dynamic.Δt) for m ∈ 1:M iteration[m] = 0 end @@ -163,16 +163,28 @@ function initialise_particle_dynamic_cache!( return nothing end -function set_Δt!(particle_dynamic_cache::ParticleDynamicCache, Δt::Real) +function set_Δt!( + particle_dynamic::ParticleDynamic, + particle_dynamic_cache::ParticleDynamicCache, + Δt::Real, +) particle_dynamic_cache.Δt = Δt - particle_dynamic_cache.rootΔt = sqrt(Δt) - particle_dynamic_cache.root2Δt = sqrt(2) * particle_dynamic_cache.rootΔt - set_Δt!(particle_dynamic_cache.method_cache, particle_dynamic_cache, Δt) + particle_dynamic_cache.root_Δt = sqrt(Δt) + particle_dynamic_cache.root_2Δt = sqrt(2) * particle_dynamic_cache.root_Δt + set_Δt!( + particle_dynamic.method, + particle_dynamic_cache.method_cache, + particle_dynamic, + particle_dynamic_cache, + Δt, + ) return nothing end function set_Δt!( + method::CBXMethod, method_cache::CBXMethodCache, + particle_dynamic::ParticleDynamic, particle_dynamic_cache::ParticleDynamicCache, Δt::Real, ) diff --git a/src/dynamics/is_dynamic_pending.jl b/src/dynamics/is_dynamic_pending.jl index ba1e1e3..ca974e0 100644 --- a/src/dynamics/is_dynamic_pending.jl +++ b/src/dynamics/is_dynamic_pending.jl @@ -42,13 +42,3 @@ function is_dynamic_pending_max_time( @expand particle_dynamic_cache iteration max_time return iteration[m] * Δt < max_time end - -function is_method_pending( - method::CBXMethod, - method_cache::CBXMethodCache, - particle_dynamic::ParticleDynamic, - particle_dynamic_cache::ParticleDynamicCache, - m::Int, -) - return true, "method_pending" -end diff --git a/src/interface/config.jl b/src/interface/config.jl index 9409a73..dd5a312 100644 --- a/src/interface/config.jl +++ b/src/interface/config.jl @@ -1,4 +1,4 @@ -for routine ∈ (:minimise, :maximise) +for routine ∈ (:minimise, :maximise, :sample) @eval begin function $routine(f; kw...) config = NamedTuple(kw) diff --git a/src/interface/minimise.jl b/src/interface/minimise.jl index 79ffeff..427f067 100644 --- a/src/interface/minimise.jl +++ b/src/interface/minimise.jl @@ -15,24 +15,22 @@ You must specify the dimension `D` of the problem. Other paramters (e.g. the num # Examples -```julia-repl -julia> minimise(f, D = 2) - +```julia +minimise(f, D = 2) ``` -```julia-repl -julia> minimise(f, config) +```julia config = (; D = 2); +minimise(f, config) ``` -```julia-repl -julia> minimise(f, D = 2, N = 20) - +```julia +minimise(f, D = 2, N = 20) ``` -```julia-repl -julia> minimise(f, config) +```julia config = (; D = 2, N = 20); +minimise(f, config) ``` """ function minimise(f, config::NamedTuple) @@ -46,7 +44,7 @@ export minimise X₀ = initialise_particles(config) parsed_X₀ = reshape(X₀, mode) - particle_dynamic = get_particle_dynamic(config, f) + particle_dynamic = get_minimise_particle_dynamic(config, f) particle_dynamic_cache = construct_particle_dynamic_cache(config, parsed_X₀, particle_dynamic) @@ -59,13 +57,9 @@ export minimise ) end -@config function get_particle_dynamic(f) +@config function get_minimise_particle_dynamic(f) @verb " • Constructing dynamic" - - # correction = NoCorrection() correction = HeavisideCorrection() - # correction = RegularisedHeavisideCorrection(1e-3) - method = construct_CBO(config, f, correction, config.noise) particle_dynamic = construct_particle_dynamic(config, method) return particle_dynamic diff --git a/src/interface/sample.jl b/src/interface/sample.jl new file mode 100644 index 0000000..1574075 --- /dev/null +++ b/src/interface/sample.jl @@ -0,0 +1,67 @@ +""" +```julia +sample(f; keywords...) +``` + +```julia +sample(f, config::NamedTuple) +``` + +Sample the distribution `exp(f)` using Consensus-Based Sampling (see [Distribution sampling](@ref)). + +You must specify the dimension `D` of the problem. Other paramters (e.g. the number of particles `N` or the number of ensembles `M` can also be specified; see [Summary of options](@ref). + +# Examples + +```julia +out = sample(f, D = 2, extended_output = true); +out.sample +``` + +```julia +config = (; D = 2, extended_output = true); +out = sample(f, config); +out.sample +``` + +```julia +out = sample(f, D = 2, N = 20, extended_output = true); +out.sample +``` + +```julia +config = (; D = 2, N = 20, extended_output = true); +out = sample(f, config); +out.sample +``` +""" +function sample(f, config::NamedTuple) + return sample_with_parsed_config(parse_config(config), f) +end +export sample + +@config function sample_with_parsed_config(f; mode) + @verb "[ConsensusBasedX.jl]: Executing sampling..." + + X₀ = initialise_particles(config) + parsed_X₀ = reshape(X₀, mode) + + particle_dynamic = get_sample_particle_dynamic(config, f) + + particle_dynamic_cache = + construct_particle_dynamic_cache(config, parsed_X₀, particle_dynamic) + + return wrapped_run_dynamic!( + config, + parsed_X₀, + particle_dynamic, + particle_dynamic_cache, + ) +end + +@config function get_sample_particle_dynamic(f) + @verb " • Constructing dynamic" + method = construct_CBS(config, f) + particle_dynamic = construct_particle_dynamic(config, method) + return particle_dynamic +end diff --git a/src/utils/arrays.jl b/src/utils/arrays.jl index 424dd2f..7d4a04a 100644 --- a/src/utils/arrays.jl +++ b/src/utils/arrays.jl @@ -81,12 +81,10 @@ function deep_zero(x::AbstractArray) return y end -nested_zeros(type::Type, dim::Int) = zeros(type, dim) +nested_zeros(type::Type, dim) = zeros(type, dim) -function nested_zeros(type::Type, dim::Int, dim2::Int) - return [zeros(type, dim2) for k ∈ 1:dim] -end +nested_zeros(type::Type, dim::Int, dim2) = [zeros(type, dim2) for k ∈ 1:dim] -function nested_zeros(type::Type, dim::Int, dim2::Int, dim3::Int) +function nested_zeros(type::Type, dim::Int, dim2::Int, dim3) return [zeros(type, dim2, dim3) for k ∈ 1:dim] end diff --git a/test/interface/minimise.jl b/test/interface/minimise.jl index 72f95d8..b2d2637 100644 --- a/test/interface/minimise.jl +++ b/test/interface/minimise.jl @@ -1,17 +1,20 @@ using ConsensusBasedX, Test function tests() - alloc(x) = Base.gc_alloc_count(x.gcstats) + f(x) = ConsensusBasedX.Quadratic(x, shift = 1) + g(x) = ConsensusBasedX.Ackley(x, shift = 1) + h(x) = ConsensusBasedX.Rastrigin(x, shift = 1) + + config = (; D = 2,) + @test_nowarn minimise(f, config) + alloc(x) = Base.gc_alloc_count(x.gcstats) config = (; D = 2, benchmark = true) - f(x) = ConsensusBasedX.Quadratic(x, shift = 1) @test alloc(minimise(f, config)) == 0 - g(x) = ConsensusBasedX.Ackley(x, shift = 1) @test alloc(minimise(g, config)) == 0 - h(x) = ConsensusBasedX.Rastrigin(x, shift = 1) @test alloc(minimise(h, config)) == 0 end diff --git a/test/interface/sample.jl b/test/interface/sample.jl new file mode 100644 index 0000000..2fda063 --- /dev/null +++ b/test/interface/sample.jl @@ -0,0 +1,10 @@ +using ConsensusBasedX, Test + +function tests() + f(x) = ConsensusBasedX.Quadratic(x, shift = 1) + + config = (; D = 2) + @test_nowarn sample(f, config) +end + +tests()