# 16: Activation Functions

### 🎯 Objective
This notebook visually explores the behavior of various activation functions available in PyTorch. We will look at their shapes, ranges, and how to access them using both the functional API (`torch.*`) and the module API (`torch.nn.*`).

### 📚 Key Concepts
- **Sigmoid:** Squashes input to $(0, 1)$. Used for probabilities.
- **Tanh:** Squashes input to $(-1, 1)$. Zero-centered.
- **ReLU (Rectified Linear Unit):** $\max(0, x)$. The default choice for hidden layers.
- **API Differences:** functions in `torch` and `torch.nn.functional` (stateless) vs. classes in `torch.nn` (layer objects).
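
The output ranges listed above can be checked numerically. The following is a minimal sketch, not part of the original notebook; the grid from -5 to 5 and the name `ranges` are choices made for this illustration only.

```python
import torch

# Evaluate each activation on a grid and record the observed output range.
# The grid limits (-5 to 5) are an arbitrary choice for this sketch.
x_grid = torch.linspace(-5, 5, 1001)

ranges = {}
for name in ["sigmoid", "tanh", "relu"]:
    y = getattr(torch, name)(x_grid)
    ranges[name] = (y.min().item(), y.max().item())
    print(f"{name:8s} min={ranges[name][0]:+.4f}  max={ranges[name][1]:+.4f}")

# Zero-centering: tanh(0) is 0, whereas sigmoid(0) is 0.5.
zero = torch.tensor(0.0)
print("tanh(0) =", torch.tanh(zero).item(), " sigmoid(0) =", torch.sigmoid(zero).item())
```

On a finite grid the sigmoid and tanh outputs stay inside their open intervals, while ReLU is bounded below by 0 and unbounded above (its maximum simply tracks the largest input).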

In [1]:
# import libraries
import torch
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size':18})

## 1. The Standard Functions (Functional API)

These are the most common activation functions used in Deep Learning. We access them directly from the `torch` library (e.g., `torch.sigmoid`).



In [None]:
# variable to evaluate over (x-axis)
x = torch.linspace(-3,3,101)

# create a function that returns the activated output
def NNoutputx(actfun):
  # get activation function type
  # getattr(obj, 'name') is equivalent to obj.name
  # Example: getattr(torch, 'relu') -> torch.relu
  actfun = getattr(torch,actfun)
  return actfun( x )
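
As a quick sanity check of the `getattr` lookup used in `NNoutputx`, the sketch below confirms that looking a function up by name gives the same callable as writing `torch.tanh` directly. The names `x_demo` and `not_an_activation` are hypothetical and exist only for this illustration.

```python
import torch

x_demo = torch.linspace(-3, 3, 7)  # small hypothetical grid for this sketch

# Looking the function up by name returns the same callable as torch.tanh
looked_up = getattr(torch, "tanh")
same_object = looked_up is torch.tanh

# Both spellings therefore give identical outputs
same_output = torch.equal(looked_up(x_demo), torch.tanh(x_demo))

# A name that does not exist in torch raises AttributeError
try:
    getattr(torch, "not_an_activation")
    raised = False
except AttributeError:
    raised = True

print(same_object, same_output, raised)
```

Because a misspelled name fails immediately with `AttributeError`, typos in the list of activation names surface at lookup time rather than producing a silently wrong plot.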

In [None]:
# the activation functions
activation_funs = [ 'relu', 'sigmoid', 'tanh' ]

fig = plt.figure(figsize=(10,8))

for actfun in activation_funs:
  plt.plot(x,NNoutputx(actfun),label=actfun,linewidth=3)

# add reference lines
dashlinecol = [.7,.7,.7]
plt.plot(x[[0,-1]],[0,0],'--',color=dashlinecol)
plt.plot(x[[0,-1]],[1,1],'--',color=dashlinecol)
plt.plot([0,0],[-1,3],'--',color=dashlinecol)

# make the plot look nicer
plt.legend()
plt.xlabel('x')
plt.ylabel(r'$\sigma(x)$')  # raw string so that \s is not treated as an escape sequence
plt.title('Various activation functions')
plt.xlim(x[[0,-1]])
plt.ylim([-1,3])
plt.show()

# Observations:
# 1. ReLU is 0 for negative x, and linear (x) for positive x.
# 2. Sigmoid is strictly between 0 and 1.
# 3. Tanh is strictly between -1 and 1.

## 2. Advanced Functions (Class API)

PyTorch also implements activation functions as **Classes** inside `torch.nn`. This allows them to be treated as layers (like `nn.Linear`).

Here we explore some variations of ReLU:
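* **ReLU6:** $\min(\max(0,x), 6)$. Caps the output at 6; used in mobile-oriented models because the bounded range holds up better under low-precision arithmetic.
* **LeakyReLU:** $\max(0.01x, x)$ (with the default slope of 0.01). Allows a small gradient when $x < 0$ to prevent "dead neurons."
* **Hardshrink:** Outputs $x$ when $|x| > \lambda$ (default $\lambda = 0.5$) and exactly 0 otherwise, which creates sparsity.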
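
The three definitions above can be reproduced by hand and compared against the PyTorch modules. This is a minimal sketch that assumes the default hyperparameters (`negative_slope=0.01` for `LeakyReLU`, `lambd=0.5` for `Hardshrink`); the tensor `x_demo` and the `*_manual` names are hypothetical.

```python
import torch
import torch.nn as nn

x_demo = torch.tensor([-2.0, -0.3, 0.0, 0.3, 2.0, 8.0])

# ReLU6: clamp the input to the interval [0, 6]
relu6_manual = torch.clamp(x_demo, min=0.0, max=6.0)

# LeakyReLU: identity for x >= 0, slope 0.01 (default negative_slope) for x < 0
leaky_manual = torch.where(x_demo >= 0, x_demo, 0.01 * x_demo)

# Hardshrink: keep x only where |x| > 0.5 (default lambd), otherwise output 0
hard_manual = torch.where(x_demo.abs() > 0.5, x_demo, torch.zeros_like(x_demo))

# Compare each manual formula with the corresponding torch.nn module
relu6_ok = torch.allclose(relu6_manual, nn.ReLU6()(x_demo))
leaky_ok = torch.allclose(leaky_manual, nn.LeakyReLU()(x_demo))
hard_ok = torch.allclose(hard_manual, nn.Hardshrink()(x_demo))
print(relu6_ok, leaky_ok, hard_ok)
```

Both hyperparameters can be changed when the class is instantiated, for example `nn.LeakyReLU(negative_slope=0.1)` or `nn.Hardshrink(lambd=1.0)`.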

In [None]:
# create a function that returns the activation FUNCTION itself
# (unlike NNoutputx above, which returns the activated output)
def NNoutput(actfun):
  # look up the activation class by name
  # Example: getattr(torch.nn, 'ReLU6') -> torch.nn.ReLU6
  actfun = getattr(torch.nn,actfun)
  
  # Crucial difference: activations in torch.nn are CLASSES.
  # We must instantiate the class first (add parentheses)
  return actfun()

In [None]:
# the activation functions
activation_funs = [ 'ReLU6', 'Hardshrink', 'LeakyReLU' ]

fig = plt.figure(figsize=(10,8))

for actfun in activation_funs:
  plt.plot(x,NNoutput(actfun)(x),label=actfun,linewidth=3)

# add reference lines
dashlinecol = [.7,.7,.7]
plt.plot(x[[0,-1]],[0,0],'--',color=dashlinecol)
plt.plot(x[[0,-1]],[1,1],'--',color=dashlinecol)
plt.plot([0,0],[-1,3],'--',color=dashlinecol)

# make the plot look nicer
plt.legend()
plt.xlabel('x')
plt.ylabel(r'$\sigma(x)$')  # raw string so that \s is not treated as an escape sequence
plt.title('Various activation functions')
plt.xlim(x[[0,-1]])
plt.ylim([-1,3])
# plt.ylim([-.1,.1])
plt.show()

In [None]:
# relu6 in more detail
# It looks like ReLU, but flattens out at y=6
x = torch.linspace(-3,9,101)
relu6 = torch.nn.ReLU6()

plt.plot(x,relu6(x))
plt.title('ReLU6')
plt.show()

## 3. Differences between `torch` and `torch.nn`

- **`torch.<func>` (e.g., `torch.relu`):** This is a **function**. It takes input, does math, returns output. It has no internal state (parameters). Good for quick calculations.
- **`torch.nn.<Class>` (e.g., `torch.nn.ReLU`):** This is a **class**. You instantiate it as an object. This object can be added to an `nn.Sequential` model definition.
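
To make the distinction concrete, here is a minimal sketch of the same small network written both ways. The layer sizes and the class name `TinyNet` are hypothetical choices for this illustration and do not come from the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Class API: the activation is a layer object inside nn.Sequential
seq_model = nn.Sequential(
    nn.Linear(2, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
)

# Functional API: the activation is called as a plain function in forward()
class TinyNet(nn.Module):  # hypothetical name for this sketch
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

func_model = TinyNet()

# Copy the weights across so the two models are directly comparable
func_model.fc1.load_state_dict(seq_model[0].state_dict())
func_model.fc2.load_state_dict(seq_model[2].state_dict())

inputs = torch.randn(5, 2)
outputs_match = torch.allclose(seq_model(inputs), func_model(inputs))

# nn.ReLU is a module, but it holds no learnable parameters
n_relu_params = len(list(nn.ReLU().parameters()))
print(outputs_match, n_relu_params)
```

Both versions compute the same thing; the choice is mostly about how the model is defined. `nn.Sequential` needs layer objects, while a custom `forward()` can use either form.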

In [None]:
# redefine x (fewer points to facilitate visualization)
x = torch.linspace(-3,3,21)

# Option 1: Functional API (in torch)
y1 = torch.relu(x)

# Option 2: Class API (in torch.nn)
f = torch.nn.ReLU()
y2 = f(x)


# the results are exactly the same
plt.plot(x,y1,'ro',label='torch.relu')
plt.plot(x,y2,'bx',label='torch.nn.ReLU')
plt.legend()
plt.xlabel('Input')
plt.ylabel('Output')
plt.show()
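
The plot shows the overlap visually; the equality can also be checked programmatically. The sketch below additionally includes `torch.nn.functional.relu`, a third spelling of the same operation that the cell above does not use. The variable names are hypothetical.

```python
import torch

x_check = torch.linspace(-3, 3, 21)

y_func = torch.relu(x_check)                 # function in torch
y_mod = torch.nn.ReLU()(x_check)             # class in torch.nn
y_nnf = torch.nn.functional.relu(x_check)    # function in torch.nn.functional

# torch.equal requires identical shapes and element-wise identical values
identical = torch.equal(y_func, y_mod) and torch.equal(y_func, y_nnf)
print(identical)
```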

In [None]:
# List of activation functions in PyTorch:
#  https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

## 4. Exploration: How Non-Linearity Works

This exploration shows how a single neuron creates interesting shapes.
We take two inputs, $x_1$ and $x_2$. We combine them linearly: $z = w_1x_1 + w_2x_2$.
Then we pass $z$ through a non-linearity (ReLU).

The result shows how the activation function "bends" the linear input.
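
A worked example with three input values makes the two steps explicit. This sketch assumes the same weights and the same relationship $x_2 = 2x_1$ as the code cell that follows; the `_demo` names are hypothetical.

```python
import torch

w1, w2 = -0.3, 0.5                        # same weights as the notebook cell
x1_demo = torch.tensor([-1.0, 0.0, 1.0])  # three illustrative inputs
x2_demo = 2 * x1_demo

# Linear part: z = w1*x1 + w2*x2 = (-0.3 + 2*0.5)*x1 = 0.7*x1
z_demo = w1 * x1_demo + w2 * x2_demo

# Nonlinear part: ReLU zeroes the negative values and passes the rest through
y_demo = torch.relu(z_demo)

print("z =", z_demo)
print("y =", y_demo)
```

With these settings $z$ is a straight line through the origin, so ReLU leaves the right half untouched and flattens the left half to zero: the "bend" happens exactly where $z$ changes sign.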

In [None]:
# The goal of these explorations is to help you appreciate the remarkably diverse nonlinear shapes that a node can produce.
# All explorations use the code below.

In [None]:
# create input vectors
x1 = torch.linspace(-1,1,20)
x2 = 2*x1

# and corresponding weights
w1 = -.3
w2 = .5

# their linear combination
linpart = x1*w1 + x2*w2

# and the nonlinear output
y = torch.relu(linpart)

# and plot!
plt.plot(x1,linpart,'bo-',label='Linear input')
plt.plot(x1,y,'rs',label='Nonlinear output')
plt.ylabel('$\\hat{y}$ (output of activation function)')
plt.xlabel('x1 variable')
# plt.ylim([-.1,.1]) # optional -- uncomment and modify to zoom in
plt.legend()
plt.show()

In [None]:
# 1) Look through the code to make sure you understand what it does (linear weighted combination -> nonlinear function).
#
# 2) Set x2=x1**2 and run the code. Then set one of the weights to be negative. Then set the negative weight to be close
#    to zero (e.g., -.01) with the positive weight relatively large (e.g., .8). Then swap the signs.
#
# 3) Set x2=x1**2, and set the weights to be .4 and .6. Now set w2=.6 (you might want to zoom in on the y-axis).
#
# 4) Set x2 to be the absolute value of x1 and both weights positive. Then set w2=-.6. Why does w2<0 have such a big impact?
#    More generally, under what conditions are the input and output identical?
#
# 5) Have fun! Spend a few minutes playing around with the code. Also try changing the activation function to tanh or
#    anything else. The goal is to see that really simple input functions with really simple weights can produce really
#    complicated-looking nonlinear outputs.
#