In [1]:
%matplotlib inline

## Multinomial Logistic Regression

Multinomial logistic regression predicts the probabilities of a categorical outcome that has more than two classes. It extends binary logistic regression, which models the probability of a two-class outcome.

### Mathematical Derivation

In multinomial logistic regression, the outcome takes one of $K$ categories, indexed by $k=1,2,\ldots,K$. We want to predict the probability of each category given a vector of predictor variables $X$. We assume that each probability is a function of the predictors and that the probabilities over the $K$ categories sum to 1.

We can model the probability of each category using the softmax function:

$$P(Y=k|X=x) = \frac{e^{\beta _{0k} + \beta _k^Tx}}{\sum_{j=1}^K e^{\beta _{0j} + \beta _j^Tx}}$$

where $Y$ is the categorical outcome, $x$ is the observed value of the predictor vector $X$, $\beta _{0k}$ and $\beta _k$ are the intercept and coefficient vector for category $k$, and $e$ is the base of the natural logarithm.

The softmax function ensures that the probabilities over the $K$ categories are positive and sum to 1. The numerator is the exponentiated linear predictor for category $k$, which is an unnormalized positive score, and the denominator is the sum of these scores over all categories, so dividing normalizes the scores into probabilities.
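
The normalization described above can be checked numerically. The following is a minimal sketch in NumPy, not part of the notebook's own code: the function name `softmax_probs` and all numeric values are made up for illustration, with 2 predictors and 3 categories.

```python
import numpy as np

def softmax_probs(x, intercepts, coefs):
    """Return P(Y=k | X=x) for every category k under the softmax model.

    x: (p,) predictor vector; intercepts: (K,); coefs: (p, K).
    """
    scores = intercepts + x @ coefs   # linear predictor for each category
    scores = scores - scores.max()    # shift for numerical stability; ratios are unchanged
    weights = np.exp(scores)          # unnormalized, strictly positive scores
    return weights / weights.sum()    # normalize so the entries sum to 1

# Illustrative (made-up) values: p = 2 predictors, K = 3 categories.
x = np.array([1.0, -0.5])
intercepts = np.array([0.1, 0.0, -0.2])
coefs = np.array([[0.5, -1.0, 0.0],
                  [1.5,  0.3, 0.0]])

probs = softmax_probs(x, intercepts, coefs)
print(probs, probs.sum())
```

Subtracting the largest score before exponentiating leaves the ratios unchanged and guards against overflow in `np.exp`.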

We can estimate the coefficients using maximum likelihood estimation. The likelihood function for multinomial logistic regression is:

$$L(\beta) = \prod _{i=1}^n \prod _{k=1}^K P(Y_i=k|X_i=x_i)^{I(Y_i=k)}$$

where $n$ is the number of observations, $I(Y_i=k)$ is an indicator function that equals 1 if $Y_i=k$ and 0 otherwise, and $P(Y_i=k|X_i=x_i)$ is the predicted probability of category $k$ for observation $i$.

The negative log-likelihood function is:

$$-l(\beta) = -\sum _{i=1}^n \sum _{k=1}^K I(Y_i=k) \log P(Y_i=k|X_i=x_i)$$
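
To make the role of the indicator $I(Y_i=k)$ concrete, here is a small sketch in NumPy that evaluates the negative log-likelihood for two observations. The helper name `negative_log_likelihood` and the probability values are hypothetical, and the labels are assumed to be integers in $0,\ldots,K-1$.

```python
import numpy as np

def negative_log_likelihood(probs, labels):
    """probs: (n, K) predicted probabilities; labels: (n,) integer classes in 0..K-1."""
    n, K = probs.shape
    indicator = np.eye(K)[labels]              # I(Y_i = k) as an (n, K) one-hot matrix
    return -np.sum(indicator * np.log(probs))  # double sum over observations i and categories k

# Made-up predicted probabilities for n = 2 observations and K = 3 categories.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
labels = np.array([0, 1])

nll = negative_log_likelihood(probs, labels)
print(nll)
```

Only the probability of the observed category contributes for each observation, so the result here equals $-(\log 0.7 + \log 0.8)$. Dividing the sum by $n$, as a mean-based loss does, rescales the objective without changing its minimizer.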

This is the function that we want to minimize in order to estimate the coefficients. We can use the scope algorithm to find the values of $\beta$ that minimize the negative log-likelihood function subject to a sparsity constraint.

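Here is Python code for solving the sparse multinomial logistic regression problem. The loss in this code omits the intercept terms and averages the negative log-likelihood over the $n$ observations, which does not change the minimizer: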

In [2]:
import jax.numpy as jnp
import numpy as np
from scope import ScopeSolver
from abess.datasets import make_multivariate_glm_data

np.random.seed(3)

n = 500  # sample size
p = 20  # number of predictors
k = 5   # number of true (nonzero) predictors
m = 3   # number of classes


data = make_multivariate_glm_data(n=n, p=p, k=k, family="multinomial", M=m)

X = data.x
y = data.y

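def multinomial_regression_loss(params):
    # Reshape the flat parameter vector into a (p, m) coefficient matrix
    beta = params.reshape((p, m))
    # Compute the logits, shape (n, m)
    logits = jnp.dot(X, beta)

    # Compute the log-softmax probabilities; subtracting the row-wise maximum
    # leaves the probabilities unchanged but avoids overflow in exp
    shifted = logits - jnp.max(logits, axis=1, keepdims=True)
    log_probs = shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=1, keepdims=True))

    # Compute the mean NLL loss (assumes y is one-hot encoded, shape (n, m))
    loss = -jnp.mean(jnp.sum(y * log_probs, axis=1))

    return loss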


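# The m coefficients of each predictor share one group label, so a predictor
# is kept or dropped as a whole; k is the number of groups to select
solver = ScopeSolver(p * m, k, group=[i for i in range(p) for _ in range(m)])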
solver.solve(multinomial_regression_loss, jit=True)


print("True parameter:\n", data.coef_)
print("real variables' index:\n", set(np.nonzero(data.coef_)[0]))
print("Estimated parameter:\n", solver.params.reshape((p, m)))
print("Estimated variables' index:\n", set(np.nonzero(solver.params.reshape((p, m)))[0]))

True parameter:
 [[  5.44916029  -0.94953634   0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [ -1.39241163 -12.96678673   0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  3.24543565   4.02033588   0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [ -1.38210809   4.07755579   0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  5.79719104   3.61096451   0.        ]]
real variables' index:
 {0, 3, 7, 10, 19}
Estimated parameter:
 [[ 1.34862198 -0.73559236 -0.61302938]
 [ 0.          0.          0.        ]
