-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: major API redesign (WIP) (#752)
Closes #694 Closes #699 Closes #714 Closes #748 ### Summary of Changes * Replace old implementation of tabular containers * New, more efficient implementation of metrics * Standalone package for metrics * New regression metrics * Abstract base class for classifiers & regressors * Introspection methods to get information about features and target of supervised models * Rename `LogisticRegressionClassifier` to `LogisticClassifier` (shorter + does not show up when searching for regression) * Rename `LinearRegressionRegressor` to `LinearRegressor` (shorter) * Rename `SupportVectorMachineClassifier` to `SupportVectorClassifier` (a little less precise, but still unambiguous and shorter) * Rename `SupportVectorMachineRegressor` to `SupportVectorRegressor` (ditto)
- Loading branch information
1 parent
0e5a54b
commit 8e781f9
Showing
163 changed files
with
7,217 additions
and
15,007 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from __future__ import annotations | ||
|
||
from timeit import timeit | ||
from typing import TYPE_CHECKING | ||
|
||
import polars as pl | ||
|
||
from benchmarks.table.utils import create_synthetic_table | ||
from safeds.data.tabular.containers import Table | ||
from safeds.ml.metrics import ClassificationMetrics | ||
|
||
|
||
REPETITIONS = 10 | ||
|
||
|
||
def _run_accuracy() -> None: | ||
ClassificationMetrics.accuracy(table.get_column("predicted"), table.get_column("expected")) | ||
|
||
|
||
def _run_f1_score() -> None: | ||
ClassificationMetrics.f1_score(table.get_column("predicted"), table.get_column("expected"), 1) | ||
|
||
|
||
def _run_precision() -> None: | ||
ClassificationMetrics.precision(table.get_column("predicted"), table.get_column("expected"), 1) | ||
|
||
|
||
def _run_recall() -> None: | ||
ClassificationMetrics.recall(table.get_column("predicted"), table.get_column("expected"), 1) | ||
|
||
|
||
if __name__ == "__main__": | ||
# Create a synthetic Table | ||
table = ( | ||
create_synthetic_table(10000, 2) | ||
.rename_column("column_0", "predicted") | ||
.rename_column("column_1", "expected") | ||
) | ||
|
||
# Run the benchmarks | ||
timings: dict[str, float] = { | ||
"accuracy": timeit( | ||
_run_accuracy, | ||
number=REPETITIONS, | ||
), | ||
"f1_score": timeit( | ||
_run_f1_score, | ||
number=REPETITIONS, | ||
), | ||
"precision": timeit( | ||
_run_precision, | ||
number=REPETITIONS, | ||
), | ||
"recall": timeit( | ||
_run_recall, | ||
number=REPETITIONS, | ||
), | ||
} | ||
|
||
# Print the timings | ||
with pl.Config( | ||
tbl_rows=-1, | ||
): | ||
print( | ||
Table( | ||
{ | ||
"method": list(timings.keys()), | ||
"timing": list(timings.values()), | ||
} | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
from .create_synthetic_table import create_synthetic_table | ||
from .create_synthetic_table_polars import create_synthetic_table_polars | ||
|
||
__all__ = [ | ||
"create_synthetic_table", | ||
"create_synthetic_table_polars", | ||
] |
Oops, something went wrong.