In [13]:
import numpy as np

def relu(x):
    """Rectified linear unit: element-wise ``max(x, 0)``.

    Args:
        x: Scalar or array-like accepted by ``np.maximum``.

    Returns:
        ``x`` with every negative entry replaced by 0.
    """
    floor = 0
    return np.maximum(x, floor)
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated without overflow.

    The textbook form computes ``exp(-x)`` directly, which overflows to
    ``inf`` (and emits a NumPy RuntimeWarning) for large negative ``x``.
    This version is mathematically identical but stays finite throughout.

    Args:
        x: Scalar or array-like of real numbers.

    Returns:
        Element-wise sigmoid of ``x`` as floating-point values in [0, 1].
    """
    # 1 / (1 + e^-x) == exp(-log(1 + e^-x)); logaddexp(0, -x) evaluates the
    # log term stably even when e^-x would not fit in a float.
    log_denominator = np.logaddexp(0, np.negative(x))
    return np.exp(-log_denominator)
def tanh(x):
    """Hyperbolic tangent activation; a thin wrapper around ``np.tanh``.

    Args:
        x: Scalar or array-like accepted by ``np.tanh``.

    Returns:
        Element-wise ``tanh`` of ``x``.
    """
    squashed = np.tanh(x)
    return squashed
def linear(x):
    """Identity activation: hand the input back untouched.

    The very same object is returned; no copy is made.
    """
    unchanged = x
    return unchanged
def softmax(x, axis=None):
    """Numerically stable softmax.

    Args:
        x: Array-like of real numbers.
        axis: Axis along which the outputs are normalized to sum to 1.
            The default ``None`` normalizes over the whole array at once,
            so all entries together sum to 1. Pass e.g. ``axis=0`` to
            normalize each column of a batch independently.

    Returns:
        Array of the same shape as ``x`` holding the softmax values.
    """
    x = np.asarray(x)
    # Subtracting the maximum leaves the result unchanged (softmax is
    # shift-invariant) but keeps np.exp from overflowing on large inputs.
    shifted = x - x.max(axis=axis, keepdims=True)
    exps = np.exp(shifted)
    # keepdims lets the normalizer broadcast back against `exps`.
    return exps / exps.sum(axis=axis, keepdims=True)

def g_relu(x):
    """Gradient of ReLU: 1 where ``x`` is strictly positive, otherwise 0.

    The slope at exactly 0 is reported as 0. The boolean mask is multiplied
    by 1 so the result is integer-valued rather than boolean.
    """
    is_positive = x > 0
    return is_positive * 1
def g_sigmoid(x):
    """Sigmoid gradient written in terms of the sigmoid's output.

    Computes ``x * (1 - x)``, which equals the derivative of the logistic
    function when ``x`` is already ``sigmoid(z)``.
    NOTE(review): presumably callers pass the activation output rather than
    the pre-activation input -- confirm against the call site.
    """
    complement = 1 - x
    return x * complement
def g_tanh(x):
    """Tanh gradient written in terms of the tanh output.

    Computes ``1 - x**2``, which equals the derivative of ``tanh`` when
    ``x`` is already ``tanh(z)``.
    NOTE(review): presumably callers pass the activation output rather than
    the pre-activation input -- confirm against the call site.
    """
    squared = x * x
    return 1 - squared
def g_linear(x):
    """Gradient of the identity activation: ones shaped like ``x``.

    The ones are produced by comparing ``x`` with itself, so the result is
    integer-valued and any NaN entry yields 0 instead of 1 (NaN != NaN).
    """
    equals_itself = (x == x)
    return equals_itself * 1
def g_softmax(x):
    """Column sums of the softmax Jacobian, returned as a column vector.

    With ``s`` the flattened input, the Jacobian is ``diag(s) - s s^T``.
    The previous implementation did not actually build that matrix:
    ``np.diag`` applied to an (n, 1) column *extracts* a one-element
    diagonal, and ``np.dot(x, x.T)`` on a 1-D vector is an inner product.
    The input is therefore flattened first and the outer product is formed
    explicitly, so 1-D vectors, row vectors and column vectors all work.

    Args:
        x: Softmax output as a vector of n values (any shape that
            flattens to length n).

    Returns:
        Array of shape (n, 1) whose j-th entry is the sum of column j of
        the Jacobian, i.e. ``s[j] * (1 - s.sum())``.

    NOTE(review): for a properly normalized softmax output ``s.sum()`` is 1,
    so this vector is (numerically) zero. If callers multiply it
    element-wise with an upstream error term, they probably need the full
    Jacobian-vector product instead -- confirm against the call site.
    """
    s = np.asarray(x).reshape(-1)
    # diag() on a 1-D vector builds the diagonal matrix; outer() gives s s^T.
    jacobian = np.diag(s) - np.outer(s, s)
    return jacobian.sum(axis=0).reshape(-1, 1)

class Activation:
    """Pair an activation function with its gradient, selected by name.

    Attributes:
        acti: The activation name exactly as passed to the constructor.
        func: The forward activation function for that name.
        grad: The matching gradient function for that name.
    """

    # Every name the constructor accepts.
    types = ("LINEAR", "RELU", "SIGMOID", "TANH", "SOFTMAX")

    def __init__(self, acti):
        """Look up the forward and gradient functions for ``acti``.

        Args:
            acti: One of the names listed in ``Activation.types``.

        Raises:
            KeyError: If ``acti`` is not a supported activation name.
        """
        # Validate up front so a typo produces an actionable message rather
        # than a bare KeyError from the dictionary lookup below.
        if acti not in Activation.types:
            raise KeyError(
                "unknown activation {!r}; expected one of: {}".format(
                    acti, ", ".join(Activation.types)
                )
            )

        funcs = {
            "TANH": tanh,
            "SIGMOID": sigmoid,
            "RELU": relu,
            "LINEAR": linear,
            "SOFTMAX": softmax,
        }

        grads = {
            "TANH": g_tanh,
            "SIGMOID": g_sigmoid,
            "RELU": g_relu,
            "LINEAR": g_linear,
            "SOFTMAX": g_softmax,
        }

        self.acti = acti
        self.func = funcs[acti]
        self.grad = grads[acti]

    def __str__(self):
        """Return a newline-prefixed label naming the activation."""
        return "\nActivation:" + self.acti

    @staticmethod
    def __test__():
        """Print each activation's forward output on a 4x4 grid of -8..7.

        Declared static because it takes no instance; this keeps it callable
        both as ``Activation.__test__()`` and on an instance.
        """
        sample = np.arange(-8, 8).reshape(4, 4)
        for name in Activation.types:
            activation = Activation(name)
            print(str(activation.func(sample)))

# Run the self-test only when this is the top-level module AND no __file__
# global exists -- presumably to restrict it to an interactive/notebook
# session rather than a script run from disk.
if __name__ == "__main__":
    if '__file__' not in globals():
        Activation.__test__()


[[-8 -7 -6 -5]
 [-4 -3 -2 -1]
 [ 0  1  2  3]
 [ 4  5  6  7]]
[[0 0 0 0]
 [0 0 0 0]
 [0 1 2 3]
 [4 5 6 7]]
[[  3.35350130e-04   9.11051194e-04   2.47262316e-03   6.69285092e-03]
 [  1.79862100e-02   4.74258732e-02   1.19202922e-01   2.68941421e-01]
 [  5.00000000e-01   7.31058579e-01   8.80797078e-01   9.52574127e-01]
 [  9.82013790e-01   9.93307149e-01   9.97527377e-01   9.99088949e-01]]
[[-0.99999977 -0.99999834 -0.99998771 -0.9999092 ]
 [-0.9993293  -0.99505475 -0.96402758 -0.76159416]
 [ 0.          0.76159416  0.96402758  0.99505475]
 [ 0.9993293   0.9999092   0.99998771  0.99999834]]
[[  1.93367168e-07   5.25626458e-07   1.42880085e-06   3.88388338e-06]
 [  1.05574896e-05   2.86982322e-05   7.80098831e-05   2.12052848e-04]
 [  5.76419403e-04   1.56687039e-03   4.25919530e-03   1.15776932e-02]
 [  3.14714330e-02   8.55482245e-02   2.32544184e-01   6.32120630e-01]]
