# Customization

## Customize the activation function

Adapted from [Ref](https://medium.com/@chinesh4/custom-activation-function-in-tensorflow-for-deep-neural-networks-from-scratch-tutorial-b12e00652e24).


In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

### Example 1

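A [Gaussian radial basis function](https://en.wikipedia.org/wiki/Radial_basis_function) is defined as
$$
\phi(x) = e^{-(\epsilon r)^2},
$$
where $r = \|x - c\|$ is the distance from a center $c$ and $\epsilon > 0$ is a shape parameter.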
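As a quick numerical illustration, the sketch below evaluates this formula in one dimension. The helper `gaussian_rbf` and its `center` and `epsilon` arguments are illustrative names, not part of the original notebook. With the center at 0 and $\epsilon = 1$ it reduces to $e^{-x^2}$.

```python
import numpy as np

def gaussian_rbf(x, center=0.0, epsilon=1.0):
    """Illustrative 1-D Gaussian RBF: exp(-(epsilon * r)^2) with r = |x - center|."""
    r = np.abs(np.asarray(x, dtype=float) - center)
    return np.exp(-(epsilon * r) ** 2)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(gaussian_rbf(x))  # largest value (1.0) at the center x = 0
print(np.allclose(gaussian_rbf(x), np.exp(-x**2)))  # True: center 0, epsilon 1 gives exp(-x^2)
```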

We found that a radial basis function works well as an activation in a network; this will be demonstrated in the following tutorials. The radial basis function responds noticeably only in an interval close to its center and is nearly zero elsewhere, which resembles the "on-and-off" idea behind an activation function.
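The "on-and-off" behaviour can be quantified: solving $e^{-x^2} = t$ gives $|x| = \sqrt{-\ln t}$, so $e^{-x^2}$ drops below 1% of its peak once $|x| > \sqrt{\ln 100} \approx 2.146$. A small NumPy sketch (the 1% `threshold` is an arbitrary choice for illustration):

```python
import numpy as np

def phi(x):
    # the 1-D activation exp(-x^2)
    return np.exp(-x**2)

# Half-width of the region where phi(x) >= threshold: exp(-x^2) = t  =>  |x| = sqrt(-ln t)
threshold = 0.01
half_width = np.sqrt(-np.log(threshold))
print(f"phi(x) >= {threshold} only for |x| <= {half_width:.3f}")

# the response decays very quickly away from the center
for x in [0.0, 1.0, 2.0, 3.0, 5.0]:
    print(x, phi(x))
```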

In one dimension, let us set the activation to $\phi(x)=e^{-x^2}$ (i.e. $\epsilon = 1$ and $r = |x|$). Since this is not in TensorFlow's [list of activation functions](https://www.tensorflow.org/api_docs/python/tf/keras/activations), we will use it to demonstrate how to customize an activation function.
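Keras accepts any callable that maps a tensor to a tensor of the same shape as an activation, so no registration step is needed. A minimal sketch, assuming TensorFlow 2.x, that applies `rbf` through `tf.keras.layers.Activation`; the same callable can also be passed as the `activation=` argument of a layer such as `Dense`:

```python
import numpy as np
import tensorflow as tf

def rbf(x):
    # same definition as in the notebook: elementwise exp(-x^2)
    return tf.math.exp(-x**2)

# wrap the plain Python callable in a Keras layer
act_layer = tf.keras.layers.Activation(rbf)
x = tf.constant([[-2.0, 0.0, 0.5, 2.0]])
y = act_layer(x)
print(y.numpy())
```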

In [2]:
#######################################
# define the activation function
def rbf(x):
    return tf.math.exp(-x**2)

#######################################
# define its derivative analytically: d/dx exp(-x^2) = -2x * exp(-x^2)
def d_rbf(x):
    return -2.0 * x * tf.math.exp(-x**2)

#######################################
# Optionally attach d_rbf as an explicit gradient with tf.custom_gradient.
# Because rbf is built only from TensorFlow ops, automatic differentiation
# already handles it; an explicit gradient is only required when the forward
# pass is not differentiable by TensorFlow (e.g. a NumPy function wrapped in
# tf.numpy_function).
@tf.custom_gradient
def tf_rbf(x):
    y = rbf(x)                      # forward pass
    def grad(upstream):
        return upstream * d_rbf(x)  # chain rule: upstream gradient times local derivative
    return y, grad
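Because `rbf` uses only TensorFlow ops, automatic differentiation already produces its gradient, so a hand-written gradient is optional here. A quick check with `tf.GradientTape` (assuming TensorFlow 2.x eager execution) compares autodiff against the analytic derivative $-2x e^{-x^2}$:

```python
import numpy as np
import tensorflow as tf

def rbf(x):
    return tf.math.exp(-x**2)

x = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5])
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it must be watched explicitly
    y = rbf(x)
autodiff_grad = tape.gradient(y, x)  # elementwise op, so this is dy_i/dx_i

analytic_grad = -2.0 * x * tf.math.exp(-x**2)  # d/dx exp(-x^2) = -2x exp(-x^2)
print(np.allclose(autodiff_grad.numpy(), analytic_grad.numpy()))  # True
```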

In [3]:
## Training data
np.random.seed(1)
x_train = np.linspace(-40,40,20)
y_train = 0.3*x_train**2 + np.random.normal(0, 1, len(x_train)) 
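The targets follow the parabola $y = 0.3x^2$ plus standard-normal noise at 20 evenly spaced points on $[-40, 40]$, so they reach roughly $0.3 \cdot 40^2 = 480$ at the ends. As a reference point (not part of the original notebook), an ordinary least-squares quadratic fit with `np.polyfit` should recover a leading coefficient close to 0.3:

```python
import numpy as np

# same data as above
np.random.seed(1)
x_train = np.linspace(-40, 40, 20)
y_train = 0.3 * x_train**2 + np.random.normal(0, 1, len(x_train))

# least-squares quadratic fit; coefficients are returned highest degree first
a, b, c = np.polyfit(x_train, y_train, deg=2)
print(a, b, c)
```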

In [4]:
model_rbf = tf.keras.Sequential()
model_rbf.add(tf.keras.layers.Dense(3, activation=rbf))     # pass the function itself, not the string 'rbf'
model_rbf.add(tf.keras.layers.Dense(1))

model_rbf.compile(loss='mean_squared_error', optimizer=tf.keras.optimizers.Adam(0.1))
model_rbf.fit(x_train,y_train, epochs=1000, verbose=0)

# model_rbf.summary()
model_rbf.get_weights()   # inspect the trained weights and biases

[array([[-2.6122294e-03,  1.6906145e-03, -2.2995217e+00]], dtype=float32),
 array([-4.2092835e-04,  2.8977118e-04,  4.8410983e+00], dtype=float32),
 array([[  57.36145 ],
        [  59.288876],
        [-115.397705]], dtype=float32),
 array([55.409958], dtype=float32)]
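`get_weights()` lists the kernel and bias of each layer in order: a $(1, 3)$ kernel and $(3,)$ bias for `Dense(3)`, then a $(3, 1)$ kernel and $(1,)$ bias for `Dense(1)`. Each hidden unit therefore computes $e^{-(w_k x + b_k)^2}$, a Gaussian bump centered at $x = -b_k / w_k$ whose width scales like $1/|w_k|$. The NumPy sketch below, using the printed values copied by hand, reproduces the forward pass; its output should closely match `model_rbf.predict(x_train)` up to the rounding of the printed weights.

```python
import numpy as np

# weights printed above by model_rbf.get_weights(), copied by hand
W1 = np.array([[-2.6122294e-03, 1.6906145e-03, -2.2995217e+00]])  # Dense(3) kernel, (1, 3)
b1 = np.array([-4.2092835e-04, 2.8977118e-04, 4.8410983e+00])      # Dense(3) bias, (3,)
W2 = np.array([[57.36145], [59.288876], [-115.397705]])             # Dense(1) kernel, (3, 1)
b2 = np.array([55.409958])                                          # Dense(1) bias, (1,)

def forward(x):
    """Reproduce Dense(3, activation=rbf) -> Dense(1) in NumPy."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)  # (n, 1): one input feature
    h = np.exp(-(x @ W1 + b1) ** 2)                # hidden activations, (n, 3)
    return h @ W2 + b2                             # network output, (n, 1)

x_train = np.linspace(-40, 40, 20)
print(forward(x_train).ravel())
```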

## Customize the loss function

In [9]:
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(3, activation=rbf))     # pass the function itself, not the string 'rbf'
model.add(tf.keras.layers.Dense(1))

In [12]:
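# custom loss: Keras calls it as loss(y_true, y_pred);
# scaling the residual by 1/10 makes this the mean squared error divided by 100
def custom_loss(ytrue, ypred):
    val = tf.math.reduce_mean(tf.math.square((ytrue - ypred) / 10))
    return val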
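Dividing the residual by 10 before squaring scales the loss by $1/10^2$, so `custom_loss` is the mean squared error divided by 100. A small check against `tf.keras.losses.MeanSquaredError`, with `y_true` and `y_pred` given the same shape $(n, 1)$ so that the subtraction does not broadcast to an $(n, n)$ matrix:

```python
import numpy as np
import tensorflow as tf

def custom_loss(ytrue, ypred):
    # same as in the notebook: mean of ((ytrue - ypred) / 10)^2
    return tf.math.reduce_mean(tf.math.square((ytrue - ypred) / 10))

y_true = tf.constant([[1.0], [2.0], [3.0]])
y_pred = tf.constant([[1.5], [1.0], [5.0]])

mse = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
print(float(custom_loss(y_true, y_pred)), float(mse) / 100)
```

Since Adam normalizes each gradient by its running magnitude, a constant rescaling of the loss leaves its updates largely unchanged (apart from the effect of its small $\epsilon$ term); the factor mainly changes the reported loss value.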

In [16]:
model.compile(loss=custom_loss,optimizer=tf.keras.optimizers.Adam(0.01))
model.fit(x_train,y_train,epochs=100,verbose=0)

model.get_weights()    # inspect the trained weights and biases

[array([[ 4.8343139e-05, -1.8820183e-01,  5.0299735e-05]], dtype=float32),
 array([1.9036455e-05, 5.2888556e-03, 2.9459905e-05], dtype=float32),
 array([[ 5.030023 ],
        [-4.4802723],
        [ 4.8218493]], dtype=float32),
 array([3.9862227], dtype=float32)]
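If the scale factor should be configurable instead of hard-coded, one option is to subclass `tf.keras.losses.Loss` and implement `call(y_true, y_pred)`. `ScaledMSE` and its `scale` argument are hypothetical names used for illustration; a minimal sketch assuming TensorFlow 2.x:

```python
import numpy as np
import tensorflow as tf

class ScaledMSE(tf.keras.losses.Loss):
    """Hypothetical generalization of custom_loss: mean of ((y_true - y_pred) / scale)^2."""

    def __init__(self, scale=10.0, name="scaled_mse"):
        super().__init__(name=name)
        self.scale = scale

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # per-sample loss; the base class then averages it over the batch
        return tf.reduce_mean(tf.square((y_true - y_pred) / self.scale), axis=-1)

y_true = tf.constant([[1.0], [2.0], [3.0]])
y_pred = tf.constant([[1.5], [1.0], [5.0]])
print(float(ScaledMSE(scale=10.0)(y_true, y_pred)))
```

It can then be passed to `compile` like the function version, e.g. `model.compile(loss=ScaledMSE(scale=10.0), optimizer=tf.keras.optimizers.Adam(0.01))`.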