In [14]:
import torch
from torch.autograd import Function
import torch.nn.functional as F

In [26]:
class SimpleSoftmax(Function):
    def forward(ctx, x):
        exp_x = torch.exp(x - x.max(dim=-1, keepdim=True).values)
        sum_exp = exp_x.sum(dim=-1, keepdim=True)
        result = exp_x / sum_exp
        ctx.save_for_backward(result)
        return result

    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        total_effect = (grad_output * result).sum(dim=-1, keepdim=True)
        grad_input = result * (grad_output - total_effect)
        return grad_input

def my_simple_softmax(x):
    return SimpleSoftmax.apply(x)

In [29]:
#与库函数对比一下
x = torch.randn(2, 3, requires_grad=True)
print(x)

y_mine = my_simple_softmax(x)
print("\nMine:\n", y_mine)

y_official = F.softmax(x, dim=-1)
print("\nOfficial:\n", y_official)

print("\nTensor close?", 
      torch.allclose(y_mine, y_official, atol=1e-6))

loss_mine = y_mine.sum()
loss_official = y_official.sum()

loss_mine.backward()
grad_mine = x.grad.clone()

x.grad.zero_()
loss_official.backward()
grad_official = x.grad

print("\nMy gradient:\n", grad_mine)
print("\nOfficial gradient:\n", grad_official)

# 比较反向传播梯度
print("\nGradient close?",
      torch.allclose(grad_mine, grad_official, atol=1e-6))

tensor([[ 0.5753,  0.1906, -1.5559],
        [-0.0312,  0.7530,  0.3748]], requires_grad=True)

Mine:
 tensor([[0.5557, 0.3783, 0.0660],
        [0.2132, 0.4670, 0.3199]], grad_fn=<SimpleSoftmaxBackward>)

Official:
 tensor([[0.5557, 0.3783, 0.0660],
        [0.2132, 0.4670, 0.3199]], grad_fn=<SoftmaxBackward0>)

Tensor close? True

My gradient:
 tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.2705e-08, 2.7832e-08, 1.9067e-08]])

Official gradient:
 tensor([[0., 0., 0.],
        [0., 0., 0.]])

Gradient close? True
