In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [23]:
class SoftMax:
    def forward(self, x: torch.Tensor):
        if x.ndim == 1:
            x = x.unsqueeze(dim=0)
        max_logits = torch.max(x, dim=-1, keepdim=True).values  # (B, N) -> (B, 1)
        exp_logits = torch.exp(x - max_logits)  # (B, N)
        return exp_logits / torch.sum(exp_logits, dim=-1, keepdim=True)

    def backward(self, y: torch.Tensor, grad_output: torch.Tensor = None):
        # y is the output from the forward pass (softmax output)
        # grad_output is the gradient of the loss with respect to the output of softmax
        
        # Compute the gradient of the softmax function
        batch_size, num_classes = y.shape
        
        if grad_output is not None:
            # Reshape grad_output to (B, N, 1)
            grad_output = grad_output.unsqueeze(-1)  # (B, N, 1)
        
        # Reshape y to (B, N, 1)
        y = y.unsqueeze(-1)  # (B, N, 1)
        
        # Compute the diagonal part of the Jacobian matrix: y * (1 - y)
        diag_jacobian = y * (1 - y)  # (B, N, 1)
        
        # Compute the outer product part of the Jacobian matrix: -y * y^T
        outer_product_jacobian = -y @ y.transpose(1, 2)  # (B, N, N)
        
        # Combine both parts to get the full Jacobian matrix
        eye = torch.eye(num_classes).to(y.device)
        jacobian_matrix = diag_jacobian * eye + outer_product_jacobian * (1- eye)  # (B, N, N)
        
        if grad_output is not None:
            # Multiply the Jacobian matrix with grad_output
            dx = torch.matmul(jacobian_matrix, grad_output).squeeze(-1)  # (B, N)
            return dx
        return jacobian_matrix

In [53]:
class CrossEntropy:
    def __init__(self, fast_backward=False):
        self.softmax = SoftMax()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # x: (B, N), y: (B,)
        batch_size, num_classes = x.shape
        self.batch_size = batch_size
        x = self.softmax.forward(x)  # (B, N) -> (B, N)
        self.preds = x
        print(f"SoftMax outputs: {x}")
        one_hot_y = F.one_hot(y, num_classes=num_classes)
        self.targets = one_hot_y
        H = -torch.log(x)
        loss = (H * one_hot_y).sum() / batch_size
        return loss

    def backward(self):
        return (self.preds - self.targets) / self.batch_size

In [24]:
x = torch.tensor([2.0, 1.0, 0.1], requires_grad=True)
softmax = SoftMax()
y = softmax.forward(x)
y

tensor([[0.6590, 0.2424, 0.0986]], grad_fn=<DivBackward0>)

In [25]:
softmax.backward(y)

tensor([[[ 0.2247, -0.1598, -0.0650],
         [-0.1598,  0.1837, -0.0239],
         [-0.0650, -0.0239,  0.0889]]], grad_fn=<AddBackward0>)

In [27]:
print(torch.autograd.grad(y[0, 0], x, retain_graph = True))
print(torch.autograd.grad(y[0, 1], x, retain_graph = True))
print(torch.autograd.grad(y[0, 2], x, retain_graph = True))

(tensor([ 0.2247, -0.1598, -0.0650]),)
(tensor([-0.1598,  0.1837, -0.0239]),)
(tensor([-0.0650, -0.0239,  0.0889]),)


In [54]:
x = torch.tensor([[0.2,0.3,0.5],[0.3,0.2,0.5],[0.4,0.4,0.2]], requires_grad=True)
y = torch.tensor([1,0,2])
ce = CrossEntropy()
loss = ce.forward(x, y)
print(f"My CE: {loss}")
_loss = F.cross_entropy(x, y)
print(f"Torch CE: {_loss}")

SoftMax outputs: tensor([[0.2894, 0.3199, 0.3907],
        [0.3199, 0.2894, 0.3907],
        [0.3548, 0.3548, 0.2905]], grad_fn=<DivBackward0>)
My CE: 1.1719828844070435
Torch CE: 1.1719828844070435


In [55]:
ce.backward()

tensor([[ 0.0965, -0.2267,  0.1302],
        [-0.2267,  0.0965,  0.1302],
        [ 0.1183,  0.1183, -0.2365]], grad_fn=<DivBackward0>)

In [56]:
print(torch.autograd.grad(loss, x, retain_graph = True))

(tensor([[ 0.0965, -0.2267,  0.1302],
        [-0.2267,  0.0965,  0.1302],
        [ 0.1183,  0.1183, -0.2365]]),)
