

# Generalized Linear Model


## Gamma Regression
Gamma regression is suited to positive, continuous responses, such as insurance claim payments or the lifetime of a redundant system.
The density of the Gamma distribution can be written in terms of a mean parameter $\mu$ and a shape parameter $\alpha$:
$$
f(y \mid \mu, \alpha)=\frac{1}{y \Gamma(\alpha)}\left(\frac{\alpha y}{\mu}\right)^{\alpha} e^{-\alpha y / \mu} I_{(0, \infty)}(y),
$$
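where $I_{(0,\infty)}(\cdot)$ denotes the indicator function of $(0, \infty)$. Under this parameterization, $E(y) = \mu$ and $\operatorname{Var}(y) = \mu^2/\alpha$. In the Gamma regression model, the responses are assumed to be independent Gamma random variables with a common shape parameter: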

$$
y_i \sim \mathrm{Gamma}(\mu_i, \alpha),
$$
where $1/\mu_i = x_i^T\beta$. Because $\mu_i$ must be positive, this inverse link requires $x_i^T\beta > 0$ for every observation.
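To see what data from this model look like, the sketch below simulates $y_i \sim \mathrm{Gamma}(\mu_i, \alpha)$ with $1/\mu_i = x_i^T\beta$ using NumPy. The sample size, coefficients, and uniform covariates are illustrative assumptions (this is not the `make_glm_data` generator used later); positive covariates and nonnegative coefficients are chosen so that every $x_i^T\beta > 0$. NumPy's sampler is parameterized by shape and scale, so mean $\mu_i$ with shape $\alpha$ corresponds to `scale = mu / alpha`.

```python
import numpy as np

rng = np.random.default_rng(0)

n, p, alpha = 20_000, 5, 5.0
beta = np.array([2.0, 0.0, 1.5, 0.0, 1.0])  # hypothetical sparse coefficients

# Positive covariates and nonnegative coefficients keep eta_i = x_i^T beta > 0,
# which the inverse link 1 / mu_i = x_i^T beta requires.
X = rng.uniform(0.0, 2.0, size=(n, p))
eta = X @ beta
mu = 1.0 / eta

# NumPy's gamma sampler takes (shape, scale); mean mu and shape alpha
# correspond to scale = mu / alpha.
y = rng.gamma(shape=alpha, scale=mu / alpha)

ratio = y * eta  # y_i / mu_i: Gamma with mean 1 and variance 1 / alpha
print("mean of y/mu:", ratio.mean())
print("variance of y/mu:", ratio.var())
```

Because $y_i/\mu_i$ follows a Gamma distribution with mean 1 and variance $1/\alpha$, the printed values should be close to 1 and 0.2 for $\alpha = 5$.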

Given $n$ independent observations $(x_i, y_i)$, $i = 1, \dots, n$, we estimate $\beta$ by minimizing the negative log-likelihood under a sparsity constraint:
$$
\arg \min _{\beta \in \mathbb{R}^p} L(\beta):=-\frac{1}{n} \sum_{i=1}^n\left\{-\alpha \left( y_i x_i^T \beta - \log \left(x_i^T \beta\right)\right) + \alpha \log \alpha + (\alpha - 1) \log y_i - \log \Gamma(\alpha) \right\} \quad \text{s.t.} \quad \|\beta\|_0 \leq s .
$$
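A direct NumPy transcription of $L(\beta)$ is sketched below, using hypothetical simulated data rather than anything from the original tutorial; the helper names `full_nll` and `reduced_loss` are illustrative. It cross-checks $L(\beta)$ against the density $f(y \mid \mu, \alpha)$ with $\mu_i = 1/(x_i^T\beta)$, and verifies that differences in $L$ equal $\alpha$ times differences in its $\beta$-dependent part. Infeasible $\beta$ (some $x_i^T\beta \le 0$) are mapped to $+\infty$.

```python
import math

import numpy as np


def full_nll(beta, X, y, alpha):
    """L(beta): average negative Gamma log-likelihood with 1 / mu_i = x_i^T beta."""
    eta = X @ beta
    if np.any(eta <= 0):
        return np.inf  # outside the domain of the inverse link
    loglik = (-alpha * (y * eta - np.log(eta)) + alpha * np.log(alpha)
              + (alpha - 1) * np.log(y) - math.lgamma(alpha))
    return -np.mean(loglik)


def reduced_loss(beta, X, y):
    """Beta-dependent part of L: mean of y_i x_i^T beta - log(x_i^T beta)."""
    eta = X @ beta
    if np.any(eta <= 0):
        return np.inf
    return np.mean(y * eta - np.log(eta))


# Hypothetical simulated data with positive covariates, so x_i^T beta > 0.
rng = np.random.default_rng(1)
n, alpha = 500, 3.0
X = rng.uniform(0.5, 1.5, size=(n, 4))
beta_true = np.array([1.0, 0.0, 0.5, 0.0])
y = rng.gamma(shape=alpha, scale=1.0 / (alpha * (X @ beta_true)))  # mean 1 / x_i^T beta

# Cross-check L against the density formula, with mu_i = 1 / x_i^T beta.
mu = 1.0 / (X @ beta_true)
log_f = (-np.log(y) - math.lgamma(alpha)
         + alpha * np.log(alpha * y / mu) - alpha * y / mu)
direct = -np.mean(log_f)

# Terms without beta cancel, so L(b1) - L(b2) = alpha * (l(b1) - l(b2)).
b1 = beta_true
b2 = np.array([0.8, 0.2, 0.5, 0.0])  # an arbitrary feasible alternative
lhs = full_nll(b1, X, y, alpha) - full_nll(b2, X, y, alpha)
rhs = alpha * (reduced_loss(b1, X, y) - reduced_loss(b2, X, y))
print(direct, full_nll(b1, X, y, alpha))
print(lhs, rhs)
```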

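Here $\|\beta\|_0$ is the number of nonzero entries of $\beta$ and $s$ is the sparsity level. Since $\alpha > 0$ and the last three terms in the braces do not depend on $\beta$, the minimizer is unchanged if we instead minimize $\frac{1}{n}\sum_{i=1}^n \left(y_i x_i^T\beta - \log(x_i^T\beta)\right)$. The following Python code minimizes this reduced objective with `ScopeSolver`: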


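```python
import numpy as np
import jax.numpy as jnp
from abess.datasets import make_glm_data  # simulated GLM data (x, y, coef_)
from scope import ScopeSolver  # assumed import path; later releases are published as `skscope`

np.random.seed(2)

n = 100  # number of observations
p = 10   # number of candidate variables
s = 3    # sparsity level: number of nonzero coefficients
data = make_glm_data(n=n, p=p, k=s, family="gamma")
X = data.x
y = data.y


# Negative log-likelihood of Gamma regression with 1/mu_i = x_i^T params,
# up to the positive factor alpha and terms that do not depend on params.
def gamma_loss(params):
    xbeta = jnp.clip(X @ params, -30, 30)
    # log() needs xbeta > 0: negative entries give nan, and xbeta == 0
    # (e.g. params = 0) gives inf.
    return jnp.mean(y * xbeta - jnp.log(xbeta))


# Select s of the p parameters and fit them by minimizing gamma_loss.
solver = ScopeSolver(p, s)
solver.solve(gamma_loss, jit=True)

print("True support set: ", np.nonzero(data.coef_)[0])
print("True parameters: ", data.coef_)
print("True loss value: ", gamma_loss(data.coef_))
print("Estimated support set: ", np.sort(solver.support_set))
print("Estimated parameters: ", solver.params)
print("Estimated loss value: ", gamma_loss(solver.params))
```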

```
True support set:  [2 6 8]
True parameters:  [ 0.          0.         16.84626207  0.          0.          0.
  9.48390875  0.          7.42158219  0.        ]
True loss value:  nan
Estimated support set:  []
Estimated parameters:  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Estimated loss value:  inf
```
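When $p$ is small, the sparsity-constrained problem can also be solved exactly by brute force, which gives an independent cross-check for a sparse solver. The sketch below is not the SCOPE algorithm behind `ScopeSolver`; it is plain exhaustive best-subset search. It enumerates every support of size $s$ (enough for $\|\beta\|_0 \le s$, because enlarging a support never increases the minimized loss), fits each one by damped Newton's method on the reduced objective, which is convex on $\{x_i^T\beta > 0\}$, and keeps the best. The simulated data, coefficients, and helper names are hypothetical; positive covariates make a feasible starting point with every $x_i^T\beta > 0$ easy to find.

```python
from itertools import combinations

import numpy as np


def reduced_loss(beta, X, y):
    """mean(y_i x_i^T beta - log(x_i^T beta)); +inf unless every x_i^T beta > 0."""
    eta = X @ beta
    if np.any(eta <= 0):
        return np.inf
    return np.mean(y * eta - np.log(eta))


def fit_on_support(X_S, y, iters=100, tol=1e-10):
    """Damped Newton's method for the reduced loss restricted to the columns X_S."""
    n, k = X_S.shape
    # Feasible start: with positive covariates, any positive beta gives eta > 0.
    beta = np.full(k, 1.0 / (np.mean(y) * X_S.sum(axis=1).mean()))
    loss = reduced_loss(beta, X_S, y)
    for _ in range(iters):
        eta = X_S @ beta
        grad = X_S.T @ (y - 1.0 / eta) / n
        hess = (X_S / eta[:, None] ** 2).T @ X_S / n
        step = -np.linalg.solve(hess, grad)
        decrement = grad @ step  # negative for a descent direction
        if -decrement < tol:
            break
        t = 1.0
        # Backtrack until the step stays feasible and decreases the loss enough.
        while t > 1e-12:
            cand = beta + t * step
            cand_loss = reduced_loss(cand, X_S, y)
            if cand_loss <= loss + 0.25 * t * decrement:
                break
            t *= 0.5
        else:
            break  # no acceptable step; keep the current iterate
        beta, loss = cand, cand_loss
    return beta, loss


def best_subset(X, y, s):
    """Exhaustive search over all supports of size s (only feasible for small p)."""
    best = (np.inf, None, None)
    for support in combinations(range(X.shape[1]), s):
        beta_S, loss = fit_on_support(X[:, list(support)], y)
        if loss < best[0]:
            best = (loss, support, beta_S)
    return best


rng = np.random.default_rng(2)
n, p, s, alpha = 2000, 8, 3, 5.0
beta_true = np.zeros(p)
beta_true[[0, 4, 6]] = [2.0, 1.5, 1.0]  # hypothetical true coefficients
X = rng.uniform(0.0, 2.0, size=(n, p))  # positive covariates keep x_i^T beta > 0
y = rng.gamma(shape=alpha, scale=1.0 / (alpha * (X @ beta_true)))

loss, support, beta_S = best_subset(X, y, s)
print("selected support:", support)
print("estimated coefficients:", np.round(beta_S, 3))
```

Exhaustive search costs $\binom{p}{s}$ fits, so it only serves as a reference on small problems; for realistic $p$, a dedicated sparse solver is needed.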
