In [16]:
# Internal
from pathlib import Path
import math

# External
import numpy as np
import torch
import torch.nn as nn
import httpx

import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

# Project configuration: storage locations, cuDNN behaviour, compute device, plot theme
DATASET_PATH = Path("./data")
CHECKPOINT_PATH = Path("./models")

# Request deterministic cuDNN algorithms and switch off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# First CUDA device when CUDA is available, CPU otherwise.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pio.templates.default = "plotly_dark"

In [2]:
def set_seed(seed=42):
  """Seed the random number generators used by this notebook.

  Seeds Python's ``random`` module, NumPy's legacy global generator and
  PyTorch (CPU, plus every CUDA device when CUDA is available).

  Args:
    seed: integer seed handed to each generator.
  """
  # Local import keeps this cell self-contained; `random` is not imported at the top.
  import random

  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    # manual_seed_all covers every visible device, including the current one.
    torch.cuda.manual_seed_all(seed)
set_seed(42)

In [3]:
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/"
# Files to download: one ".config" / ".tar" pair per pretrained activation
pretrained_files = [
    f"FashionMNIST_{act_name}{suffix}"
    for act_name in ("elu", "leakyrelu", "relu", "sigmoid", "swish", "tanh")
    for suffix in (".config", ".tar")
]

In [4]:
# Create the target directory; exist_ok makes this a no-op when it is already there.
DATASET_PATH.mkdir(parents=True, exist_ok=True)

for file_name in pretrained_files:
  file_url = base_url + file_name
  file_path = DATASET_PATH / file_name
  # Keep files fetched by an earlier run so Restart & Run All does not download
  # everything again. Delete a file by hand to force a fresh copy.
  if file_path.is_file():
    continue
  response = httpx.get(file_url)
  # Raise on 4xx/5xx responses rather than saving an error body under the file name.
  response.raise_for_status()
  file_path.write_bytes(response.content)

In [5]:
class ActivationFunction(nn.Module):
  """Base class for the activation modules defined in this notebook.

  On construction it stores the concrete class name in ``name`` and a
  ``config`` dict holding that name under the ``"name"`` key.
  """

  def __init__(self) -> None:
    super().__init__()
    cls_name = type(self).__name__
    self.name = cls_name
    self.config = {"name": cls_name}

In [6]:
class Sigmoid(ActivationFunction):
  """Element-wise logistic sigmoid, 1 / (1 + exp(-x))."""

  def forward(self,x):
    # Use the built-in instead of spelling out the formula: with float32 inputs
    # below roughly -88, exp(-x) overflows to inf and back-propagating through
    # 1/(1+exp(-x)) multiplies 0 by inf, which yields NaN gradients.
    return torch.sigmoid(x)

class Tanh(ActivationFunction):
  """Element-wise hyperbolic tangent, (exp(x) - exp(-x)) / (exp(x) + exp(-x))."""

  def forward(self,x):
    # Use the built-in instead of spelling out the formula: once |x| exceeds
    # roughly 88 in float32 one of the exponentials overflows to inf and the
    # quotient becomes inf/inf = NaN, whereas tanh should saturate at +/-1.
    return torch.tanh(x)

In [7]:
class ReLU(ActivationFunction):
  """Rectified linear unit built from a 0/1 mask of the positive entries."""

  def forward(self,x):
    # 1.0 where the input is strictly positive, 0.0 elsewhere
    positive_mask = torch.gt(x, 0).float()
    return torch.mul(x, positive_mask)

In [8]:
class LeakyReLU(ActivationFunction):
  """Leaky ReLU: identity for positive inputs, ``alpha * x`` for the rest.

  Args:
    alpha: slope applied to non-positive inputs.
  """

  def __init__(self,alpha=0.001) -> None:
    super().__init__()
    self.alpha = alpha
    # Store the slope next to the name so the config describes the module completely.
    self.config["alpha"] = alpha

  def forward(self,x):
    return torch.where(x>0, x, self.alpha*x)


In [9]:
class Swish(ActivationFunction):
  """Swish activation: the input multiplied by its own sigmoid."""

  def forward(self,x):
    gate = torch.sigmoid(x)
    return x * gate

In [10]:

class ELU(ActivationFunction):
    """Exponential linear unit: ``x`` for positive inputs, ``exp(x) - 1`` otherwise."""

    def forward(self, x):
        # torch.where evaluates (and back-propagates through) both branches.
        # Clamping the exponential's argument to <= 0 keeps exp() finite for
        # large positive x; unclamped, that branch overflows to inf and its
        # zeroed gradient turns into 0 * inf = NaN. Values for x <= 0 are
        # unaffected by the clamp.
        negative_part = torch.expm1(x.clamp(max=0))
        return torch.where(x > 0, x, negative_part)

In [11]:
# Lookup table from lower-case activation name to the class implementing it
act_fn_by_name = dict(
    sigmoid=Sigmoid,
    tanh=Tanh,
    relu=ReLU,
    leakyrelu=LeakyReLU,
    elu=ELU,
    swish=Swish,
)

In [23]:
def get_grads(act,x):
  """Return the gradient of ``act(x).sum()`` with respect to ``x``.

  Args:
    act: callable mapping a tensor to a tensor.
    x: tensor of evaluation points; it is copied, never modified.

  Returns:
    Tensor with the same shape as ``x`` holding the gradients.
  """
  # detach() makes the copy a fresh leaf. Without it, an input that already
  # requires grad produces a non-leaf clone whose .grad stays None, and the
  # backward pass leaks a gradient onto the caller's tensor.
  x = x.detach().clone().requires_grad_()
  out = act(x)
  out.sum().backward()
  return x.grad

In [28]:
def vis_activation(fig,act,x,row,col):
  """Add an activation curve and its gradient curve to one subplot of ``fig``.

  Args:
    fig: figure the two line traces are added to.
    act: activation callable evaluated on ``x``.
    x: tensor of evaluation points, used as the shared x-axis.
    row: subplot row passed to ``add_trace``.
    col: subplot column passed to ``add_trace``.
  """
  outputs = act(x)
  gradients = get_grads(act,x)
  xs = x.cpu().numpy()
  curves = (
      ("ActFn", outputs.cpu().numpy()),
      ("gradFn", gradients.cpu().numpy()),
  )
  for trace_name, ys in curves:
    fig.add_trace(go.Scatter(x=xs,y=ys,name=trace_name,mode="lines"),row=row,col=col)

In [29]:
# Instantiate every registered activation with its default arguments.
act_fns = [act_fn() for act_fn in act_fn_by_name.values()]
x = torch.linspace(-5,5,1000)
# Two panels per row, rounding up when the count is odd.
rows = math.ceil(len(act_fns)/2)
# Title each panel with its activation's name: every panel uses the same two
# trace names, so without titles the subplots cannot be told apart.
# subplot_titles is consumed in row-major order, matching the loop below.
fig = make_subplots(rows=rows,cols=2,subplot_titles=[act.name for act in act_fns])
for i,act in enumerate(act_fns):
  row,col = i//2+1,i%2+1
  vis_activation(fig,act,x,row,col)
fig