In [30]:
import torch

In [37]:
def G_modified(X, model):
    # 开始计时
    # start = time.time()
    
    input_dim, m = model.W.shape  # m: 隐藏层神经元数量, input_dim: 输入维度
    batch_size = X.shape[0]       # batch_size: 批处理大小
    
    # 初始化 Jacobian 矩阵 J，大小为 (batch_size, m * (input_dim + 1))
    J = torch.zeros(batch_size, m * (input_dim + 1), device=X.device)
    
    # 计算所有样本的 <w_i, x> 和 ReLU 激活
    relu_input = X @ model.W  # (batch_size, m)
    relu_output = torch.relu(relu_input)  # (batch_size, m)
    
    # 对 w_i 的部分并行计算 Jacobian
    for j in range(m):
        mask = (relu_input[:, j] > 0).float()  # 只选择 ReLU 激活大于0的元素
        J[:, j*input_dim:(j+1)*input_dim] = (model.a[j] * X * mask.view(-1, 1)) / m
        print(j, (model.a[j] * X * mask.view(-1, 1)) / m)
    # 对 a_i 的部分并行计算 Jacobian
    J[:, m*input_dim:] = relu_output / m
    return J

In [38]:
class test_model(torch.nn.Module):
    def __init__(self, D, m, W: torch.Tensor, a: torch.Tensor):
        super(test_model, self).__init__()
        self.m = m
        self.W = torch.nn.Parameter(W, requires_grad=True)
        self.a = torch.nn.Parameter(a, requires_grad=True)

    def forward(self, X):
        return torch.relu(X @ self.W) @ self.a / self.m

In [80]:
D, m = 3, 2
W = torch.tensor([[0.5, -0.3], [0.8, 0.6], [-0.2, 0.7]])
a = torch.tensor([0.9, -1.1])
X = torch.tensor([[1., 2., 3.], [0.5, 1.0, 1.5]])
model = test_model(3, 2, W, a)

In [81]:
# 计算 G_modified
J = G_modified(X, model)
print(J)

0 tensor([[0.4500, 0.9000, 1.3500],
        [0.2250, 0.4500, 0.6750]], grad_fn=<DivBackward0>)
1 tensor([[-0.5500, -1.1000, -1.6500],
        [-0.2750, -0.5500, -0.8250]], grad_fn=<DivBackward0>)
tensor([[ 0.4500,  0.9000,  1.3500, -0.5500, -1.1000, -1.6500,  0.7500,  1.5000],
        [ 0.2250,  0.4500,  0.6750, -0.2750, -0.5500, -0.8250,  0.3750,  0.7500]],
       grad_fn=<CopySlices>)


In [82]:
P = torch.zeros((m*(D+1), m*(D+1)))

for i in range(D):
    for j in range(m):
        P[j*D + i, i*m + j] = 1
# 令最后m行m列为单位矩阵
for i in range(m*D, m*(D+1)):
    P[i, i] = 1

print(torch.mm(J, P))

tensor([[ 0.4500, -0.5500,  0.9000, -1.1000,  1.3500, -1.6500,  0.7500,  1.5000],
        [ 0.2250, -0.2750,  0.4500, -0.5500,  0.6750, -0.8250,  0.3750,  0.7500]],
       grad_fn=<MmBackward0>)


In [79]:
def auto_grad_G(X, model):
    # length = model.W.shape[0] * model.W.shape[1]
    # height = X.shape[0]
    # J = torch.zeros(height, length)
    # y = model(X).flatten()
    # grad_y = torch.zeros(y.shape)
    # for i in range(y.shape[0]):
    #     grad_y.zero_()
    #     grad_y[i] = 1
    #     w_grad = torch.autograd.grad(y, model.W, grad_y, retain_graph=True, create_graph=True)[0]
    #     J[i] = w_grad.flatten()
    # return J
    output = model(X)
    output.backward()
    return model.W.grad, model.a.grad

model = test_model(3, 2, W, a)
X = torch.tensor([[0.5, 1.0, 1.5]])
w_grad, a_grad = auto_grad_G(X, model)
print(model.W.data.flatten())
print(w_grad.flatten())
print(model.a.data.flatten(), a_grad.flatten())

tensor([ 0.5000, -0.3000,  0.8000,  0.6000, -0.2000,  0.7000])
tensor([ 0.2250, -0.2750,  0.4500, -0.5500,  0.6750, -0.8250])
tensor([ 0.9000, -1.1000]) tensor([0.3750, 0.7500])


In [57]:
tmp_G = torch.tensor([0.45, 0.9, 1.35, -0.55, -1.1, -1.65])
print(tmp_G.reshape(model.W.shape))

tensor([[ 0.4500,  0.9000],
        [ 1.3500, -0.5500],
        [-1.1000, -1.6500]])
