Skip to content

Commit

Permalink
feat: add alpha parameter to lasso_regression (#232)
Browse files Browse the repository at this point in the history
Closes #163.

### Summary of Changes

Add alpha parameter to lasso regression

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
3 people committed Apr 28, 2023
1 parent f4f44a6 commit b5050b9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
34 changes: 29 additions & 5 deletions src/safeds/ml/classical/regression/_lasso_regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import warn

from sklearn.linear_model import Lasso as sk_Lasso

Expand All @@ -13,9 +14,32 @@


class LassoRegression(Regressor):
"""Lasso regression."""

def __init__(self) -> None:
"""Lasso regression.
Parameters
----------
alpha : float
Controls the regularization of the model. The higher the value, the more regularized it becomes.
Raises
------
ValueError
If alpha is negative.
"""

def __init__(self, alpha: float = 1.0) -> None:
if alpha < 0:
raise ValueError("alpha must be non-negative")
if alpha == 0:
warn(
(
"Setting alpha to zero makes this model equivalent to LinearRegression. You should use "
"LinearRegression instead for better numerical stability."
),
UserWarning,
stacklevel=2,
)
self._alpha = alpha
self._wrapped_regressor: sk_Lasso | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
Expand All @@ -41,10 +65,10 @@ def fit(self, training_set: TaggedTable) -> LassoRegression:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_Lasso()
wrapped_regressor = sk_Lasso(alpha=self._alpha)
fit(wrapped_regressor, training_set)

result = LassoRegression()
result = LassoRegression(alpha=self._alpha)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
28 changes: 28 additions & 0 deletions tests/safeds/ml/classical/regression/test_lasso_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.regression import LassoRegression


def test_should_throw_value_error() -> None:
with pytest.raises(ValueError, match="alpha must be non-negative"):
LassoRegression(alpha=-1)


def test_should_throw_warning() -> None:
with pytest.warns(
UserWarning,
match=(
"Setting alpha to zero makes this model equivalent to LinearRegression. You "
"should use LinearRegression instead for better numerical stability."
),
):
LassoRegression(alpha=0)


def test_should_give_alpha_to_sklearn() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_table = training_set.tag_columns("col1")

regressor = LassoRegression(alpha=1).fit(tagged_table)
assert regressor._wrapped_regressor is not None
assert regressor._wrapped_regressor.alpha == regressor._alpha

0 comments on commit b5050b9

Please sign in to comment.