An unofficial implementation of Bayesian Flow Networks (BFNs) for discrete variables in JAX.
BFNs are a new class of generative models that share some of the philosophy of diffusion models: both try to model a complex probability distribution by iteratively learning the data distribution under various levels of corruption.[^1] In diffusion models, a neural network acts on the space of the data itself, e.g. (in the reverse process) taking as input the noisy value of every pixel and outputting a (hopefully less noisy) estimate for every pixel.[^2] In BFNs, the neural network acts on the space of parametrised probability distributions for each factorised component of the data: e.g. each pixel intensity is parametrised by the mean and standard deviation of a Gaussian distribution, and the inputs and outputs of the neural network are the estimated means and standard deviations for each pixel.
Whereas the reverse process of a diffusion model starts from an image of pure noise, BFNs start with a uniform prior over the individual parameters of each pixel's probability distribution. At each step during training, the model gets to view a corrupted version of each pixel (with the level of corruption set by the noise schedule), and the pixel parameters are updated according to the rules of Bayesian inference. The neural network then sees all pixel distributions simultaneously and gets another go at updating the parameters of the pixel distributions (which is how pixel correlations get learnt). These steps repeat until barely any noise is being added to the true image, much like in diffusion models. Conceptually, with BFNs there is no need to have in mind a forward diffusion process whose reverse we are trying to match: we just start with prior beliefs about parameters, then update our beliefs during the "reverse" process according to a combination of Bayesian inference and a neural network.
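For the discrete case implemented here, that Bayesian update has a neat closed form in the paper: the noisy observation lives in logit space and is simply added to the current logits before renormalising. A minimal sketch in JAX (function and variable names are mine, not necessarily those used in this repository):

```python
import jax.numpy as jnp
from jax.nn import softmax

def bayesian_update(theta: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Posterior categorical parameters after seeing a noisy observation.

    theta: current categorical probabilities, shape (length, K).
    y: noisy logit-space observation of the data, shape (length, K).
    The closed form from the paper is theta' ∝ exp(y) * theta, i.e. add y
    to the current logits and renormalise.
    """
    return softmax(jnp.log(theta) + y, axis=-1)
```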
Acting on the parameters of a factorised probability distribution allows a consistent framework for modelling both continuous and discrete variables (and discretised continuous variables in the middle!). In one case the parameters are the means and standard deviations of Gaussian distributions, and in the other they are the logits of categorical distributions; in both cases the parameters are real numbers. Corrupting data can always be interpreted the same way: smoothing out each probability distribution through convolution with a noise kernel, then sampling from the resulting distribution. Hence for discrete variables, there is no need to define a Markov transition kernel or to map to a continuous embedding space in which to diffuse.
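For discrete variables specifically, the paper's sender distribution (the corrupted view of the data) works out to a Gaussian in logit space centred on a scaled one-hot encoding of the true class, with the accuracy α setting how much is revealed. A hedged sketch (names are mine):

```python
import jax
import jax.numpy as jnp

def sample_sender(key, x: jnp.ndarray, alpha: float, K: int) -> jnp.ndarray:
    """Draw a corrupted, logit-space observation of the true classes.

    x: integer class labels, shape (length,).
    Following the closed form in the paper, y ~ N(alpha * (K * e_x - 1), alpha * K * I),
    where e_x is the one-hot encoding of x. Small alpha reveals little about x;
    large alpha pins it down.
    """
    e_x = jax.nn.one_hot(x, K)                               # (length, K)
    mean = alpha * (K * e_x - 1.0)
    return mean + jnp.sqrt(alpha * K) * jax.random.normal(key, e_x.shape)
```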
It also turns out that in all cases training is just maximising the log-likelihood of the data through the evidence lower bound (ELBO), with no auxiliary losses needed.[^3]
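For instance, the continuous-time loss for discrete data collapses to a weighted squared error between the one-hot data and the network's output probabilities. Below is a rough Monte Carlo sketch under the paper's β(t) = β(1)t² accuracy schedule; `net` is a placeholder for whatever network maps (parameters, time) to logits:

```python
import jax
import jax.numpy as jnp

def continuous_time_loss(key, net, x: jnp.ndarray, beta_1: float, K: int):
    """Single-sample Monte Carlo estimate of the continuous-time discrete loss.

    x: integer class labels, shape (length,).
    With beta(t) = beta_1 * t**2, the paper's closed form is roughly
    L_inf = K * beta_1 * E_t[ t * || one_hot(x) - probs(theta, t) ||^2 ].
    """
    t_key, flow_key = jax.random.split(key)
    t = jax.random.uniform(t_key)                            # t ~ U(0, 1)
    beta_t = beta_1 * t**2

    # Sample the Bayesian flow distribution: theta = softmax(y), with
    # y ~ N(beta(t) * (K * e_x - 1), beta(t) * K * I).
    e_x = jax.nn.one_hot(x, K)
    y = beta_t * (K * e_x - 1.0) + jnp.sqrt(beta_t * K) * jax.random.normal(flow_key, e_x.shape)
    theta = jax.nn.softmax(y, axis=-1)

    probs = jax.nn.softmax(net(theta, t), axis=-1)           # output distribution
    return K * beta_1 * t * jnp.sum((e_x - probs) ** 2)
```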
Furthermore, there are no restrictions placed on the architecture of the neural network because all it has to do is take as input a tensor of shape (num_params, length) and output a tensor of equal size.
When modelling discrete variables such as text tokens, many transformers already accept one-hot encoded tensors and output logits with that exact shape, so minimal modifications are needed to get up and running.
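As a rough illustration of how small that adaptation can be, here is a hypothetical wrapper (the names `apply_network` and `transformer_fn`, and the choice of time conditioning, are mine rather than anything this repository commits to):

```python
import jax.numpy as jnp

def apply_network(transformer_fn, theta: jnp.ndarray, t: float) -> jnp.ndarray:
    """Adapt a sequence model that maps float token features to logits.

    theta: current categorical probabilities, shape (length, K), standing in
    for the one-hot token encoding the transformer would normally receive.
    t: the current time in [0, 1], appended here as an extra input feature
    (just one possible way of conditioning on time).
    Returns logits of shape (length, K), the same shape as theta.
    """
    time_feature = jnp.full((theta.shape[0], 1), t)
    inputs = jnp.concatenate([theta, time_feature], axis=-1)  # (length, K + 1)
    return transformer_fn(inputs)
```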
The BFN paper reports a better bits-per-character score on the text8 dataset than other discrete diffusion models such as Multinomial Diffusion and D3PM.[^4]
The Bayesian Flow Network preprint is quite heavy on the setup needed to derive closed-form expressions for the loss and sampling procedures, but the final expressions and pseudocode are comparatively simple to implement.
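For example, the sampling procedure for discrete data is essentially a short loop: repeatedly sample from the network's output distribution, simulate a sender sample for that guess, and fold it into the parameters with a Bayesian update. A sketch following the paper's pseudocode (with a placeholder `net`; names are mine):

```python
import jax
import jax.numpy as jnp

def sample(key, net, length: int, K: int, beta_1: float, n_steps: int) -> jnp.ndarray:
    """Generate a sequence of class labels with an n-step BFN sampler."""
    theta = jnp.full((length, K), 1.0 / K)                   # uniform prior over classes

    for i in range(1, n_steps + 1):
        key, k_key, y_key = jax.random.split(key, 3)
        t = (i - 1) / n_steps

        # Sample a guess from the network's output distribution at the current time.
        k = jax.random.categorical(k_key, net(theta, t), axis=-1)   # (length,)

        # Simulate a sender sample for that guess and fold it in with a Bayesian update.
        alpha = beta_1 * (2 * i - 1) / n_steps**2
        e_k = jax.nn.one_hot(k, K)
        y = alpha * (K * e_k - 1.0) + jnp.sqrt(alpha * K) * jax.random.normal(y_key, e_k.shape)
        theta = jax.nn.softmax(jnp.log(theta) + y, axis=-1)

    # Final draw from the output distribution at t = 1.
    key, final_key = jax.random.split(key)
    return jax.random.categorical(final_key, net(theta, 1.0), axis=-1)
```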
Below are some notebooks that interactively demonstrate some concepts in the paper.
I found it easiest to play around with this repository by pip installing the package in editable mode:
```bash
git clone https://github.com/ElisR/BFN.git
cd BFN
pip install -e .
```

- Loss function and sampling for discrete distributions.
- Simple discrete training example notebook.
- Basic tests for discrete case.
- Bayesian flow visualisation for discrete distribution.
- Bayesian flow visualisation for continuous distribution.
- Loss function and sampling for continuous probability distribution.
- Create a working `pyproject.toml`.
[^1]: The success of diffusion models shows that this is seemingly easier than trying to learn the data distribution directly.

[^2]: That estimate is often reparametrised as an estimate for the noise added to a clean image at that time-step.

[^3]: In the paper they first present this ELBO as the expected number of bits required for Alice, who has access to the true data, to transmit it to Bob according to the BFN scheme described above. In this interpretation Alice sends latent variables (increasingly revealing noisy observations of the true data) to Bob, who continually updates his posterior belief of the factorised distribution according to Bayesian inference and a neural network. The estimate for the number of bits assumes that Alice sends latent variables and finally the true data according to an efficient bits-back encoding scheme.

[^4]: More parameters seem to have been used than for the quoted D3PM result, so more direct comparisons would be nice.