From 4b4c0be753f23c6a86253a31ddc6e513b26e54f9 Mon Sep 17 00:00:00 2001 From: Avi Shinnar Date: Mon, 12 Apr 2021 08:12:36 -0400 Subject: [PATCH] Add curated schema for complement_nb --- lale/lib/sklearn/__init__.py | 3 + lale/lib/sklearn/complement_nb.py | 117 ++++++++++++++++++++++++++++++ test/test_core_classifiers.py | 1 + 3 files changed, 121 insertions(+) create mode 100644 lale/lib/sklearn/complement_nb.py diff --git a/lale/lib/sklearn/__init__.py b/lale/lib/sklearn/__init__.py index beb8db879d..7bdca9f1bb 100644 --- a/lale/lib/sklearn/__init__.py +++ b/lale/lib/sklearn/__init__.py @@ -27,6 +27,7 @@ * lale.lib.sklearn. `BernoulliNB`_ * lale.lib.sklearn. `CalibratedClassifierCV`_ * lale.lib.sklearn. `CCA`_ +* lale.lib.sklearn. `ComplementNB`_ * lale.lib.sklearn. `DecisionTreeClassifier`_ * lale.lib.sklearn. `DummyClassifier`_ * lale.lib.sklearn. `ExtraTreesClassifier`_ @@ -102,6 +103,7 @@ .. _`Birch`: lale.lib.sklearn.birch.html .. _`CalibratedClassifierCV`: lale.lib.sklearn.calibrated_classifier_cv.html .. _`ColumnTransformer`: lale.lib.sklearn.column_transformer.html +.. _`ComplementNB`: lale.lib.sklearn.complement_nb.html .. _`DecisionTreeClassifier`: lale.lib.sklearn.decision_tree_classifier.html .. _`DecisionTreeRegressor`: lale.lib.sklearn.decision_tree_regressor.html .. _`DummyClassifier`: lale.lib.sklearn.dummy_classifier.html @@ -165,6 +167,7 @@ from .calibrated_classifier_cv import CalibratedClassifierCV from .cca import CCA from .column_transformer import ColumnTransformer +from .complement_nb import ComplementNB from .decision_tree_classifier import DecisionTreeClassifier from .decision_tree_regressor import DecisionTreeRegressor from .dummy_classifier import DummyClassifier diff --git a/lale/lib/sklearn/complement_nb.py b/lale/lib/sklearn/complement_nb.py new file mode 100644 index 0000000000..32fa56bcd7 --- /dev/null +++ b/lale/lib/sklearn/complement_nb.py @@ -0,0 +1,117 @@ +from sklearn.naive_bayes import ComplementNB as Op + +from lale.docstrings import set_docstrings +from lale.operators import make_operator + +_hyperparams_schema = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "The Complement Naive Bayes classifier described in Rennie et al. (2003).", + "allOf": [ + { + "type": "object", + "required": ["alpha", "fit_prior", "class_prior", "norm"], + "relevantToOptimizer": [], + "additionalProperties": False, + "properties": { + "alpha": { + "type": "number", + "default": 1.0, + "description": "Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).", + }, + "fit_prior": { + "type": "boolean", + "default": True, + "description": "Only used in edge case with a single class in the training set.", + }, + "class_prior": { + "anyOf": [ + {"type": "array", "items": {"type": "number"}}, + {"enum": [None]}, + ], + "default": None, + "description": "Prior probabilities of the classes. Not used.", + }, + "norm": { + "type": "boolean", + "default": False, + "description": "Whether or not a second normalization of the weights is performed", + }, + }, + }, + ], +} +_input_fit_schema = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "Fit Naive Bayes classifier according to X, y", + "type": "object", + "required": ["X", "y"], + "properties": { + "X": { + "type": "array", + "items": {"type": "array", "items": {"type": "number"}}, + "description": "Training vectors, where n_samples is the number of samples and n_features is the number of features.", + }, + "y": { + "type": "array", + "items": {"type": "number"}, + "description": "Target values.", + }, + "sample_weight": { + "anyOf": [{"type": "array", "items": {"type": "number"}}, {"enum": [None]}], + "default": None, + "description": "Weights applied to individual samples (1", + }, + }, +} +_input_predict_schema = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "Perform classification on an array of test vectors X.", + "type": "object", + "required": ["X"], + "properties": { + "X": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}} + }, +} +_output_predict_schema = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "Predicted target values for X", + "type": "array", + "items": {"type": "number"}, +} +_input_predict_proba_schema = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "Return probability estimates for the test vector X.", + "type": "object", + "required": ["X"], + "properties": { + "X": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}} + }, +} +_output_predict_proba_schema = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "Returns the probability of the samples for each class in the model", + "type": "array", + "items": {"type": "array", "items": {"type": "number"}}, +} +_combined_schemas = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": """`Complement Naive Bayes`_ classifier described in Rennie et al. (2003). + +.. _`Complement Naive Bayes`: https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.ComplementNB + """, + "documentation_url": "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.complement_nb.html", + "import_from": "sklearn.naive_bayes", + "type": "object", + "tags": {"pre": [], "op": ["estimator"], "post": []}, + "properties": { + "hyperparams": _hyperparams_schema, + "input_fit": _input_fit_schema, + "input_predict": _input_predict_schema, + "output_predict": _output_predict_schema, + "input_predict_proba": _input_predict_proba_schema, + "output_predict_proba": _output_predict_proba_schema, + }, +} +ComplementNB = make_operator(Op, _combined_schemas) + +set_docstrings(ComplementNB) diff --git a/test/test_core_classifiers.py b/test/test_core_classifiers.py index 41796d3c82..fcf97479fb 100644 --- a/test/test_core_classifiers.py +++ b/test/test_core_classifiers.py @@ -131,6 +131,7 @@ def test_classifier(self): "lale.lib.sklearn.BernoulliNB", "lale.lib.sklearn.CalibratedClassifierCV", "lale.lib.sklearn.CCA", + "lale.lib.sklearn.ComplementNB", "lale.lib.sklearn.DummyClassifier", "lale.lib.sklearn.RandomForestClassifier", "lale.lib.sklearn.DecisionTreeClassifier",