In [2]:
'''
梯度反传规律研究
'''
import torch
import torch.nn as nn
from torch.optim import Adam, SGD, RMSprop, AdamW
import torch.nn.functional as F

def check_missing_gradients(model):
    """Print a diagnostic for each trainable parameter with a missing or all-zero gradient.

    Walks ``model.named_parameters()`` and, for every parameter with
    ``requires_grad=True``:

    * prints an ``[ERROR]`` line when ``param.grad`` is ``None``;
    * otherwise prints a ``[WARNING]`` line when every gradient entry equals 0.

    Parameters that do not require gradients are ignored. Nothing is returned;
    the report goes to standard output only.
    """
    for name, param in model.named_parameters():
        # Frozen parameters are never expected to carry a gradient.
        if not param.requires_grad:
            continue

        grad = param.grad
        if grad is None:
            print(f"[ERROR] 参数 '{name}' 需要梯度（requires_grad=True），但未收到梯度！")
            continue

        if torch.all(grad == 0):
            print(f"[WARNING] 参数 '{name}' 需要梯度，但梯度全零（可能未被正确更新）")

class testNet(nn.Module):
    """Small MLP classifier that owns its own Adam optimizer.

    ``forward`` evaluates only ``self.net``. The additional ``self.extr``
    layer is registered as a sub-module (so its parameters are handed to the
    optimizer) but is never called in ``forward``; after ``backward`` its
    parameters therefore still have ``grad is None``.
    """

    def __init__(self, dim_in: int = 10, dim_out: int = 4, **kwargs) -> None:
        """Build the layers and the optimizer.

        Args:
            dim_in: Size of the input features (also used as the hidden width).
            dim_out: Number of output logits.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)

        # Main path: Linear -> LayerNorm -> ReLU -> Linear.
        hidden = nn.Linear(dim_in, dim_in)
        norm = nn.LayerNorm(dim_in)
        head = nn.Linear(dim_in, dim_out)
        self.net = nn.Sequential(hidden, norm, nn.ReLU(), head)

        # Extra projection that forward() does not use.
        self.extr = nn.Linear(dim_in, dim_out)

        # Optimizer over every registered parameter (net and extr alike).
        self.optimizer = Adam(
            self.parameters(),
            lr=1e-3,
            betas=(0.9, 0.999),
            weight_decay=1e-6,
        )

    def loss(self, results, labels):
        """Return the cross-entropy loss of ``results`` against ``labels``.

        Args:
            results: tensor, [B, dim_out]
            labels: tensor, [B, ]
        """
        return F.cross_entropy(results, labels)

    def forward(self, x):
        """Run ``x`` through ``self.net`` and return the resulting logits."""
        return self.net(x)

# --- One training step on random data, with a gradient check before the update ---

# Problem sizes for the toy experiment.
batch_size = 3
dim_in = 10
dim_out = 4

# Random inputs of shape (batch_size, dim_in) and integer class labels in [0, dim_out).
datas = torch.randn((batch_size, dim_in))
labels = torch.randint(0, dim_out, (batch_size, ))

net = testNet(dim_in, dim_out)

# Forward pass and loss.
results = net(datas)
loss = net.loss(results, labels)
# Clear any stale gradients, back-propagate, then report parameters whose
# gradient is missing or all zero before the optimizer applies the update.
net.optimizer.zero_grad()
loss.backward()
check_missing_gradients(net)
net.optimizer.step()

[ERROR] 参数 'extr.weight' 需要梯度（requires_grad=True），但未收到梯度！
[ERROR] 参数 'extr.bias' 需要梯度（requires_grad=True），但未收到梯度！
