Skip to content
51 changes: 51 additions & 0 deletions machine_learning/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,57 @@ def binary_cross_entropy(
return np.mean(bce_loss)


def binary_focal_cross_entropy(
y_true: np.ndarray,
y_pred: np.ndarray,
gamma: float = 2.0,
alpha: float = 0.25,
epsilon: float = 1e-15,
) -> float:
"""
Calculate the mean binary focal cross-entropy (BFCE) loss between true labels
and predicted probabilities.

BFCE loss quantifies dissimilarity between true labels (0 or 1) and predicted
probabilities. It's a variation of binary cross-entropy that addresses class
imbalance by focusing on hard examples.

BCFE = -Σ(alpha * (1 - y_pred)**gamma * y_true * log(y_pred)
+ (1 - alpha) * y_pred**gamma * (1 - y_true) * log(1 - y_pred))

Reference: [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf)

Parameters:
- y_true: True binary labels (0 or 1).
- y_pred: Predicted probabilities for class 1.
- gamma: Focusing parameter for modulating the loss (default: 2.0).
- alpha: Weighting factor for class 1 (default: 0.25).
- epsilon: Small constant to avoid numerical instability.

>>> true_labels = np.array([0, 1, 1, 0, 1])
>>> predicted_probs = np.array([0.2, 0.7, 0.9, 0.3, 0.8])
>>> binary_focal_cross_entropy(true_labels, predicted_probs)
0.008257977659239775
>>> true_labels = np.array([0, 1, 1, 0, 1])
>>> predicted_probs = np.array([0.3, 0.8, 0.9, 0.2])
>>> binary_focal_cross_entropy(true_labels, predicted_probs)
Traceback (most recent call last):
...
ValueError: Input arrays must have the same length.
"""
if len(y_true) != len(y_pred):
raise ValueError("Input arrays must have the same length.")
# Clip predicted probabilities to avoid log(0)
y_pred = np.clip(y_pred, epsilon, 1 - epsilon)

bcfe_loss = -(
alpha * (1 - y_pred) ** gamma * y_true * np.log(y_pred)
+ (1 - alpha) * y_pred**gamma * (1 - y_true) * np.log(1 - y_pred)
)

return np.mean(bcfe_loss)


def categorical_cross_entropy(
y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-15
) -> float:
Expand Down