# Activation Functions: Breaking Linearity

**Inference Engineering Series - Notebook 2**

---

In the previous notebook, we saw that neural networks are built from linear layers: $y = Wx + b$. But here's the problem -- if you stack two linear layers without anything between them, you just get another linear layer. The network can't learn anything more complex than a single layer could.

**Activation functions** are the non-linear operations we insert between layers to break this linearity. With them, a sufficiently large network can approximate a very broad class of functions (the universal approximation property).

In this notebook, we'll explore every major activation function used in modern LLMs and understand why the field has converged on specific choices like SwiGLU.

## What You'll Learn

1. **Why activation functions are necessary** - the linear collapse problem
2. **Classic activation functions** - Sigmoid, Tanh, ReLU
3. **Modern activation functions** - GELU, SiLU/Swish, SwiGLU
4. **Visual comparison** of all major activation functions
5. **Which LLMs use which activations** - and why
6. **Compute cost** of different activation functions during inference

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import torch
import torch.nn as nn
import torch.nn.functional as F

print(f"PyTorch version: {torch.__version__}")

## Part 1: The Linearity Collapse Problem

Let's prove mathematically and experimentally that stacking linear layers without activation functions is pointless.

If we have two linear layers:
- Layer 1: $h = W_1 x + b_1$
- Layer 2: $y = W_2 h + b_2$

Substituting:
$$y = W_2(W_1 x + b_1) + b_2 = (W_2 W_1) x + (W_2 b_1 + b_2) = W' x + b'$$

This is just a single linear layer with $W' = W_2 W_1$ and $b' = W_2 b_1 + b_2$.
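
The same substitution works for any depth: folding the layers one at a time always leaves a single weight matrix and a single bias. Here is a minimal sketch of that folding, assuming arbitrary layer widths chosen only for illustration (`dims`, `layers`, `W_total`, and `b_total` are hypothetical names, not part of the notebook):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stack of 5 linear layers: 4 -> 8 -> 16 -> 8 -> 6 -> 3
dims = [4, 8, 16, 8, 6, 3]
layers = [(rng.standard_normal((d_out, d_in)), rng.standard_normal(d_out))
          for d_in, d_out in zip(dims[:-1], dims[1:])]

# Fold the stack into one (W_total, b_total), one layer at a time
W_total = np.eye(dims[0])
b_total = np.zeros(dims[0])
for W, b in layers:
    W_total = W @ W_total
    b_total = W @ b_total + b

# Compare against running the layers sequentially
x = rng.standard_normal(dims[0])
h = x
for W, b in layers:
    h = W @ h + b

print("Collapsed shape:", W_total.shape)
print("Outputs match:", np.allclose(h, W_total @ x + b_total))
```

However many layers are stacked, the folded result is one matrix of shape (output dim, input dim), so depth alone adds no expressive power.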

In [None]:
# Demonstrate linear collapse
np.random.seed(42)

# Two linear layers: 4 -> 8 -> 3
W1 = np.random.randn(8, 4)  # Layer 1 weights
b1 = np.random.randn(8)     # Layer 1 bias
W2 = np.random.randn(3, 8)  # Layer 2 weights
b2 = np.random.randn(3)     # Layer 2 bias

# Collapsed single layer equivalent
W_collapsed = W2 @ W1           # (3, 4)
b_collapsed = W2 @ b1 + b2      # (3,)

# Test with random inputs
x = np.random.randn(4)

# Two-layer computation
h = W1 @ x + b1
y_two_layers = W2 @ h + b2

# Single collapsed layer
y_collapsed = W_collapsed @ x + b_collapsed

print("Two-layer output:  ", y_two_layers.round(6))
print("Collapsed output:  ", y_collapsed.round(6))
print(f"\nAre they identical? {np.allclose(y_two_layers, y_collapsed)}")
print(f"Max difference:    {np.max(np.abs(y_two_layers - y_collapsed)):.2e}")
print("\n=> Two linear layers without activation = one linear layer!")
print("   All that extra computation and parameters are wasted.")

In [None]:
# Visual demonstration: linear layers can only learn linear decision boundaries
from matplotlib.colors import ListedColormap

# Create a non-linear dataset (XOR-like)
np.random.seed(42)
n_points = 200

# Class 0: top-left and bottom-right
x0a = np.random.randn(n_points//4, 2) * 0.4 + np.array([-1, 1])
x0b = np.random.randn(n_points//4, 2) * 0.4 + np.array([1, -1])
x0 = np.vstack([x0a, x0b])

# Class 1: top-right and bottom-left
x1a = np.random.randn(n_points//4, 2) * 0.4 + np.array([1, 1])
x1b = np.random.randn(n_points//4, 2) * 0.4 + np.array([-1, -1])
x1 = np.vstack([x1a, x1b])

X = np.vstack([x0, x1])
y = np.array([0] * len(x0) + [1] * len(x1))

# Train linear model (no activation)
X_torch = torch.tensor(X, dtype=torch.float32)
y_torch = torch.tensor(y, dtype=torch.long)

linear_model = nn.Sequential(
    nn.Linear(2, 16),
    nn.Linear(16, 16),
    nn.Linear(16, 2)
)

nonlinear_model = nn.Sequential(
    nn.Linear(2, 16),
    nn.ReLU(),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.Linear(16, 2)
)

# Train both models
for model, name in [(linear_model, 'Linear'), (nonlinear_model, 'Non-linear')]:
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(500):
        optimizer.zero_grad()
        output = model(X_torch)
        loss = criterion(output, y_torch)
        loss.backward()
        optimizer.step()
    
    acc = (model(X_torch).argmax(dim=1) == y_torch).float().mean()
    print(f"{name:12s} model - Final loss: {loss.item():.4f}, Accuracy: {acc.item():.1%}")

In [None]:
# Visualize decision boundaries
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Create a mesh grid for decision boundary
xx, yy = np.meshgrid(np.linspace(-2.5, 2.5, 200), np.linspace(-2.5, 2.5, 200))
grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)

for ax, model, title in [(axes[0], linear_model, 'WITHOUT Activation Functions\n(Linear Only - Cannot Solve XOR)'),
                          (axes[1], nonlinear_model, 'WITH Activation Functions (ReLU)\n(Can Learn Non-linear Boundaries)')]:
    with torch.no_grad():
        Z = model(grid).argmax(dim=1).numpy().reshape(xx.shape)
    
    ax.contourf(xx, yy, Z, alpha=0.3, cmap=ListedColormap(['#FF6B6B', '#4ECDC4']))
    ax.scatter(x0[:, 0], x0[:, 1], c='#FF6B6B', edgecolors='black', s=30, label='Class 0')
    ax.scatter(x1[:, 0], x1[:, 1], c='#4ECDC4', edgecolors='black', s=30, label='Class 1')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend()
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-2.5, 2.5)
    ax.grid(True, alpha=0.2)

plt.suptitle('Why Activation Functions Are Essential', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## Part 2: Classic Activation Functions

Let's explore the activation functions that shaped the history of deep learning.

In [None]:
# Define all activation functions from first principles

def sigmoid(x):
    """Sigmoid: squashes input to (0, 1)"""
    return 1 / (1 + np.exp(-x))

def tanh(x):
    """Tanh: squashes input to (-1, 1)"""
    return np.tanh(x)

def relu(x):
    """ReLU: max(0, x) - the workhorse of deep learning"""
    return np.maximum(0, x)

def leaky_relu(x, alpha=0.01):
    """Leaky ReLU: allows small gradient for negative inputs"""
    return np.where(x > 0, x, alpha * x)

# Their derivatives (important for understanding gradient flow)
def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1 - s)

def tanh_grad(x):
    return 1 - np.tanh(x)**2

def relu_grad(x):
    return np.where(x > 0, 1.0, 0.0)

def leaky_relu_grad(x, alpha=0.01):
    return np.where(x > 0, 1.0, alpha)

In [None]:
# Plot classic activation functions and their gradients
x = np.linspace(-5, 5, 1000)

classics = [
    ('Sigmoid', sigmoid, sigmoid_grad, '#FF6B6B'),
    ('Tanh', tanh, tanh_grad, '#4ECDC4'),
    ('ReLU', relu, relu_grad, '#45B7D1'),
    ('Leaky ReLU', leaky_relu, leaky_relu_grad, '#96CEB4'),
]

fig, axes = plt.subplots(2, 4, figsize=(18, 8))

for idx, (name, func, grad_func, color) in enumerate(classics):
    # Function
    ax = axes[0, idx]
    ax.plot(x, func(x), linewidth=2.5, color=color)
    ax.axhline(y=0, color='black', linewidth=0.5)
    ax.axvline(x=0, color='black', linewidth=0.5)
    ax.set_title(f'{name}', fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-2, 2)
    
    # Gradient
    ax = axes[1, idx]
    ax.plot(x, grad_func(x), linewidth=2.5, color=color, linestyle='--')
    ax.axhline(y=0, color='black', linewidth=0.5)
    ax.axvline(x=0, color='black', linewidth=0.5)
    ax.set_title(f'{name} Gradient', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.1, 1.2)

axes[0, 0].set_ylabel('f(x)', fontsize=12)
axes[1, 0].set_ylabel("f'(x)", fontsize=12)

plt.suptitle('Classic Activation Functions and Their Gradients', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### Problems with Classic Activations

| Activation | Problem |
|---|---|
| **Sigmoid** | Saturates near 0 and 1, so gradients vanish for large-magnitude inputs; the gradient peaks at only 0.25 (at x=0). Output is not zero-centered. |
| **Tanh** | Zero-centered, but still saturates for large-magnitude inputs. Gradient peaks at 1.0 (at x=0). |
| **ReLU** | "Dead neurons": if a neuron's pre-activation is negative for every input, its gradient is exactly 0 and it may never recover. |
| **Leaky ReLU** | Avoids dead neurons, but the negative slope is a hyperparameter that must be chosen by hand. |
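
To make the saturation rows concrete: the sigmoid derivative never exceeds 0.25, so the activation derivatives alone shrink a backpropagated signal by at least 4x per sigmoid layer (the weights can offset or worsen this). A rough sketch of that bound, with a depth of 30 chosen arbitrarily for illustration:

```python
import numpy as np

x = np.linspace(-10, 10, 20001)  # dense grid centered on x = 0

sig = 1 / (1 + np.exp(-x))
sigmoid_peak = np.max(sig * (1 - sig))    # peak of sigmoid'(x)
tanh_peak = np.max(1 - np.tanh(x) ** 2)   # peak of tanh'(x)

depth = 30  # illustrative depth, not taken from any particular model
sigmoid_bound = sigmoid_peak ** depth     # best case for a product of sigmoid' terms

print(f"max sigmoid'(x) = {sigmoid_peak:.4f}")
print(f"max tanh'(x)    = {tanh_peak:.4f}")
print(f"upper bound on the product of {depth} sigmoid derivatives: {sigmoid_bound:.2e}")
```

This counts only the activation derivatives; the full gradient also includes the weight matrices, so it is an intuition aid rather than a prediction for a real network.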

## Part 3: Modern Activation Functions Used in LLMs

Modern language models have converged on smoother activation functions. Let's explore them.

In [None]:
# Modern activation functions from first principles

def gelu(x):
    """GELU: Gaussian Error Linear Unit
    Used in: BERT, GPT-2, GPT-3
    Exact form: GELU(x) = x * Phi(x), where Phi is the CDF of the standard normal.
    The return line uses the common tanh approximation of that expression.
    """
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def silu(x):
    """SiLU (Sigmoid Linear Unit) / Swish
    Used in: EfficientNet, various models
    SiLU(x) = x * sigmoid(x)
    """
    return x * sigmoid(x)

def mish(x):
    """Mish activation
    Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
    """
    # logaddexp(0, x) = ln(1 + e^x), computed without overflow for large x
    return x * np.tanh(np.logaddexp(0, x))

print("Key insight: All modern activations follow the pattern x * g(x)")
print("where g(x) is a smooth gating function between 0 and 1.")
print()
print("  GELU: x * Phi(x)       - Phi is Gaussian CDF")
print("  SiLU: x * sigmoid(x)   - sigmoid is the gate")
print("  Mish: x * tanh(softplus(x))")

In [None]:
# Compare all modern activations
x = np.linspace(-5, 5, 1000)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: All activations overlaid
ax = axes[0]
functions = [
    ('ReLU', relu, '#45B7D1', '-'),
    ('GELU', gelu, '#FF6B6B', '-'),
    ('SiLU/Swish', silu, '#4ECDC4', '-'),
    ('Mish', mish, '#96CEB4', '--'),
]

for name, func, color, style in functions:
    ax.plot(x, func(x), linewidth=2.5, color=color, linestyle=style, label=name)

ax.axhline(y=0, color='black', linewidth=0.5)
ax.axvline(x=0, color='black', linewidth=0.5)
ax.set_title('Modern Activation Functions', fontsize=13, fontweight='bold')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
ax.set_xlim(-5, 5)
ax.set_ylim(-1, 5)
ax.set_xlabel('x', fontsize=12)
ax.set_ylabel('f(x)', fontsize=12)

# Right: Zoomed in near zero - the critical region
ax = axes[1]
x_zoom = np.linspace(-2, 2, 1000)

for name, func, color, style in functions:
    ax.plot(x_zoom, func(x_zoom), linewidth=2.5, color=color, linestyle=style, label=name)

ax.axhline(y=0, color='black', linewidth=0.5)
ax.axvline(x=0, color='black', linewidth=0.5)
ax.set_title('Zoomed: Near Zero Region', fontsize=13, fontweight='bold')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
ax.set_xlabel('x', fontsize=12)
ax.set_ylabel('f(x)', fontsize=12)

plt.tight_layout()
plt.show()

print("Key observation: GELU and SiLU are smooth, while ReLU has a sharp corner at x=0")
print("The smoothness near zero helps with gradient flow during training.")
print("GELU and SiLU also allow small negative values, unlike ReLU.")
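
The `gelu` helper in this part uses the tanh formula, which is an approximation of the exact definition x * Phi(x). The sketch below compares the two on a grid, assuming the exact form is written with the error function; `gelu_exact` and `gelu_tanh` are names introduced here for illustration:

```python
import math
import numpy as np

def gelu_exact(x):
    """Exact GELU: x * Phi(x), with Phi expressed through the error function."""
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1 + erf(x / math.sqrt(2)))

def gelu_tanh(x):
    """Tanh approximation of GELU (same formula as the gelu helper)."""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

x = np.linspace(-5, 5, 1001)
max_err = np.max(np.abs(gelu_exact(x) - gelu_tanh(x)))
print(f"Max |exact - tanh approx| on [-5, 5]: {max_err:.2e}")
```

The gap is small enough that the two forms look identical in a plot. PyTorch exposes both through the `approximate` argument of `F.gelu`, which the benchmark in Part 8 uses.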

## Part 4: SwiGLU - The Activation Function of Modern LLMs

Most modern LLMs (Llama, Mistral, Qwen, etc.) use **SwiGLU** in their feed-forward networks; Gemma uses the closely related GeGLU, which gates with GELU instead of SiLU. SwiGLU is not just an activation function -- it's an **activation + gating mechanism**.

Standard FFN:
$$\text{FFN}(x) = W_2 \cdot \text{activation}(W_1 x + b_1) + b_2$$

SwiGLU FFN:
$$\text{FFN}_{\text{SwiGLU}}(x) = W_2 \cdot [\text{SiLU}(W_{gate} x) \odot (W_{up} x)]$$

Where $\odot$ is element-wise multiplication. The key insight: instead of one up projection, we use **two** -- one goes through SiLU and acts as a "gate" that controls how much of the other passes through. Llama-style implementations omit the bias terms, which is why none appear in this formula.
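
Because SwiGLU has three weight matrices instead of two, its inner width is usually shrunk so the parameter count stays close to a standard FFN with 4x expansion: solving $3 d f = 2 d (4d)$ gives $f = \tfrac{8}{3} d$. A small sketch of that arithmetic follows. Rounding up to a multiple of 256 is an assumption (Llama-style code does something similar, but the exact rule varies by model), and `swiglu_width` is a hypothetical helper:

```python
def swiglu_width(hidden_dim, multiple_of=256):
    """Inner width that roughly matches a 4x standard FFN in parameter count."""
    raw = int(2 * (4 * hidden_dim) / 3)  # 2/3 of the usual 4x width
    # Round up to a hardware-friendly multiple (assumed convention)
    return multiple_of * ((raw + multiple_of - 1) // multiple_of)

hidden_dim = 4096
f = swiglu_width(hidden_dim)

standard_params = 2 * hidden_dim * (4 * hidden_dim)  # up + down, biases ignored
swiglu_params = 3 * hidden_dim * f                   # gate + up + down

print(f"SwiGLU inner width: {f}")
print(f"Parameter ratio (SwiGLU / standard): {swiglu_params / standard_params:.3f}")
```

With `hidden_dim = 4096` the rounded width is 11008, the same value as the `intermediate_size` in the Llama-2-7B config.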

In [None]:
class StandardFFN(nn.Module):
    """Standard Feed-Forward Network with ReLU/GELU."""
    def __init__(self, hidden_dim, ffn_dim, activation='relu'):
        super().__init__()
        self.up = nn.Linear(hidden_dim, ffn_dim)
        self.down = nn.Linear(ffn_dim, hidden_dim)
        self.activation = nn.ReLU() if activation == 'relu' else nn.GELU()
    
    def forward(self, x):
        x = self.up(x)           # Project up
        x = self.activation(x)    # Activation
        x = self.down(x)          # Project down
        return x

class SwiGLUFFN(nn.Module):
    """SwiGLU Feed-Forward Network (used in Llama, Mistral, etc.)."""
    def __init__(self, hidden_dim, ffn_dim):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)  # Gate
        self.up_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)    # Up projection
        self.down_proj = nn.Linear(ffn_dim, hidden_dim, bias=False)  # Down projection
    
    def forward(self, x):
        gate = F.silu(self.gate_proj(x))  # SiLU applied to gate projection
        up = self.up_proj(x)               # Up projection (no activation)
        x = gate * up                      # Element-wise gating
        x = self.down_proj(x)              # Project back down
        return x

# Compare parameter counts
hidden_dim = 4096
ffn_dim_standard = 4 * hidden_dim  # 16384
ffn_dim_swiglu = int(2/3 * 4 * hidden_dim)  # 10922 (reduced to keep param count similar)

standard = StandardFFN(hidden_dim, ffn_dim_standard)
swiglu = SwiGLUFFN(hidden_dim, ffn_dim_swiglu)

std_params = sum(p.numel() for p in standard.parameters())
swiglu_params = sum(p.numel() for p in swiglu.parameters())

print(f"Standard FFN (ReLU):")
print(f"  FFN dim: {ffn_dim_standard}")
print(f"  Parameters: {std_params:,}")
print(f"  Matrices: up({hidden_dim}x{ffn_dim_standard}) + down({ffn_dim_standard}x{hidden_dim})")
print(f"\nSwiGLU FFN:")
print(f"  FFN dim: {ffn_dim_swiglu}")
print(f"  Parameters: {swiglu_params:,}")
print(f"  Matrices: gate({hidden_dim}x{ffn_dim_swiglu}) + up({hidden_dim}x{ffn_dim_swiglu}) + down({ffn_dim_swiglu}x{hidden_dim})")
print(f"\nParameter ratio: {swiglu_params/std_params:.2f}x")

In [None]:
# Visualize the gating mechanism in SwiGLU
np.random.seed(42)
x_demo = np.linspace(-3, 3, 100)

# Simulate gate and up projections (1D for visualization)
gate_values = silu(x_demo * 1.5 + 0.5)  # After SiLU
up_values = x_demo * 0.8 - 0.3           # Linear (no activation)
output = gate_values * up_values           # Element-wise product

fig, axes = plt.subplots(1, 4, figsize=(18, 4))

axes[0].plot(x_demo, silu(x_demo * 1.5 + 0.5), color='#FF6B6B', linewidth=2)
axes[0].set_title('Gate: SiLU(W_gate * x)', fontsize=12, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].axhline(y=0, color='black', linewidth=0.5)

axes[1].plot(x_demo, up_values, color='#4ECDC4', linewidth=2)
axes[1].set_title('Up: W_up * x', fontsize=12, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].axhline(y=0, color='black', linewidth=0.5)

axes[2].text(0.5, 0.5, 'gate * up\n(element-wise)', fontsize=14, ha='center', va='center',
            transform=axes[2].transAxes, fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='#FFD93D', alpha=0.8))
axes[2].axis('off')

axes[3].plot(x_demo, output, color='#45B7D1', linewidth=2)
axes[3].set_title('Output: gate * up', fontsize=12, fontweight='bold')
axes[3].grid(True, alpha=0.3)
axes[3].axhline(y=0, color='black', linewidth=0.5)

plt.suptitle('SwiGLU Gating Mechanism', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("The gate controls how much of the 'up' projection passes through.")
print("This gives the model more expressivity than a simple activation function.")

## Part 5: Which LLMs Use Which Activations?

Let's survey the activation functions used across popular LLM families.

In [None]:
# LLM activation function survey
llm_activations = [
    ('GPT-2 (2019)', 'GELU', 'OpenAI', 2019),
    ('GPT-3 (2020)', 'GELU', 'OpenAI', 2020),
    ('BERT (2018)', 'GELU', 'Google', 2018),
    ('T5 (2019)', 'ReLU', 'Google', 2019),
    ('PaLM (2022)', 'SwiGLU', 'Google', 2022),
    ('Llama 1 (2023)', 'SwiGLU', 'Meta', 2023),
    ('Llama 2 (2023)', 'SwiGLU', 'Meta', 2023),
    ('Llama 3 (2024)', 'SwiGLU', 'Meta', 2024),
    ('Mistral (2023)', 'SwiGLU', 'Mistral AI', 2023),
    ('Mixtral (2024)', 'SwiGLU', 'Mistral AI', 2024),
    ('Gemma (2024)', 'GELU', 'Google', 2024),  # strictly GeGLU (GELU-gated GLU); grouped with GELU here
    ('Qwen 2 (2024)', 'SwiGLU', 'Alibaba', 2024),
    ('Phi-3 (2024)', 'SwiGLU', 'Microsoft', 2024),
    ('DeepSeek-V2 (2024)', 'SwiGLU', 'DeepSeek', 2024),
]

print(f"{'Model':<25s} {'Activation':<12s} {'Organization':<15s} {'Year'}")
print("-" * 60)
for model, act, org, year in llm_activations:
    print(f"{model:<25s} {act:<12s} {org:<15s} {year}")

# Count
from collections import Counter
act_counts = Counter(act for _, act, _, _ in llm_activations)
print("\nActivation function usage:")
for act, count in act_counts.most_common():
    print(f"  {act}: {count} models")

In [None]:
# Visualize the timeline
fig, ax = plt.subplots(figsize=(14, 6))

activation_colors = {'GELU': '#FF6B6B', 'SwiGLU': '#4ECDC4', 'ReLU': '#45B7D1'}

for idx, (model, act, org, year) in enumerate(llm_activations):
    color = activation_colors.get(act, 'gray')
    ax.scatter(year, idx, s=200, c=color, edgecolors='black', linewidth=1, zorder=3)
    ax.text(year + 0.1, idx, f'  {model}', fontsize=10, va='center')

# Legend
legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor=c, 
                          markersize=12, label=name) 
                   for name, c in activation_colors.items()]
ax.legend(handles=legend_elements, fontsize=12, loc='lower right')

ax.set_xlabel('Year', fontsize=12)
ax.set_title('LLM Activation Functions Over Time', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.set_yticks([])

plt.tight_layout()
plt.show()

print("\nClear trend: The industry has converged on SwiGLU starting from 2022-2023.")
print("PaLM was one of the first major models to use it, followed by Llama.")

## Part 6: Let's Verify with a Real Model

Let's load a real model and inspect what activation function it uses.

In [None]:
!pip install transformers -q

In [None]:
from transformers import AutoConfig

models_to_check = [
    'gpt2',
    'bert-base-uncased',
    'meta-llama/Llama-2-7b-hf',
    'mistralai/Mistral-7B-v0.1',
    'Qwen/Qwen2-0.5B',
    'google/gemma-2b',
    'microsoft/phi-2',
]

print(f"{'Model':<40s} {'Activation':<15s} {'Hidden Dim':>10s} {'FFN Dim':>10s}")
print("-" * 80)

for model_name in models_to_check:
    try:
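        # These architectures ship with transformers, so trust_remote_code stays at its default (False)
        config = AutoConfig.from_pretrained(model_name)
        
        # Different models store activation in different config fields.
        # Gated FFNs report the gate's non-linearity (e.g. 'silu'), not 'swiglu'.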
        act = getattr(config, 'hidden_act', 
               getattr(config, 'activation_function', 
               getattr(config, 'hidden_activation', 'unknown')))
        hidden = getattr(config, 'hidden_size', 
                  getattr(config, 'n_embd', '?'))
        ffn = getattr(config, 'intermediate_size', 
               getattr(config, 'n_inner', '?'))
        
        print(f"{model_name:<40s} {str(act):<15s} {str(hidden):>10s} {str(ffn):>10s}")
    except Exception as e:
        print(f"{model_name:<40s} Error: {e}")

## Part 7: Side-by-Side Comparison

Let's create a comprehensive side-by-side comparison showing how each activation transforms a distribution of values.

In [None]:
# How do activation functions transform a Gaussian distribution?
np.random.seed(42)
input_values = np.random.randn(10000)

activations = [
    ('ReLU', relu),
    ('GELU', gelu),
    ('SiLU/Swish', silu),
    ('Sigmoid', sigmoid),
    ('Tanh', tanh),
]

fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# Input distribution
ax = axes[0, 0]
ax.hist(input_values, bins=80, color='gray', alpha=0.7, density=True)
ax.set_title('Input Distribution\n(Standard Normal)', fontsize=12, fontweight='bold')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_xlim(-4, 4)

colors = ['#45B7D1', '#FF6B6B', '#4ECDC4', '#FFD93D', '#96CEB4']
for idx, ((name, func), color) in enumerate(zip(activations, colors)):
    row, col = divmod(idx + 1, 3)
    ax = axes[row, col]
    
    output = func(input_values)
    ax.hist(output, bins=80, color=color, alpha=0.7, density=True)
    ax.set_title(f'After {name}\nmean={output.mean():.3f}, std={output.std():.3f}', 
                fontsize=12, fontweight='bold')
    ax.set_xlabel('Value')
    ax.set_ylabel('Density')

plt.suptitle('How Activation Functions Transform Gaussian Inputs', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# The key comparison: how much information is preserved?
print("Information preservation analysis:")
print("-" * 60)

for name, func in activations:
    output = func(input_values)
    
    # What fraction of values are set to (near) zero?
    near_zero = np.mean(np.abs(output) < 0.01)
    
    # What's the effective range?
    p5, p95 = np.percentile(output, [5, 95])
    
    # Correlation with input
    correlation = np.corrcoef(input_values, output)[0, 1]
    
    print(f"{name:15s}: near_zero={near_zero:.1%}, range=[{p5:.2f}, {p95:.2f}], "
          f"corr_with_input={correlation:.3f}")
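
Some of these statistics can be checked by hand. For a standard normal input, ReLU zeroes half of the probability mass and has mean $1/\sqrt{2\pi} \approx 0.399$. The sketch below is a quick sanity check of that closed form against a sample; the sample size and seed are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(200_000)

relu_out = np.maximum(0, z)

analytic_mean = 1 / np.sqrt(2 * np.pi)   # E[max(0, Z)] for Z ~ N(0, 1)
empirical_mean = relu_out.mean()
zero_fraction = np.mean(relu_out == 0)   # should be close to 0.5

print(f"Analytic mean:   {analytic_mean:.4f}")
print(f"Empirical mean:  {empirical_mean:.4f}")
print(f"Fraction zeroed: {zero_fraction:.3f}")
```

The positive mean is the shift visible in the ReLU histogram: the output is no longer zero-centered even though the input is.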

## Part 8: Compute Cost of Activation Functions

During inference, activation functions are not free -- they take time to compute. Let's benchmark them.

In [None]:
import time

# Benchmark activation functions in PyTorch
x_bench = torch.randn(1024, 4096)  # Typical hidden state

torch_activations = [
    ('ReLU', F.relu),
    ('GELU (approx)', lambda x: F.gelu(x, approximate='tanh')),
    ('GELU (exact)', lambda x: F.gelu(x, approximate='none')),
    ('SiLU/Swish', F.silu),
    ('Sigmoid', torch.sigmoid),
    ('Tanh', torch.tanh),
]

print(f"Benchmarking on tensor shape: {x_bench.shape}")
print(f"Total elements: {x_bench.numel():,}")
print("-" * 50)

results = []
for name, func in torch_activations:
    # Warmup
    for _ in range(10):
        _ = func(x_bench)
    
    # Benchmark
    times = []
    for _ in range(100):
        start = time.perf_counter()  # monotonic, high-resolution timer
        _ = func(x_bench)
        times.append(time.perf_counter() - start)
    
    median_time = np.median(times) * 1000
    results.append((name, median_time))
    print(f"{name:20s}: {median_time:.3f} ms")

# Relative to ReLU
relu_time = results[0][1]
print("\nRelative to ReLU:")
for name, t in results:
    print(f"  {name:20s}: {t/relu_time:.2f}x")

In [None]:
# Visualize compute cost comparison
fig, ax = plt.subplots(figsize=(10, 5))

names = [r[0] for r in results]
times = [r[1] for r in results]
colors = ['#45B7D1', '#FF6B6B', '#FF8E8E', '#4ECDC4', '#FFD93D', '#96CEB4']

bars = ax.barh(names, times, color=colors, edgecolor='black', linewidth=0.5)

for bar, t in zip(bars, times):
    ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2, 
           f'{t:.3f} ms', va='center', fontsize=11)

ax.set_xlabel('Time (ms)', fontsize=12)
ax.set_title('Activation Function Compute Time (CPU, 1024x4096 tensor)', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.show()

print("Note: On GPU, the differences are much smaller because activation functions")
print("are typically memory-bound and fused with other operations (kernel fusion).")
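
To put these timings in context, a back-of-the-envelope count shows why activations rarely dominate. Per token, the three SwiGLU matmuls cost on the order of $2 \cdot 3 \cdot d \cdot f$ FLOPs, while the element-wise part touches only $f$ values. This is a rough sketch: the sizes and `ops_per_element` are assumed illustrative numbers, not measurements.

```python
hidden_dim = 4096   # d (assumed, Llama-7B-like)
ffn_dim = 11008     # f (assumed)

# Each (d x f) matmul costs about 2 * d * f FLOPs per token (multiply + add)
matmul_flops = 3 * 2 * hidden_dim * ffn_dim

# Element-wise SiLU + gating: assume ~10 simple ops per element (a loose guess)
ops_per_element = 10
activation_flops = ops_per_element * ffn_dim

ratio = matmul_flops / activation_flops
print(f"Matmul FLOPs per token:   {matmul_flops:,}")
print(f"Activation ops per token: {activation_flops:,}")
print(f"Matmuls are ~{ratio:,.0f}x more work")
```

Even if the per-element estimate is off by an order of magnitude, the matmuls still account for almost all of the arithmetic.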

## Part 9: Activation Functions and the Vanishing Gradient Problem

Let's visualize why sigmoid/tanh cause vanishing gradients in deep networks and how ReLU-family functions mitigate this.

In [None]:
# Simulate gradient flow through many layers
def simulate_gradient_flow(activation_name, num_layers=50, hidden_dim=128):
    """Simulate how gradients flow backward through layers."""
    
    if activation_name == 'sigmoid':
        activation = nn.Sigmoid()
    elif activation_name == 'tanh':
        activation = nn.Tanh()
    elif activation_name == 'relu':
        activation = nn.ReLU()
    elif activation_name == 'gelu':
        activation = nn.GELU()
    elif activation_name == 'silu':
        activation = nn.SiLU()
    else:
        raise ValueError(f"Unknown activation: {activation_name}")
    
    # Build deep network
    layers = []
    for i in range(num_layers):
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(activation)
    model = nn.Sequential(*layers)
    
    # Forward pass
    x = torch.randn(1, hidden_dim, requires_grad=True)
    output = model(x)
    loss = output.sum()
    loss.backward()
    
    # Collect gradient norms at each layer
    grad_norms = []
    for layer in model:
        if hasattr(layer, 'weight') and layer.weight.grad is not None:
            grad_norms.append(layer.weight.grad.norm().item())
    
    return grad_norms

# Run for different activations
fig, ax = plt.subplots(figsize=(14, 6))

activation_names = ['sigmoid', 'tanh', 'relu', 'gelu', 'silu']
colors = ['#FFD93D', '#96CEB4', '#45B7D1', '#FF6B6B', '#4ECDC4']

for act_name, color in zip(activation_names, colors):
    try:
        grads = simulate_gradient_flow(act_name, num_layers=30)
        # grad_norms is ordered from the first layer (input side) to the last (output side)
        ax.plot(range(len(grads)), grads, linewidth=2, color=color, label=act_name.upper())
    except Exception as e:
        print(f"{act_name}: {e}")

ax.set_xlabel('Layer (from first to last)', fontsize=12)
ax.set_ylabel('Gradient Norm', fontsize=12)
ax.set_title('Gradient Norms Across Layers (30-layer network)', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_yscale('log')
plt.tight_layout()
plt.show()

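print("Saturating activations (especially sigmoid) shrink gradients at every layer, so early layers receive tiny updates")
print("ReLU-family activations pass gradients through unchanged for active units; the exact curves also depend on weight initialization")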

## Part 10: The Full Picture - Activation in a Transformer FFN

Let's put it all together and show how activation functions fit into the feed-forward network of a transformer block.

In [None]:
# Implement a complete transformer FFN block and trace through it

class TransformerFFNBlock(nn.Module):
    """Complete FFN block as used in a transformer."""
    def __init__(self, hidden_dim=768, ffn_type='swiglu'):
        super().__init__()
        self.ffn_type = ffn_type
        
        if ffn_type == 'swiglu':
            ffn_dim = int(hidden_dim * 8/3)  # SwiGLU uses 8/3 multiplier
            self.gate_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)
            self.up_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)
            self.down_proj = nn.Linear(ffn_dim, hidden_dim, bias=False)
        elif ffn_type == 'gelu':
            ffn_dim = hidden_dim * 4
            self.up_proj = nn.Linear(hidden_dim, ffn_dim)
            self.down_proj = nn.Linear(ffn_dim, hidden_dim)
        
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.ffn_dim = ffn_dim
    
    def forward(self, x, return_intermediates=False):
        intermediates = {}
        
        # Layer norm
        normed = self.layer_norm(x)
        intermediates['after_layernorm'] = normed.detach()
        
        if self.ffn_type == 'swiglu':
            gate = self.gate_proj(normed)
            intermediates['gate_before_activation'] = gate.detach()
            
            gate = F.silu(gate)
            intermediates['gate_after_silu'] = gate.detach()
            
            up = self.up_proj(normed)
            intermediates['up_projection'] = up.detach()
            
            hidden = gate * up
            intermediates['after_gating'] = hidden.detach()
            
            output = self.down_proj(hidden)
        else:
            up = self.up_proj(normed)
            intermediates['before_activation'] = up.detach()
            
            activated = F.gelu(up)
            intermediates['after_activation'] = activated.detach()
            
            output = self.down_proj(activated)
        
        intermediates['output'] = output.detach()
        
        # Residual connection
        result = x + output
        intermediates['after_residual'] = result.detach()
        
        if return_intermediates:
            return result, intermediates
        return result

# Run forward pass with SwiGLU
ffn_block = TransformerFFNBlock(hidden_dim=768, ffn_type='swiglu')
ffn_block.eval()

x_input = torch.randn(1, 10, 768)  # (batch=1, seq_len=10, hidden=768)
with torch.no_grad():
    output, intermediates = ffn_block(x_input, return_intermediates=True)

print("SwiGLU FFN Trace (showing first token):")
print("=" * 60)
for name, tensor in intermediates.items():
    vals = tensor[0, 0]  # First batch, first token
    print(f"{name:30s} shape={str(list(tensor.shape)):15s} "
          f"mean={vals.mean():.4f} std={vals.std():.4f} "
          f"min={vals.min():.4f} max={vals.max():.4f}")

In [None]:
# Visualize the intermediate values
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

plot_keys = [
    ('after_layernorm', 'After LayerNorm', '#45B7D1'),
    ('gate_before_activation', 'Gate (before SiLU)', '#FFD93D'),
    ('gate_after_silu', 'Gate (after SiLU)', '#FF6B6B'),
    ('up_projection', 'Up Projection', '#4ECDC4'),
    ('after_gating', 'After Gating (gate * up)', '#96CEB4'),
    ('after_residual', 'After Residual + FFN', '#DDA0DD'),
]

for idx, (key, title, color) in enumerate(plot_keys):
    ax = axes[idx // 3, idx % 3]
    vals = intermediates[key][0, 0].numpy()  # First token
    ax.hist(vals, bins=80, color=color, alpha=0.7, edgecolor='black', linewidth=0.3)
    ax.set_title(f'{title}\nmean={vals.mean():.3f}, std={vals.std():.3f}', fontsize=11, fontweight='bold')
    ax.axvline(x=0, color='red', linestyle='--', alpha=0.5)

plt.suptitle('Value Distributions Through SwiGLU FFN Block', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 11: Sparsity Patterns

An interesting property of activation functions like ReLU: they create **sparsity** in the activations. ReLU produces exact zeros, while GELU and SiLU produce many values that are only close to zero. This has implications for inference optimization (e.g., skipping computation for inactive neurons).
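
One way exact zeros can be exploited: in the down projection, a hidden unit whose activation is zero contributes nothing, so its column of the weight matrix can be skipped. The sketch below uses NumPy with made-up sizes. `W_down`, `h`, and `active` are illustrative names, and real kernels need careful engineering to turn this into an actual speedup.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, ffn_dim = 64, 256   # small, arbitrary sizes

W_down = rng.standard_normal((hidden_dim, ffn_dim))
h = np.maximum(0, rng.standard_normal(ffn_dim))   # ReLU output: roughly half are exact zeros

# Dense computation uses every column
y_dense = W_down @ h

# Sparse computation uses only the columns for non-zero activations
active = np.flatnonzero(h)
y_sparse = W_down[:, active] @ h[active]

print(f"Active units: {active.size}/{ffn_dim}")
print("Same result:", np.allclose(y_dense, y_sparse))
```

This trick relies on exact zeros, so it applies directly to ReLU; near-zero outputs from GELU or SiLU would need a threshold and would change the result slightly.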

In [None]:
# Analyze sparsity patterns of different activations
x_test = torch.randn(1000, 4096)  # 1000 samples, 4096 dimensions

activations_to_test = [
    ('ReLU', F.relu),
    ('GELU', F.gelu),
    ('SiLU', F.silu),
    ('Sigmoid', torch.sigmoid),
    ('Tanh', torch.tanh),
]

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

print(f"{'Activation':<12s} {'Zero%':>8s} {'Near-Zero%':>12s} {'Mean':>8s}")
print("-" * 45)

for idx, (name, func) in enumerate(activations_to_test):
    output = func(x_test)
    
    exact_zeros = (output == 0).float().mean().item()
    near_zeros = (output.abs() < 0.1).float().mean().item()
    mean_val = output.mean().item()
    
    print(f"{name:<12s} {exact_zeros:>7.1%} {near_zeros:>11.1%} {mean_val:>8.4f}")
    
    # Visualize sparsity pattern (first sample, first 256 dims)
    ax = axes[idx]
    vals = output[0, :256].numpy()
    ax.bar(range(len(vals)), vals, width=1.0,
           color=['#FF6B6B' if v > 0.1 else '#E8E8E8' for v in np.abs(vals)])
    ax.set_title(f'{name}\nActive: {(np.abs(vals)>0.1).mean():.0%}', fontsize=11, fontweight='bold')
    ax.set_xlabel('Dimension')
    ax.axhline(y=0, color='black', linewidth=0.5)

plt.suptitle('Activation Sparsity Patterns (first 256 dimensions)', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

---

## Key Takeaways

1. **Activation functions are essential** - without them, any number of stacked linear layers collapses to a single linear transformation. They give neural networks the power to learn non-linear functions.

2. **The evolution**: Sigmoid/Tanh (1990s) -> ReLU (2012) -> GELU (2016) -> SwiGLU (2020+). Each generation solved problems of the previous one.

3. **Modern LLMs overwhelmingly use SwiGLU** (Llama, Mistral, Qwen, Phi, DeepSeek). Earlier models use GELU (GPT-2, GPT-3, BERT), and Gemma uses GeGLU, the GELU-gated sibling of SwiGLU. The industry has converged on these choices.

4. **SwiGLU is more than an activation** - it's a gated mechanism using two projections: `SiLU(W_gate * x) * (W_up * x)`. The gate controls information flow.

5. **The tradeoff with SwiGLU**: it uses 3 weight matrices instead of 2 (gate + up + down), but the FFN dimension is reduced to about 2/3 of the standard 4x width (roughly 8/3 x hidden) to compensate, keeping total parameter count similar.

6. **For inference**: activation functions are typically not the bottleneck; matrix multiplications dominate compute cost. On GPU, the element-wise activation is usually fused with neighboring operations, which hides most of its remaining cost.

7. **Sparsity from activations** is a research direction for faster inference: if many activation values are near zero, we might skip those computations.

---

**Next notebook:** We'll explore tokenization -- how LLMs convert text into the numbers that flow through these linear layers and activations.