Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Substantial updates to tutorial 01_gaussian-mixture-model #439

Closed
wants to merge 16 commits into from

Conversation

JasonPekos
Copy link
Member

@JasonPekos JasonPekos commented Apr 10, 2024

First PR! Hope everything is ok.

As discussed in the slack, this PR adds the following significant changes to this tutorial

  • Covers the use of ordered() from Bijectors.jl in making the model identifiable (currently it is multimodal, and the seed is just lucky.)
  • Introduces a (MUCH faster) version of the model with assignments marginalized out with Turing.@addlogprob!
  • A very simple version of the marginalized model using only ~ MixtureModel(dists, weights)
  • Example of recovering marginalized assignment draws with generated_quantities()

There are also a few minor changes:

  • Changed chain plots from grouping chains together to grouping parameters together. I think this makes the multimodality discussion much clearer.
  • Added a tiny amount of burnin to the MCMC sampler, as really far-out initialization was making the plots hard to read
  • Added one more chain to n-chains to make multimodality more likely
  • Added and reworked tests for the models (now tests rhat to make sure the bijector discussions are all correct)

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, I made a few comments 🙂

Where we sum the components with `logsumexp` from the [`StatsFuns.jl` package](https://github.com/JuliaStats/StatsFuns.jl).


The manually incremented likelihood can be added to the log-probability with `Turing.@addlogprob!`, giving us the following model:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should not recommend the use of Turing.@addlogprob! in it's so easy to misuse and to get (silently) wrong results because it operates completely outside of the ~ logic in Turing/DynamicPPL. Instead I think usually one should use ~ with a (possibly custom) distribution.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good! I initially wasn't going to include that section for basically the reasons you bring up, but I ended up including it (even though I don't actually sample from that model) to motivate what's going on with the MixtureModel lpdf.

I can replace it with a custom distribution (although this might be a little long for a model that's really just exposition), or omit it entirely.

Now, re-running our model, we can see that the assigned means are consistent across chains:

```julia
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe let's keep the tutorial simple and avoid surprising warnings in singlethreaded environments:

Suggested change
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
chains = sample(model, sampler, nsamples, nchains; discard_initial = burn);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually un-resolving this because I don't think it works? As it is right now, I'm not sure if Turing allows multiple chains without specifying a type of parallelism.

The documentation, if it's current, seems to suggest I should do something like:

chains = mapreduce(c -> sample(model_fun, sampler, 1000), chainscat, 1:num_chains)

I'm not sure if that's worth it just to get rid of the warning — let me know what you think though.

Comment on lines 372 to 376
# Return sample_class(yi) for fixed μ, w.
function sample_class(xi)
lvec = [(logpdf(d, xi) + log(w[i])) for (i, d) in enumerate(dists)]
rand(Categorical(exp.(lvec .- logsumexp(lvec))))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be defined outside of the model and probably use softmax or softmax! directly.

JasonPekos and others added 12 commits April 10, 2024 12:30
fix sample call

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
remove use of MCMCThread()

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Remove Bijectors import

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
@JasonPekos JasonPekos mentioned this pull request Apr 17, 2024
@JasonPekos
Copy link
Member Author

I think maybe this should be closed and revisited when #441 is done?

fwiw the current thing that's keeping this frozen is the multithreading stuff. If we want to stay away from:

chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);

to avoid warnings in single threaded environments, we'll need to update a bunch of tutorials, because this is pretty common across all the tutorials.

@yebai
Copy link
Member

yebai commented May 23, 2024

Thanks, @JasonPekos, for the PR. Would you like to migrate your changes here to #441?

@JasonPekos
Copy link
Member Author

Thanks, @JasonPekos, for the PR. Would you like to migrate your changes here to #441?

Yup, will do.

@JasonPekos JasonPekos closed this May 23, 2024
@JasonPekos JasonPekos mentioned this pull request May 24, 2024
10 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants