In [1]:
%matplotlib inline

Multiple-response linear regression, also known as multi-output, multivariate, or multidimensional linear regression, predicts several dependent variables at once, each as a linear combination of the same set of independent variables. It differs from simple linear regression, which predicts a single dependent variable from a single independent variable, and from multiple linear regression, which predicts a single dependent variable from several independent variables.
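Without any sparsity constraint, multi-response least squares decouples: fitting all responses jointly gives the same coefficients as fitting each response on its own. The minimal NumPy sketch below checks this on synthetic data (the sizes and data are illustrative assumptions, not from the original). The row-sparsity constraint in the objective is what couples the responses, because a predictor must then be kept or dropped for all responses together.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, m = 50, 4, 3                      # samples, predictors, responses (illustrative sizes)
X = rng.normal(size=(n, p))             # design matrix, one row per observation
B_true = rng.normal(size=(p, m))        # row j = predictor j, column r = response r
Y = X @ B_true + 0.1 * rng.normal(size=(n, m))

# Joint fit: lstsq accepts a 2-D right-hand side and solves every column at once.
B_joint, *_ = np.linalg.lstsq(X, Y, rcond=None)

# Separate fits: one ordinary least-squares regression per response.
B_separate = np.column_stack(
    [np.linalg.lstsq(X, Y[:, r], rcond=None)[0] for r in range(m)]
)

print(np.allclose(B_joint, B_separate))  # True
```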

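The model is expressed as:
$$
    y = B^{*\top} x + \epsilon,
$$
where $y \in \mathbb{R}^m$ is the response vector, $x \in \mathbb{R}^p$ is the predictor vector, $B^* \in \mathbb{R}^{p \times m}$ is the sparse coefficient matrix whose $j$-th row holds the coefficients of predictor $j$ for all $m$ responses, and $\epsilon$ is an $m$-dimensional random noise vector with zero mean.

With $n$ independent observations stacked row-wise into the predictor matrix $X \in \mathbb{R}^{n \times p}$ and the response matrix $Y \in \mathbb{R}^{n \times m}$, we can estimate $B^*$ by minimizing the objective function under a sparsity constraint:
$$ \arg\min_{B} L(B) := \|Y - X B\|_F^2, \quad \text{s.t.} \quad \|B\|_{0,2} \leq s, $$
where $\|\cdot\|_F$ is the Frobenius norm and $\|B\|_{0,2}$ is the number of non-zero rows of $B$, so at most $s$ predictors are used across all responses. In the code below, $s$ corresponds to `k`.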
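As a worked illustration of the two quantities in the objective, the sketch below evaluates $\|B\|_{0,2}$ and $L(B)$ with NumPy for a small hand-written $B$. It assumes $B$ is stored as a $p \times m$ array whose row $j$ holds predictor $j$'s coefficients, so the fitted responses are `X @ B`. The helper names `group_l0_norm` and `objective` are hypothetical, chosen for this example.

```python
import numpy as np

def group_l0_norm(B, tol=0.0):
    """||B||_{0,2}: number of rows of B whose Euclidean norm exceeds tol."""
    return int(np.sum(np.linalg.norm(B, axis=1) > tol))

def objective(B, X, Y):
    """Squared Frobenius norm of the residual Y - X B."""
    return float(np.sum((Y - X @ B) ** 2))

B = np.array([[1.0, 2.0],
              [0.0, 0.0],
              [0.0, -3.0],              # counts as non-zero although one entry is 0
              [0.0, 0.0]])
X = np.eye(4)                           # n = p = 4, so X @ B equals B
Y = np.ones((4, 2))

print(group_l0_norm(B))                 # 2
print(objective(B, X, Y))               # 22.0
```

Row 2 illustrates the group norm: a row counts as soon as any of its $m$ entries is non-zero, which is why selection happens per predictor rather than per coefficient.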

The following Python code solves the sparse multiple-response linear regression problem with `scope`:


### Import necessary packages 

In [4]:
from scope import ScopeSolver
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_regression

### Set a seed for reproducibility

In [5]:
np.random.seed(5)

### Generate the data

In [6]:
n, p, k, m = 10, 5, 3, 2
x, y, coef = make_regression(n_samples=n, n_features=p, n_informative=k, n_targets=m, coef=True)

### Define the multiple-response linear regression loss

In [7]:
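def multi_linear_objective(params):
    # Reshape the flat vector of length p * m into the p x m matrix B (row i = predictor i),
    # then return the sum of squared residuals ||Y - XB||_F^2 over all samples and responses.
    return jnp.sum(jnp.square(y - jnp.matmul(x, params.reshape((p, m)))))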
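The loss takes a flat parameter vector of length $pm$ and reshapes it into a $p \times m$ matrix. The NumPy sketch below (synthetic data; `flat_loss` is a hypothetical NumPy stand-in for the JAX function above) checks two consequences of that layout: the flat loss equals $\|Y - XB\|_F^2$ for the reshaped matrix, and with NumPy's default row-major ordering, flat index `i * m + j` is entry `(i, j)`, so the $m$ coefficients of predictor `i` form a contiguous block.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p, m = 10, 5, 2
x = rng.normal(size=(n, p))
y = rng.normal(size=(n, m))

def flat_loss(params):
    # Same computation as multi_linear_objective, written with NumPy.
    return np.sum(np.square(y - x @ params.reshape((p, m))))

params = rng.normal(size=p * m)
B = params.reshape((p, m))

# The flat loss is the squared Frobenius norm of the residual matrix.
same_loss = np.isclose(flat_loss(params), np.linalg.norm(y - x @ B, "fro") ** 2)
print(same_loss)  # True

# Row-major layout: flat index i * m + j maps to entry (i, j).
row_major = all(params[i * m + j] == B[i, j] for i in range(p) for j in range(m))
print(row_major)  # True
```

This layout is why a group label of `i` on flat indices `i * m` through `i * m + m - 1` selects exactly row `i` of $B$.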

### Use scope to solve the sparse multiple response linear regression problem

In [10]:
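# Optimize p * m parameters and select k groups; group[i * m + j] = i ties the
# m coefficients of predictor i together, so whole rows of B are kept or dropped.
solver = ScopeSolver(p * m, k, group=[i for i in range(p) for j in range(m)])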
solver.solve(multi_linear_objective)

print("Estimated parameter:\n", solver.params.reshape((p, m)))
print("True parameter:\n", coef)

Estimated parameter:
 [[54.63583377 79.61427562]
 [ 0.          0.        ]
 [ 0.          0.        ]
 [36.54777739 24.42908797]
 [ 5.11427902 18.86677609]]
True parameter:
 [[54.63583485 79.61427209]
 [ 0.          0.        ]
 [ 0.          0.        ]
 [36.54777679 24.4290867 ]
 [ 5.11428031 18.86677358]]
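For a problem this small, the row-sparse objective can also be solved exactly by enumeration, which gives an independent check of the kind of answer shown above. The sketch below is an exhaustive best-subset search, feasible only for small $p$ (here $\binom{5}{3} = 10$ candidate supports): it fits least squares on every set of $k$ rows of $B$ and keeps the support with the smallest residual. It uses synthetic noise-free NumPy data rather than `make_regression`, with a true support of rows 0, 3 and 4 chosen to mirror the output above; the name `best_subset_rows` is hypothetical.

```python
import itertools
import numpy as np

def best_subset_rows(X, Y, k):
    """Exactly minimize ||Y - X B||_F^2 subject to B having at most k non-zero rows.

    Enumerates every size-k set of predictors, so the cost grows as C(p, k).
    """
    n, p = X.shape
    best_loss, best_B = np.inf, None
    for support in itertools.combinations(range(p), k):
        cols = list(support)
        # Least squares restricted to the chosen predictors, all responses at once.
        B_sub, *_ = np.linalg.lstsq(X[:, cols], Y, rcond=None)
        loss = np.sum((Y - X[:, cols] @ B_sub) ** 2)
        if loss < best_loss:
            best_B = np.zeros((p, Y.shape[1]))
            best_B[cols] = B_sub
            best_loss = loss
    return best_B, best_loss

rng = np.random.default_rng(5)
n, p, k, m = 10, 5, 3, 2
X = rng.normal(size=(n, p))
B_true = np.zeros((p, m))
B_true[[0, 3, 4]] = rng.normal(scale=50.0, size=(k, m))   # row-sparse ground truth
Y = X @ B_true                                            # noise-free responses

B_hat, loss = best_subset_rows(X, Y, k)
print(np.flatnonzero(np.linalg.norm(B_hat, axis=1)))      # [0 3 4]
print(np.allclose(B_hat, B_true))                         # True
```

Exhaustive search grows combinatorially with $p$, which is why an iterative solver such as `ScopeSolver` is used in practice; on noise-free data like this, both should recover the true support.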
