
Bayesian quantification prototype #28

Merged · 2 commits into HLT-ISTI:devel on Mar 15, 2024

Conversation

@pawel-czyz

Hi Alex,

In #27 we discussed two threads: other solvers for ACC/PCC (such as BBSE or invariant ratio estimators) and Bayesian quantification.

This PR sketches what the Bayesian quantification estimator could look like. Instead of solving an equation (by explicit inversion or optimization), it uses a Markov chain Monte Carlo (NUTS) sampler to find all prevalence vectors $P_\text{test}(Y)$ compatible with the observed distribution of classifier predictions.
More formally, we are doing Bayesian inference in the following model:
$$\pi' \sim \mathrm{Dirichlet}(\alpha), \qquad \phi_{\cdot \mid y} \sim \mathrm{Dirichlet}(\beta) \quad \text{for each class } y,$$

$$(N_1, \dots, N_K) \mid \pi', \phi \;\sim\; \mathrm{Multinomial}\Big(N, \; \textstyle\sum_y \pi'_y \, \phi_{\cdot \mid y}\Big),$$

where $\pi'$ is the latent variable modelling $P_\text{test}(Y)$, $\phi$ is a latent variable modelling the $P(C\mid Y)$ matrix (where $C$ is the classifier prediction and $Y$ is the true label), and $(N_1, \dots, N_K)$ are the observed counts of each classifier prediction among the $N$ test points.
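A minimal NumPyro sketch of a model with this structure looks as follows; the uniform Dirichlet priors, the variable names, and the use of confusion counts on held-out labeled data to inform $\phi$ are illustrative choices here, not necessarily the exact settings of this PR:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS


def model(n_y_c_labeled, n_c_unlabeled):
    """n_y_c_labeled: (L, K) counts of (true label, prediction) pairs on
    held-out labeled data; n_c_unlabeled: (K,) prediction counts on test data."""
    L, K = n_y_c_labeled.shape
    # pi': latent test prevalence vector, modelling P_test(Y)
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(L)))
    # phi: latent P(C|Y) matrix, one Dirichlet-distributed row per true class
    phi = numpyro.sample("phi", dist.Dirichlet(jnp.ones((L, K))))
    # The labeled confusion counts inform phi...
    numpyro.sample("F_yc",
                   dist.Multinomial(n_y_c_labeled.sum(axis=1), phi),
                   obs=n_y_c_labeled)
    # ...and the test predictions follow the mixture sum_y pi_y * phi_{.|y}
    p_c = jnp.einsum("yc,y->c", phi, pi)
    numpyro.sample("N_c",
                   dist.Multinomial(n_c_unlabeled.sum(), p_c),
                   obs=n_c_unlabeled)


mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0),
         n_y_c_labeled=jnp.array([[40, 10], [5, 45]]),
         n_c_unlabeled=jnp.array([70, 30]))
pi_samples = mcmc.get_samples()["pi"]  # posterior draws of P_test(Y)
```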

@AlexMoreo marked this pull request as ready for review on March 15, 2024, 13:29
@AlexMoreo merged commit 2cc4908 into HLT-ISTI:devel on Mar 15, 2024
@AlexMoreo (Contributor)

Just merged.

The method seems to be working like a charm. I noticed it's a bit slow in tests, though, although this can be due to some CUDA errors JAX is reporting after init...

Nice contribution, thanks!

@pawel-czyz (Author)

Hi Alex,

Thanks for merging! Regarding:

I noticed it's a bit slow in tests, though, although this can be due to some CUDA errors JAX is reporting after init...

you are definitely right that it's slower than matrix-inversion methods: it samples from the posterior using MCMC rather than inverting a matrix. With the current settings it usually takes a few seconds to run on my laptop. For unit tests, decreasing the number of collected samples (in both the warmup and the proper sampling phases of the MCMC) should speed up the runtime, but there is a trade-off between the time taken and how well the posterior is explored (and hence how good the estimates of the mean and credible intervals obtained from the samples are). The current settings seemed to work quite well for the range of problems I studied, but I expect they will generally be suboptimal (e.g., problems with tens or hundreds of categories will probably need many more samples).
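For instance, assuming the sampling is driven through NumPyro's `MCMC` wrapper (as in the sketch above), a test-only configuration could simply shrink both phases; the specific numbers here are illustrative:

```python
from numpyro.infer import MCMC, NUTS

# Default-ish settings: slower, but better posterior exploration.
mcmc_full = MCMC(NUTS(model), num_warmup=500, num_samples=1000)

# Unit-test settings: much faster, but mean and credible-interval
# estimates computed from so few draws will be noticeably noisier.
mcmc_fast = MCMC(NUTS(model), num_warmup=50, num_samples=100)
```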

The CUDA errors JAX is reporting should be harmless: they are most likely just JAX noting that it is running on the CPU rather than on CUDA. If CUDA is available on the machine (and the CUDA-compatible version of JAX is installed), the sampling should be a bit faster, but I generally run it on a CPU-only machine.
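If the messages are distracting (e.g., in CI logs), one option is to pin JAX to the CPU explicitly before any JAX computation runs; this is a generic JAX setting, not something this PR itself configures:

```python
import jax

# Select the CPU backend before any JAX computation is executed;
# this avoids the "falling back to CPU"-style messages on machines
# without a usable CUDA installation.
jax.config.update("jax_platform_name", "cpu")

print(jax.devices())  # e.g. [CpuDevice(id=0)]
```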

Best wishes,
Pawel
