### Activation Functions
In this dashboard we'll look at the plots of different activation functions and their gradients.  
The list of activation functions is generated dynamically by introspecting the [keras activations](https://keras.io/api/layers/activations/#available-activations) module. For each activation function, the UI for its additional keyword parameters (e.g. for `relu`) is also generated dynamically by introspecting the function's signature.

In [None]:
import inspect
from collections import OrderedDict

import tensorflow as tf

import ipywidgets as w
import bqplot.pyplot as plt

In [None]:
from tensorflow.keras import activations

def get_activations(module=None):
    """Introspect a module and collect its activation functions.

    A member is treated as an activation function when it is a plain Python
    function whose signature has a parameter named ``x``.

    Args:
        module: module (or any object) to introspect. Defaults to the
            ``tensorflow.keras.activations`` module imported above.

    Returns:
        dict mapping function name -> function object, in the name-sorted
        order produced by ``inspect.getmembers``.
    """
    if module is None:
        module = activations
    return {
        name: func
        for name, func in inspect.getmembers(module, inspect.isfunction)
        if "x" in inspect.signature(func).parameters
    }

In [None]:
# widget class used for a parameter, chosen by the type of its default value;
# defaults of any other type fall back to FloatText
WIDGETS_MAP = {float: w.FloatText, int: w.IntText, bool: w.Checkbox}
textbox_layout = w.Layout(width="180px")


class KeywordArgsWidget(w.Box):
    """Automatically generated keyword-argument UI for a callable.

    One input widget is created for every parameter of ``obj`` that has a
    default value; the widget class is looked up in ``WIDGETS_MAP`` by the
    type of that default (``FloatText`` otherwise).

    Args:
        obj: class or function to introspect. If falsy, no widgets are built.
        orientation: ``"horizontal"`` lays the widgets out in an ``HBox``;
            any other value uses a ``VBox``.
    """

    def __init__(self, obj, orientation="horizontal"):
        self.obj = obj
        self.orientation = orientation

        self.param_wids = {}
        if self.orientation == "horizontal":
            self.widgets_layout = w.HBox()
        else:
            self.widgets_layout = w.VBox()
        self._build_widgets()
        super().__init__(children=[self.widgets_layout])

    def _build_widgets(self):
        """Create one input widget per defaulted parameter of ``self.obj``."""
        # always defined (possibly empty) so get_param_values() and callers
        # iterating ``.widgets`` work even when obj is falsy
        self.widgets = OrderedDict()
        # parameters whose default is None (see get_param_values)
        self._none_default_params = set()
        if not self.obj:
            return

        params = inspect.signature(self.obj).parameters
        for param_name, param in params.items():
            default = param.default
            if default is inspect.Parameter.empty:
                # required parameters (e.g. the input ``x``) get no widget
                continue
            widget_cls = WIDGETS_MAP.get(type(default), w.FloatText)
            widget_kwargs = dict(description=param_name, layout=textbox_layout)
            if default is None:
                # FloatText's value trait cannot hold None: leave the widget at
                # its own initial value and treat a falsy entry as "unset"
                self._none_default_params.add(param_name)
            else:
                widget_kwargs["value"] = default
            self.widgets[param_name] = widget_cls(**widget_kwargs)
        self.widgets_layout.children = list(self.widgets.values())

    def get_param_values(self):
        """Return the keyword arguments currently entered in the widgets.

        Values are passed through as-is, including falsy ones such as ``0.0``
        or ``False``. The exception is parameters whose default is ``None``:
        their widget cannot show ``None``, so a falsy entry means "not set"
        and the parameter is omitted, letting the callable's default apply.

        Returns:
            dict mapping parameter name -> widget value.
        """
        return {
            name: widget.value
            for name, widget in self.widgets.items()
            if name not in self._none_default_params or widget.value
        }

In [None]:
activations_dict = get_activations()
# remove softmax since it takes 2D tensors and x below is 1-D;
# pop() with a default avoids a KeyError if it was not collected
activations_dict.pop("softmax", None)

# points in [-10, 10) with step 0.1 at which each activation and its
# gradient are evaluated
x = tf.range(-10, 10, delta=.1)

In [None]:
# one keyword-args widget per activation func
activations_param_widgets = {
    func_name: KeywordArgsWidget(func)
    for func_name, func in activations_dict.items()
}

activations_dd = w.Dropdown(
    description="Activation",
    value="relu",
    options=list(activations_dict),
)
# holds the parameter widgets of the currently selected activation
param_widget_placeholder = w.Box()

fig_layout = w.Layout(width="700px", height="500px")
fig_args = dict(animation_duration=750, layout=fig_layout)

activation_fig = plt.figure(**fig_args)
activation_plot = plt.plot(x.numpy(), [], "m")

grad_fig = plt.figure(**fig_args)
grad_plot = plt.plot(x.numpy(), [], "y")


def update(*args):
    """Redraw the activation and gradient plots for the current selection.

    Any positional arguments are ignored, so this works both as an
    ipywidgets ``observe`` callback (which passes a change dict) and when
    called directly.
    """
    selected_activation = activations_dd.value
    activation_func = activations_dict[selected_activation]
    param_widget = activations_param_widgets[selected_activation]
    param_widget_placeholder.children = [param_widget]

    with tf.GradientTape() as tape:
        # x is a plain tensor (not a Variable), so the tape must watch it
        tape.watch(x)
        # evaluate the activation func on x with its keyword params, if any
        param_vals = param_widget.get_param_values()
        y = activation_func(x, **param_vals)

    # gradient of sum(y) w.r.t. x, i.e. the element-wise derivative for an
    # element-wise activation; zeros (instead of None) if y is not connected
    # to x, so .numpy() below cannot fail
    dy_dx = tape.gradient(
        y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO
    )

    # update plots
    activation_fig.title = f"{selected_activation}(x)"
    grad_fig.title = f"{selected_activation}'(x)"
    activation_plot.y = y.numpy()
    grad_plot.y = dy_dx.numpy()


# register callbacks; observe() without ``names`` fires for every trait change
# (a selection also updates e.g. index/label), causing redundant redraws
activations_dd.observe(update, names="value")

# redraw plots whenever activation func params are updated
for keyword_arg_widget in activations_param_widgets.values():
    for widget in keyword_arg_widget.widgets.values():
        widget.observe(update, names="value")

update(None)

w.VBox([w.HBox([activations_dd, param_widget_placeholder]),
        w.HBox([activation_fig, grad_fig])])