Reference: https://github.com/digantamisra98/Echo for implementation/equation of activation functions.
<br>Reference: https://echo-ai.readthedocs.io/en/latest/ for explanation of functions/parameters and graphs.

In [1]:
import torch
import torch.nn.functional as F

### Weighted Tanh

In [21]:
# Applies the weighted tanh function element-wise:
# weightedtanh(x) = tanh(x * weight)

def weighted_tanh(input, weight=1, inplace=False):
    """Apply the weighted tanh function element-wise: tanh(weight * input).

    Args:
        input: tensor to transform.
        weight: multiplier applied to ``input`` before the tanh.
        inplace: when truthy, overwrite ``input`` with the result instead of
            allocating a new tensor.

    Returns:
        The transformed tensor. With ``inplace`` set this is ``input`` itself.
    """
    if not inplace:
        return torch.tanh(weight * input)

    input *= weight
    input.tanh_()
    # Hand the mutated tensor back so both modes can be used in an expression;
    # the in-place path previously fell through and returned None.
    return input

### Swish

In [None]:
# Applies the Swish function element-wise
# Swish(x, beta) = x * sigmoid(beta * x) = \frac{x}{1 + e^{-beta * x}}

def swish(input, beta=1.25):
    """Swish activation: each element is scaled by sigmoid(beta * element).

    Args:
        input: tensor to transform.
        beta: slope applied inside the sigmoid gate.

    Returns:
        A new tensor holding input * sigmoid(beta * input).
    """
    gate = torch.sigmoid(beta * input)
    return input * gate

### ESwish

In [None]:
# Applies the E-Swish function element-wise
# ESwish(x, beta) = beta*x*sigmoid(x)

def eswish(input, beta=1.375):
    """E-Swish activation: beta * input * sigmoid(input), element-wise.

    Args:
        input: tensor to transform.
        beta: constant multiplier applied to the gated value.

    Returns:
        A new tensor with the E-Swish values.
    """
    scaled = beta * input
    return scaled * torch.sigmoid(input)

### Aria-2

In [2]:
# Applies the Aria-2 function element-wise
# Aria2(x, alpha, beta) = (1+e^{-beta*x})^{-alpha}

def aria2(input, beta=0.5, alpha=1):
    """Aria-2 activation: (1 + exp(-beta * input)) ** (-alpha), element-wise.

    Args:
        input: tensor to transform.
        beta: factor applied to the input inside the exponential.
        alpha: exponent; the base is raised to -alpha.

    Returns:
        A new tensor with the Aria-2 values.
    """
    base = torch.exp(-beta * input) + 1
    return torch.pow(base, -alpha)

### ELiSH

In [23]:
# Applies the ELiSH (Exponential Linear Sigmoid SquasHing) function element-wise
# ELiSH(x) =x / (1+e^{-x}), x >= 0  
#          =(e^{x} - 1) / (1 + e^{-x}), x < 0

def elish(input):
    """Apply ELiSH (Exponential Linear Sigmoid SquasHing) element-wise.

    ELiSH(x) = x * sigmoid(x)               for x >= 0
             = (exp(x) - 1) * sigmoid(x)    for x <  0

    Args:
        input: floating-point tensor to transform.

    Returns:
        A new tensor with the same shape and dtype as ``input``.
    """
    gate = torch.sigmoid(input)
    # Only non-positive values reach exp(): with mask-multiplication the
    # unselected branch overflowed to inf for large positive inputs and
    # 0 * inf produced NaN. Clamping also keeps the gradient of the
    # unselected branch finite under torch.where.
    non_positive = torch.clamp(input, max=0)
    negative_branch = (torch.exp(non_positive) - 1) * gate
    return torch.where(input >= 0, input * gate, negative_branch)

### HardELiSH

In [1]:
# Applies the HardELiSH (Exponential Linear Sigmoid SquasHing) function element-wise
# HardELiSH(x) = \left\{\begin{matrix} x \times max(0, min(1, (x + 1) / 2)), x \geq 0 \\ (e^{x} - 1) \times max(0, min(1, (x + 1) / 2)), x < 0 \end{matrix}\right.

