
Extending chain representations #57

Open · cscherrer opened this issue Feb 25, 2021 · 8 comments

@cscherrer
Hi,

I had misunderstood some of the goals of this package; thankfully, @cpfiffer got me straightened out. I'm looking into using this as an interface for populating SampleChains.

The docs seem mostly oriented toward people building new samplers, and not so much toward people building new ways of representing chains. So I have lots of questions...

  1. I'd like to avoid collecting all of the samples and then calling bundle_samples. It seems like I should instead be able to overload save!! and make bundle_samples a no-op (rough sketch at the end of this comment). Is that right?
  2. Cameron mentioned a resume function that can pick up sampling where it left off, without needing to go back through the warmup phase. But I don't see it in this repo. Where can I find it?
  3. Where are things like per-sample diagnostics stored? Divergences for HMC, that sort of thing.
  4. Do you have any examples of using this with log-weighted samples? I need these for importance sampling, VI, etc.

I'm sure I'll have more to come, but this will get me started. Thanks :)
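
For (1), a rough sketch of what I have in mind, just to make the question concrete (SampleChain here is a placeholder for whatever growable chain type I end up with, and I'm guessing at the exact signatures):

using AbstractMCMC

# Grow the existing chain in place as each sample arrives,
# instead of collecting a plain vector of samples first.
function AbstractMCMC.save!!(chain::SampleChain, sample, i, model, sampler, N; kwargs...)
    push!(chain, sample)    # SampleChain is hypothetical; push! would append one sample
    return chain
end

# Bundling would then be a no-op, since the chain is already in its final form.
function AbstractMCMC.bundle_samples(chain::SampleChain, model, sampler, state, chain_type; kwargs...)
    return chain
end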

@cpfiffer
Member

  1. Yeah, I think that's right.
  2. The resume stuff is in Turing proper -- it kind of implicitly assumes that the resume knowledge lives with the modelling language. We could add it in here, I think, but I'm not sure what form it would take yet.
  3. Totally up to you -- store them in whatever internal struct you like, as part of the state or the sample, as needed (in this context).
  4. MCMCChains has a field for this, but it isn't something we use that much. I'd just include it in the state or the sample and hand it off to the chain to handle as needed.

@cscherrer
Author

Great, thanks @cpfiffer.

The resume will change the state of an existing object, so I think it should have a !. I could just fall back on my current interface for this.

I think I'm still confused about what you mean by "state". There are a few things other than samples that can be important. Some are fixed-size, like

  • The iterator
  • Characteristics that don't change but apply to the run as a whole, for example the log-density function or the mass matrix for HMC.
  • The log-density contribution of each scalar component of a sample. This comes into play for samplers like Gen.

And some that scale linearly with the number of samples:

  • Per-sample diagnostics, like divergences. These don't change the semantics of the samples.
  • Log-density values.
  • Log-weights. These do affect the semantics, for example when computing expected values.

I think only the first group should be considered "state", and per-sample diagnostics should be separate from samples and state (I'm currently calling them "info", which IIRC I got from Turing somewhere). I was thinking of having separate logp and logweight fields as part of the interface, one entry for each sample.
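
Very roughly, the split I'm describing looks like this (only a sketch to illustrate the grouping; none of these names are meant to be part of the interface):

# Fixed-size things: set up once (possibly during warmup), never grow.
struct FixedState
    iterator      # the sampler's iterator / internal state
    logdensity    # the log-density function for the run
    metric        # e.g. the mass matrix for HMC
end

# Things that grow linearly with the number of samples.
struct PerSample
    samples       # the draws themselves
    logp          # log-density value of each draw
    logweight     # log-weights; these do change the semantics (e.g. expectations)
    info          # per-sample diagnostics (e.g. divergences); purely informational
end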

@cscherrer
Author

Hmm, and further complicating this is that the state isn't currently available at the point where samples are saved:

samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

That could make this very tricky.

@cscherrer
Author

I guess I could overload mcmcsample?

@devmotion
Member

Sure, one can always roll a custom mcmcsample. The only downside is that you lose some of the default features. BTW, if you want more control over the sampling procedure (e.g., computing statistics after every nth step, or using convergence criteria), the iteration (or transducer) interface can be useful.
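
For example, something along these lines (just a sketch; I'm assuming the steps iterator yields one sample per iteration, so check the docs for the exact names and signatures):

using AbstractMCMC, Random

function sample_until_converged(rng, model, sampler; nmax=10_000, converged=samples -> false)
    samples = []
    for (i, sample) in enumerate(AbstractMCMC.steps(rng, model, sampler))
        push!(samples, sample)
        # evaluate a user-supplied convergence criterion every 100 steps
        i % 100 == 0 && converged(samples) && break
        i >= nmax && break
    end
    return samples
end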

Regarding the points in the OP, I agree with what @cpfiffer said. resume is currently defined in DynamicPPL, but there are a bunch of issues and discussions in which we suggested moving it to AbstractMCMC (along with the possibility to specify initial samples). In general, the policy was to experiment with interface changes/extensions in downstream packages first, before moving them to AbstractMCMC and enforcing them in all implementations. I imagine that the sample/mcmcsample methods should probably allow specifying an initial state, and then resume could be defined as

function resume(rng::Random.AbstractRNG, chain, args...; kwargs...)
    return sample(
        rng, getmodel(chain), getsampler(chain), args...;
        state=getstate(chain), kwargs...,
    )
end
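
where getmodel, getsampler, and getstate are hypothetical accessors that the chain type would have to provide. Then e.g. resume(Random.default_rng(), chain, 1_000) would draw 1000 more samples without redoing the warmup phase.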

@cscherrer
Author

Thanks @devmotion, I was hoping to allow a convergence criterion as a stopping condition, so this is great.

There does seem to be an assumption that everything the user could need has to be part of the sample. For DynamicHMC, my setup looks like this (AdvancedHMC will be very similar):

@concrete struct DynamicHMCChain{T} <: AbstractChain{T}
    samples     # :: AbstractVector{T}
    logp        # log-density for distribution the sample was drawn from
    info        # Per-sample metadata, type depends on sampler used
    meta        # Metadata associated with the chain as a whole
    state       # Iterator state; not intended to be accessed by the end user
    transform   # Specific to HMC (the variable transformation)
end

Here

  • samples includes variables specified by the model:
julia> samples(chain)
100-element TupleVector with schema (x = Float64, σ = Float64)
(x = -0.1±0.34, σ = 0.576±0.37)
  • logp has the log-density information for each sample:
julia> logp(chain)[1:5]
5-element ElasticArrays.ElasticVector{Float64, 0, Vector{Float64}}:
 -1.136224195720376
 -0.42132266397402207
 -0.9789248604768969
 -1.136224195720376
 -0.8517859618293282
  • Many samplers will also use a logweights field

  • info has some diagnostic information:

julia> info(chain)[1:5]
5-element ElasticArrays.ElasticVector{DynamicHMC.TreeStatisticsNUTS, 0, Vector{DynamicHMC.TreeStatisticsNUTS}}:
 DynamicHMC.TreeStatisticsNUTS(-1.283461597663962, 3, turning at positions 6:9, 0.9703961901160859, 11, DynamicHMC.Directions(0xdfea943d))
 DynamicHMC.TreeStatisticsNUTS(-1.150959614787742, 1, turning at positions 2:3, 0.9646928495527286, 3, DynamicHMC.Directions(0x74715257))
 DynamicHMC.TreeStatisticsNUTS(-1.1699430991621091, 3, turning at positions 3:6, 1.0, 11, DynamicHMC.Directions(0x8472ffea))
 DynamicHMC.TreeStatisticsNUTS(-1.941965877904205, 1, turning at positions 2:3, 0.7405784149911505, 3, DynamicHMC.Directions(0x8c1d2457))
 DynamicHMC.TreeStatisticsNUTS(-1.3103584844087501, 2, turning at positions -2:1, 0.9999999999999999, 3, DynamicHMC.Directions(0x9483096d))
  • meta has information determined by the warmup phase, and will be different for each sampler:
julia> meta(chain).H
Hamiltonian with Gaussian kinetic energy (Diagonal), diag(M⁻¹): [1.1613920024118645, 0.7589536122573856]

julia> meta(chain).algorithm
DynamicHMC.NUTS{Val{:generalized}}(10, -1000.0, Val{:generalized}())

julia> meta(chain).ϵ
0.2634132789343616

julia> meta(chain).rng
Random._GLOBAL_RNG()
  • state contains the iterator state, and is assumed not to be accessed by the end user
  • transform is specific to HMC, and is just what you'd expect.

I guess I could cram my samples, logp, logweights, and info into your samples, as long as you don't assign any semantics to this. Then meta and transform would only be written after warmup, and our state fields would match up. Does that sound right?

@cpfiffer
Member

Yeah, that should work. I think you could just dump them into a tuple or small wrapper struct when you return them as state.

@cscherrer
Author

I can't return them as state; that would make them unavailable, since save!! doesn't include state as an argument. I think everything needs to be in the sample, and then I can pull it apart after receiving it.
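
Roughly like this (a sketch only; I'm assuming step returns some composite sample, e.g. a NamedTuple, and the draw/logp/info field names are made up):

# Each sample handed to save!! would bundle the draw plus all per-sample extras:
#     sample = (draw = ..., logp = ..., info = ...)
function AbstractMCMC.save!!(chain::DynamicHMCChain, sample, i, model, sampler, N; kwargs...)
    push!(chain.samples, sample.draw)   # the draw itself
    push!(chain.logp, sample.logp)      # its log-density
    push!(chain.info, sample.info)      # per-sample diagnostics
    return chain
end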
