In [143]:
import sys
import os
import numpy as np

In [144]:
class ReLu():
    def __init__(self):
        self.mask = None

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """ forward
        self.mask : boolean_mask_ where(x>0, True)
        """
        self.mask  = [x > 0]
        dims = list(x.shape)
        out = np.zeros(dims)
        out[x > 0] = x[x > 0]
        return out

    def backward(self, grad: np.ndarray) -> np.ndarray:
        mask = self.mask[0]
        dx = grad * mask
        return dx

In [149]:
# Create data
batch, channel, dim = np.arange(2,5)
x = np.random.randn(batch, channel, dim)
# ReLu forward
relu = ReLu()
out = relu(x)
print('-----------------------  x  ---------------------\n',x)
print()
print('--------------------- output --------------------\n',out)
print()
# Create demo gradient
grad = np.random.randn(batch, channel, dim)
dx = relu.backward(grad)
print('----------------------- dx ----------------------\n', dx)

-----------------------  x  ---------------------
 [[[ 0.34361829 -1.76304016  0.32408397 -0.38508228]
  [-0.676922    0.61167629  1.03099952  0.93128012]
  [-0.83921752 -0.30921238  0.33126343  0.97554513]]

 [[-0.47917424 -0.18565898 -1.10633497 -1.19620662]
  [ 0.81252582  1.35624003 -0.07201012  1.0035329 ]
  [ 0.36163603 -0.64511975  0.36139561  1.53803657]]]

--------------------- output --------------------
 [[[0.34361829 0.         0.32408397 0.        ]
  [0.         0.61167629 1.03099952 0.93128012]
  [0.         0.         0.33126343 0.97554513]]

 [[0.         0.         0.         0.        ]
  [0.81252582 1.35624003 0.         1.0035329 ]
  [0.36163603 0.         0.36139561 1.53803657]]]

----------------------- dx ----------------------
 [[[-0.03582604  0.         -2.6197451   0.        ]
  [ 0.         -0.29900735  0.09176078 -1.98756891]
  [-0.          0.          1.47789404 -0.51827022]]

 [[-0.         -0.          0.          0.        ]
  [-0.5297602   0.51326743 