Let's look at some popular activation functions and their effect on training.
It's important to choose good activation functions.

In [6]:
# Standard libraries
import os
import json
import math
import numpy as np
from typing import Any, Sequence
import pickle
from copy import deepcopy

## import plotting libraries
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')
import seaborn as sns
sns.set()

## progress bar
from tqdm.auto import tqdm

## Jax
import jax
import jax.numpy as jnp
from jax import random

## Flax
import flax
from flax import linen as nn
from flax.training import train_state, checkpoints

## optimization using optax
import optax

  set_matplotlib_formats('svg', 'pdf')


In [4]:
# Downloaded datasets are cached here, so each dataset is only fetched once.
DATASET_PATH = os.path.abspath("data/")
# Pretrained / saved model files for this notebook are stored here.
CHECKPOINT_PATH = os.path.abspath("saved_models/activation_functions/")

# Report the first device returned by jax.devices().
print("device:", jax.devices()[0])


device: TFRT_CPU_0


In [5]:
# Download the pretrained models that are used later in this notebook.
import urllib.request  # for downloading URL-based files
from urllib.error import HTTPError  # raised for HTTP error responses

# GitHub location of the saved models
base_url = 'https://raw.githubusercontent.com/phlippe/saved_models/main/JAX/tutorial3/'

# Files to download: one .config and one .tar file per activation function
pretrained_files = [
    "FashionMNIST_elu.config", "FashionMNIST_elu.tar",
    "FashionMNIST_leakyrelu.config", "FashionMNIST_leakyrelu.tar",
    "FashionMNIST_relu.config", "FashionMNIST_relu.tar",
    "FashionMNIST_sigmoid.config", "FashionMNIST_sigmoid.tar",
    "FashionMNIST_swish.config", "FashionMNIST_swish.tar",
    "FashionMNIST_tanh.config", "FashionMNIST_tanh.tar"
]

# Create the checkpoint directory if it doesn't exist
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file: skip it if it is already present, otherwise try to download it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if os.path.isfile(file_path):
        continue
    file_url = base_url + file_name
    print(f"downloading {file_url}...")
    # Download under a temporary name and move it into place only after a
    # successful transfer. An interrupted download then never leaves a
    # truncated file that the isfile() check above would treat as complete.
    partial_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(file_url, partial_path)
    except HTTPError as e:
        # Best effort: report what failed and carry on with the remaining files.
        print(f"Could not download {file_url}: {e}\n"
              "Please download the file manually, or contact the author "
              "and include this error message.")
        continue
    os.replace(partial_path, file_path)

## Common activation functions

For a better understanding, let's implement some common activation functions ourselves, even though they are already available in `jax.nn` and `flax.linen`. <br>

We implement each activation function as a subclass of `nn.Module` (from `flax.linen`), so that it can be integrated into a neural network like any other layer. <br>

- sigmoid (nn.sigmoid)
- tanh (nn.tanh)
- relu (nn.relu)
- leaky relu (nn.leaky_relu)
- elu (nn.elu)
- swish (nn.swish)

In [10]:
################################

class Sigmoid(nn.Module):
    """Logistic sigmoid, sigma(x) = 1 / (1 + e^{-x}), implemented with jnp.exp.

    The naive expression overflows for very negative x: e^{-x} becomes inf. The
    forward value still evaluates to 0, but the gradient through it can become
    NaN (0 * inf). This version only ever evaluates e^{-|x|}, which lies in
    (0, 1]. For x < 0 it uses the identity sigma(x) = e^x / (1 + e^x).
    """

    def __call__(self, x):
        # -|x| via an explicit where, so the derivative at x = 0 comes from the x >= 0 branch
        decay = jnp.exp(jnp.where(x < 0, x, -x))
        # x >= 0: 1 / (1 + e^{-x}) (the textbook form); x < 0: e^x / (1 + e^x)
        return jnp.where(x < 0, decay / (1 + decay), 1 / (1 + decay))

###################################

class Tanh(nn.Module):
    """Hyperbolic tangent, tanh(x) = (e^x - e^{-x}) / (e^x + e^{-x}), built from jnp.exp.

    The textbook form evaluates inf / inf = NaN once |x| exceeds the exp overflow
    threshold (about 88.7 in float32). Here we use the equivalent odd-symmetric
    form tanh(x) = sign(x) * (1 - e^{-2|x|}) / (1 + e^{-2|x|}), in which the
    exponential always lies in (0, 1].
    """

    def __call__(self, x):
        # |x| via an explicit where, so the derivative at x = 0 comes from the x >= 0 branch
        abs_x = jnp.where(x < 0, -x, x)
        decay = jnp.exp(-2 * abs_x)
        magnitude = (1 - decay) / (1 + decay)
        # tanh is odd: tanh(-x) = -tanh(x)
        return jnp.where(x < 0, -magnitude, magnitude)

class ReLU(nn.Module):
    """Rectified linear unit: keeps positive inputs and maps everything else to 0."""

    def __call__(self, x):
        # element-wise max(0, x)
        return jnp.maximum(0, x)

class LeakyReLU(nn.Module):
    """Leaky ReLU: identity for positive inputs, slope `alpha` for the rest."""

    # Slope applied to non-positive inputs.
    alpha: float = 0.1

    def __call__(self, x):
        negative_part = self.alpha * x
        return jnp.where(x > 0, x, negative_part)

class ELU(nn.Module):
    """Exponential linear unit: x for x > 0, alpha * (e^x - 1) otherwise."""

    # Scale of the negative branch. The default of 1.0 reproduces the previous fixed behaviour.
    alpha: float = 1.0

    def __call__(self, x):
        # jnp.where differentiates through BOTH branches. If raw x were fed to exp,
        # large positive x would overflow, and the masked-out branch would add
        # 0 * inf = NaN to the gradient. So the exp input is first restricted to
        # x <= 0 (the "double where" pattern).
        non_pos_x = jnp.where(x > 0, 0.0, x)
        # expm1(x) computes e^x - 1 more accurately than exp(x) - 1 for small |x|
        return jnp.where(x > 0, x, self.alpha * jnp.expm1(non_pos_x))

class Swish(nn.Module):
    """Swish activation: the input scaled by its own sigmoid, x * sigmoid(x)."""

    def __call__(self, x):
        # Uses Flax's built-in nn.sigmoid rather than the hand-written Sigmoid module
        gate = nn.sigmoid(x)
        return x * gate

# Lookup table from activation-function name to its Module class. The names match
# the suffixes of the pretrained file names (e.g. "FashionMNIST_leakyrelu.tar").
act_fn_by_name = dict(
    sigmoid=Sigmoid,
    tanh=Tanh,
    relu=ReLU,
    leakyrelu=LeakyReLU,
    elu=ELU,
    swish=Swish,
)