Skip to content

Commit

Permalink
feat: Added parameter c to SupportVectorMachines (#267)
Browse files Browse the repository at this point in the history
Closes #169.

### Summary of Changes

Added parameter c to SupportVectorMachines

---------

Co-authored-by: alex-senger <91055000+alex-senger@users.noreply.github.com>
Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
4 people committed May 5, 2023
1 parent 5adadad commit a88eb8b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 8 deletions.
24 changes: 20 additions & 4 deletions src/safeds/ml/classical/classification/_support_vector_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,30 @@


class SupportVectorMachine(Classifier):
"""Support vector machine."""
"""
Support vector machine.
def __init__(self) -> None:
Parameters
----------
c: float
The strength of regularization. Must be strictly positive.
Raises
------
ValueError
If `c` is less than or equal to 0.
"""

def __init__(self, c: float = 1.0) -> None:
# Internal state
self._wrapped_classifier: sk_SVC | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None

if c <= 0:
raise ValueError("The strength of regularization given by the c parameter must be strictly positive.")
self._c = c

def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
"""
Create a copy of this classifier and fit it with the given training data.
Expand All @@ -42,10 +58,10 @@ def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_classifier = sk_SVC()
wrapped_classifier = sk_SVC(C=self._c)
fit(wrapped_classifier, training_set)

result = SupportVectorMachine()
result = SupportVectorMachine(self._c)
result._wrapped_classifier = wrapped_classifier
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
24 changes: 20 additions & 4 deletions src/safeds/ml/classical/regression/_support_vector_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,30 @@


class SupportVectorMachine(Regressor):
"""Support vector machine."""
"""
Support vector machine.
def __init__(self) -> None:
Parameters
----------
c: float
The strength of regularization. Must be strictly positive.
Raises
------
ValueError
If `c` is less than or equal to 0.
"""

def __init__(self, c: float = 1.0) -> None:
# Internal state
self._wrapped_regressor: sk_SVR | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None

if c <= 0:
raise ValueError("The strength of regularization given by the c parameter must be strictly positive.")
self._c = c

def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
"""
Create a copy of this regressor and fit it with the given training data.
Expand All @@ -42,10 +58,10 @@ def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_SVR()
wrapped_regressor = sk_SVR(C=self._c)
fit(wrapped_regressor, training_set)

result = SupportVectorMachine()
result = SupportVectorMachine(self._c)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml.classical.classification import SupportVectorMachine


@pytest.fixture()
def training_set() -> TaggedTable:
table = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
return table.tag_columns(target_name="col1", feature_names=["col2"])


class TestC:
def test_should_be_passed_to_fitted_model(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set=training_set)
assert fitted_model._c == 2

def test_should_be_passed_to_sklearn(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set)
assert fitted_model._wrapped_classifier is not None
assert fitted_model._wrapped_classifier.C == 2

def test_should_raise_if_less_than_or_equal_to_0(self) -> None:
with pytest.raises(
ValueError,
match="The strength of regularization given by the c parameter must be strictly positive.",
):
SupportVectorMachine(c=-1)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml.classical.regression import SupportVectorMachine


@pytest.fixture()
def training_set() -> TaggedTable:
table = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
return table.tag_columns(target_name="col1", feature_names=["col2"])


class TestC:
def test_should_be_passed_to_fitted_model(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set=training_set)
assert fitted_model._c == 2

def test_should_be_passed_to_sklearn(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set)
assert fitted_model._wrapped_regressor is not None
assert fitted_model._wrapped_regressor.C == 2

def test_should_raise_if_less_than_or_equal_to_0(self) -> None:
with pytest.raises(
ValueError,
match="The strength of regularization given by the c parameter must be strictly positive.",
):
SupportVectorMachine(c=-1)

0 comments on commit a88eb8b

Please sign in to comment.