Skip to content

Commit

Permalink
MNT parameter validation for covariance.empirical_covariance (scikit-…
Browse files Browse the repository at this point in the history
  • Loading branch information
shogohida authored and Vincent-Maladiere committed Dec 14, 2022
1 parent a6fa94f commit e7473de
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
7 changes: 7 additions & 0 deletions sklearn/covariance/_empirical_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .. import config_context
from ..base import BaseEstimator
from ..utils import check_array
from ..utils._param_validation import validate_params
from ..utils.extmath import fast_logdet
from ..metrics.pairwise import pairwise_distances

Expand Down Expand Up @@ -48,6 +49,12 @@ def log_likelihood(emp_cov, precision):
return log_likelihood_


@validate_params(
{
"X": ["array-like"],
"assume_centered": ["boolean"],
}
)
def empirical_covariance(X, *, assume_centered=False):
"""Compute the Maximum likelihood covariance estimator.
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _check_function_param_validation(
PARAM_VALIDATION_FUNCTION_LIST = [
"sklearn.cluster.estimate_bandwidth",
"sklearn.cluster.kmeans_plusplus",
"sklearn.covariance.empirical_covariance",
"sklearn.feature_extraction.grid_to_graph",
"sklearn.feature_extraction.img_to_graph",
"sklearn.metrics.accuracy_score",
Expand Down

0 comments on commit e7473de

Please sign in to comment.