# The Limits of Depth

_Adapted from [Dataflowr Module 14b](https://dataflowr.github.io/website/modules/14b-depth/) by Andrei Bursuc._

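In practice, training very deep networks is **hard**. In this notebook, we explore the main obstacles (vanishing gradients and the degradation problem) and their solutions: better activation functions, careful initialization, and residual connections.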

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
%matplotlib inline

np.random.seed(42)
torch.manual_seed(42)

# Use GPU if available: prefer CUDA, then Apple-silicon MPS, otherwise CPU.
# NOTE(review): none of the cells in this notebook move models or tensors to
# `device`; everything below runs on the default (CPU) device.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f'Using device: {device}')

In [None]:
def make_spiral_data(n=500, noise=0.08):
    """Generate two interleaved 2-D spirals for binary classification.

    Class 0 lies on the spiral ``r(t) * (cos t, sin t)`` for ``t`` in
    ``[0, 4*pi]`` (two full turns), with radius ``r(t) = 2 * t / (4*pi)``
    growing from 0 to 2. Class 1 is the same spiral rotated by ``pi``.
    Gaussian noise with standard deviation ``noise`` is added to the radius.
    Class 0 reuses the first ``n // 2`` values of the same noise draw as class 1.

    Args:
        n: Total number of points. Class 0 gets ``n // 2`` points and class 1
            the remaining ``n - n // 2``, so exactly ``n`` points are returned
            even when ``n`` is odd.
        noise: Standard deviation of the radial noise.

    Returns:
        ``(X, y)``: ``X`` is a float tensor of shape ``(n, 2)``; ``y`` is a long
        tensor of shape ``(n,)`` with labels 0 (first block) and 1.
    """
    n0 = n // 2
    n1 = n - n0
    # One draw of size n1 (== n0 when n is even) keeps RNG consumption, and
    # hence the generated data, unchanged for even n.
    eps = torch.randn(n1) * noise
    t0 = torch.linspace(0, 4 * np.pi, n0)
    t1 = torch.linspace(0, 4 * np.pi, n1)
    r0 = t0 / (4 * np.pi) * 2 + eps[:n0]
    r1 = t1 / (4 * np.pi) * 2 + eps
    X0 = torch.stack([r0 * torch.cos(t0), r0 * torch.sin(t0)], dim=1)
    X1 = torch.stack([r1 * torch.cos(t1 + np.pi), r1 * torch.sin(t1 + np.pi)], dim=1)
    X = torch.cat([X0, X1], dim=0)
    y = torch.cat([torch.zeros(n0), torch.ones(n1)]).long()
    return X, y

def plot_decision_boundary(model, X, y, ax=None, title=None, resolution=200):
    """Plot the decision regions of a 2-D, two-class classifier with its data.

    The model is evaluated on a ``resolution x resolution`` grid covering the
    bounding box of ``X`` plus a 0.5 margin. The argmax class is drawn as a
    filled contour, and the points of ``X`` are overlaid and coloured by ``y``.

    Args:
        model: Module mapping ``(N, 2)`` inputs to ``(N, C)`` logits. It is run
            in ``eval()`` mode under ``no_grad``, and its previous train/eval
            mode is restored afterwards.
        X: Tensor of shape ``(N, 2)`` with the data points.
        y: Tensor of shape ``(N,)`` with labels 0 and 1.
        ax: Matplotlib axes to draw on; a new 5x5 figure is created if None.
        title: Optional axes title.
        resolution: Number of grid points along each axis.

    Returns:
        The matplotlib axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))
    # Plotting needs CPU tensors regardless of where the data lives.
    X = X.detach().cpu()
    y = y.detach().cpu()
    margin = 0.5
    x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
    y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, resolution),
        torch.linspace(y_min, y_max, resolution),
        indexing='xy'
    )
    grid = torch.stack([xx.flatten(), yy.flatten()], dim=1)
    # Evaluate on the model's own device/dtype so GPU/MPS models work too.
    param = next(model.parameters(), None)
    if param is not None:
        grid = grid.to(device=param.device, dtype=param.dtype)
    # eval() so layers such as dropout/batch-norm behave deterministically;
    # the caller's mode is restored even if the forward pass raises.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(grid)
            preds = logits.argmax(dim=1).reshape(xx.shape).cpu()
    finally:
        model.train(was_training)
    cmap_bg = ListedColormap(['#AACCFF', '#FFAAAA'])
    ax.contourf(xx.numpy(), yy.numpy(), preds.numpy(), alpha=0.3, cmap=cmap_bg)
    ax.scatter(X[y == 0, 0], X[y == 0, 1], c='royalblue', alpha=0.5, s=10, edgecolors='k', linewidths=0.3)
    ax.scatter(X[y == 1, 0], X[y == 1, 1], c='crimson', alpha=0.5, s=10, edgecolors='k', linewidths=0.3)
    if title:
        ax.set_title(title, fontsize=12)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.2)
    return ax

## The Vanishing Gradient Problem

### The Chain Rule in Deep Networks

During backpropagation, the gradient of the loss with respect to the weights of layer $\ell$ involves a **product of Jacobians** across all subsequent layers:

$$
\frac{\partial \mathcal{L}}{\partial W_\ell} = \frac{\partial \mathcal{L}}{\partial f_L} \cdot \prod_{k=\ell+1}^{L} \frac{\partial f_k}{\partial f_{k-1}} \cdot \frac{\partial f_\ell}{\partial W_\ell}
$$

Each factor $\frac{\partial f_k}{\partial f_{k-1}} = \text{diag}(\sigma'(z_k)) \cdot W_k$, where $f_k = \sigma(z_k)$ and $z_k = W_k f_{k-1} + b_k$, depends on the **derivative of the activation function**.

### Sigmoid

The sigmoid activation function is:

$$
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \sigma'(x) = \sigma(x)(1 - \sigma(x))
$$

In [None]:
# Visualize sigmoid and its derivative
x = torch.linspace(-6, 6, 300)
sigmoid = torch.sigmoid(x)
sigmoid_grad = sigmoid * (1 - sigmoid)  # sigma'(x) = sigma(x) * (1 - sigma(x))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: the sigmoid itself, with its value at x = 0 (0.5) marked.
ax1.plot(x.numpy(), sigmoid.numpy(), 'b-', linewidth=2)
ax1.set_title(r'Sigmoid: $\sigma(x) = 1/(1+e^{-x})$', fontsize=13)
ax1.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax1.set_xlabel('x')
ax1.grid(True, alpha=0.3)

# Right panel: the derivative, peaking at 1/4 at x = 0.
ax2.plot(x.numpy(), sigmoid_grad.numpy(), 'r-', linewidth=2)
ax2.axhline(y=0.25, color='gray', linestyle='--', alpha=0.5, label='max = 1/4')
ax2.fill_between(x.numpy(), sigmoid_grad.numpy(), alpha=0.2, color='red')
ax2.set_title(r"Sigmoid derivative: $\sigma'(x) = \sigma(x)(1-\sigma(x))$", fontsize=13)
ax2.set_xlabel('x')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

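The maximum value of $\sigma'(x)$ is $\frac{1}{4}$ (at $x=0$). This means that at each layer, the activation contributes a factor $\leq \frac{1}{4}$ to the gradient. The weight matrices $W_k$ also enter the product.

For a network with $L$ layers using sigmoid activations, ignoring the effect of the weights (e.g. assuming $\|W_k\| \leq 1$), this gives roughly: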

$$
\left\|\frac{\partial \mathcal{L}}{\partial W_1}\right\| \leq \left(\frac{1}{4}\right)^{L-1} \cdot \left\|\frac{\partial \mathcal{L}}{\partial W_L}\right\|
$$

With $L = 10$ layers: $(1/4)^9 \approx 4 \times 10^{-6}$ — the gradient has **vanished**.


In practice, we can compare gradient norms across layers for networks of increasing depth, using **sigmoid** activation.

In [None]:
def get_gradient_norms(model, X, y):
    """Run one forward/backward pass and return the gradient norm of each weight.

    The model is switched to training mode and its gradients are cleared. The
    cross-entropy loss of ``model(X)`` against ``y`` is then back-propagated.
    The resulting gradients are left in each parameter's ``.grad``.

    Returns:
        A list of floats: the L2 (Frobenius) norm of ``param.grad`` for every
        parameter whose name contains ``'weight'`` and that received a
        gradient, in ``model.named_parameters()`` (registration) order.
    """
    model.train()
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()
    criterion(model(X), y).backward()
    return [
        p.grad.norm().item()
        for pname, p in model.named_parameters()
        if 'weight' in pname and p.grad is not None
    ]


X_spiral, y_spiral = make_spiral_data(n=1000)


def _sigmoid_mlp(n_hidden, width=20):
    """Sigmoid MLP: Linear(2, width) + Sigmoid, then (n_hidden - 1) x
    [Linear(width, width) + Sigmoid], then Linear(width, 2)."""
    layers = [nn.Linear(2, width), nn.Sigmoid()]
    for _ in range(n_hidden - 1):
        layers += [nn.Linear(width, width), nn.Sigmoid()]
    layers.append(nn.Linear(width, 2))
    return nn.Sequential(*layers)


# Sigmoid networks of increasing depth: (label, zero-argument model factory).
# The `d=d` default binds each depth at definition time.
sigmoid_models = [(f'{d} layers', lambda d=d: _sigmoid_mlp(d)) for d in (4, 8, 16)]

fig, axes = plt.subplots(1, 3, figsize=(18, 4.5))

for ax, (name, make_model) in zip(axes, sigmoid_models):
    torch.manual_seed(42)
    model = make_model()
    grad_norms = get_gradient_norms(model, X_spiral, y_spiral)
    # Reverse so bar 0 is the output layer and the last bar is the input layer.
    output_to_input = grad_norms[::-1]
    ax.bar(range(len(output_to_input)), output_to_input, color='indianred', alpha=0.8)
    ax.set_title(name, fontsize=13)
    ax.set_xlabel('Layer (output → input)', fontsize=11)
    ax.set_ylabel('Gradient norm', fontsize=11)
    ax.grid(True, alpha=0.3, axis='y')

plt.suptitle('Vanishing gradients with Sigmoid activation (at initialization)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## Alternative Activation Functions

### ReLU

The ReLU activation function $\text{ReLU}(x) = \max(0, x)$ has derivative:

$$
\text{ReLU}'(x) = \begin{cases} 1 & \text{if } x > 0 \\ 0 & \text{if } x < 0 \end{cases}
$$

Unlike Sigmoid (whose derivative is always $\leq 1/4$), ReLU **preserves the gradient magnitude** for active neurons ($\text{ReLU}'(x) = 1$). Given **proper initialization**, this keeps gradient norms roughly stable across layers.

In [None]:
# Compare gradient norms: Sigmoid vs ReLU across depths
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

configs = [
    ('Sigmoid (Xavier init)', nn.Sigmoid, lambda w: nn.init.xavier_normal_(w)),
    ('ReLU (Kaiming init)',   nn.ReLU,    lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')),
]

depths = [4, 8, 16, 32]


def _initialised_mlp(act_cls, init_weight, n_hidden, width=20):
    """MLP 2 -> width (x n_hidden, each followed by `act_cls`) -> 2.

    Every Linear weight is re-initialised with `init_weight` (in layer order)
    and every bias is zeroed.
    """
    layers = [nn.Linear(2, width), act_cls()]
    for _ in range(n_hidden - 1):
        layers.extend([nn.Linear(width, width), act_cls()])
    layers.append(nn.Linear(width, 2))
    net = nn.Sequential(*layers)
    for linear in (mod for mod in net if isinstance(mod, nn.Linear)):
        init_weight(linear.weight)
        nn.init.zeros_(linear.bias)
    return net


for ax, (act_name, Act, init_weight) in zip(axes, configs):
    for d in depths:
        torch.manual_seed(42)
        model = _initialised_mlp(Act, init_weight, d)
        grad_norms = get_gradient_norms(model, X_spiral, y_spiral)
        # Reverse so x = 0 is the output layer and larger x is closer to the input.
        ax.plot(range(len(grad_norms)), grad_norms[::-1], 'o-', label=f'{d} layers',
                markersize=3, alpha=0.8)
    ax.set_title(act_name, fontsize=14)
    ax.set_xlabel('Layer (output → input)')
    ax.set_ylabel('Gradient norm')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Gradient norms at initialization (log scale)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

ReLU largely mitigates the vanishing gradient problem but introduces the **dead neuron** ("dying ReLU") issue: if a neuron's pre-activation is negative for every input, its gradient is always zero and its incoming weights never update.

### Variations

Several variants have been proposed:


- **LeakyReLU**: $\max(\alpha x, x)$ with a small, fixed slope $\alpha$ (e.g. $0.01$)
- **PReLU**: $\max(\alpha x, x)$ where the slope $\alpha$ is a learned parameter
- **ELU**: $x$ if $x>0$, $\alpha(e^x-1)$ otherwise
- **GELU**: $x \cdot \Phi(x)$
- **SiLU/Swish**: $x \cdot \sigma(x)$

In [None]:
# Visualize activation functions
x = torch.linspace(-4, 4, 300)

# Apply each activation to the same input grid (LeakyReLU with slope 0.1 for visibility).
activation_fns = {
    'ReLU': F.relu,
    'LeakyReLU': lambda t: F.leaky_relu(t, 0.1),
    'ELU': F.elu,
    'GELU': F.gelu,
    'SiLU (Swish)': F.silu,
}
activations = {label: fn(x) for label, fn in activation_fns.items()}

n_panels = len(activations)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 3.5))

for ax, name in zip(axes, activations):
    ax.plot(x.numpy(), activations[name].numpy(), linewidth=2)
    # Reference axes through the origin.
    ax.axhline(y=0, color='gray', linewidth=0.5)
    ax.axvline(x=0, color='gray', linewidth=0.5)
    ax.set_title(name, fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-2, 4)

plt.suptitle('Common Activation Functions', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

We can compare these alternatives on our spiral dataset.

In [None]:
# Compare activations on spiral data with an MLP of 8 hidden layers
X_spiral, y_spiral = make_spiral_data(n=1000)

activations = [
    ('Sigmoid',   nn.Sigmoid),
    ('ReLU',      nn.ReLU),
    ('LeakyReLU', nn.LeakyReLU),
    ('GELU',      nn.GELU),
]


def _init_for_activation_(net, act_cls):
    """Re-initialise every Linear in `net` with a scheme matched to `act_cls`.

    Sigmoid uses Xavier normal. LeakyReLU uses Kaiming normal with the gain for
    nn.LeakyReLU()'s default negative slope of 0.01. Anything else (ReLU,
    GELU) uses the Kaiming ReLU gain; for GELU this is an approximation.
    All biases are zeroed.
    """
    for m in net:
        if not isinstance(m, nn.Linear):
            continue
        if act_cls is nn.Sigmoid:
            nn.init.xavier_normal_(m.weight)
        elif act_cls is nn.LeakyReLU:
            nn.init.kaiming_normal_(m.weight, a=0.01, nonlinearity='leaky_relu')
        else:
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        nn.init.zeros_(m.bias)


fig, axes = plt.subplots(1, 4, figsize=(20, 4))

for ax, (name, Act) in zip(axes, activations):
    torch.manual_seed(42)
    # 2 -> 20, then 7 x (20 -> 20), then 20 -> 2: 8 hidden layers, 9 Linear layers.
    model = nn.Sequential(
        nn.Linear(2, 20), Act(), nn.Linear(20, 20), Act(),
        nn.Linear(20, 20), Act(), nn.Linear(20, 20), Act(),
        nn.Linear(20, 20), Act(), nn.Linear(20, 20), Act(),
        nn.Linear(20, 20), Act(), nn.Linear(20, 20), Act(),
        nn.Linear(20, 2),
    )
    _init_for_activation_(model, Act)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(3000):
        logits = model(X_spiral)
        loss = criterion(logits, y_spiral)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        final_acc = (model(X_spiral).argmax(1) == y_spiral).float().mean().item()

    print(f'{name}: accuracy={final_acc:.2%}')
    plot_decision_boundary(model, X_spiral, y_spiral, ax=ax,
                           title=f'{name} (acc={final_acc:.0%})')

plt.suptitle('8-layer MLP with different activations', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## The Degradation Problem

Unfortunately, deeper plain networks (with ReLU) can have **higher training error** than their shallower counterparts. This is not overfitting — the model fails to fit even the *training data*.

This phenomenon is called the **degradation problem**.


A deeper model should be at least as expressive as a shallower one (the extra layers could, in theory, just learn the identity). But in practice, optimization struggles to find this solution.

In [None]:
# Degradation problem: deeper plain ReLU networks perform worse
X_spiral, y_spiral = make_spiral_data(n=1000)

depths = [4, 8, 16, 32]  # number of hidden layers


def _plain_relu_mlp(n_hidden, width=20):
    """Plain (no-skip) ReLU MLP with `n_hidden` hidden layers of size `width`.

    Weights use Kaiming-normal init for ReLU and biases are zeroed.
    """
    net = nn.Sequential(
        nn.Linear(2, width), nn.ReLU(),
        *[mod for _ in range(n_hidden - 1) for mod in (nn.Linear(width, width), nn.ReLU())],
        nn.Linear(width, 2),
    )
    for mod in net:
        if isinstance(mod, nn.Linear):
            nn.init.kaiming_normal_(mod.weight, nonlinearity='relu')
            nn.init.zeros_(mod.bias)
    return net


def _train_and_track(net, X, y, optimizer, n_epochs, log_every=10):
    """Full-batch cross-entropy training.

    Returns the training accuracy recorded after every `log_every` epochs.
    """
    criterion = nn.CrossEntropyLoss()
    history = []
    for epoch in range(1, n_epochs + 1):
        loss = criterion(net(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            with torch.no_grad():
                history.append((net(X).argmax(1) == y).float().mean().item())
    return history


fig, axes = plt.subplots(1, len(depths), figsize=(5 * len(depths), 4.5))
all_histories = []

for ax, n_hidden in zip(axes, depths):
    torch.manual_seed(42)
    model = _plain_relu_mlp(n_hidden)
    n_params = sum(p.numel() for p in model.parameters())

    # Train with SGD (not Adam) — Adam's adaptive LR masks the degradation problem
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    acc_history = _train_and_track(model, X_spiral, y_spiral, optimizer, n_epochs=5000)

    final_acc = acc_history[-1]
    all_histories.append((n_hidden, acc_history))
    print(f'{n_hidden} hidden layers: accuracy={final_acc:.2%}, params={n_params}')

    plot_decision_boundary(model, X_spiral, y_spiral, ax=ax,
                           title=f'{n_hidden} hidden layers\nacc={final_acc:.0%}, {n_params} params')

plt.suptitle('Plain ReLU networks of increasing depth (SGD)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Training accuracy curves (smoothed with a trailing moving average)
def _smooth_curve(epochs, values, window=10):
    """Return ``(epochs, smoothed)`` for a trailing moving average of `values`.

    Each smoothed point is the mean of the `window` most recent values. It is
    placed at the epoch of the last of those values. The window shrinks to
    ``len(values)`` for short histories, and empty input gives empty arrays.
    """
    values = np.asarray(values, dtype=float)
    epochs = np.asarray(epochs)
    w = min(window, len(values))
    if w == 0:
        return epochs[:0], values[:0]
    kernel = np.ones(w) / w
    return epochs[w - 1:], np.convolve(values, kernel, mode='valid')


fig, ax = plt.subplots(figsize=(10, 5))
colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(all_histories)))

for (n_hidden, acc_history), color in zip(all_histories, colors):
    # Entry i was recorded after (i + 1) * 10 completed epochs.
    epochs = (np.arange(len(acc_history)) + 1) * 10
    smooth_epochs, smoothed = _smooth_curve(epochs, acc_history)
    ax.plot(smooth_epochs, smoothed, color=color, linewidth=2, label=f'{n_hidden} hidden layers')

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Training Accuracy', fontsize=12)
ax.set_title('Degradation Problem: Deeper ≠ Better (plain ReLU networks)', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_ylim(0.4, 1.05)
plt.tight_layout()
plt.show()

The deeper networks have **more parameters** and are **at least as expressive**, yet they can end up with lower **training** accuracy. This is not overfitting — it is an optimization failure.

The degradation problem is distinct from vanishing gradients. Even with ReLU and Kaiming initialization, very deep networks are hard to optimize because:

1. **The loss landscape becomes increasingly chaotic** — deeper networks have more saddle points and poor local minima.
2. **Learning the identity is hard** — if the extra layers need to approximate the identity function to match a shallower network's performance, they must learn $W \approx I$, which is a non-trivial optimization target for randomly initialized weights.
3. **Small per-layer errors accumulate** — each layer introduces a small approximation error, and these errors compound through the depth of the network.

## Going deeper: Residual Connections (ResNet)

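Residual networks (ResNets) add **skip connections** around groups of layers. Instead of learning a mapping $\mathcal{H}(x)$ directly, each block learns a **residual** $\mathcal{F}(x)$ that is added to its input: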

$$
\mathcal{H}(x) = x + \mathcal{F}(x)
$$

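where $\mathcal{F}(x) = W_2 \, \sigma(W_1 x + b_1) + b_2$ is the residual branch. (The `ResidualBlock` below also applies a ReLU after the addition, i.e. it computes $\text{ReLU}(x + \mathcal{F}(x))$, as in the original ResNet block.)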

- If the optimal transformation is close to the identity, the network only needs to learn $\mathcal{F}(x) \approx 0$, which is easy.
- Gradients flow directly through the skip connection: $\frac{\partial \mathcal{H}}{\partial x} = I + \frac{\partial \mathcal{F}}{\partial x}$, preventing vanishing gradients even with dozens of layers.

In [None]:
class ResidualBlock(nn.Module):
    """Fully connected residual block computing ``relu(x + fc2(relu(fc1(x))))``.

    The residual branch is two ``width -> width`` linear layers with a ReLU in
    between. Its output is added to the unchanged input (the skip connection),
    and a ReLU is applied after the sum. Input and output both have ``width``
    features.
    """

    def __init__(self, width):
        super().__init__()
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        branch = self.fc2(F.relu(self.fc1(x)))
        # Skip connection: add the input back, then apply the post-sum ReLU.
        return F.relu(x + branch)


class PlainBlock(nn.Module):
    """Fully connected block without a skip connection: ``relu(fc2(relu(fc1(x))))``.

    Two ``width -> width`` linear layers, each followed by a ReLU. The input
    is not added back, so any identity-like mapping must be learned by the
    weights themselves.
    """

    def __init__(self, width):
        super().__init__()
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return F.relu(self.fc2(hidden))


def make_deep_net(block_class, n_blocks, width=20, in_features=2, out_features=2):
    """Build ``Linear -> ReLU -> n_blocks x block_class(width) -> Linear``.

    Args:
        block_class: Callable taking ``width`` and returning a module that maps
            ``width`` features to ``width`` features (e.g. a block class).
        n_blocks: Number of blocks to stack; 0 gives a one-hidden-layer MLP.
        width: Hidden width shared by all blocks.
        in_features: Input dimension (2 for the 2-D spiral data).
        out_features: Output dimension, i.e. the number of classes.

    Returns:
        An ``nn.Sequential`` with ``n_blocks + 3`` top-level modules.
    """
    layers = [nn.Linear(in_features, width), nn.ReLU()]
    layers.extend(block_class(width) for _ in range(n_blocks))
    layers.append(nn.Linear(width, out_features))
    return nn.Sequential(*layers)

Let's compare **plain networks** (no skip connections) versus **ResNets** as we increase depth. Each block has 2 linear layers, so a network with $n$ blocks has $2n + 2$ total layers.

In [None]:
X_spiral, y_spiral = make_spiral_data(n=1000)

block_counts = [2, 5, 10, 25]  # 6, 12, 22, 52 total layers (2 * n_blocks + 2)
n_epochs = 3000

fig, axes = plt.subplots(2, len(block_counts), figsize=(5 * len(block_counts), 10))
resnet_histories = {'Plain': [], 'ResNet': []}

for j, n_blocks in enumerate(block_counts):
    total_layers = 2 * n_blocks + 2
    for i, (name, BlockClass) in enumerate([('Plain', PlainBlock), ('ResNet', ResidualBlock)]):
        torch.manual_seed(42)
        model = make_deep_net(BlockClass, n_blocks)
        n_params = sum(p.numel() for p in model.parameters())

        # Train with Adam and a cosine learning-rate decay spanning the whole run
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
        criterion = nn.CrossEntropyLoss()
        acc_history = []
        for epoch in range(n_epochs):
            logits = model(X_spiral)
            loss = criterion(logits, y_spiral)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Advance the schedule once per epoch; without this call the
            # learning rate would stay fixed at 1e-3 for the whole run.
            scheduler.step()
            if (epoch + 1) % 10 == 0:
                with torch.no_grad():
                    acc = (model(X_spiral).argmax(1) == y_spiral).float().mean().item()
                    acc_history.append(acc)

        final_acc = acc_history[-1]
        resnet_histories[name].append((n_blocks, total_layers, n_params, acc_history))
        print(f'{name} {total_layers}L: accuracy={final_acc:.2%}, params={n_params}')

        plot_decision_boundary(model, X_spiral, y_spiral, ax=axes[i, j],
            title=f'{name} — {total_layers} layers\nacc={final_acc:.0%}, {n_params} params')

axes[0, 0].set_ylabel('Plain (no skip)', fontsize=14)
axes[1, 0].set_ylabel('ResNet (skip)', fontsize=14)
plt.tight_layout()
plt.show()

The training curves tell the same story: plain networks struggle to converge as depth grows, while ResNets train far more reliably as depth increases.

In [None]:
# Training curves: Plain vs ResNet (smoothed with a trailing moving average)
def _smooth_curve(epochs, values, window=10):
    """Return ``(epochs, smoothed)`` for a trailing moving average of `values`.

    Each smoothed point is the mean of the `window` most recent values. It is
    placed at the epoch of the last of those values. The window shrinks to
    ``len(values)`` for short histories, and empty input gives empty arrays.
    """
    values = np.asarray(values, dtype=float)
    epochs = np.asarray(epochs)
    w = min(window, len(values))
    if w == 0:
        return epochs[:0], values[:0]
    kernel = np.ones(w) / w
    return epochs[w - 1:], np.convolve(values, kernel, mode='valid')


fig, axes = plt.subplots(1, 2, figsize=(16, 5))

colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(block_counts)))

# Plain runs first (dashed), then ResNet runs (solid); same colour = same depth.
for name, linestyle, alpha in (('Plain', '--', 0.7), ('ResNet', '-', 0.9)):
    for color, (n_blocks, total_layers, n_params, acc_history) in zip(colors, resnet_histories[name]):
        # Entry i was recorded after (i + 1) * 10 completed epochs.
        epochs = (np.arange(len(acc_history)) + 1) * 10
        smooth_epochs, smoothed = _smooth_curve(epochs, acc_history)
        axes[0].plot(smooth_epochs, smoothed,
                     color=color, linestyle=linestyle, alpha=alpha, linewidth=2,
                     label=f'{name} {total_layers}L')

axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Training Accuracy (dashed=Plain, solid=ResNet)')
axes[0].legend(fontsize=8, ncol=2)
axes[0].grid(True, alpha=0.3)

# Final (unsmoothed) accuracy vs depth
plain_accs = [h[-1] for _, _, _, h in resnet_histories['Plain']]
resnet_accs = [h[-1] for _, _, _, h in resnet_histories['ResNet']]
total_layers_list = [tl for _, tl, _, _ in resnet_histories['Plain']]

axes[1].plot(total_layers_list, plain_accs, 'o--', color='indianred', linewidth=2,
             markersize=8, label='Plain')
axes[1].plot(total_layers_list, resnet_accs, 's-', color='seagreen', linewidth=2,
             markersize=8, label='ResNet')
axes[1].set_xlabel('Total layers')
axes[1].set_ylabel('Final accuracy')
axes[1].set_title('Final Accuracy vs Depth')
axes[1].legend(fontsize=12)
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim(0.4, 1.05)

plt.tight_layout()
plt.show()