Implement Binarized Neural Networks from http://arxiv.org/pdf/1602.02…
Sathish Nagappan authored and Jennifer Myers committed Jul 13, 2016
1 parent bf71b9d · commit caf0aaa
Showing 18 changed files with 1,050 additions and 66 deletions.
14 changes: 14 additions & 0 deletions examples/binary/README.md
@@ -0,0 +1,14 @@
## Model

This is an implementation of a Binarized Neural Network trained on the MNIST dataset.
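
The core idea, from the paper cited below, is to constrain weights and activations to +1 or -1 in the forward pass while keeping full-precision values for the gradient updates. A minimal NumPy sketch of the deterministic binarization (illustration only, not the neon backend code):

```
import numpy as np

def binarize_det(x):
    # deterministic binarization: +1 where x >= 0, -1 elsewhere
    return np.where(x >= 0, 1.0, -1.0)

print(binarize_det(np.array([0.3, -1.2, 0.0, 0.7])))  # [ 1. -1.  1.  1.]
```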

### Instructions
```
python binary/train.py -e 20
```
## Citation
```
Binarized Neural Networks: Training Neural Networks with Weights and Activations Constrained to +1 or -1
http://arxiv.org/pdf/1602.02830v3.pdf
```

90 changes: 90 additions & 0 deletions examples/binary/train.py
@@ -0,0 +1,90 @@
#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2016 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""
Trains BinaryNet on the MNIST dataset.
Reference:
"Binarized Neural Networks: Training Neural Networks with Weights and
Activations Constrained to +1 or -1"
http://arxiv.org/pdf/1602.02830v3.pdf
Usage:
python binary/train.py -e 20
"""

from neon.callbacks.callbacks import Callbacks
from neon.data import ArrayIterator, load_mnist
from neon.initializers import Uniform
from neon.layers import BinaryAffine, GeneralizedCost
from neon.models import Model
from neon.optimizers import MultiOptimizer, ShiftAdaMax, ShiftSchedule
from neon.transforms import Identity, Misclassification, Sign, SquareHingeLoss
from neon.util.argparser import NeonArgparser


# parse the command line arguments
parser = NeonArgparser(__doc__)

args = parser.parse_args()

# load up the mnist data set
# split into train and test sets
(X_train, y_train), (X_test, y_test), nclass = load_mnist(path=args.data_dir)

# setup a training set iterator
train_set = ArrayIterator(X_train, y_train, nclass=nclass, lshape=(1, 28, 28))
# setup a validation data set iterator
valid_set = ArrayIterator(X_test, y_test, nclass=nclass, lshape=(1, 28, 28))

# setup weight initialization function
init = Uniform(-1, 1)

# setup layers
layers = [
    BinaryAffine(nout=4096, init=init, batch_norm=True, activation=Sign()),
    BinaryAffine(nout=4096, init=init, batch_norm=True, activation=Sign()),
    BinaryAffine(nout=4096, init=init, batch_norm=True, activation=Sign()),
    BinaryAffine(nout=10, init=init, batch_norm=True, activation=Identity())
]

# setup cost function as Square Hinge Loss
cost = GeneralizedCost(costfunc=SquareHingeLoss())

# setup optimizer
LR_start = 1.65e-2


def ShiftAdaMax_with_Scale(LR=1):
    return ShiftAdaMax(learning_rate=LR_start * LR, schedule=ShiftSchedule(2, shift_size=1))


# per-layer learning-rate scales: each constant equals
# sqrt((fan_in + fan_out) / 1.5) for the matching BinaryLinear layer
# (784-4096, 4096-4096, 4096-10), matching a Glorot-style coefficient
optimizer = MultiOptimizer({
    'default': ShiftAdaMax_with_Scale(),
    'BinaryLinear_0': ShiftAdaMax_with_Scale(57.038),
    'BinaryLinear_1': ShiftAdaMax_with_Scale(73.9008),
    'BinaryLinear_2': ShiftAdaMax_with_Scale(73.9008),
    'BinaryLinear_3': ShiftAdaMax_with_Scale(52.3195)
})

# initialize model object
bnn = Model(layers=layers)

# configure callbacks
callbacks = Callbacks(bnn, eval_set=valid_set, **args.callback_args)

# run fit
bnn.fit(train_set, optimizer=optimizer, num_epochs=args.epochs, cost=cost, callbacks=callbacks)
print('Misclassification error = %.1f%%' % (bnn.eval(valid_set, metric=Misclassification())*100))
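
ShiftAdaMax and ShiftSchedule correspond to the paper's shift-based AdaMax and learning-rate schedule, where multiplications are replaced by cheap shifts with power-of-two factors. A rough NumPy sketch of the AP2 (approximate power-of-two) rounding such updates rely on (the helper name is illustrative, not a neon API):

```
import numpy as np

def ap2(x):
    # round to the nearest power of two, preserving sign (nonzero input
    # assumed); multiplying by ap2(x) is then a cheap bit shift
    return np.sign(x) * 2.0 ** np.round(np.log2(np.abs(x)))

print(ap2(np.array([0.3, -1.7, 6.0])))  # [ 0.25 -2.    8.  ]
```
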
77 changes: 72 additions & 5 deletions neon/backends/backend.py
@@ -33,9 +33,9 @@ class OpCollection(object):
    zero_operand_ops = {"rand", "onehot"}
    unary_ops = {"finite", "neg", "abs", "sgn", "sqrt", "sqr", "exp", "log",
                 "exp2", "log2", "sig", "sig2", "tanh", "tanh2", "transpose",
-                "safelog"}
+                "safelog", "rint", "binarize"}
    binary_ops = {"assign", "add", "sub", "mul", "div", "eq", "ne", "lt", "le",
-                 "gt", "ge", "pow", "minimum", "maximum", "dot"}
+                 "gt", "ge", "pow", "minimum", "maximum", "dot", "shift"}
    reduction_ops = {"sum", "max", "min", "argmax", "argmin"}
    float_ops = zero_operand_ops | unary_ops | binary_ops
    ew_ops = float_ops - {'dot', 'transpose'}
@@ -697,6 +697,18 @@ def dot(self, a, b, out=None):
        """
        return OpTreeNode.build("dot", a, b, out=out)

    def xnor_compound_dot(self, A, B, C, beta=0.0):
        """
        Performs an XNOR GEMM (matrix multiply on binarized operands):
        C = A * B

        Arguments:
            A (Tensor): left-hand side operand.
            B (Tensor): right-hand side operand.
            C (Tensor): output operand.
            beta (float, optional): scale applied to the existing contents
                of C before the product is accumulated (GEMM convention).
        """
        raise NotImplementedError()
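
For ±1 operands packed as bits (1 for +1, 0 for -1), a dot product reduces to an XNOR followed by a population count, which is what makes a dedicated XNOR GEMM worthwhile. A small pure-Python sketch of the identity (illustration, not the backend kernel):

```
def xnor_dot(a_bits, b_bits, n):
    # matching bits contribute +1, mismatching bits -1, so the dot
    # product is 2 * popcount(xnor) - n
    matches = bin(~(a_bits ^ b_bits) & ((1 << n) - 1)).count('1')
    return 2 * matches - n

# [+1, -1, +1] . [+1, +1, +1] = 1 - 1 + 1 = 1
print(xnor_dot(0b101, 0b111, 3))  # 1
```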

    def add(self, a, b, out=None):
        """
        Perform element-wise addition on the operands, storing the resultant
@@ -1069,6 +1081,35 @@ def finite(self, a, out=None):
        """
        return OpTreeNode.build("finite", a, None, out=out)

    def rint(self, a, out=None):
        """
        Perform element-wise rounding to nearest int.

        Arguments:
            a (Tensor): input to be transformed.
            out (Tensor, optional): where the result will be stored. If out is
                None, only the op-tree will be returned.

        Returns:
            OpTreeNode: the resulting op-tree
        """
        return OpTreeNode.build("rint", a, None, out=out)

    def binarize(self, a, stochastic=True, out=None):
        """
        Perform element-wise binarization to +1 or -1.

        Arguments:
            a (Tensor): input to be transformed.
            stochastic (bool, optional): when True, binarize stochastically;
                when False, use the deterministic sign function. Defaults
                to True.
            out (Tensor, optional): where the result will be stored. If out is
                None, only the op-tree will be returned.

        Returns:
            OpTreeNode: the resulting op-tree
        """
        return OpTreeNode.build("binarize", a, None, stochastic=stochastic, out=out)
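
A NumPy sketch of the stochastic case as described in the paper, where the probability of +1 is the "hard sigmoid" of the input (an interpretation of this op's semantics, not the kernel code):

```
import numpy as np

def binarize_stochastic(x, rng=np.random):
    # hard sigmoid clip((x + 1) / 2, 0, 1) gives P(+1)
    p = np.clip((x + 1.0) / 2.0, 0.0, 1.0)
    return np.where(rng.uniform(size=x.shape) < p, 1.0, -1.0)
```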

    def equal(self, a, b, out=None):
        """
        Performs element-wise equality testing on each element of left and
@@ -1205,6 +1246,25 @@ def minimum(self, a, b, out=None):
        """
        return OpTreeNode.build("minimum", a, b, out=out)

    def shift(self, a, b, value=True, out=None):
        """
        Performs element-wise shift based on corresponding elements of left
        and right, storing the result in out. A positive shift amount shifts
        left and a negative one shifts right. Each operand is assumed to be
        the same shape (or broadcastable as such).

        Arguments:
            a (Tensor, numeric): left-hand side operand.
            b (Tensor, numeric): right-hand side operand.
            value (bool, optional): when True, b holds values whose nearest
                powers of two determine the shift amounts; when False, b
                holds the shift exponents directly. Defaults to True.
            out (Tensor, optional): where the result will be stored. If out is
                None, only the op-tree will be returned.

        Returns:
            OpTreeNode: the resulting op-tree
        """
        return OpTreeNode.build("shift", a, b, value=value, out=out)
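
A NumPy sketch of one plausible reading of these semantics (hedged; the actual kernel may differ in sign and rounding details): with value=True the shift amount comes from the nearest power of two of b, with value=False b is the exponent itself:

```
import numpy as np

def shift_sketch(a, b, value=True):
    # shifting by e multiplies by 2**e; positive e is a left shift,
    # negative e a right shift (sign of b ignored in this sketch)
    exp = np.round(np.log2(np.abs(b))) if value else b
    return a * 2.0 ** exp

print(shift_sketch(np.array([3.0]), np.array([0.25])))  # [0.75]
```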

    def clip(self, a, a_min, a_max, out=None):
        """
        Performs element-wise clipping of Tensor `a`, storing the result in out.
@@ -1359,7 +1419,7 @@ def mean(self, a, axis=None, partial=None, out=None, keepdims=True):
            return self.multiply(self.sum(a), 1.0 / (shape[0] * shape[1]), out=out)
        return self.multiply(self.sum(a, axis=axis), 1.0 / shape[axis], out=out)

-    def var(self, a, axis=None, partial=None, out=None, keepdims=True):
+    def var(self, a, axis=None, partial=None, out=None, keepdims=True, binary=False):
        """
        Calculates the variance of the elements along the specified
        axes.
@@ -1379,9 +1439,16 @@ def var(self, a, axis=None, partial=None, out=None, keepdims=True):
        Returns:
            OpTreeNode: the resulting op-tree
        """
        if binary:
            # shift-based variance: approximate squaring of the centered
            # values with a power-of-two shift (see the sketch below)
            # instead of a true multiply
            def self_shift(x):
                return self.shift(x, x)
            op = self_shift
        else:
            op = self.square

        if axis is None:
-            return self.mean(self.square(a - self.mean(a)), out=out)
-        return self.mean(self.square(a - self.mean(a, axis=axis)), axis=axis, out=out)
+            return self.mean(op(a - self.mean(a)), out=out)
+        return self.mean(op(a - self.mean(a, axis=axis)), axis=axis, out=out)
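
A quick NumPy check of how well the shift approximates the true variance (pure illustration; c * sign(c) * 2**round(log2|c|) stands in for shift(c, c)):

```
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
c = x - x.mean()
# approximate c**2 by c * sign(c) * 2**round(log2|c|)
approx_sq = c * np.sign(c) * 2.0 ** np.round(np.log2(np.abs(c)))
print(approx_sq.mean(), x.var())  # 1.625 (shift-based) vs 1.25 (exact)
```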

    def std(self, a, axis=None, partial=None, out=None, keepdims=True):
        """
