Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add state_from_transition, parameters and setparameters!! #86

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
178 changes: 178 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,181 @@ For chains of this type, AbstractMCMC defines the following two methods.
AbstractMCMC.chainscat
AbstractMCMC.chainsstack
```

## Interacting with states of samplers

To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
```@docs
AbstractMCMC.values
AbstractMCMC.setvalues!!
```
and optionally
```@docs
AbstractMCMC.updatestate!!(state, transition, state_prev)
```
These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers.

### Example: `MixtureSampler`

In a `MixtureSampler` we need two things:
- `components`: collection of samplers.
- `weights`: collection of weights representing the probability of chosing the corresponding sampler.

```julia
struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler
components::C
weights::W
end
```

To implement the state, we need to keep track of a couple of things:
- `index`: the index of the sampler used in this `step`.
- `transition`: the transition resulting from this `step`.
- `states`: the current states of _all_ the components.
Two aspects of this might seem a bit strange:
1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously.
2. We need to put the `transition` from the `step` into the state.

The reason for (1) is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc.

For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`.

Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`.
```julia
struct MixtureState{T,S}
index::Int
transition::T
states::S
end
```
The `step` for a `MixtureSampler` is defined by the following generative process
```math
\begin{aligned}
i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\
X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
\end{aligned}
```
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler.
[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler.

If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code:

```julia
# Update the corresponding state, i.e. `state.states[i]`, using
# the state and transition from the previous iteration.
state_current = AbstractMCMC.updatestate!!(
state.states[i], state.states[i_prev], state.transition
)

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model, sampler_current, sampler_state;
kwargs...
)
```

The full [`AbstractMCMC.step`](@ref) implementation would then be something like:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...)
# Sample the component to use in this `step`.
i = rand(Categorical(sampler.weights))
sampler_current = sampler.components[i]

# Update the corresponding state, i.e. `state.states[i]`, using
# the state and transition from the previous iteration.
i_prev = state.index
state_current = AbstractMCMC.updatestate!!(
state.states[i], state.states[i_prev], state.transition
)

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model, sampler_current, state_current;
kwargs...
)

# Create the new states.
# NOTE: Code below will result in `states_new` being a `Vector`.
# If we wanted to allow usage of alternative containers, e.g. `Tuple`,
# it would be better to use something like `@set states[i] = state_current`
# where `@set` is from Setfield.jl.
states_new = map(1:length(state.states)) do j
if j == i
# Replace the i-th state with the new one.
state_current
else
# Otherwise we just carry over the previous ones.
state.states[j]
end
end

# Create the new `MixtureState`.
state_new = MixtureState(i, transition, states_new)

return transition, state_new
end
```

And for the initial [`AbstractMCMC.step`](@ref) we have:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...)
# Initialize every state.
transitions_and_states = map(sampler.components) do spl
AbstractMCMC.step(rng, model, spl; kwargs...)
end

# Sample the component to use this `step`.
i = rand(Categorical(sampler.weights))
# Extract the corresponding transition.
transition = first(transitions_and_states[i])
# Extract states.
states = map(last, transitions_and_states)
# Create new `MixtureState`.
state = MixtureState(i, transition, states)

return transition, state
end
```

Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `values` and `setvalues!!`:

```julia
function AbstractMCMC.updatestate!!(::AdvancedMH.Transition, state_prev::AdvancedMH.Transition)
# Let's `deepcopy` just to be certain.
return deepcopy(state_prev)
end
```

To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do

```julia
sampler = MixtureSampler([sampler1, sampler2], [0.1, 0.9])
transition, state = AbstractMCMC.step(rng, model, sampler)
while ...
transition, state = AbstractMCMC.step(rng, model, sampler, state)
end
```

As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`:

```julia
struct ManyModels{M} <: AbstractMCMC.AbstractModel
models::M
end

Base.getindex(model::ManyModels, I...) = model.models[I...]
```

Then the above `step` would just extract the `model` corresponding to the current sampler:

```julia
# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model[i], sampler_current, state_current;
kwargs...
)
```

This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`.
28 changes: 28 additions & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,34 @@
"""
struct MCMCSerial <: AbstractMCMCEnsemble end

"""
updatestate!!(state, transition_prev[, state_prev])

Return new instance of `state` using information from `transition_prev` and, optionally, `state_prev`.

Defaults to `setvalues!!(state, values(transition_prev))`.
"""
updatestate!!(state, transition_prev, state_prev) = updatestate!!(state, transition_prev)
updatestate!!(state, transition) = setvalues!!(state, Base.values(transition))

Check warning on line 90 in src/AbstractMCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractMCMC.jl#L89-L90

Added lines #L89 - L90 were not covered by tests

"""
setvalues!!(state, values)

Update the values of the `state` with `values` and return it.

If `state` can be updated in-place, it is expected that this function returns `state` with updated
values. Otherwise a new `state` object with the new `values` is returned.
"""
function setvalues!! end

@doc """
values(transition)

Return values in `transition`.
"""
Base.values
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to use Base.values here? I think adding this documentation should be considered type piracy and would only be allowed if we define it only for our own types as arguments (which we can't do here).

Unrelated, why did you use @doc?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to use Base.values here? I think adding this documentation should be considered type piracy and would only be allowed if we define it only for our own types as arguments (which we can't do here).

Hmm, I do agree with this. I'm coming around to values not being the best idea.

Unrelated, why did you use @doc?

IIRC you need @doc here to be able to add the docstring? I.e. if you remove the @doc it won't be added to the docstring of Base.values.



Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

include("samplingstats.jl")
include("logging.jl")
include("interface.jl")
Expand Down