In [None]:
import torch
from torch import Tensor

# Tutorial 1b: Softmax Function

**Question:** To make the logistic regressor output probabilities, its raw outputs (the logits) need to be passed through a softmax layer. Implement a softmax layer yourself. What numerical issues may arise in this layer? How can you solve them? Use the testing code to confirm you implemented it correctly.

In [None]:
# One row of 20 logits: uniform noise from torch.rand shifted up by 100,
# so every entry is large.
logits = 100 + torch.rand(1, 20)

In [None]:
# Show the sampled logits (random noise shifted up by 100 in the cell above).
print(logits)

tensor([[100.0316, 100.4614, 100.7729, 100.1885, 100.0520, 100.3409, 100.2897,
         100.1420, 100.0357, 100.6357, 100.9895, 100.4577, 100.5971, 100.7787,
         100.4909, 100.5415, 100.9061, 100.0407, 100.1753, 100.3512]])


In [None]:
def bad_softmax(x: Tensor) -> Tensor:
    """Naive softmax over the last dimension, kept deliberately unstable.

    Evaluates ``exp(x) / sum(exp(x))`` without shifting the inputs first.
    When the entries of ``x`` are large, ``exp`` overflows to ``inf`` and the
    division produces ``nan``; the function exists to demonstrate exactly
    that failure mode.

    Args:
        x: Tensor of logits. Normalisation runs over its last dimension.

    Returns:
        Tensor with the same shape as ``x``. For moderate inputs each slice
        along the last dimension sums to 1; for large inputs it is ``nan``.
    """
    # Use the argument itself (not a notebook-level variable) so the result
    # depends only on what the caller passes in.
    x_exp = torch.exp(x)
    # Normalise across the last axis, keeping it for broadcasting, so every
    # row becomes its own distribution.
    return x_exp / x_exp.sum(dim=-1, keepdim=True)

In [None]:
# A proper softmax output adds up to 1; see what the naive version gives.
bad_softmax(logits).sum()

tensor(nan)

In [None]:
def good_softmax(x: Tensor, dim: int = 1) -> Tensor:
    """Numerically stable softmax along ``dim``.

    Softmax is unchanged when a constant is subtracted from every entry of a
    slice, so each slice is shifted by its own maximum before exponentiating.
    The largest exponent is then exactly 0, which keeps ``exp`` from
    overflowing and guarantees the normalising sum is at least 1.

    Args:
        x: Tensor of logits.
        dim: Dimension to normalise over. Defaults to 1, the class axis of a
            ``(batch, classes)`` tensor.

    Returns:
        Tensor with the same shape as ``x`` whose slices along ``dim`` are
        non-negative and sum to 1.
    """
    # Shift by the per-slice maximum rather than the global one: with a
    # global shift, slices much smaller than the overall maximum would
    # underflow to all zeros and turn into 0/0.
    shifted = x - x.max(dim=dim, keepdim=True).values
    x_exp = torch.exp(shifted)
    partition = x_exp.sum(dim=dim, keepdim=True)
    return x_exp / partition


In [None]:
# The probabilities returned by the stable softmax should add up to 1.
good_softmax(logits).sum()

tensor([[0.0326, 0.0502, 0.0685, 0.0382, 0.0333, 0.0445, 0.0423, 0.0365, 0.0328,
         0.0597, 0.0851, 0.0500, 0.0575, 0.0689, 0.0517, 0.0544, 0.0783, 0.0329,
         0.0377, 0.0449]])


tensor(1.0000)

Because of numerical issues like the one you just experienced, PyTorch code typically uses a `LogSoftmax` layer.

**Question [optional]:** PyTorch automatically computes the backpropagation gradient of a module for you. However, it can be instructive to derive and implement your own backward function. Try and implement the backward function for your softmax module and confirm that it is correct.

In [None]:
class Softmax(torch.nn.Module):
    """Softmax over the last dimension with a hand-written backward pass.

    ``forward`` caches its output so that ``backward`` can turn a gradient
    with respect to the softmax output into a gradient with respect to the
    input, without going through autograd.
    """

    def __init__(self):
        super().__init__()
        # Output of the most recent forward() call; None until forward() runs.
        self.forward_output = None

    def forward(self, x):
        """Return the softmax of ``x`` along its last dimension.

        Args:
            x: Tensor of logits.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Shift every slice by its own maximum so exp() cannot overflow.
        exp = torch.exp(x - x.max(dim=-1, keepdim=True).values)
        softmax = exp / exp.sum(dim=-1, keepdim=True)
        # Cache a detached copy for the manual backward pass.
        self.forward_output = softmax.detach()
        return softmax

    def backward(self, grad_output):
        """Propagate a gradient through the most recent forward pass.

        With ``s`` the cached softmax output and ``g`` the incoming gradient,
        the gradient with respect to the input is
        ``s * (g - sum(s * g))``, where the sum runs over the last dimension.

        Args:
            grad_output: Gradient of the loss with respect to the softmax
                output; same shape as that output.

        Returns:
            Gradient of the loss with respect to the input of ``forward``.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.forward_output is None:
            raise RuntimeError("Softmax.backward() called before forward()")

        softmax_output = self.forward_output
        weighted = softmax_output * grad_output
        # Subtract each slice's softmax-weighted total of the incoming gradient.
        return weighted - softmax_output * weighted.sum(dim=-1, keepdim=True)



