
Improving interaction between different implementations of the interface #85

Open · torfjelde opened this issue Oct 17, 2021 · 7 comments
Labels: enhancement (New feature or request)

torfjelde (Member) commented Oct 17, 2021

How do you feel about adding something like:

"""
    state_from_transiton(state, transition_prev[, state_prev])

Return new instance of `state` using information from `transition_prev` and, optionally, `state_prev`.

Defaults to `setparameters!!(state, parameters(transition_prev))`.
"""
function state_from_transition(state, transition_prev, state_prev)
    return state_from_transition(state, transition_prev)
end

function state_from_transition(state, transition)
    return setparameters!!(state, parameters(transition))
end

"""
    setparameters!!(state, parameters)

Return new instance of `state` with parameters set to `parameters`.
"""
setparameters!!

"""
    parameters(transition)

Return parameters in `transition`.
"""
parameters

to make it easier for samplers to interact across packages? Then you would just need to implement `state_from_transition` (or the two primitives it relies on) for the different types to get cross-package compatibility.
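For concreteness, here is a minimal sketch of what an implementing package would add; `RWTransition` and `RWState` are hypothetical types used purely for illustration:

```julia
# Hypothetical transition/state types for some random-walk sampler.
struct RWTransition{T}
    params::T
end

struct RWState{T}
    params::T
end

# Implementing the two primitives is enough; the `state_from_transition`
# default above then lets any other sampler's transition initialize this state.
parameters(transition::RWTransition) = transition.params
setparameters!!(state::RWState, params) = RWState(params)
```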

Another issue is the `model` argument, which is often specific to a particular sampler implementation. We've previously spoken about generalizing this so that we don't have a bunch of these lying around, e.g. `AdvancedMH.DensityModel` and `AdvancedHMC.DifferentiableDensityModel`. But maybe we should also add a function `getmodel(model, sampler, state)` or something, which is `identity` by default but allows one to provide a model type that encodes a bunch of different models for specific samplers. E.g. in the case of a `MixtureSampler` you might have a `MixtureState` which, among other things, holds the current sampler index, and a `ManyModels` which simply wraps a collection of models corresponding to each of the components/samplers in the `MixtureSampler`:

```julia
getmodel(model::ManyModels, sampler::MixtureSampler, state::MixtureState) = model[state.current_index]
```
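where `ManyModels` and `MixtureState` could be as simple as the following (a hypothetical sketch, since neither type exists yet):

```julia
# Hypothetical wrapper around one model per component of the mixture.
struct ManyModels{T} <: AbstractMCMC.AbstractModel
    models::T
end

# Indexing a `ManyModels` returns the model for the given component.
Base.getindex(model::ManyModels, i) = model.models[i]

# Hypothetical state for `MixtureSampler`: the index of the currently
# active component plus the per-component sampler states.
struct MixtureState{S}
    current_index::Int
    states::S
end
```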

I've been running into quite a few scenarios recently where I'd love to have something like this, e.g. wanting to implement `MixtureSampler` and `CompositionSampler`.

torfjelde (Member, Author) commented:

@devmotion @cpfiffer @yebai thoughts?

cpfiffer (Member) commented:

Yep, I 100% love both of these -- I think it's a bit of a shortcoming in our interface methods that we push the parameter stuff to the side. This seems minimal and unobtrusive enough to be a good fit for AbstractMCMC.

devmotion (Member) commented:

We have already implemented some similar functions for working with transitions in Turing, e.g. `metadata`, `getparams`, `getlogp`, and `getlogevidence`. I think it could be useful to move some of them upstream.

devmotion (Member) commented:

I am less sure about `getmodel` though. It seems orthogonal to the problem of multiple disconnected model types, and quite specific to the mixture example? It also doesn't seem very scalable to implement such a function for more general combinations of states, samplers, and models.

Generally, I think it would be helpful to be more honest about the supported model types (possibly reusing model types such as the discussed `DensityModel`) in the implementations or, even better and more scalable if possible, to only use functions of a generic, yet-to-be-added interface for models, such as e.g. `loglikelihood` etc.

torfjelde (Member, Author) commented:

Awesome!

> Generally, I think it would be helpful to be more honest about the supported model types (possibly reusing model types such as the discussed `DensityModel`) in the implementations or, even better and more scalable if possible, to only use functions of a generic, yet-to-be-added interface for models, such as e.g. `loglikelihood` etc.

I 100% agree with this. We had some issues reaching a consensus the last time we discussed it, but maybe we can now 👍

How about this direction (I'm trying to do something similar to what I think you proposed before, @devmotion):

```julia
struct DensityModel{F} <: AbstractModel
    logdensity::F
end

logdensity(model::DensityModel, args...) = model.logdensity(args...)

"""
    Differentiable{N}

Represents `N`-th order differentiability.
"""
struct Differentiable{N} end
const NonDifferentiable = Differentiable{0}
const FirstOrderDifferentiable = Differentiable{1}
const SecondOrderDifferentiable = Differentiable{2}

# A combination of two components is only as differentiable as the least smooth one.
function Base.:+(::Differentiable{N1}, ::Differentiable{N2}) where {N1,N2}
    return Differentiable{min(N1,N2)}()
end

"""
    differentiable(model)

Return an instance of `Differentiable{N}`, where `N` represents the order.
"""
differentiable(model::DensityModel) = differentiable(model.logdensity)

"""
    PosteriorModel

Represents a model which can be decomposed into a prior and a likelihood.
"""
struct PosteriorModel{P1,P2} <: AbstractModel
    logprior::P1
    loglikelihood::P2
end

logprior(model::PosteriorModel, args...) = model.logprior(args...)
loglikelihood(model::PosteriorModel, args...) = model.loglikelihood(args...)
logdensity(model::PosteriorModel, args...) = logprior(model, args...) + loglikelihood(model, args...)

function differentiable(model::PosteriorModel)
    return differentiable(model.logprior) + differentiable(model.loglikelihood)
end
```

?
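As a quick usage sketch of the trait (the `order` helper below is made up, not part of the proposal):

```julia
# Declare a user-provided log-density to be first-order differentiable.
mylogdensity(x) = -sum(abs2, x) / 2
differentiable(::typeof(mylogdensity)) = FirstOrderDifferentiable()

model = DensityModel(mylogdensity)
differentiable(model) === FirstOrderDifferentiable()  # true, via the `DensityModel` fallback

# A gradient-based sampler could then guard itself up front:
order(::Differentiable{N}) where {N} = N
@assert order(differentiable(model)) >= 1 "this sampler requires first-order differentiability"
```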

Then we can also add (but in a different package; maybe Bijectors.jl itself or Turing.jl):

```julia
struct TransformedModel{M,B} <: AbstractMCMC.AbstractModel
    model::M
    transform::B
end

function AbstractMCMC.logdensity(tmodel::TransformedModel, y)
    # Map `y` back to the original space, picking up the log-abs-det-Jacobian.
    x, logjac = forward(tmodel.transform, y)
    return AbstractMCMC.logdensity(tmodel.model, x) + logjac
end

function AbstractMCMC.differentiable(tmodel::TransformedModel)
    return AbstractMCMC.differentiable(tmodel.model)
end
```

And then things would "just work".
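For instance, composing the pieces above (a hypothetical usage sketch; `MyTransform` is a trivial stand-in for a bijector whose `forward` returns the transformed value and the log-abs-det-Jacobian, as assumed above):

```julia
# Trivial stand-in transform: identity, with zero log-Jacobian.
struct MyTransform end
forward(::MyTransform, y) = (y, 0.0)

# A hypothetical prior/likelihood pair.
logprior_fn(x) = -sum(abs2, x) / 2
loglikelihood_fn(x) = -sum(abs2, x .- 1) / 2
model = PosteriorModel(logprior_fn, loglikelihood_fn)

# Wrap it so sampling happens in transformed (e.g. unconstrained) space.
tmodel = TransformedModel(model, MyTransform())

# Any sampler written against `logdensity`/`differentiable` can now use
# `tmodel` without knowing anything about the transformation.
lp = AbstractMCMC.logdensity(tmodel, randn(2))
```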

We might also want some of the following methods (though implementations should go somewhere else); a sketch of the trait idea follows this list:

- `domain(model)`: return some notion of which domain the model expects inputs from.
  - The only issue is that we can't really say much about what to expect as a return value here.
  - It should also go together with a `hasdomain` to indicate whether this method is implemented.
- `length`/`size`/etc.: return properties of the variables used in the model (this kind of falls under `domain` if that could be handled nicely).
  - This doesn't make sense for every model, so we might need something similar to the iterator traits, e.g. `HasLength`, `HasSize`, etc. But I also don't like this because we'll end up with a lot of "maybe"-existing methods 😕
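If we did go the trait route, the pattern might look like this (purely a sketch of the shape; all names are hypothetical):

```julia
# Iterator-style traits describing whether a model knows its dimensionality.
abstract type LengthTrait end
struct HasLength <: LengthTrait end
struct NoLength <: LengthTrait end

# Default: assume nothing is known about the model's dimensionality.
lengthtrait(::AbstractModel) = NoLength()

# A model type that does know its dimension would opt in:
#     lengthtrait(::MyModel) = HasLength()
#     Base.length(model::MyModel) = ...
```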

cpfiffer (Member) commented:

I LOVE this sketch. It's super minimal, and I think it's flexible enough to meet a bunch of downstream needs. I'm happy with putting it in AbstractMCMC since it touches so few things.

torfjelde (Member, Author) commented Apr 17, 2024

Thinking about this again due to work on the Gibbs sampler for Turing.jl (TuringLang/Turing.jl#2099).

I think we need the following from an `AbstractSampler` to implement "most" interesting "meta"-samplers:

1. Ability to get and set the parameters and log-probability.
2. Ability to go from the state of one sampler, say, `SamplerA`, to the state of another sampler, say, `SamplerB`.
3. Ability to re-evaluate the log-probability as the sampler would.
   - The crucial part is the "as the sampler would"; e.g. we might need to recompute gradients too.

In short, we need something like:

```julia
# Needs to be implemented on a case-by-case basis.
function getparams_and_logprob(sampler, state)
    # TODO: implement
end

function setparams_and_logprob!!(sampler, state, params, logprob)
    # TODO: implement
end

# Default getters and setters, derived from the two methods above.
getparams(sampler, state) = first(getparams_and_logprob(sampler, state))
getlogprob(sampler, state) = last(getparams_and_logprob(sampler, state))

function setparams!!(sampler, state, params)
    return setparams_and_logprob!!(sampler, state, params, getlogprob(sampler, state))
end
function setlogprob!!(sampler, state, logprob)
    return setparams_and_logprob!!(sampler, state, getparams(sampler, state), logprob)
end

# Default implementation.
function state_from(model_dst, sampler_dst, sampler_src, state_dst, state_src)
    # Extract parameters and log-probability from the source sampler.
    params_src, lp_src = getparams_and_logprob(sampler_src, state_src)
    # Set the parameters and log-probability in the destination state.
    return setparams_and_logprob!!(sampler_dst, state_dst, params_src, lp_src)
end

function state_from_with_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src)
    # Extract parameters from the source sampler.
    params_src = getparams(sampler_src, state_src)
    # Set the parameters in the destination state.
    state_dst = setparams!!(sampler_dst, state_dst, params_src)
    # Re-evaluate the log density of the destination model.
    return recompute_logprob!!(model_dst, sampler_dst, state_dst)
end

# Default implementation.
function recompute_logprob!!(model::AbstractMCMC.LogDensityModel, sampler, state)
    # Re-evaluate the log density at the current parameters.
    params = getparams(sampler, state)
    lp = LogDensityProblems.logdensity(model.logdensity, params)
    return setlogprob!!(sampler, state, lp)
end
```
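To make the division of labour concrete: for a hypothetical sampler `MySampler` whose state stores a parameter vector and a cached log-probability, the case-by-case part would be just the two methods below (all names made up for illustration):

```julia
struct MySampler <: AbstractMCMC.AbstractSampler end

# Hypothetical state: parameters plus the cached log-probability.
struct MyState{T,L}
    params::T
    logprob::L
end

# Only these two methods need to be implemented per sampler; `getparams`,
# `setparams!!`, `state_from`, etc. then all come for free.
getparams_and_logprob(::MySampler, state::MyState) = (state.params, state.logprob)
function setparams_and_logprob!!(::MySampler, state::MyState, params, logprob)
    return MyState(params, logprob)
end
```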

For example, if we want compositions of samplers, we can implement that as:

```julia
function composition_step(
    rng::Random.AbstractRNG,
    model_outer,
    model_inner,
    sampler_outer,
    sampler_inner,
    state_outer,
    state_inner;
    kwargs...
)
    # Take a step with the inner sampler.
    transition_inner, state_inner = AbstractMCMC.step(
        rng,
        model_inner,
        sampler_inner,
        state_inner;
        kwargs...
    )

    # Update the outer state from the inner state.
    state_outer = if composition_requires_recompute_logprob(model_outer, sampler_outer, sampler_inner, state_outer, state_inner)
        state_from_with_recompute_logprob(model_outer, sampler_outer, sampler_inner, state_outer, state_inner)
    else
        state_from(model_outer, sampler_outer, sampler_inner, state_outer, state_inner)
    end

    # Take a step with the outer sampler.
    transition_outer, state_outer = AbstractMCMC.step(
        rng,
        model_outer,
        sampler_outer,
        state_outer;
        kwargs...
    )

    return (transition_inner, transition_outer), (state_inner, state_outer)
end
```
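Here `composition_requires_recompute_logprob` would be a predicate with a conservative default, something like the following sketch (the opt-out method is hypothetical):

```julia
# Conservative default: always re-evaluate the log-probability when moving
# between samplers, unless a more specific method says it's unnecessary.
function composition_requires_recompute_logprob(model, sampler_dst, sampler_src, state_dst, state_src)
    return true
end

# Example opt-out: two samplers known to target the same log-density could add
#     composition_requires_recompute_logprob(model, ::MyMH, ::MyMH, state_dst, state_src) = false
```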

Another example is Gibbs sampling, though this only requires `recompute_logprob!!`. Implementing a Gibbs sampler boils down to the following step function:

```julia
function gibbs_step(
    rng::Random.AbstractRNG,
    model_dst,
    sampler_dst,
    sampler_src,
    state_dst,
    state_src;
    kwargs...
)
    # `model_dst` might be different here, e.g. conditioned on new values,
    # so we need to check whether we have to recompute the log-probability.
    if gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src)
        # Re-evaluate the log density of the destination model.
        state_dst = recompute_logprob!!(model_dst, sampler_dst, state_dst)
    end

    # Step!
    return AbstractMCMC.step(rng, model_dst, sampler_dst, state_dst; kwargs...)
end
```
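For context, a full Gibbs sweep would then just apply this step to each conditional in turn. A rough sketch, where `condition` is a hypothetical function producing the conditional model for component `i` given the states of the other components:

```julia
function gibbs_sweep(rng, model, samplers, states; kwargs...)
    for i in eachindex(samplers)
        # Condition the joint model on the current values of all other components.
        model_i = condition(model, states, i)
        # The "source" state is the one updated most recently.
        j = i == firstindex(samplers) ? lastindex(samplers) : i - 1
        # (Transitions are discarded here; a real implementation would collect them.)
        _, states[i] = gibbs_step(
            rng, model_i, samplers[i], samplers[j], states[i], states[j];
            kwargs...
        )
    end
    return states
end
```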

EDIT: Currently giving this `recompute_logprob!!` a go in TuringLang/Turing.jl#2099.
