In [None]:
from typing import Callable, List
import matplotlib.pyplot as plt

import torch
from torch import nn, Tensor, optim
import torch.nn.functional as F


class ActivationFunction(nn.Module):
    """Common base for the hand-written activation modules.

    Each instance records the name of its concrete class in ``name`` and
    exposes it again through the ``config`` dictionary under the ``'name'``
    key. Subclasses are expected to supply ``forward``.
    """

    def __init__(self):
        super().__init__()
        cls_name = type(self).__name__
        self.name = cls_name
        self.config = {'name': cls_name}


class Sigmoid(ActivationFunction):
    """Logistic sigmoid, ``1 / (1 + exp(-x))``, evaluated in an overflow-safe form."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the sigmoid element-wise.

        The textbook expression overflows ``exp(-x)`` to ``inf`` for very
        negative inputs (roughly ``x < -88`` in float32); the forward value
        is still 0 but autograd then multiplies ``0 * inf`` and yields a NaN
        gradient. Here only non-positive values are exponentiated, so every
        intermediate stays finite.

        Args:
            x: Input tensor.

        Returns:
            Tensor with the same shape as ``x`` holding values in ``[0, 1]``.
        """
        is_neg = x < 0
        # z == exp(-|x|), always in (0, 1]. torch.where is used instead of
        # abs() so that autograd still gives the correct slope 0.25 at x == 0.
        z = torch.exp(torch.where(is_neg, x, -x))
        # x >= 0: 1 / (1 + exp(-x));   x < 0: exp(x) / (1 + exp(x)).
        numerator = torch.where(is_neg, z, torch.ones_like(z))
        return numerator / (1 + z)
    

class Tanh(ActivationFunction):
    """Hyperbolic tangent evaluated in an overflow-safe form."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply tanh element-wise.

        Computing ``(exp(x) - exp(-x)) / (exp(x) + exp(-x))`` directly
        overflows one of the exponentials for large ``|x|`` (roughly above 88
        in float32) and returns ``inf / inf = NaN``. This version uses
        ``tanh(|x|) = (1 - z) / (1 + z)`` with ``z = exp(-2|x|)`` and restores
        the sign afterwards, so the exponent is never positive.

        Args:
            x: Input tensor.

        Returns:
            Tensor with the same shape as ``x`` holding values in ``[-1, 1]``.
        """
        is_neg = x < 0
        # z == exp(-2|x|), always in (0, 1]. torch.where is used instead of
        # abs()/sign() so that autograd still gives the correct slope 1 at 0.
        z = torch.exp(torch.where(is_neg, 2 * x, -2 * x))
        magnitude = (1 - z) / (1 + z)
        # tanh is odd: flip the sign back for negative inputs.
        return torch.where(is_neg, -magnitude, magnitude)


class ReLU(ActivationFunction):
    """Rectified linear unit: ``max(x, 0)`` with a zero gradient at ``x == 0``."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply ReLU element-wise.

        Selecting with ``torch.where`` instead of multiplying by a float mask
        avoids two problems of ``x * (x > 0).float()``: ``-inf * 0`` produced
        NaN, and the float32 mask changed the result dtype for half,
        bfloat16 and integer inputs.

        Args:
            x: Input tensor.

        Returns:
            Tensor with the same shape and dtype as ``x``; non-positive
            entries are replaced by zero, NaN entries are passed through.
        """
        # `x <= 0` (rather than `x > 0` on the other branch) keeps NaN inputs
        # visible in the output and routes x == 0 to the constant branch, so
        # the gradient there is 0.
        return torch.where(x <= 0, torch.zeros_like(x), x)

    
# Registry of the available activations: lowercase name -> class (not instance).
act_fn_by_name = dict(sigmoid=Sigmoid, tanh=Tanh, relu=ReLU)


def get_grads(act_fn: nn.Module, x: Tensor) -> Tensor:
    """Return the gradient of ``act_fn(x).sum()`` with respect to ``x``.

    For an activation that acts element-wise this is the derivative of the
    activation evaluated at each entry of ``x``.

    Args:
        act_fn: Callable mapping a tensor to a tensor (e.g. an ``nn.Module``).
        x: Points at which to evaluate the gradient. It is neither modified
            nor does its ``.grad`` receive anything.

    Returns:
        Tensor of the same shape as ``x`` containing the gradients.
    """
    # Detach before cloning so the working copy is always a fresh leaf tensor.
    # Without it, an input that already requires grad yields a non-leaf clone
    # whose ``.grad`` is None, and backward() would accumulate into the
    # caller's tensor instead.
    x = x.detach().clone().requires_grad_(True)
    out = act_fn(x)
    # Summing gives a scalar to call backward() on.
    out.sum().backward()
    return x.grad


def vis_act_fn(act_fn: Callable, ax: "plt.Axes", x: Tensor) -> None:
    """Plot an activation function and its gradient on one matplotlib axes.

    Args:
        act_fn: Activation to visualise. It is called on ``x`` and must expose
            a ``name`` attribute, which is used as the plot title.
        ax: Axes object to draw on.
        x: Input values at which the activation and its gradient are
            evaluated.
    """
    # Detach before the NumPy conversion: Tensor.numpy() raises for tensors
    # that require grad, which happens e.g. when act_fn has learnable
    # parameters.
    y = act_fn(x).detach()
    y_grads = get_grads(act_fn, x)

    # From here on x, y and y_grads are NumPy arrays on the host.
    x, y, y_grads = x.detach().cpu().numpy(), y.cpu().numpy(), y_grads.cpu().numpy()

    ax.plot(x, y, linewidth=2, label='Activation Function')
    ax.plot(x, y_grads, linewidth=2, label='Gradient')
    ax.set_title(act_fn.name)
    ax.legend()
    # The upper y-limit is taken from the largest *input* value.
    ax.set_ylim(-1.5, x.max())


# Instantiate every registered activation and plot each one, together with its
# gradient, in its own panel of a single-row figure.
act_fns = [act_fn() for act_fn in act_fn_by_name.values()]
x = torch.linspace(-5, 5, 1000)

# Size the grid from the registry instead of hard-coding three panels, so
# adding or removing an activation does not break the indexing below.
n_plots = len(act_fns)
# squeeze=False always returns a 2-D array of axes (even for one panel);
# ravel() flattens it to one axes per activation.
fig, axes = plt.subplots(1, n_plots, figsize=(3 * n_plots, 3), squeeze=False)
axes = axes.ravel()
for act_fn, ax in zip(act_fns, axes):
    vis_act_fn(act_fn, ax, x)

fig.subplots_adjust(hspace=0.3)
fig.tight_layout()
plt.show()