Skip to content

Commit

Permalink
feat: set alpha parameter for regularization of `ElasticNetRegressi…
Browse files Browse the repository at this point in the history
…on` (#238)

Closes #165.

### Summary of Changes

add alpha parameter for `elasticnetregression`.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
3 people committed Apr 28, 2023
1 parent b5050b9 commit e642d1d
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 23 deletions.
56 changes: 46 additions & 10 deletions src/safeds/ml/classical/regression/_elastic_net_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from typing import TYPE_CHECKING
from warnings import warn

from sklearn.linear_model import ElasticNet as sk_ElasticNet

Expand All @@ -14,24 +15,59 @@


class ElasticNetRegression(Regressor):
"""Elastic net regression."""
"""Elastic net regression.
Parameters
----------
alpha : float
Controls the regularization of the model. The higher the value, the more regularized it becomes.
lasso_ratio: float
Number between 0 and 1 that controls the ratio between Lasso- and Ridge regularization.
lasso_ratio=0 is essentially RidgeRegression
lasso_ratio=1 is essentially LassoRegression
Raises
------
ValueError
If alpha is negative.
"""

def __init__(self, alpha: float = 1.0, lasso_ratio: float = 0.5) -> None:
if alpha < 0:
raise ValueError("alpha must be non-negative")
if alpha == 0:
warn(
(
"Setting alpha to zero makes this model equivalent to LinearRegression. You should use "
"LinearRegression instead for better numerical stability."
),
UserWarning,
stacklevel=2,
)

self._alpha = alpha

def __init__(self, lasso_ratio: float = 0.5) -> None:
if lasso_ratio < 0 or lasso_ratio > 1:
raise ValueError("lasso_ratio must be between 0 and 1.")
elif lasso_ratio == 0:
warnings.warn(
"ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability.",
stacklevel=1,
(
"ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability."
),
stacklevel=2,
)
elif lasso_ratio == 1:
warnings.warn(
"ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability.",
stacklevel=1,
(
"ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability."
),
stacklevel=2,
)
self.lasso_ratio = lasso_ratio

self._wrapped_regressor: sk_ElasticNet | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
Expand All @@ -57,10 +93,10 @@ def fit(self, training_set: TaggedTable) -> ElasticNetRegression:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_ElasticNet(l1_ratio=self.lasso_ratio)
wrapped_regressor = sk_ElasticNet(alpha=self._alpha, l1_ratio=self.lasso_ratio)
fit(wrapped_regressor, training_set)

result = ElasticNetRegression(self.lasso_ratio)
result = ElasticNetRegression(alpha=self._alpha, lasso_ratio=self.lasso_ratio)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
55 changes: 42 additions & 13 deletions tests/safeds/ml/classical/regression/test_elastic_net_regression.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,68 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.regression._elastic_net_regression import ElasticNetRegression
from safeds.ml.classical.regression import ElasticNetRegression


def test_lasso_ratio_valid() -> None:
def test_should_throw_value_error_alpha() -> None:
with pytest.raises(ValueError, match="alpha must be non-negative"):
ElasticNetRegression(alpha=-1.0)


def test_should_throw_warning_alpha() -> None:
with pytest.warns(
UserWarning,
match=(
"Setting alpha to zero makes this model equivalent to LinearRegression. You "
"should use LinearRegression instead for better numerical stability."
),
):
ElasticNetRegression(alpha=0.0)


def test_should_give_alpha_to_sklearn() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"])

elastic_net_regression = ElasticNetRegression(alpha=1.0).fit(tagged_training_set)
assert elastic_net_regression._wrapped_regressor is not None
assert elastic_net_regression._wrapped_regressor.alpha == elastic_net_regression._alpha


def test_should_give_lasso_ratio_to_sklearn() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"])
lasso_ratio = 0.3

elastic_net_regression = ElasticNetRegression(lasso_ratio).fit(tagged_training_set)
elastic_net_regression = ElasticNetRegression(lasso_ratio=lasso_ratio).fit(tagged_training_set)
assert elastic_net_regression._wrapped_regressor is not None
assert elastic_net_regression._wrapped_regressor.l1_ratio == lasso_ratio


def test_lasso_ratio_invalid() -> None:
def test_should_throw_value_error_lasso_ratio() -> None:
with pytest.raises(ValueError, match="lasso_ratio must be between 0 and 1."):
ElasticNetRegression(-1)
ElasticNetRegression(lasso_ratio=-1.0)


def test_lasso_ratio_zero() -> None:
def test_should_throw_warning_lasso_ratio_zero() -> None:
with pytest.warns(
UserWarning,
match="ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability.",
match=(
"ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability."
),
):
ElasticNetRegression(0)
ElasticNetRegression(lasso_ratio=0)


def test_lasso_ratio_one() -> None:
def test_should_throw_warning_lasso_ratio_one() -> None:
with pytest.warns(
UserWarning,
match="ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability.",
match=(
"ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability."
),
):
ElasticNetRegression(1)
ElasticNetRegression(lasso_ratio=1)


# (Default parameter is tested in `test_regressor.py`.)

0 comments on commit e642d1d

Please sign in to comment.