# In [8]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    The value is computed as ``exp(-logaddexp(0, -x))``, which is
    algebraically identical to the textbook formula. ``np.logaddexp`` never
    materialises ``exp`` of a large positive number. As a result, very
    negative inputs (x below roughly -709) return 0 instead of overflowing
    ``np.exp(-x)`` and emitting a RuntimeWarning.

    Parameters
    ----------
    x : array_like
        Input values.

    Returns
    -------
    ndarray or numpy scalar
        Values in [0, 1], broadcast to the shape of ``x``.
    """
    return np.exp(-np.logaddexp(0, -x))

def sigmoid_derivative(x):
    """Return the derivative of ``sigmoid`` evaluated at ``x``.

    Uses the identity ``sigmoid'(x) = s * (1 - s)`` with ``s = sigmoid(x)``.

    Parameters
    ----------
    x : array_like
        Input values.

    Returns
    -------
    ndarray or numpy scalar
        Derivative values, broadcast to the shape of ``x``.
    """
    # Evaluate the sigmoid once and reuse it for both factors.
    s = sigmoid(x)
    return s * (1 - s)

def tanh(x):
    """Apply the hyperbolic tangent element-wise (a thin wrapper over ``np.tanh``)."""
    activated = np.tanh(x)
    return activated

def tanh_derivative(x):
    """Return the derivative of tanh, ``1 - tanh(x)**2``, element-wise."""
    t = np.tanh(x)
    return 1 - t**2

def relu(x):
    """Rectified linear unit: element-wise ``max(0, x)``."""
    floor = 0
    rectified = np.maximum(floor, x)
    return rectified

def relu_derivative(x):
    """Return the ReLU derivative: 1 where ``x > 0``, otherwise 0 (including at 0)."""
    is_positive = x > 0
    return np.where(is_positive, 1, 0)

def parametric_relu(x, alpha=0.01):
    """Leaky / parametric ReLU: ``x`` where ``x > 0``, otherwise ``alpha * x``.

    The value is selected by the sign of ``x``. This is the same piecewise
    split that ``parametric_relu_derivative`` uses (slope 1 for ``x > 0``,
    ``alpha`` otherwise). The previous ``np.maximum(alpha * x, x)`` form
    only matched this definition for ``alpha <= 1``; for larger ``alpha``
    it returned ``alpha * x`` on the positive side.

    Parameters
    ----------
    x : array_like
        Input values.
    alpha : float, default 0.01
        Slope applied to non-positive inputs.

    Returns
    -------
    ndarray
        Activated values with the shape of ``x``.
    """
    # asarray lets plain lists/scalars work with the comparison below.
    x = np.asarray(x)
    return np.where(x > 0, x, alpha * x)

def parametric_relu_derivative(x, alpha=0.01):
    """Return the parametric ReLU slope: 1 where ``x > 0``, otherwise ``alpha``."""
    positive_slope, negative_slope = 1, alpha
    return np.where(x > 0, positive_slope, negative_slope)

# Evaluate every activation and its derivative on 100 evenly spaced points in [-5, 5].
x_values = np.linspace(-5, 5, 100)

sigmoid_trace = go.Scatter(x=x_values, y=sigmoid(x_values), name='Sigmoid', mode='lines')
sigmoid_derivative_trace = go.Scatter(x=x_values, y=sigmoid_derivative(x_values), name='Sigmoid Derivative', mode='lines')

tanh_trace = go.Scatter(x=x_values, y=tanh(x_values), name='Tanh', mode='lines')
tanh_derivative_trace = go.Scatter(x=x_values, y=tanh_derivative(x_values), name='Tanh Derivative', mode='lines')

relu_trace = go.Scatter(x=x_values, y=relu(x_values), name='ReLU', mode='lines')
relu_derivative_trace = go.Scatter(x=x_values, y=relu_derivative(x_values), name='ReLU Derivative', mode='lines')

# Negative-side slope used for both the parametric ReLU curve and its derivative.
initial_alpha = 0.05
parametric_relu_trace = go.Scatter(x=x_values, y=parametric_relu(x_values, alpha=initial_alpha), name=f'Parametric ReLU (alpha = {initial_alpha})', mode='lines')
parametric_relu_derivative_trace = go.Scatter(x=x_values, y=parametric_relu_derivative(x_values, alpha=initial_alpha), name='Parametric ReLU Derivative', mode='lines')

# (dropdown label, activation trace, derivative trace).
# Traces are added to the figure in this order, two per group, so trace
# index j belongs to group j // 2. The visibility masks below rely on this.
trace_groups = [
    ("Sigmoid", sigmoid_trace, sigmoid_derivative_trace),
    ("Tanh", tanh_trace, tanh_derivative_trace),
    ("ReLU", relu_trace, relu_derivative_trace),
    ("Parametric ReLU", parametric_relu_trace, parametric_relu_derivative_trace),
]

fig = make_subplots(rows=1, cols=1)
for _label, activation_trace, derivative_trace in trace_groups:
    fig.add_trace(activation_trace)
    fig.add_trace(derivative_trace)

fig.update_xaxes(title_text="x")
fig.update_yaxes(title_text="f(x)")
fig.update_layout(title_text="Activation functions and their derivatives", title_x=0.4)

n_traces = 2 * len(trace_groups)
# "All" restores the initial view: every activation AND every derivative.
# Each per-function button shows only that function's two traces.
buttons = [dict(label="All", method="update", args=[{"visible": [True] * n_traces}])]
buttons += [
    dict(
        label=label,
        method="update",
        args=[{"visible": [j // 2 == i for j in range(n_traces)]}],
    )
    for i, (label, _, _) in enumerate(trace_groups)
]

fig.update_layout(
    updatemenus=[
        {
            "buttons": buttons,
            "direction": "down",
            "pad": {"r": 10, "t": 10},
            "showactive": True,
            # y > 1 places the dropdown above the plotting area.
            "x": 0.33,
            "xanchor": "left",
            "y": 1.163,
            "yanchor": "top",
        },
    ]
)

fig.show()