def hard_elish(input):
    """Apply HardELiSH element-wise.

    HardELiSH(x) = x * clip((x + 1) / 2, 0, 1)               for x >= 0
                 = (exp(x) - 1) * clip((x + 1) / 2, 0, 1)    for x <  0

    Args:
        input: floating-point tensor to transform (on any device).

    Returns:
        A new tensor with the same shape, dtype and device as ``input``.
    """
    # clamp works on whatever device the input lives on; the previous version
    # created its 0/1 constants on cuda:0 whenever CUDA was available, even
    # for CPU inputs.
    hard_gate = torch.clamp((input + 1.0) / 2.0, min=0.0, max=1.0)
    # Keep exp() away from positive values so the unselected branch cannot
    # overflow to inf (0 * inf was NaN with mask-multiplication).
    non_positive = torch.clamp(input, max=0)
    negative_branch = (torch.exp(non_positive) - 1) * hard_gate
    return torch.where(input >= 0, input * hard_gate, negative_branch)

### Mila

In [5]:
# Applies the mila function element-wise
# mila(x) = x * tanh(softplus(beta + x)) = x * tanh(ln(1 + e^{beta + x}))

def mila(input, beta=-0.25):
    """Mila activation: input * tanh(softplus(input + beta)), element-wise.

    Args:
        input: tensor to transform.
        beta: shift added to the input before the softplus.

    Returns:
        A new tensor with the Mila values.
    """
    shifted = input + beta
    squashed = F.softplus(shifted).tanh()
    return input * squashed

### SineReLU

In [6]:
# Applies the SineReLU activation function element-wise
# SineReLU(x, epsilon) = \left\{\begin{matrix} x, x > 0 \\ epsilon * (sin(x) - cos(x)), x \leq 0 \end{matrix}\right.

def sineReLU(input, eps=0.01):
    """SineReLU activation, element-wise.

    Positive elements pass through unchanged; the remaining elements become
    eps * (sin(x) - cos(x)).

    Args:
        input: tensor to transform.
        eps: scale of the sinusoidal part used for non-positive elements.

    Returns:
        A new tensor with the SineReLU values.
    """
    # Float 0/1 masks select which formula contributes to each element.
    positive = (input > 0).float()
    non_positive = (input <= 0).float()
    wave = torch.sin(input) - torch.cos(input)
    return positive * input + non_positive * eps * wave

### Flatten T-Swish

In [7]:
# Applies the FTS (Flatten T-Swish) activation function element-wise
# FTS(x) = \left\{\begin{matrix} \frac{x}{1 + e^{-x}}, x \geq 0 \\ 0, x < 0 \end{matrix}\right.

def fts(input):
    """Flatten T-Swish activation, element-wise.

    Computes x / (1 + exp(-x)) and floors the result at zero, so negative
    elements map to zero.

    Args:
        input: tensor to transform.

    Returns:
        A new tensor with the FTS values.
    """
    gated = input / (torch.exp(-input) + 1)
    return gated.clamp(min=0)

### SQNL

In [8]:
# Applies the SQNL activation function element-wise
# SQNL(x) = \left\{\begin{matrix} 1, x > 2 \\ x - \frac{x^2}{4}, 0 \leq x \leq 2 \\ x + \frac{x^2}{4}, -2 \leq x < 0 \\ -1, x < -2 \end{matrix}\right.

def sqnl(input):
    """SQNL (square non-linearity) activation, element-wise.

    Saturates at +1 above 2 and at -1 below -2; in between it follows
    x - x^2/4 for 0 <= x <= 2 and x + x^2/4 for -2 <= x < 0.

    Args:
        input: tensor to transform.

    Returns:
        A new tensor with the SQNL values.
    """
    quarter_square = input ** 2 / 4

    # Float 0/1 masks, one per piece of the definition.
    above = (input > 2).float()
    below = (input < -2).float()

    upper_curve = (input - quarter_square) * (input >= 0).float() * (input <= 2).float()
    lower_curve = (input + quarter_square) * (input < 0).float() * (input >= -2).float()

    return above + upper_curve + lower_curve - below

### ISRU

In [9]:
# Applies the ISRU function element-wise
# ISRU(x, alpha) = \frac{x}{\sqrt{1 + alpha * x^2}}

def isru(input, alpha=1.0):
    """ISRU (inverse square root unit): x / sqrt(1 + alpha * x^2), element-wise.

    Args:
        input: tensor to transform.
        alpha: weight of the squared term under the square root.

    Returns:
        A new tensor with the ISRU values.
    """
    denominator = torch.sqrt(alpha * input ** 2 + 1)
    return input / denominator

