In [1]:
import torch
from torch.autograd import Function
import torch.nn.functional as F
from torch import nn

In [3]:
class SimpleSoftmax(Function):
    def forward(ctx, x):
        exp_x = torch.exp(x - x.max(dim=-1, keepdim=True).values)
        sum_exp = exp_x.sum(dim=-1, keepdim=True)
        result = exp_x / sum_exp
        ctx.save_for_backward(result)
        return result

    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        total_effect = (grad_output * result).sum(dim=-1, keepdim=True)
        grad_input = result * (grad_output - total_effect)
        return grad_input

def my_simple_softmax(x):
    return SimpleSoftmax.apply(x)

In [5]:
#与库函数对比一下
x = torch.randn(4, 3, requires_grad=True)
print(x)

labels = torch.tensor([0, 2, 1, 2])

y_mine = my_simple_softmax(x)
print("\nMine:\n", y_mine)

y_official = F.softmax(x, dim=-1)
print("\nOfficial:\n", y_official)

loss = nn.CrossEntropyLoss()

print("\nTensor close?", 
      torch.allclose(y_mine, y_official, atol=1e-6))

loss_mine = loss(y_mine, labels)
loss_official = loss(y_official, labels)

loss_mine.backward()
grad_mine = x.grad.clone()

x.grad.zero_()
loss_official.backward()
grad_official = x.grad

print("\nMy gradient:\n", grad_mine)
print("\nOfficial gradient:\n", grad_official)

# 比较反向传播梯度
print("\nGradient close?",
      torch.allclose(grad_mine, grad_official, atol=1e-6))

tensor([[-0.8923,  0.7547,  0.5274],
        [ 1.3074, -0.0950,  0.9085],
        [-0.6985,  0.1870,  1.0975],
        [-0.7084,  2.2288, -0.2792]], requires_grad=True)

Mine:
 tensor([[0.0968, 0.5027, 0.4005],
        [0.5216, 0.1283, 0.3501],
        [0.1058, 0.2565, 0.6376],
        [0.0467, 0.8815, 0.0718]], grad_fn=<SimpleSoftmaxBackward>)

Official:
 tensor([[0.0968, 0.5027, 0.4005],
        [0.5216, 0.1283, 0.3501],
        [0.1058, 0.2565, 0.6376],
        [0.0467, 0.8815, 0.0718]], grad_fn=<SoftmaxBackward0>)

Tensor close? True

My gradient:
 tensor([[-0.0243,  0.0157,  0.0087],
        [ 0.0507,  0.0083, -0.0590],
        [ 0.0034, -0.0531,  0.0497],
        [-0.0023,  0.0236, -0.0213]])

Official gradient:
 tensor([[-0.0243,  0.0157,  0.0087],
        [ 0.0507,  0.0083, -0.0590],
        [ 0.0034, -0.0531,  0.0497],
        [-0.0023,  0.0236, -0.0213]])

Gradient close? True
