In [1]:
import numpy as np


class Softmax:
    """Row-wise softmax layer that caches its last output for the backward pass.

    The forward pass (calling the instance) applies softmax along axis 1 of a
    2-D array. ``backward`` then returns, for every row of the cached output,
    the Jacobian of that row's softmax with respect to its inputs.
    """

    def __init__(self):
        # Result of the most recent forward pass; empty until the layer is called.
        self.output: np.ndarray = np.array([])

    def __call__(self, x: np.ndarray, normalization: bool = True) -> np.ndarray:
        """Compute softmax over axis 1 of ``x`` and cache the result.

        Args:
            x: 2-D array of scores; softmax is taken across each row.
            normalization: When true, each row's maximum is subtracted before
                exponentiation. Softmax is invariant to a per-row shift, so the
                result is mathematically the same, but the largest exponent
                becomes ``exp(0)`` and ``np.exp`` cannot overflow. When false,
                the raw scores are exponentiated directly.

        Returns:
            Array with the same shape as ``x``; it is also stored in
            ``self.output`` for use by ``backward``.
        """
        if normalization:
            exponents = np.exp(x - np.max(x, axis=1, keepdims=True))
        else:
            exponents = np.exp(x)

        self.output = exponents / np.sum(exponents, axis=1, keepdims=True)

        return self.output

    def backward(self) -> np.ndarray:
        """Return the softmax Jacobian for every row of the cached output.

        For a row ``s`` of ``self.output`` the Jacobian entry is
        ``J[i, j] = s[i] * (1 - s[j])`` when ``i == j`` and ``-s[i] * s[j]``
        otherwise, i.e. ``s[i] * (delta_ij - s[j])``.

        Returns:
            Array of shape ``(batch, n, n)`` where ``(batch, n)`` is the shape
            of ``self.output``. If the cached output holds no elements (for
            example, the layer has not been called yet) an empty 1-D array is
            returned.
        """
        probs = self.output
        if probs.size == 0:
            # Nothing cached to differentiate; keep the historical result of
            # an empty 1-D array instead of raising.
            return np.array([])

        n_classes = probs.shape[1]
        # Broadcast (batch, n, 1) * ((n, n) - (batch, 1, n)) -> (batch, n, n):
        # entry [b, i, j] is s_bi * (delta_ij - s_bj), computed for all rows at
        # once instead of looping over every (row, i, j) triple in Python.
        return probs[:, :, np.newaxis] * (np.eye(n_classes) - probs[:, np.newaxis, :])


In [15]:
# Sample scores: three rows of five values each; the last row has one entry
# (10) far larger than the others.
x = np.array([
    [1, 2, 3, 4, 5],
    [5, 4, 3, 2, 1],
    [0.001, 0.002, 0.0015, 0.003, 10],
])

# example with normalization (the layer's default)
softmax_layer = Softmax()

normalized_output = softmax_layer(x)
print(f'output:\n{normalized_output}')

normalized_jacobians = softmax_layer.backward()
print(f'backpropagation:\n{normalized_jacobians}')

output:
[[1.16562310e-02 3.16849208e-02 8.61285444e-02 2.34121657e-01
  6.36408647e-01]
 [6.36408647e-01 2.34121657e-01 8.61285444e-02 3.16849208e-02
  1.16562310e-02]
 [4.54370855e-05 4.54825454e-05 4.54598098e-05 4.55280507e-05
  9.99818093e-01]]
backpropagation:
[[[ 1.15203632e-02 -3.69326755e-04 -1.00393421e-03 -2.72897611e-03
   -7.41812617e-03]
  [-3.69326755e-04  3.06809866e-02 -2.72897611e-03 -7.41812617e-03
   -2.01645576e-02]
  [-1.00393421e-03 -2.72897611e-03  7.87104183e-02 -2.01645576e-02
   -5.48129504e-02]
  [-2.72897611e-03 -7.41812617e-03 -2.01645576e-02  1.79308707e-01
   -1.48997047e-01]
  [-7.41812617e-03 -2.01645576e-02 -5.48129504e-02 -1.48997047e-01
    2.31392681e-01]]

 [[ 2.31392681e-01 -1.48997047e-01 -5.48129504e-02 -2.01645576e-02
   -7.41812617e-03]
  [-1.48997047e-01  1.79308707e-01 -2.01645576e-02 -7.41812617e-03
   -2.72897611e-03]
  [-5.48129504e-02 -2.01645576e-02  7.87104183e-02 -2.72897611e-03
   -1.00393421e-03]
  [-2.01645576e-02 -7.41812617e-03 -

In [16]:
# example without normalization
# The flag has to be passed explicitly: Softmax.__call__ defaults to
# normalization=True, so omitting it would run the max-subtracted path again
# rather than the raw-exponent path this cell is meant to demonstrate.
softmax_wo_norm_layer = Softmax()
print(f'output:\n{softmax_wo_norm_layer(x, normalization=False)}')
print(f'backpropagation:\n{softmax_wo_norm_layer.backward()}')

output:
[[1.16562310e-02 3.16849208e-02 8.61285444e-02 2.34121657e-01
  6.36408647e-01]
 [6.36408647e-01 2.34121657e-01 8.61285444e-02 3.16849208e-02
  1.16562310e-02]
 [4.54370855e-05 4.54825454e-05 4.54598098e-05 4.55280507e-05
  9.99818093e-01]]
backpropagation:
[[[ 1.15203632e-02 -3.69326755e-04 -1.00393421e-03 -2.72897611e-03
   -7.41812617e-03]
  [-3.69326755e-04  3.06809866e-02 -2.72897611e-03 -7.41812617e-03
   -2.01645576e-02]
  [-1.00393421e-03 -2.72897611e-03  7.87104183e-02 -2.01645576e-02
   -5.48129504e-02]
  [-2.72897611e-03 -7.41812617e-03 -2.01645576e-02  1.79308707e-01
   -1.48997047e-01]
  [-7.41812617e-03 -2.01645576e-02 -5.48129504e-02 -1.48997047e-01
    2.31392681e-01]]

 [[ 2.31392681e-01 -1.48997047e-01 -5.48129504e-02 -2.01645576e-02
   -7.41812617e-03]
  [-1.48997047e-01  1.79308707e-01 -2.01645576e-02 -7.41812617e-03
   -2.72897611e-03]
  [-5.48129504e-02 -2.01645576e-02  7.87104183e-02 -2.72897611e-03
   -1.00393421e-03]
  [-2.01645576e-02 -7.41812617e-03 -