#### 7. Provide an example where the gradients vanish for the sigmoid activation function.

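Consider a very simple network that takes an input $x$, maps it to $-r \cdot x$ for some large weight $r \gg 0$, and then applies the sigmoid. This gives the overall function (written $f_r$ in the plots below, to make the dependence on $r$ explicit)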
$$
f(x) = \sigma(-rx)
$$
whose derivative with respect to the input $x$ (by the chain rule and $\sigma' = \sigma(1 - \sigma)$) is
$$
f'(x) = -r \sigma'(-rx) = -r \sigma(-rx) (1 - \sigma(-rx)).
$$
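Suppose we have the input $x = 1$. Then
$$
f'(1) = -r \sigma(-r) (1 - \sigma(-r)).
$$
For large $r$ we have $\sigma(-r) = 1/(1 + e^{r}) \approx e^{-r}$ and $1 - \sigma(-r) \approx 1$, so $|f'(1)| \approx r e^{-r} \to 0$ as $r \to \infty$. The exponential decay of $\sigma(-r)$ overwhelms the linear factor $r$. As a result, the gradient vanishes even though the weight keeps growing.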

![image](mlp_7_1.png)

In [None]:
import torch
from torch import Tensor
import matplotlib.pyplot as plt

# Fix the input at x = 1 and sweep the weight scale r in f_r(x) = sigma(-r * x).
rs: Tensor = torch.linspace(0.5, 9, 1000)
# sigma(-r) appears in both f_r(1) and its derivative, so compute it once.
sigmoid_rs: Tensor = (-rs).sigmoid()
# Function value f_r(1) = sigma(-r)
f_rs: Tensor = sigmoid_rs
# Derivative with respect to the INPUT x (not with respect to r), evaluated at x = 1:
#   f_r'(1) = -r * sigma(-r) * (1 - sigma(-r))
# The leading -r comes from the chain rule through the pre-activation -r * x.
fp_rs: Tensor = -rs * sigmoid_rs * (1 - sigmoid_rs)

# Sanity check: autograd's d/dx at x = 1 must agree with the closed form above.
x_ones: Tensor = torch.ones_like(rs, requires_grad=True)
(-rs * x_ones).sigmoid().sum().backward()
assert torch.allclose(x_ones.grad, fp_rs), "closed-form derivative disagrees with autograd"

# Create subplots
fig, axs = plt.subplots(2, 1, figsize=(7, 10))

# Top panel: the activation at x = 1, which decreases towards 0 as r grows.
axs[0].plot(rs, f_rs, label="$f_r(1) = \\sigma(-r)$")
axs[0].set_xlabel("$r$")
axs[0].set_ylabel("$f_r(1)$")
axs[0].legend()
axs[0].set_title("Function $f_r(x) = \\sigma(-rx)$ at $x = 1$")

# Bottom panel: the input gradient at x = 1, which shrinks towards 0 for large r
# despite the growing factor r.
axs[1].plot(rs, fp_rs, label="$f_r'(1) = -r \\sigma(-r) (1 - \\sigma(-r))$")
axs[1].axhline(0, color="red", linestyle="--")
axs[1].set_xlabel("$r$")
axs[1].set_ylabel("$f_r'(1)$")
axs[1].legend()
axs[1].set_title("Derivative $\\partial f_r / \\partial x$ at $x = 1$")

plt.tight_layout()
plt.savefig("mlp_7_2.png")
plt.show()

In [None]:
import torch
from torch import Tensor
import matplotlib.pyplot as plt

def _input_grad_at_one(rs: Tensor) -> tuple[Tensor, Tensor]:
    """Return ``(sigma(-r), d/dx sigma(-r * x) at x = 1)`` for each weight r in *rs*."""
    sig = (-rs).sigmoid()
    # Chain rule: d/dx sigma(-r x) = -r * sigma(-r x) * (1 - sigma(-r x)), here at x = 1.
    return sig, -rs * sig * (1 - sig)


# Full sweep of r, then a zoom into large r where the gradient is close to 0.
rs1: Tensor = torch.linspace(0.5, 9, 1000)
sigmoid_rs1, fps1 = _input_grad_at_one(rs1)

rs2: Tensor = torch.linspace(7, 9, 1000)
sigmoid_rs2, fps2 = _input_grad_at_one(rs2)

# Create subplots: one panel per r range, drawn identically apart from the title.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (rs1, fps1, "Range 0.5 to 9"),
    (rs2, fps2, "Zoomed Range 7 to 9"),
)
for ax, (r_vals, grads, title) in zip(axs, panels):
    ax.plot(r_vals, grads, label="$\\nabla f_r(1)$")
    ax.axhline(0, color="red", linestyle="--")
    ax.set_xlabel("$r$")
    ax.set_ylabel("$\\nabla f_r(1)$")
    ax.legend()
    ax.set_title(title)

plt.tight_layout()
plt.savefig("mlp_7_1.png")
plt.show()