In [15]:
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """
    Computes a numerically stable softmax function.

    Softmax is invariant to adding a constant to every logit in a slice, so
    subtracting the per-slice maximum leaves the result unchanged while
    keeping ``np.exp`` from overflowing on large logits.

    Args:
        z (np.ndarray): Raw scores (logits). A 1D array is a single sample;
            a 2D array is a batch with one sample per row. Array-likes such
            as lists are accepted and converted with ``np.asarray``.
        axis (int): Axis along which to normalize. Defaults to the last axis,
            i.e. across the scores of each sample (matching the previous
            behavior for 1D and 2D inputs).

    Returns:
        np.ndarray: Probabilities with exactly the same shape as ``z``; every
        slice along ``axis`` sums to 1. (Single-row 2D inputs are no longer
        collapsed to 1D.)
    """
    z = np.asarray(z)

    # The stability trick: subtract the max of each slice. keepdims=True keeps
    # the reduced axis with size 1 so the result broadcasts back against z.
    max_z = np.max(z, axis=axis, keepdims=True)
    exp_z = np.exp(z - max_z)

    # Normalize so each slice along `axis` sums to 1; shape is preserved.
    return exp_z / np.sum(exp_z, axis=axis, keepdims=True)

# --- Example Usage ---
# Small logits: exp() would be safe here even without the max-subtraction trick.
logits = np.array([2.0, 1.0, 0.1])
print(f"Softmax on small logits:\n{softmax(logits)}\n")
# Output: [0.65900114 0.24243297 0.09856589]

# Large logits: a naive np.exp(1010) overflows float64 to inf (giving nan after
# normalization); subtracting the max first keeps every exponent <= 0.
logits_large = np.array([1000, 1010, 990])
print(f"Softmax on large logits:\n{softmax(logits_large)}")
# Output: [4.53978686e-05 9.99954600e-01 2.06106005e-09]

# Sanity checks: both results are finite probability distributions.
for scores in (logits, logits_large):
    probs = softmax(scores)
    assert np.all(np.isfinite(probs)) and np.isclose(probs.sum(), 1.0)

Softmax on small logits:
[0.65900114 0.24243297 0.09856589]

Softmax on large logits:
[4.53978686e-05 9.99954600e-01 2.06106005e-09]


In [None]:
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
print(logits.shape)

# Promote a 1D vector to a batch of one sample: 1 row, and -1 tells NumPy to
# infer the number of columns (here 3). A separate name keeps the 1D `logits`
# used by other cells from being silently replaced with a 2D array.
logits_2d = logits.reshape(1, -1) if logits.ndim == 1 else logits
print(logits_2d.shape)

# `axis` names the dimension that the reduction collapses:
#   axis=0 -> reduce down the rows, giving one result per column
#   axis=1 -> reduce across the columns, giving one result per row
# keepdims=True keeps the reduced axis with size 1, so (1, 3) -> (1, 1)
# instead of (1,). That shape broadcasts back against logits_2d (e.g.
# `logits_2d - max_logits`); call .squeeze() afterwards to drop size-1 axes.
max_logits = np.max(logits_2d, axis=1, keepdims=True)
print(max_logits.shape)
print(max_logits.size)

(3,)
(1, 3)
(1, 1)
1


In [18]:
# `axis` names the dimension that a reduction collapses:
#   axis=1 -> reduce across the columns, giving one result per row
#   axis=0 -> reduce down the rows, giving one result per column
z = np.array([[2.0, 1.0, 0.1]])  # shape (1, 3)
# np.max(z, axis=1) has shape (1,); .squeeze() drops that size-1 axis and
# leaves a 0-d array, which prints as the bare scalar 2.0 (not [2.0]).
print(np.max(z, axis=1).squeeze())  # 2.0 (max of the only row)

z = np.array([[2.0, 1.0, 0.1],
              [5.0, 3.0, 4.0]])  # shape (2, 3)
# The result already has shape (2,), so no .squeeze() is needed.
print(np.max(z, axis=1))  # [2. 5.] (max of each row)

# axis=0 instead gives the max of each column.
assert np.array_equal(np.max(z, axis=0), [5.0, 3.0, 4.0])

2.0
[2. 5.]