### ISRLU

In [10]:
# Applies the ISRLU function element-wise
# ISRLU(x, alpha) = \left\{\begin{matrix} x, x \geq 0 \\ x * (\frac{1}{\sqrt{1 + \alpha * x^2}}), x < 0 \end{matrix}\right.

def isrlu(input, alpha=1.0):
    """ISRLU (inverse square root linear unit), element-wise.

    Non-negative elements pass through unchanged; negative elements take the
    value returned by ``isru(input, alpha)``.

    Args:
        input: tensor to transform.
        alpha: forwarded to ``isru`` for the negative side.

    Returns:
        A new tensor with the ISRLU values.
    """
    # Float 0/1 masks pick the formula used for each element.
    negative = (input < 0).float()
    non_negative = (input >= 0).float()
    squashed = isru(input, alpha)
    return negative * squashed + non_negative * input

### Bent's identity

In [11]:
# Applies the Bent's Identity function element-wise
# bentId(x) = x + \frac{\sqrt{x^{2}+1}-1}{2}

def bent_id(input):
    """Bent identity activation: x + (sqrt(x^2 + 1) - 1) / 2, element-wise.

    Args:
        input: tensor to transform.

    Returns:
        A new tensor with the bent-identity values.
    """
    root = torch.sqrt(input ** 2 + 1)
    bend = (root - 1) / 2
    return input + bend

### Soft Clipping

In [None]:
# Applies the Soft Clipping function element-wise
# SC(x) = 1 / \alpha * log(\frac{1 + e^{\alpha * x}}{1 + e^{\alpha * (x-1)}})

def soft_clipping(input, alpha=0.5):
    """Apply the Soft Clipping function element-wise.

    SC(x) = (1 / alpha) * log((1 + exp(alpha * x)) / (1 + exp(alpha * (x - 1))))

    Args:
        input: floating-point tensor to transform.
        alpha: sharpness of the clipping; must be non-zero.

    Returns:
        A new tensor with the soft-clipped values.
    """
    # log((1 + e^u) / (1 + e^v)) == softplus(u) - softplus(v). Evaluating the
    # quotient directly overflows to inf / inf = NaN once alpha * x is large;
    # softplus stays finite there.
    upper = F.softplus(alpha * input)
    lower = F.softplus(alpha * (input - 1))
    return (1 / alpha) * (upper - lower)

### SReLU

### BReLU

In [4]:
# Not sure
# BReLU is applied differently at even and odd indices
# BReLU(x_i) = \left\{\begin{matrix} f(x_i), i \mod 2 = 0 \\ - f(-x_i), i \mod 2 \neq 0 \end{matrix}\right.

def brelu(input):
    """Apply BReLU along the first dimension of ``input``.

    Entries at even positions of dimension 0 get ReLU(x); entries at odd
    positions get -ReLU(-x).

    Args:
        input: tensor with at least one dimension.

    Returns:
        A new tensor of the same shape; ``input`` itself is not modified.
    """
    output = input.clone()

    # Strided slices address the even/odd positions directly, so no Python
    # index lists have to be built.
    output[0::2] = input[0::2].clamp(min=0)

    # Mirrored ReLU for odd positions: negate, rectify, negate back.
    flipped = -input[1::2]
    output[1::2] = -flipped.clamp(min=0)

    return output

### APL

In [5]:
# Not sure
# APL is applied differently at each index
# APL(x_i) = max(0, x) + \sum_{s=1}^{S}{a_i^s * max(0, -x + b_i^s)}   (the superscript s is an index into a and b, not an exponent)
def apl(input, S, a, b):
    """Apply the APL (adaptive piecewise linear) function.

    APL(x) = max(0, x) + sum over s of a[s] * max(0, -x + b[s])

    Args:
        input: tensor to transform.
        S: number of hinge terms to add; ``a`` and ``b`` are indexed with
            0 .. S-1.
        a: per-hinge slopes, indexable by ``s``.
        b: per-hinge offsets, indexable by ``s``.

    Returns:
        A new tensor with the APL values.
    """
    output = input.clamp(min=0)

    for s in range(S):
        # In the APL formula the superscript s selects the s-th hinge
        # parameter. Raising a[s] / b[s] to the power s (as was done before)
        # turned the first hinge into constants 1 and distorted the rest.
        hinge = (-input + b[s]).clamp(min=0)
        # Out-of-place add so a[s] may broadcast against the input.
        output = output + a[s] * hinge

    return output

