In [1]:
import mxnet as mx
import numpy as np

In [2]:
class Binarize(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        """Implements forward computation.

        is_train : bool, whether forwarding for training or testing.
        req : list of {'null', 'write', 'inplace', 'add'}, how to assign to out_data. 'null' means skip assignment, etc.
        in_data : list of NDArray, input data.
        out_data : list of NDArray, pre-allocated output buffers.
        aux : list of NDArray, mutable auxiliary states. Usually not used.
        """
        _x = in_data[0].asnumpy()
        x = _x.copy()
        # Just to make sure that the results are only in -1 and 1.
        x[x == 0] = -1
        y = np.sign(x) * np.mean(np.abs(_x))
        print(y)
        self.assign(out_data[0], req[0], mx.nd.array(y))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Implements backward computation

        req : list of {'null', 'write', 'inplace', 'add'}, how to assign to in_grad
        out_grad : list of NDArray, gradient w.r.t. output data.
        in_grad : list of NDArray, gradient w.r.t. input data. This is the output buffer.
        """
        x = in_data[0].asnumpy()
        dx = out_grad[0].asnumpy()
        dx[x <= -1] = 0
        dx[x >= 1] = 0
        self.assign(in_grad[0], req[0], mx.nd.array(dx))
        
@mx.operator.register("binarize")
class BinarizeProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(BinarizeProp, self).__init__(True)

    def create_operator(self, ctx, in_shapes, in_dtypes):
        #  create and return the CustomOp class.
        return Binarize()

In [3]:
class BinaryBlock(mx.gluon.Block):
    def __init__(self, **kwargs):
        super(BinaryBlock, self).__init__(**kwargs)
    
    def forward(self, x):
        return mx.nd.Custom(x, op_type='binarize')

In [4]:
binarize = BinaryBlock()
binarize.initialize()
x = mx.nd.uniform(shape=(4, 3))
y = binarize(x)
print(y)

[[0.63595426 0.63595426 0.63595426]
 [0.63595426 0.63595426 0.63595426]
 [0.63595426 0.63595426 0.63595426]
 [0.63595426 0.63595426 0.63595426]]

[[0.63595426 0.63595426 0.63595426]
 [0.63595426 0.63595426 0.63595426]
 [0.63595426 0.63595426 0.63595426]
 [0.63595426 0.63595426 0.63595426]]
<NDArray 4x3 @cpu(0)>
