<a href="https://colab.research.google.com/github/AshishBora/TorchLeet/blob/main/softmax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Implement Softmax function from scratch

import torch
import numpy as np


def custom_softmax(logits, axis=0):
    """Compute the softmax of ``logits`` along dimension ``axis``.

    The maximum of each slice along ``axis`` is subtracted before
    exponentiating. Softmax is unchanged by adding a constant along the
    normalized dimension, so this does not alter the result. It keeps
    ``exp`` from overflowing to ``inf`` (and producing NaN) on large logits.

    Args:
        logits: Tensor of unnormalized scores.
        axis: Dimension to normalize over (default 0, the first dimension).
            Negative values count from the last dimension.

    Returns:
        Tensor with the same shape as ``logits``. Its entries along ``axis``
        are non-negative and sum to 1. ``-inf`` entries receive zero
        probability. A slice made entirely of ``-inf`` yields NaN, because
        ``-inf - (-inf)`` is NaN.
    """
    # amax reduces straight to the maxima. torch.max(..., dim) would also
    # compute argmax indices that are never used.
    max_logits = torch.amax(logits, dim=axis, keepdim=True)
    exp_logits = torch.exp(logits - max_logits)
    return exp_logits / exp_logits.sum(dim=axis, keepdim=True)


# TODO: Make this a parameterized test
def test_softmax():
    # test 1
    logits = torch.Tensor([0, 0, 0, 0])
    expected_softmax = [0.25 for _ in range(4)]
    softmax = custom_softmax(logits)
    np.testing.assert_allclose(softmax, expected_softmax)

    # test 2
    logits = torch.Tensor([1, 1, 1, 1])
    expected_softmax = [0.25 for _ in range(4)]
    softmax = custom_softmax(logits)
    np.testing.assert_allclose(softmax, expected_softmax)

    # test 3
    logits = torch.Tensor([-np.inf, -np.inf, 1, 1])
    expected_softmax = [0, 0, 0.5, 0.5]
    softmax = custom_softmax(logits)
    np.testing.assert_allclose(softmax, expected_softmax)

    # test 4
    logits = torch.Tensor([0, 1, 2, 3])
    expected_softmax = [0.0320586 , 0.08714432, 0.23688282, 0.64391426]
    softmax = custom_softmax(logits)
    np.testing.assert_allclose(softmax, expected_softmax, atol=1e-4)


test_softmax()

In [None]:
!pip install parameterized

In [None]:
import unittest
from parameterized import parameterized

class TestSequence(unittest.TestCase):
    """Parameterized checks of ``custom_softmax`` on 1-D inputs."""

    @parameterized.expand([
        ("all_zeros", [0, 0, 0, 0], [0.25] * 4),
        ("all_ones", [1, 1, 1, 1], [0.25] * 4),
        ("neg_inf_masked", [-np.inf, -np.inf, 1, 1], [0, 0, 0.5, 0.5]),
        ("increasing", [0, 1, 2, 3],
         [0.0320586, 0.08714432, 0.23688282, 0.64391426]),
        # In float32, exp(1000) is inf; this case passes only if the
        # implementation subtracts the max before exponentiating.
        ("large_logits", [1000, 1000, 1000, 1000], [0.25] * 4),
    ])
    def test_sequence(self, name, logits, expected_softmax):
        """Assert ``custom_softmax(logits)`` matches ``expected_softmax``.

        ``name`` labels the case. ``parameterized.expand`` also uses a leading
        string parameter in the generated test-method name. Here it is
        additionally passed as the assertion's failure message.
        """
        softmax = custom_softmax(torch.tensor(logits, dtype=torch.float32))
        np.testing.assert_allclose(
            softmax, expected_softmax, atol=1e-6, err_msg=name
        )


# argv/exit are overridden so unittest does not parse the notebook kernel's
# command line or raise SystemExit when run inside a notebook cell.
unittest.main(argv=[''], verbosity=2, exit=False)