In a standard binary classification task, the loss is usually defined with respect to the predicted probability $p=\frac{1}{1+\exp(-\beta^{\top}x)}$ as the following cross entropy:
$$
\text{CrossEntropy}(p)=\left\{
\begin{aligned}
&-\log(p), &y=1 \\
&-\log(1-p), &y=0 \\
\end{aligned}
\right.
$$
However, in the setting where the positive and negative classes are extremely imbalanced, an alternative is the following focal loss:
$$
\text{FocalLoss}(p;\alpha,\gamma)=\left\{
\begin{aligned}
&-\alpha (1-p)^{\gamma}\log(p), &y=1 \\
&-(1-\alpha) p^{\gamma}\log(1-p), &y=0 \\
\end{aligned}
\right.
$$
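where $\alpha\in[0,1]$ weights the two classes against each other and $\gamma\geq 0$ down-weights samples that are already well classified; with $\alpha=0.5$ and $\gamma=0$ the focal loss reduces to one half of the cross entropy.
Here, we show how to perform a sparse imbalanced binary classification task using ``focal loss`` and ``scope``.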

In [1]:
import numpy as np
np.random.seed(123)
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scope import ScopeSolver

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

In [2]:
X, y = load_breast_cancer(return_X_y=True)

# Remove the first 170 samples labelled 0 so that label 1 heavily outnumbers
# label 0, which yields an imbalanced dataset.
idx_drop = np.flatnonzero(y == 0)[:170]
X = np.delete(X, idx_drop, axis=0)
y = np.delete(y, idx_drop)

# Standardize every column to zero mean and unit standard deviation.
X = (X - X.mean(axis=0)) / X.std(axis=0)

# Pad the design matrix with standard-normal noise columns up to 100 features.
n_samples, n_features = X.shape
X = np.hstack((X, np.random.randn(n_samples, 100 - n_features)))

# Hold out 30% of the samples for testing.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

print('X shape: ', X.shape)
print('Imbalanced ratio: ', np.round(np.sum(y == 1) / np.sum(y == 0), 3))

X shape:  (399, 100)
Imbalanced ratio:  8.5


In [3]:
def cross_entropy(params):
    """Mean binary cross entropy of a logistic model on the training set.

    The predicted probability is ``p = sigmoid(X_train @ params)``. A sample
    contributes ``-log(p)`` when its label in ``y_train`` is 1 and
    ``-log(1 - p)`` when its label is 0; any other label contributes 0.

    Parameters
    ----------
    params : array
        Coefficient vector with one entry per column of ``X_train``.

    Returns
    -------
    Scalar mean loss over the rows of ``X_train``.
    """
    logits = X_train @ params
    # -log(sigmoid(z)) = log(1 + exp(-z)) and -log(1 - sigmoid(z)) = log(1 + exp(z)).
    # Evaluating them with logaddexp keeps the loss and its gradient finite even
    # when the sigmoid would saturate to exactly 0 or 1 (where log(0) = -inf).
    loss_if_pos = jnp.logaddexp(0.0, -logits)
    loss_if_neg = jnp.logaddexp(0.0, logits)
    per_sample = jnp.where(
        y_train == 1,
        loss_if_pos,
        jnp.where(y_train == 0, loss_if_neg, 0.0),
    )
    return jnp.mean(per_sample)

# Configure the solver over all training columns with sparsity=8 and
# minimise the cross-entropy objective.
solver = ScopeSolver(
    dimensionality=X_train.shape[1],
    sparsity=8,
)
params = solver.solve(cross_entropy)

# Turn the logistic probabilities on the test set into hard labels at 0.5.
test_prob = 1 / (1 + jnp.exp(- X_test @ params))
y_pred = (test_prob >= 0.5).astype(int)
print('F1 score of cross entropy: ', f1_score(y_test, y_pred).round(3))

F1 score of cross entropy:  0.865


In [4]:
def focal_loss(params):
    """Mean focal loss of a logistic model on the training set.

    With ``p = sigmoid(X_train @ params)``, a sample contributes
    ``-alpha * (1 - p)**gamma * log(p)`` when its label in ``y_train`` is 1
    and ``-(1 - alpha) * p**gamma * log(1 - p)`` when its label is 0; any
    other label contributes 0. ``alpha = 0.1`` and ``gamma = 2`` are fixed
    inside the function.

    Parameters
    ----------
    params : array
        Coefficient vector with one entry per column of ``X_train``.

    Returns
    -------
    Scalar mean loss over the rows of ``X_train``.
    """
    alpha, gamma = 0.1, 2
    logits = X_train @ params
    # Work in log-space: log(p) = -log(1 + exp(-z)), log(1 - p) = -log(1 + exp(z)).
    # Both are finite for any finite logit, so the loss and its gradient never
    # hit log(0) = -inf when the sigmoid saturates.
    log_p = -jnp.logaddexp(0.0, -logits)
    log_one_minus_p = -jnp.logaddexp(0.0, logits)
    # (1 - p)**gamma == exp(gamma * log(1 - p)), and likewise for p**gamma.
    loss_if_pos = -alpha * jnp.exp(gamma * log_one_minus_p) * log_p
    loss_if_neg = -(1 - alpha) * jnp.exp(gamma * log_p) * log_one_minus_p
    per_sample = jnp.where(
        y_train == 1,
        loss_if_pos,
        jnp.where(y_train == 0, loss_if_neg, 0.0),
    )
    return jnp.mean(per_sample)

# Configure the solver over all training columns with sparsity=8 and
# minimise the focal-loss objective.
solver = ScopeSolver(
    dimensionality=X_train.shape[1],
    sparsity=8,
)
params = solver.solve(focal_loss)

# Turn the logistic probabilities on the test set into hard labels at 0.5.
test_prob = 1 / (1 + jnp.exp(- X_test @ params))
y_pred = (test_prob >= 0.5).astype(int)
print('F1 score of focal loss: ', f1_score(y_test, y_pred).round(3))

F1 score of focal loss:  0.941
