Skip to content

Commit

Permalink
LDA and QDA: Raise an informative error if the user forgot to set sto…
Browse files Browse the repository at this point in the history
…re_covariance=True when instantiating the model
  • Loading branch information
andreArtelt committed May 8, 2020
1 parent 3e8e3cf commit 5619300
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ceml/sklearn/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def rebuild_model(self, model):
"""
if not isinstance(model, LinearDiscriminantAnalysis):
raise TypeError(f"model has to be an instance of 'sklearn.discriminant_analysis.LinearDiscriminantAnalysis' but not of {type(model)}")
if not hasattr(model, "covariance_"):
raise AttributeError("You have to set store_covariance=True when instantiating a new sklearn.discriminant_analysis.LinearDiscriminantAnalysis model")

return Lda(model)

Expand Down
2 changes: 2 additions & 0 deletions ceml/sklearn/qda.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def rebuild_model(self, model):
"""
if not isinstance(model, QuadraticDiscriminantAnalysis):
raise TypeError(f"model has to be an instance of 'sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis' but not of {type(model)}")
if not hasattr(model, "covariance_"):
raise AttributeError("You have to set store_covariance=True when instantiating a new sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis model")

return Qda(model)

Expand Down

0 comments on commit 5619300

Please sign in to comment.