Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consensus-based sampling #23

Merged
merged 1 commit into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ makedocs(;
"Home" => "index.md",
"Basic usage" => [
"Function minimisation" => "function_minimisation.md"
"Distribution sampling" => "distribution_sampling.md"
"Method parameters" => "method_parameters.md"
"Stopping criteria" => "stopping_criteria.md"
"Particle initialisation" => "particle_initialisation.md"
Expand Down
43 changes: 43 additions & 0 deletions docs/src/distribution_sampling.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Distribution sampling

ConsensusBasedX.jl also provides consensus-based sampling, [J. A. Carrillo, F. Hoffmann, A. M. Stuart, and U. Vaes (2022)](https://onlinelibrary.wiley.com/doi/10.1111/sapm.12470). The package exports `sample`, which behaves exactly as `minimise` in [Function minimisation](@ref). It assumes you have defined a function `f(x::AbstractVector)` that takes a single vector argumemt `x` of length `D = length(x)`.

For instance, if `D = 2`, you can sample `exp(-f)` by running:
```julia
out = sample(f, D = 2, extended_output=true)
out.sample
```
[Full-code example](https://github.com/PdIPS/ConsensusBasedX.jl/blob/main/examples/basic_usage/sample_with_keywords.jl).

!!! note
You must always provide `D`.


## Using a `config` object

For more advanced usage, you will select several options. You can pass these as extra keyword arguments to `sample`, or you can create a `NamedTuple` called `config` and pass that:
```julia
config = (; D = 2, extended_output=true)
out = sample(f, config)
out.sample
```
[Full-code example](https://github.com/PdIPS/ConsensusBasedX.jl/blob/main/examples/basic_usage/sample_with_config.jl).

!!! note
If you pass a `Dict` instead, it will be converted to a `NamedTuple` automatically.


## Running on minimisation mode

Consensus-based sampling can also be used for minimisation. If you want to run it in that mode, pass the option `CBS_mode = :minimise`.


## Method reference

```@index
Pages = ["distribution_sampling.md"]
```

```@docs
ConsensusBasedX.sample
```
1 change: 0 additions & 1 deletion docs/src/function_minimisation.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ Full-code examples are provided for the [keyword](https://github.com/PdIPS/Conse

## Method reference


```@index
Pages = ["function_minimisation.md"]
```
Expand Down
17 changes: 17 additions & 0 deletions docs/src/low_level.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,20 @@ The full reference is:
```@docs
ConsensusBasedX.ConsensusBasedOptimisationCache
```

## `ConsensusBasedSampling`

The `ConsensusBasedSampling` struct (of type `CBXMethod`) defines the details of the *consensus-based sampling method* (function evaluations, covariance matrix...).

```@docs
ConsensusBasedX.ConsensusBasedSampling
```

### `ConsensusBasedSamplingCache`

`ConsensusBasedSampling` requires a cache, `ConsensusBasedSamplingCache` (of type `CBXMethodCache`). This can be constructed with [`ConsensusBasedX.construct_method_cache`](@ref).

The full reference is:
```@docs
ConsensusBasedX.ConsensusBasedSamplingCache
```
4 changes: 4 additions & 0 deletions docs/src/summary_options.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ See [Stopping criteria](@ref).
- `extended_output::Bool = false` controls the output, and by default returns only the computed minimiser. `extended_output = true` returns additional information, see [Extended output](@ref).
- `parallelisation = :NoParallelisation` controls the parallelisation of the `minimise` routine, switched off by default. `parallelisation=:EnsembleParallelisation` enables parallelisation, see [Parallelisation](@ref).
- `verbosity::Int = 0` is the verbosity level. `verbosity = 0` produces no output to console. `verbosity = 1` produces some output.

## Consensus-based sampling options

- `CBS_mode = :sampling` controls the mode of consensus-based sampling. If you want to perform a minimisation, pass `CBS_mode = :minimise` instead.
6 changes: 6 additions & 0 deletions examples/basic_usage/sample_to_minimise.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using ConsensusBasedX

f(x) = ConsensusBasedX.Ackley(x, shift = 1)

config = (; D = 2, N = 20, CBS_mode = :minimise)
sample(f, config)
7 changes: 7 additions & 0 deletions examples/basic_usage/sample_with_config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using ConsensusBasedX

f(x) = ConsensusBasedX.Ackley(x, shift = 1)

config = (; D = 2, N = 20, extended_output = true)
out = sample(f, config)
out.sample
6 changes: 6 additions & 0 deletions examples/basic_usage/sample_with_keywords.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using ConsensusBasedX

f(x) = ConsensusBasedX.Ackley(x, shift = 1)

out = sample(f, D = 2, N = 20, extended_output = true)
out.sample
2 changes: 1 addition & 1 deletion src/CBO/CBO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Fields:

- `f`, the objective function.
- `correction<:CBXCorrection`, a correction term.
- `α::Float64`, the the exponential weight parameter.
- `α::Float64`, the exponential weight parameter.
- `λ::Float64`, the drift strengh.
- `σ::Float64`, the noise strengh.
"""
Expand Down
8 changes: 4 additions & 4 deletions src/CBO/CBO_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function compute_CBO_update!(
particle_dynamic_cache::ParticleDynamicCache,
m::Int,
) where {TF, TCorrection}
@expand particle_dynamic_cache D N X dX Δt root2Δt
@expand particle_dynamic_cache D N X dX Δt root_2Δt
@expand method correction λ σ
@expand method_cache consensus consensus_energy distance energy

Expand All @@ -59,7 +59,7 @@ function compute_CBO_update!(
λ *
(consensus[m][d] - X[m][n][d]) *
correction(energy[m][n] - consensus_energy[m]) +
root2Δt * σ * distance[m][n] * randn()
root_2Δt * σ * distance[m][n] * randn()
end
end
return nothing
Expand All @@ -72,7 +72,7 @@ function compute_CBO_update!(
particle_dynamic_cache::ParticleDynamicCache,
m::Int,
) where {TF, TCorrection}
@expand particle_dynamic_cache D N X dX Δt root2Δt
@expand particle_dynamic_cache D N X dX Δt root_2Δt
@expand method correction λ σ
@expand method_cache consensus consensus_energy energy

Expand All @@ -81,7 +81,7 @@ function compute_CBO_update!(
dX[m][n][d] =
(consensus[m][d] - X[m][n][d]) * (
Δt * λ * correction(energy[m][n] - consensus_energy[m]) +
root2Δt * σ * randn()
root_2Δt * σ * randn()
)
end
end
Expand Down
16 changes: 8 additions & 8 deletions src/CBO/is_method_pending.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function is_method_pending(
method::ConsensusBasedOptimisation,
method_cache::ConsensusBasedOptimisationCache,
method::CBXMethod,
method_cache::CBXMethodCache,
particle_dynamic::ParticleDynamic,
particle_dynamic_cache::ParticleDynamicCache,
m::Int,
Expand Down Expand Up @@ -39,8 +39,8 @@ function is_method_pending(
end

function is_method_pending_energy_threshold(
method::ConsensusBasedOptimisation,
method_cache::ConsensusBasedOptimisationCache,
method::CBXMethod,
method_cache::CBXMethodCache,
particle_dynamic::ParticleDynamic,
particle_dynamic_cache::ParticleDynamicCache,
m::Int,
Expand All @@ -50,8 +50,8 @@ function is_method_pending_energy_threshold(
end

function is_method_pending_energy_tolerance(
method::ConsensusBasedOptimisation,
method_cache::ConsensusBasedOptimisationCache,
method::CBXMethod,
method_cache::CBXMethodCache,
particle_dynamic::ParticleDynamic,
particle_dynamic_cache::ParticleDynamicCache,
m::Int,
Expand All @@ -62,8 +62,8 @@ function is_method_pending_energy_tolerance(
end

function is_method_pending_max_evaluations(
method::ConsensusBasedOptimisation,
method_cache::ConsensusBasedOptimisationCache,
method::CBXMethod,
method_cache::CBXMethodCache,
particle_dynamic::ParticleDynamic,
particle_dynamic_cache::ParticleDynamicCache,
m::Int,
Expand Down