In [None]:
import torch
from torch import Tensor

# Tutorial 1b: Softmax Function

**Question:** To have the logistic regressor output probabilities, they need to be processed through a softmax layer. Implement a softmax layer yourself. What numerical issues may arise in this layer? How can you solve them? Use the testing code to confirm you implemented it correctly.

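In the softmax layer, numerical issues arise from exponentiating inputs that are very large or very negative.

> In floating-point arithmetic this leads to
>> - overflow: `exp(x)` exceeds the largest representable number and becomes `inf`; the division `inf / inf` then yields NaN. This is what happens in our `bad_softmax` function, since `exp(100)` is about 2.7e43 while the float32 maximum is about 3.4e38
>> - underflow: `exp(x)` for very negative `x` is rounded to zero; if every term underflows, the denominator is zero and `0 / 0` is also NaN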
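
Both failure modes can be reproduced with single values. The snippet below is a minimal sketch in float32 (PyTorch's default dtype); the inputs `100.0` and `-200.0` are illustrative choices, not values taken from the exercise.

```python
import torch

# float32 represents magnitudes up to about 3.4e38
big = torch.tensor([100.0])     # exp(100) is about 2.7e43, beyond the float32 range
small = torch.tensor([-200.0])  # exp(-200) is about 1.4e-87, too small for float32

overflowed = torch.exp(big)      # rounds to inf
underflowed = torch.exp(small)   # rounds to 0

nan_from_overflow = overflowed / overflowed     # inf / inf
nan_from_underflow = underflowed / underflowed  # 0 / 0

print(overflowed, underflowed)
print(nan_from_overflow, nan_from_underflow)
```

In both cases the NaN appears in the division, not in the exponential itself.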


To solve these issues, we use the standard max-subtraction trick in our `good_softmax` function:

> Shift the inputs before applying the softmax by subtracting the maximum value from every input. The largest shifted input is then zero, so every exponential lies in (0, 1] and cannot overflow, and the denominator is at least 1, so it cannot vanish. Because softmax is invariant to adding the same constant to all inputs, the resulting probabilities are mathematically identical but numerically stable.
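
The equivalence claimed above follows from a short calculation. As a sketch, for any constant $c$ (the symbol $c$ is introduced here for illustration), the factor $e^{-c}$ cancels between numerator and denominator:

```latex
\mathrm{softmax}(x - c)_i
  = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \mathrm{softmax}(x)_i
```

Choosing $c = \max_j x_j$ makes every exponent non-positive, which is the shift used for numerical stability.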

In [None]:
logits = torch.rand((1, 20)) + 100

In [None]:
logits

tensor([[100.6454, 100.2599, 100.5345, 100.7686, 100.6782, 100.7261, 100.9622,
         100.0570, 100.1590, 100.5197, 100.2220, 100.3827, 100.8921, 100.6602,
         100.4100, 100.7493, 100.8350, 100.5977, 100.9931, 100.5157]])

In [None]:
def bad_softmax(x: Tensor) -> Tensor:
    # Naive softmax over the last dimension: exp(x) overflows to inf for large x
    exp_x = torch.exp(x)
    return exp_x / torch.sum(exp_x, dim=-1, keepdim=True)

In [None]:
torch.sum(bad_softmax(logits))

tensor(nan)

In [None]:
def good_softmax(x: Tensor) -> Tensor:
    # Shift each row by its own maximum so the largest exponent is exp(0) = 1.
    # keepdim=True lets the row-wise maximum broadcast against x.
    z = x - torch.max(x, dim=-1, keepdim=True).values
    exp_z = torch.exp(z)
    return exp_z / torch.sum(exp_z, dim=-1, keepdim=True)

In [None]:
torch.sum(good_softmax(logits))

tensor(1.0000)

Because of numerical issues like the one you just experienced, PyTorch code typically uses a `LogSoftmax` layer, which computes log-probabilities directly instead of taking the logarithm of a softmax output.
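
To illustrate why working in log space helps, here is a minimal sketch that computes log-softmax as `x - logsumexp(x)` and compares it with `torch.nn.LogSoftmax`. The helper name `log_softmax_sketch` and the input values are assumptions made for this example; the sketch is not a description of PyTorch's internal implementation.

```python
import torch

def log_softmax_sketch(x: torch.Tensor) -> torch.Tensor:
    # log softmax(x)_i = x_i - log(sum_j exp(x_j));
    # torch.logsumexp evaluates the second term without overflow
    return x - torch.logsumexp(x, dim=-1, keepdim=True)

example_logits = torch.tensor([[100.0, 101.0, 102.0]])  # illustrative values

manual = log_softmax_sketch(example_logits)
builtin = torch.nn.LogSoftmax(dim=-1)(example_logits)

print(manual)
print(builtin)
```

Both results stay finite for logits around 100, where the naive `exp` overflows in float32.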

**Question [optional]:** PyTorch automatically computes the backpropagation gradient of a module for you. However, it can be instructive to derive and implement your own backward function. Try and implement the backward function for your softmax module and confirm that it is correct.
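
As a starting point for the derivation, the softmax Jacobian and the resulting backward rule can be written as follows. This is a sketch in notation introduced here: $s = \mathrm{softmax}(x)$, $g = \partial L / \partial s$ is the upstream gradient, and $\delta_{ij}$ is the Kronecker delta.

```latex
\frac{\partial s_i}{\partial x_j} = s_i \left( \delta_{ij} - s_j \right)

\frac{\partial L}{\partial x_j}
  = \sum_i g_i \, \frac{\partial s_i}{\partial x_j}
  = s_j \Bigl( g_j - \sum_i g_i s_i \Bigr)
```

In matrix form the Jacobian is $\mathrm{diag}(s) - s s^\top$, which is symmetric.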

In [None]:
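import torch

class CustomSoftmax(torch.nn.Module):
    def forward(self, x):
        # Softmax over dim=1: one probability distribution per row of a 2-D input
        exp_x = torch.exp(x)
        softmax = exp_x / torch.sum(exp_x, dim=1, keepdim=True)
        self.softmax = softmax  # Save for backward pass
        return softmax

    def backward(self, grad_output):
        # Manual vector-Jacobian product, one row at a time.
        # Autograd does not call this method; it must be invoked explicitly.
        grad_input = []
        for i, s in enumerate(self.softmax):
            # Softmax Jacobian: d s_k / d x_j = s_k * (delta_kj - s_j)
            jacobian = torch.diag(s) - torch.outer(s, s)
            grad_input.append(torch.matmul(jacobian.t(), grad_output[i]))
        return torch.stack(grad_input)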

In [None]:
# Testing the custom softmax module
softmax_module = CustomSoftmax()
logits = torch.rand(1, 20, requires_grad=True)
softmax = softmax_module(logits)
# softmax.sum() is always 1, so the exact gradient is zero and the
# printed values are floating-point rounding noise
loss = softmax.sum()
# This runs autograd through forward(); CustomSoftmax.backward is not called
loss.backward()

# Checking gradients
print("Custom softmax gradients:")
print(logits.grad)

# Comparing with PyTorch's autograd
# logits.grad is not reset, so this second backward adds to the values above
torch_softmax = torch.nn.functional.softmax(logits, dim=1)
loss = torch_softmax.sum()
loss.backward()
print("\nPyTorch softmax gradients:")
print(logits.grad)

Custom softmax gradients:
tensor([[-6.4832e-09, -6.0406e-09, -8.2848e-09, -6.0312e-09, -5.5803e-09,
         -7.5740e-09, -8.1076e-09, -4.0920e-09, -4.4220e-09, -3.8470e-09,
         -3.7530e-09, -6.8119e-09, -6.1010e-09, -4.9001e-09, -4.1363e-09,
         -4.6201e-09, -3.7287e-09, -9.3406e-09, -4.7468e-09, -9.3086e-09]])

PyTorch softmax gradients:
tensor([[-6.4832e-09, -6.0406e-09, -8.2848e-09, -6.0312e-09, -5.5803e-09,
         -7.5740e-09, -8.1076e-09, -4.0920e-09, -4.4220e-09, -3.8470e-09,
         -3.7530e-09, -6.8119e-09, -6.1010e-09, -4.9001e-09, -4.1363e-09,
         -4.6201e-09, -3.7287e-09, -9.3406e-09, -4.7468e-09, -9.3086e-09]])
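
For a check that exercises a hand-written backward directly, one option is to wrap the softmax in a `torch.autograd.Function`, whose `backward` autograd does call. The sketch below is one possible way to do this, not the tutorial's reference solution: the class name `SoftmaxFunction`, the random `weights`, and the use of float64 are all assumptions made for this example. The weighted loss is used because a plain `sum()` of softmax outputs is constant and has zero gradient.

```python
import torch

class SoftmaxFunction(torch.autograd.Function):
    """Hypothetical example: softmax over the last dim with a manual backward."""

    @staticmethod
    def forward(ctx, x):
        z = x - x.max(dim=-1, keepdim=True).values  # shift for stability
        exp_z = torch.exp(z)
        s = exp_z / exp_z.sum(dim=-1, keepdim=True)
        ctx.save_for_backward(s)
        return s

    @staticmethod
    def backward(ctx, grad_output):
        (s,) = ctx.saved_tensors
        # Vector-Jacobian product: s * (g - sum(g * s)) per row
        dot = (grad_output * s).sum(dim=-1, keepdim=True)
        return s * (grad_output - dot)

torch.manual_seed(0)
x_custom = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)
x_ref = x_custom.detach().clone().requires_grad_(True)
weights = torch.randn(4, 5, dtype=torch.float64)  # makes the loss non-constant

# Same loss through the manual backward and through PyTorch's built-in softmax
(SoftmaxFunction.apply(x_custom) * weights).sum().backward()
(torch.softmax(x_ref, dim=-1) * weights).sum().backward()

grads_match = torch.allclose(x_custom.grad, x_ref.grad)
print(grads_match)
```

Using two separate leaf tensors keeps the two gradients from accumulating into the same `.grad` field.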
