In [1]:
import tensorflow as tf
from tensorflow.python.framework import ops # for gradient
from tensorflow.python.ops import gen_nn_ops # compute gradient

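- `op`: gradient를 계산할 대상인 forward 연산(여기서는 Relu). `op.outputs[0]`이 Relu의 출력값(feature)임
- `grad`: 역전파로 들어오는, `op` 출력에 대한 gradient. Guided ReLU는 feature > 0 이면서 gradient > 0인 위치에서만 grad를 통과시킴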

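- decorator(`@ops.RegisterGradient("GuidedRelu")`): 함수(`_GuidedReluGrad`)를 입력으로 받아 `"GuidedRelu"`라는 이름으로 gradient 레지스트리에 등록한 뒤, 그 함수를 그대로 반환함
- 따라서 아래의 function은 새로운 op가 아니라 사용자가 정의한 gradient 함수이며, `gradient_override_map({"Relu": "GuidedRelu"})` 블록 안에서 생성된 'Relu' op의 gradient를 이 함수로 대체함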

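```python
# decorator syntax
@ops.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(op, grad):
    ...  # something

# is equivalent to the explicit call below
def _GuidedReluGrad(op, grad):
    ...  # something
_GuidedReluGrad = ops.RegisterGradient("GuidedRelu")(_GuidedReluGrad)
```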

In [3]:
# Start from an empty default graph so the constants/ops built in the cells below are
# the only nodes, and re-running the notebook does not accumulate nodes from earlier runs.
tf.reset_default_graph()

In [8]:
# Forward feature map (3x3, mixed signs) used as the ReLU activation in the examples below.
feature_values = [[1, -1, 5],
                  [2, -5, -7],
                  [-3, 2, 4]]
feature_out = tf.constant(feature_values, dtype=tf.float32, name='features')
feature_out

<tf.Tensor 'features:0' shape=(3, 3) dtype=float32>

In [9]:
# Incoming gradient (3x3, mixed signs) that each ReLU backprop variant will filter.
grad_values = [[-2, 3, -1],
               [6, -3, 1],
               [2, -1, 3]]
grad = tf.constant(grad_values, dtype=tf.float32, name='gradients')
grad

<tf.Tensor 'gradients:0' shape=(3, 3) dtype=float32>

In [12]:
# BackpropRelu: keep grad only where the forward feature output > 0
BackpropRelu = gen_nn_ops.relu_grad(grad, feature_out)

# DeconvRelu: keep grad only where grad > 0 (the forward feature is ignored).
# tf.zeros_like keeps grad's dtype and does not need a fully-known static shape,
# unlike tf.zeros(grad.get_shape()).
DeconvRelu = tf.where(0. < grad, grad, tf.zeros_like(grad))

# GuidedReluGrad: keep grad only where (feature output > 0 & grad > 0).
# Reuses BackpropRelu instead of building a second, identical relu_grad op.
GuidedReluGrad = tf.where(0. < grad, BackpropRelu, tf.zeros_like(grad))

In [17]:
# Evaluate all three backprop variants in a single run of one InteractiveSession.
# The session is not closed in this cell.
sess = tf.InteractiveSession()
to_show = [('BackpropRelu: \n', BackpropRelu),
           ('DeconvRelu: \n', DeconvRelu),
           ('GuidedReluGrad: \n', GuidedReluGrad)]
evaluated = sess.run([tensor for label, tensor in to_show])
for (label, tensor), value in zip(to_show, evaluated):
    print(label, value)

BackpropRelu: 
 [[-2.  0. -1.]
 [ 6. -0.  0.]
 [ 0. -1.  3.]]
DeconvRelu: 
 [[0. 3. 0.]
 [6. 0. 1.]
 [2. 0. 3.]]
GuidedReluGrad: 
 [[0. 0. 0.]
 [6. 0. 0.]
 [0. 0. 3.]]


```python
@ops.RegisterGradient("BackpropRelu")
def _BackpropRelu(unused_op, grad):
    """Plain ReLU backprop: pass `grad` only where the forward output is positive.

    Args:
        unused_op: the forward op being differentiated (the Relu, when used under
            gradient_override_map). Despite its name, its first output IS used
            as the mask source.
        grad: gradient flowing into that op's output.

    Returns:
        `grad` with entries zeroed where the forward output is not > 0.
    """
    relu_output = unused_op.outputs[0]
    return gen_nn_ops.relu_grad(grad, relu_output)

@ops.RegisterGradient("DeconvRelu")
def _DeconvRelu(unused_op, grad):
    """'Deconvnet' ReLU backprop: pass only the positive entries of `grad`.

    The forward activation is ignored, hence `unused_op`.

    Args:
        unused_op: the forward op being differentiated; not used.
        grad: gradient flowing into that op's output.

    Returns:
        `grad` where `grad > 0`, zero elsewhere, with the same shape and dtype as `grad`.
    """
    # tf.zeros(shape) defaults to float32, which makes tf.where fail for float16/float64
    # gradients; zeros_like matches grad's dtype and (dynamic) shape.
    return tf.where(0. < grad, grad, tf.zeros_like(grad))

@ops.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(unused_op, grad):
    """Guided-backprop ReLU: pass `grad` only where both the forward output and
    `grad` itself are positive.

    Args:
        unused_op: the forward op being differentiated (the Relu, when used under
            gradient_override_map). Despite its name, its first output IS used
            as the mask source.
        grad: gradient flowing into that op's output.

    Returns:
        `grad` where (forward output > 0 and grad > 0), zero elsewhere, with the
        same shape and dtype as `grad`.
    """
    # Standard ReLU backprop first: zero where the forward output is not > 0.
    relu_masked = gen_nn_ops.relu_grad(grad, unused_op.outputs[0])
    # Then additionally drop non-positive incoming gradients. zeros_like keeps grad's
    # dtype; tf.zeros(shape) would default to float32 and break non-float32 graphs.
    return tf.where(0. < grad, relu_masked, tf.zeros_like(grad))

```

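```python
g = tf.get_default_graph()
# "<method>": one of the names registered above — "BackpropRelu", "DeconvRelu" or "GuidedRelu"
with g.gradient_override_map({"Relu": "GuidedRelu"}):
    ...  # build the graph here: every Relu op created inside this block uses the overridden gradient
```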