Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
71 lines (47 sloc) 2.13 KB
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import keras.backend as K
def round_through(x):
    '''Round element-wise while letting gradients pass straight through.

    The forward value is `x + (round(x) - x)`, i.e. the rounded tensor.
    The `(round(x) - x)` term is wrapped in `K.stop_gradient`, so during
    back-propagation only the bare `x` term contributes and the gradient
    with respect to `x` is the identity.

    A trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182)

    # Arguments
        x: backend tensor.

    # Returns
        `x` plus the gradient-stopped rounding residual.
    '''
    # Residual that moves x onto its rounded value.
    residual = K.round(x) - x
    # Treat the residual as a constant for differentiation purposes.
    return x + K.stop_gradient(residual)
def _hard_sigmoid(x):
    '''Piecewise-linear sigmoid: `clip(0.5 * x + 0.5, 0, 1)`.

    This is not the same formula as the more conventional hard sigmoid
    (see the definition of `K.hard_sigmoid`).

    # Arguments
        x: backend tensor.

    # Returns
        Tensor with values clipped to the range [0, 1].

    # Reference:
    - [BinaryNet: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1, Courbariaux et al. 2016](http://arxiv.org/abs/1602.02830)
    '''
    # Affine map that sends -1 -> 0 and +1 -> 1 before clipping.
    shifted = (0.5 * x) + 0.5
    return K.clip(shifted, 0, 1)
def binary_sigmoid(x):
    '''Binary hard sigmoid for training binarized neural networks.

    Applies `_hard_sigmoid` and then rounds the result with
    `round_through`, so the rounding step does not block gradients.

    # Arguments
        x: backend tensor.

    # Returns
        The rounded hard-sigmoid of `x`.

    # Reference:
    - [BinaryNet: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1, Courbariaux et al. 2016](http://arxiv.org/abs/1602.02830)
    '''
    squashed = _hard_sigmoid(x)
    return round_through(squashed)
def binary_tanh(x):
    '''Binary hard tanh for training binarized neural networks.

    The neurons' activation binarization function.
    It behaves like the sign function during forward propagation,
    and like:

        hard_tanh(x) = 2 * _hard_sigmoid(x) - 1

    during back propagation, which clears the gradient when |x| > 1.

    # Arguments
        x: backend tensor.

    # Returns
        `2 * round_through(_hard_sigmoid(x)) - 1`.

    # Reference:
    - [BinaryNet: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1, Courbariaux et al. 2016](http://arxiv.org/abs/1602.02830)
    '''
    # binary_sigmoid(x) is round_through(_hard_sigmoid(x)); rescale its
    # output from the [0, 1] range to the [-1, 1] range.
    unit_step = binary_sigmoid(x)
    return 2 * unit_step - 1
def binarize(W, H=1):
    '''The weights' binarization function.

    Computes `H * binary_tanh(W / H)`.

    # Arguments
        W: backend tensor of weights.
        H: scale of the binarized values (default 1).

    # Returns
        The binarized weights.

    # Reference:
    - [BinaryNet: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1, Courbariaux et al. 2016](http://arxiv.org/abs/1602.02830)
    '''
    # [-H, H] -> -H or H
    normalized = W / H
    return H * binary_tanh(normalized)
def _mean_abs(x, axis=None, keepdims=False):
    '''Mean of the absolute values of `x`, excluded from differentiation.

    # Arguments
        x: backend tensor.
        axis: forwarded to `K.mean`.
        keepdims: forwarded to `K.mean`.

    # Returns
        `K.mean(K.abs(x), ...)` wrapped in `K.stop_gradient`.
    '''
    magnitudes = K.abs(x)
    average = K.mean(magnitudes, axis=axis, keepdims=keepdims)
    # The result is used as a constant: no gradient flows back through it.
    return K.stop_gradient(average)
def xnorize(W, H=1., axis=None, keepdims=False):
    '''Return a gradient-stopped mean-absolute scale and binarized weights.

    # Arguments
        W: backend tensor of weights.
        H: passed to `binarize` (default 1.).
        axis: passed to `_mean_abs`.
        keepdims: passed to `_mean_abs`.

    # Returns
        Tuple `(Wa, Wb)` where `Wa = _mean_abs(W, axis, keepdims)` and
        `Wb = binarize(W, H)`.
    '''
    binary_weights = binarize(W, H)
    abs_mean = _mean_abs(W, axis, keepdims)
    return abs_mean, binary_weights