### Soft Exponential

In [6]:
# Not sure
# Applies the soft exponential function element-wise
# SoftExponential(x, \alpha) = \left\{\begin{matrix} - \frac{log(1 - \alpha(x + \alpha))}{\alpha}, \alpha < 0 \\ x, \alpha = 0 \\ \frac{e^{\alpha * x} - 1}{\alpha} + \alpha, \alpha > 0 \end{matrix}\right.

def softExp(input, alpha=0.0):
    """Soft exponential activation, element-wise.

    alpha > 0:  (exp(alpha * x) - 1) / alpha + alpha
    alpha < 0:  -log(1 - alpha * (x + alpha)) / alpha
    alpha == 0: x (the input object itself is returned)

    Args:
        input: tensor to transform.
        alpha: selects which of the three forms is used.

    Returns:
        The transformed tensor, or None if ``alpha`` satisfies none of the
        three comparisons (e.g. NaN).
    """
    if alpha > 0.0:
        growth = torch.exp(alpha * input) - 1
        return growth / alpha + alpha

    if alpha < 0.0:
        log_arg = 1 - alpha * (input + alpha)
        return -torch.log(log_arg) / alpha

    # Identity for alpha == 0; anything else left over yields None.
    return input if alpha == 0.0 else None

### Maxout

### Mish

In [13]:
# Applies the Mish function element-wise
# mish(x) = x * tanh(ln(1 + e^{x}))

def mish(input):
    """Mish activation: x * tanh(ln(1 + exp(x))), element-wise.

    Args:
        input: tensor to transform.

    Returns:
        A new tensor with the Mish values.
    """
    log_term = torch.log(torch.exp(input) + 1)
    return input * log_term.tanh()

### Beta Mish

In [11]:
# Applies the Beta Mish function element-wise
# beta_mish(x, beta) = x * tanh(ln((1 + e^{x})^{beta}))

def beta_mish(input, beta=1.5):
    """Beta Mish activation: x * tanh(ln((1 + exp(x)) ** beta)), element-wise.

    Args:
        input: tensor to transform.
        beta: exponent applied to (1 + exp(x)) before the logarithm.

    Returns:
        A new tensor with the Beta Mish values.
    """
    powered = torch.pow(torch.exp(input) + 1, beta)
    squashed = torch.log(powered).tanh()
    return input * squashed

### LeCun's Tanh

In [14]:
# Applies the Le Cun's Tanh function element-wise
# lecun_tanh(x) = 1.7159 * tanh((2/3) * input)

def lecun_tanh(input):
    """LeCun's tanh activation: 1.7159 * tanh(2 * x / 3), element-wise.

    Args:
        input: tensor to transform.

    Returns:
        A new tensor with the scaled-tanh values.
    """
    scaled = 2 * input / 3
    return torch.tanh(scaled) * 1.7159

### SiLU

In [None]:
# Applies the Sigmoid Linear Unit (SiLU) function element-wise
# SiLU(x) = x * sigmoid(x)

def silu(input, inplace=False):
    """Apply the Sigmoid Linear Unit element-wise: SiLU(x) = x * sigmoid(x).

    Args:
        input: tensor to transform.
        inplace: when truthy, overwrite ``input`` with the result.

    Returns:
        The transformed tensor. With ``inplace`` set this is ``input`` itself.
    """
    if not inplace:
        return input * torch.sigmoid(input)

    # The gate is computed from the untouched values before input is
    # overwritten. mul_ returns the mutated tensor, so the in-place path no
    # longer returns None.
    return input.mul_(torch.sigmoid(input))

### NLReLU

In [16]:
# Applies the natural logarithm ReLU activation function element-wise

def nl_relu(input, beta=1.):
    """Natural-logarithm ReLU: ln(1 + beta * max(0, x)), element-wise.

    The result is additionally multiplied by a float 0/1 mask of the
    strictly positive elements.

    Args:
        input: tensor to transform.
        beta: scale applied to the rectified input inside the logarithm.

    Returns:
        A new tensor with the NLReLU values.
    """
    positive = (input > 0).float()
    rectified = input.clamp(min=0)
    return positive * torch.log(beta * rectified + 1.)