# In [None]:
"""残差连接演示"""

import torch
import torch.nn as nn


class ResidualConnection(nn.Module):
    """Five-stage Linear+ReLU stack with optional skip connections.

    Four 3->3 stages are followed by one 3->1 stage. When ``use_residual``
    is truthy, a stage's output is added to its input whenever the two have
    the same shape; otherwise the output simply replaces the input.
    """

    def __init__(self, use_residual):
        """Build the stack.

        Args:
            use_residual: If truthy, ``forward`` adds each stage's input to
                its output when their shapes match.
        """
        super().__init__()
        self.use_residual = use_residual
        # (in_features, out_features) of every stage, in construction order.
        stage_dims = [(3, 3)] * 4 + [(3, 1)]
        self.layers = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())
                for n_in, n_out in stage_dims
            ]
        )

    def forward(self, x):
        """Pass ``x`` through every stage in order and return the result.

        Args:
            x: Input tensor whose last dimension is 3 (the first Linear
                layer has ``in_features=3``).

        Returns:
            The tensor produced by the last stage.
        """
        for stage in self.layers:
            stage_out = stage(x)
            # The skip path is only usable when shapes agree, which rules
            # out the final 3->1 stage.
            add_skip = self.use_residual and stage_out.shape == x.shape
            x = x + stage_out if add_skip else stage_out
        return x


def print_gradients(model, x):
    """Run one forward/backward pass and print each weight's mean |gradient|.

    The loss is the mean squared error between ``model(x)`` and an all-zero
    target of the same shape as the output.

    Args:
        model: An ``nn.Module``; its parameters' ``.grad`` fields are reset
            and then overwritten by this call.
        x: Input tensor passed to ``model``.

    Returns:
        None. One line is printed for every parameter whose name contains
        ``"weight"`` and that received a gradient.
    """
    # Drop gradients left by earlier backward passes; otherwise a repeated
    # call on the same model would report accumulated values.
    model.zero_grad()
    output = model(x)
    # Matching the output's shape avoids relying on MSELoss broadcasting
    # when the output is not exactly (1, 1).
    target = torch.zeros_like(output)
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss.backward()
    for name, param in model.named_parameters():
        # Parameters that received no gradient (e.g. frozen ones) have
        # grad=None and are skipped.
        if "weight" not in name or param.grad is None:
            continue
        print(f"{name} 梯度平均值为 {param.grad.abs().mean().item()}")


# Demo: compare per-layer weight-gradient magnitudes with and without skip
# connections. The seed is set only once, before either model is built.
torch.manual_seed(3309)
x = torch.tensor([[1.0, 0.0, 1.0]])

# Run 1 -- the printed label means "without residual connections".
print("不使用残差连接")
plain_model = ResidualConnection(use_residual=False)
print_gradients(plain_model, x)
# Recorded output of this run (verbatim):
# layers.0.0.weight 梯度平均值为 0.0021177639719098806
# layers.1.0.weight 梯度平均值为 0.004377271514385939
# layers.2.0.weight 梯度平均值为 0.017308441922068596
# layers.3.0.weight 梯度平均值为 0.011054093018174171
# layers.4.0.weight 梯度平均值为 0.240334153175354

# Run 2 -- the printed label means "with residual connections".
print("使用残差连接")
residual_model = ResidualConnection(use_residual=True)
print_gradients(residual_model, x)
# Recorded output of this run (verbatim); every layer's mean gradient is
# larger than in run 1:
# layers.0.0.weight 梯度平均值为 0.13310331106185913
# layers.1.0.weight 梯度平均值为 0.35141655802726746
# layers.2.0.weight 梯度平均值为 0.23713016510009766
# layers.3.0.weight 梯度平均值为 0.6084821224212646
# layers.4.0.weight 梯度平均值为 2.3913090229034424