Skip to content

Commit

Permalink
Merge pull request #496 from GAA-UAM/issue/classifier_classes/logreg
Browse files Browse the repository at this point in the history
Issue/classifier_classes/logreg
  • Loading branch information
vnmabus committed Nov 23, 2022
2 parents aee15f4 + a9e9e51 commit 8756059
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ rst-directives =
versionadded,versionchanged,

rst-roles =
attr,class,doc,footcite,footcite:ts,func,meth,mod,obj,ref,term,
attr,class,doc,footcite,footcite:ts,func,meth,mod,obj,ref,term,external:class

allowed-domain-names = data, obj, result, results, val, value, values, var

Expand Down
9 changes: 6 additions & 3 deletions skfda/ml/classification/_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
self.solver = solver
self.max_iter = max_iter

def fit( # noqa: D102
def fit( # noqa: D102, WPS210
self,
X: FDataGrid,
y: NDArrayAny,
Expand Down Expand Up @@ -162,7 +162,10 @@ def fit( # noqa: D102
# This does not improve
selected_indexes = selected_indexes[:n_selected]
selected_values = selected_values[:, :n_selected]
likelihood_curves_data = likelihood_curves_data[:n_selected, t]
likelihood_curves_data = likelihood_curves_data[
:n_selected,
n_features - 1,
]
break

last_max_likelihood = max_likelihood
Expand All @@ -171,7 +174,7 @@ def fit( # noqa: D102
selected_values[:, n_selected] = X.data_matrix[:, tmax, 0]

# fit for the complete set of points
mvlr.fit(selected_values, y_ind)
mvlr.fit(selected_values, y)

self.coef_ = mvlr.coef_
self.intercept_ = mvlr.intercept_
Expand Down
24 changes: 24 additions & 0 deletions skfda/tests/test_classifier_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@

from skfda._utils._sklearn_adapter import ClassifierMixin
from skfda.datasets import make_gaussian_process
from skfda.exploratory.depth import ModifiedBandDepth
from skfda.exploratory.stats.covariance import ParametricGaussianCovariance
from skfda.misc.covariances import Gaussian
from skfda.ml.classification import (
DDClassifier,
DDGClassifier,
DTMClassifier,
KNeighborsClassifier,
LogisticRegression,
MaximumDepthClassifier,
NearestCentroid,
QuadraticDiscriminantAnalysis,
RadiusNeighborsClassifier,
)
from skfda.representation import FData
Expand All @@ -20,7 +30,21 @@

@pytest.fixture(
params=[
DDClassifier(degree=2),
DDGClassifier(
depth_method=[("mbd", ModifiedBandDepth())],
multivariate_classifier=KNeighborsClassifier(),
),
DTMClassifier(proportiontocut=0.25),
KNeighborsClassifier(),
LogisticRegression(),
MaximumDepthClassifier(),
NearestCentroid(),
QuadraticDiscriminantAnalysis(
cov_estimator=ParametricGaussianCovariance(
Gaussian(),
),
),
RadiusNeighborsClassifier(),
],
ids=lambda clf: type(clf).__name__,
Expand Down

0 comments on commit 8756059

Please sign in to comment.