Skip to content

Commit

Permalink
Merge pull request #494 from GAA-UAM/issue/classifier_classes/qda
Browse files Browse the repository at this point in the history
Issue/classifier classes/qda
  • Loading branch information
vnmabus committed Nov 15, 2022
2 parents fca9f31 + f40872f commit 7b51542
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
33 changes: 18 additions & 15 deletions skfda/ml/classification/_qda.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import Sequence, TypeVar, Union

import numpy as np
from scipy.linalg import logm
Expand All @@ -11,12 +11,14 @@
from ..._utils._sklearn_adapter import BaseEstimator, ClassifierMixin
from ...exploratory.stats.covariance import CovarianceEstimator
from ...representation import FDataGrid
from ...typing._numpy import NDArrayFloat, NDArrayInt
from ...typing._numpy import NDArrayFloat, NDArrayInt, NDArrayStr

Target = TypeVar("Target", bound=Union[NDArrayInt, NDArrayStr])


class QuadraticDiscriminantAnalysis(
BaseEstimator,
ClassifierMixin[FDataGrid, NDArrayInt],
ClassifierMixin[FDataGrid, Target],
):
"""
Functional quadratic discriminant analysis.
Expand Down Expand Up @@ -101,8 +103,8 @@ class QuadraticDiscriminantAnalysis(
>>> round(qda.score(X_test, y_test), 2)
0.96
"""

means_: Sequence[FDataGrid]

def __init__(
Expand All @@ -117,8 +119,8 @@ def __init__(
def fit(
self,
X: FDataGrid,
y: NDArrayInt,
) -> QuadraticDiscriminantAnalysis:
y: Target,
) -> QuadraticDiscriminantAnalysis[Target]:
"""
Fit the model using X as training data and y as target values.
Expand All @@ -130,7 +132,7 @@ def fit(
self
"""
classes, y_ind = _classifier_get_classes(y)
self.classes = classes
self.classes_ = classes
self.y_ind = y_ind

self._fit_gaussian_process(X)
Expand All @@ -150,7 +152,7 @@ def fit(

return self

def predict(self, X: FDataGrid) -> NDArrayInt:
def predict(self, X: FDataGrid) -> Target:
"""
Predict the class labels for the provided data.
Expand All @@ -163,12 +165,13 @@ def predict(self, X: FDataGrid) -> NDArrayInt:
"""
check_is_fitted(self)

return np.argmax( # type: ignore[no-any-return]
self._calculate_log_likelihood(X.data_matrix),
axis=1,
)
return self.classes_[ # type: ignore[no-any-return]
np.argmax(
self._calculate_log_likelihood(X.data_matrix),
axis=1,
)]

def _calculate_priors(self, y: NDArrayInt) -> NDArrayFloat:
def _calculate_priors(self, y: Target) -> NDArrayFloat:
"""
Calculate the prior probability of each class.
Expand Down Expand Up @@ -198,7 +201,7 @@ def _fit_gaussian_process(
cov_estimators = []
means = []
covariance = []
for class_index, _ in enumerate(self.classes):
for class_index, _ in enumerate(self.classes_):
X_class = X[self.y_ind == class_index]
cov_estimator = clone(self.cov_estimator).fit(X_class)

Expand Down Expand Up @@ -235,7 +238,7 @@ def _calculate_log_likelihood(self, X: NDArrayFloat) -> NDArrayFloat:
self._regularized_covariances,
X_centered,
),
(-1, self.classes.size),
(-1, self.classes_.size),
)

return np.asarray(
Expand Down
2 changes: 1 addition & 1 deletion skfda/tests/test_classifier_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_classes(
"""Test classes attribute of classifiers."""
n_samples = 30
y = np.resize(classes, n_samples)
X = make_gaussian_process(n_samples=n_samples)
X = make_gaussian_process(n_samples=n_samples, random_state=0)
classifier.fit(X, y)
resulting_classes = np.unique(classifier.predict(X))

Expand Down

0 comments on commit 7b51542

Please sign in to comment.