
Question about retrieving binary weights from the model #685

Open
utexasnew opened this issue Jul 14, 2021 · 2 comments

Comments


utexasnew commented Jul 14, 2021

After building out the simple BNN from the following guide: https://docs.larq.dev/larq/tutorials/binarynet_cifar10/

I try retrieving binary weights to examine via https://docs.larq.dev/larq/guides/bnn-optimization/#retrieving-the-binary-weights

And I notice that despite the kernel quantization, I receive kernel values that are not entirely +1 and -1. For example, values such as 8.16999555e-01 and 3.77580225e-02 appear within the weight kernel.

Is there any intuitive explanation for this? Thank you!


lgeiger commented Jul 15, 2021

Outside of a quantized scope this is expected when training models with latent weights, as explained in the docs you linked above: the weights are only binarized in the forward pass and are stored as floating-point values during training.
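As a rough numeric sketch of that explanation, using the two values from the question: the forward pass sees only the signs of the latent weights (the convention below matches larq's ste_sign quantizer, which maps zero to +1).

import numpy as np

# Latent weights as seen outside the quantized scope (values from the question).
latent = np.array([8.16999555e-01, 3.77580225e-02])

# Forward-pass binarization: sign with sign(0) == +1, as in larq's ste_sign.
binary = np.where(latent >= 0, 1.0, -1.0)
print(binary)  # [1. 1.]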

However, the following should return binarized weights for the binary convolutions:

import larq

with larq.context.quantized_scope(True):
    weights = model.get_weights()  # get binary weights

Keep in mind, though, that the model might include some full-precision layers, such as batch norms, that won't appear quantized in the Keras model but can be fused when deploying with Larq Compute Engine.
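As a minimal sketch of how to see this (assuming a Keras model built with larq layers, e.g. from the tutorial above), you can check inside the scope which weights come back as pure ±1; batch-norm parameters and other full-precision variables will not:

import numpy as np
import larq

# Report which weights are exactly +1/-1 inside the quantized scope.
with larq.context.quantized_scope(True):
    for layer in model.layers:
        for variable, value in zip(layer.weights, layer.get_weights()):
            is_binary = np.all(np.isin(value, [-1.0, 1.0]))
            print(f"{variable.name}: binary={is_binary}")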


mervess commented Sep 15, 2021

If the model was saved within that scope beforehand, the snippet below also works for me when loading the model or the weights.

import larq as lq  # required when loading the model, even though lq isn't referenced directly
weights = model.get_weights()
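For reference, a small sketch of the save-then-load flow being described here (the file name is a hypothetical placeholder, not from the thread):

import tensorflow as tf
import larq as lq

# Save while the quantized scope is active, so the stored kernels are the
# binarized +1/-1 values rather than the latent floats.
with lq.context.quantized_scope(True):
    model.save("binary_model.h5")  # hypothetical path

# Importing larq beforehand lets Keras deserialize the quantized layers.
loaded = tf.keras.models.load_model("binary_model.h5")
weights = loaded.get_weights()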
