
Implementing Radial BNN #12

Open
silasbrack opened this issue Feb 8, 2022 · 2 comments

Comments

@silasbrack

Hi,

I’m trying to fit a radial BNN posterior variational approximation as per this paper.

However, since I'll be training a BNN, I don't want to write a custom guide defining this variational approximation for every layer, so I was trying to implement a custom AutoGuide that automatically places a radial approximation on all of my weights.

The radial approximation is defined as follows:
w = mu + sigma * (epsilon_MFVI / ||epsilon_MFVI||) * r
where I just need to sample all epsilon_MFVI from an independent standard normal distribution, normalize them, and multiply them by r, which is a scalar sampled from a standard normal.

How could I go about implementing this in TyXe?
Is there a smarter way of implementing this variational approximation?

P.S. Big fan of this project!

Thanks in advance.

@hpplyt
Collaborator

hpplyt commented Feb 9, 2022

Hi,
this should overall be doable. Essentially you'd need to

  1. implement a RadialNormal distribution class that inherits from pyro.distributions.Distribution. The easiest approach would probably be to inherit from pyro.distributions.Normal and override the rsample and log_prob methods. As far as I remember, the parameterization is the same as for a Normal distribution, so you can inherit the boilerplate code.
  2. implement a corresponding autoguide class. If you just want something quick and dirty that works, you should be able to subclass pyro.infer.autoguide.AutoNormal, copy-paste their forward method and change the line where they instantiate the dist.Normal to use your RadialNormal instead.

And then the AutoRadial guide would (hopefully :-) ) work in place of an AutoNormal guide as in the examples.
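A minimal sketch of step (1), assuming the parameterization really does match a Normal. It uses torch.distributions.Normal so the snippet is self-contained; pyro.distributions.Normal wraps the same class, so in practice you'd subclass that instead. The name RadialNormal is just the one suggested above, and log_prob is left out here even though a correct ELBO would need it too:

```python
import torch
from torch.distributions import Normal


class RadialNormal(Normal):
    """Radial posterior sketch: a unit direction from normalized
    standard-normal noise, scaled by a scalar standard-normal radius r.
    Note: log_prob would also need overriding; omitted in this sketch."""

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        # epsilon_MFVI ~ N(0, I), normalized to the unit sphere.
        # For simplicity this normalizes over all dimensions; with a
        # non-trivial sample_shape you'd normalize per sample instead.
        eps = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device)
        direction = eps / eps.norm()
        # r is a single scalar radius drawn from a standard normal.
        r = torch.randn((), dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + self.scale * direction * r


q = RadialNormal(torch.zeros(10), torch.ones(10))
w = q.rsample()
print(w.shape)  # torch.Size([10])
```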


(2) is unfortunately a bit ugly; ideally we'd have some kind of autoguide factory class in TyXe that can generate autoguides for a given distribution, to make adding custom distributions easier. I'll give this some more thought when I get the chance.

As an additional note on (1), you might also want to implement the KL divergence between the RadialNormal and a Normal distribution (I think that's what they use as a prior in the paper). For that you need to register the KL as a decorated function like:

@torch.distributions.kl.register_kl(RadialNormal, dist.Normal)
def _kl(q, p):
    ...

where q is a RadialNormal and p a Normal object.

Sorry this is all a bit more involved than it should be, but I hope this helps. If you need any more details or if I overlooked any issues, let me know. And if you have a go at an implementation, feel free to link your repo/a gist here, I'm happy to take a look at it.

@silasbrack
Author

Hey, sorry for the delay; thanks a lot for the help!

With your tips, I managed to implement a script for running VI on BNNs with this radial approximation.
Fortunately, the KL divergence between the radial posterior and a normal prior is the same as for the mean-field (up to a constant), so I didn't bother updating the calculation of the KL divergence, just the sampling.
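For anyone following along, that shortcut can be sketched like this: since the radial-vs-normal KL matches the mean-field KL up to an additive constant, one can register torch's closed-form Normal-Normal KL for the pair and drop the constant, which doesn't affect ELBO gradients. RadialNormal here is just a placeholder for the custom distribution discussed in this thread:

```python
import torch
from torch.distributions import Normal, kl_divergence
from torch.distributions.kl import register_kl


class RadialNormal(Normal):
    """Placeholder for the radial distribution discussed in this thread."""


@register_kl(RadialNormal, Normal)
def _kl_radial_normal(q, p):
    # Reuse the closed-form Normal-Normal KL; the constant offset between
    # the radial and mean-field entropies is dropped, since an additive
    # constant in the KL doesn't change the gradients of the ELBO.
    return kl_divergence(Normal(q.loc, q.scale), p)


q = RadialNormal(torch.zeros(3), torch.ones(3))
p = Normal(torch.zeros(3), torch.ones(3))
print(kl_divergence(q, p))  # tensor([0., 0., 0.])
```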

I've actually had great results with this posterior. In my experience, the mean-field approximation often struggles to converge to an accurate solution, while the radial posterior consistently outperforms both the mean-field and low-rank approximations.

Feel free to take a look at it in https://github.com/silasbrack/approximate-inference-for-bayesian-neural-networks/blob/main/src/guides/radial.py
