Skip to content

Commit

Permalink
Add curated schema for CCA
Browse files Browse the repository at this point in the history
  • Loading branch information
shinnar committed Apr 13, 2021
1 parent 965c250 commit 8863bda
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lale/lib/sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
* lale.lib.sklearn. `BaggingClassifier`_
* lale.lib.sklearn. `BernoulliNB`_
* lale.lib.sklearn. `CalibratedClassifierCV`_
* lale.lib.sklearn. `CCA`_
* lale.lib.sklearn. `DecisionTreeClassifier`_
* lale.lib.sklearn. `DummyClassifier`_
* lale.lib.sklearn. `ExtraTreesClassifier`_
Expand Down Expand Up @@ -162,6 +163,7 @@
from .binarizer import Binarizer
from .birch import Birch
from .calibrated_classifier_cv import CalibratedClassifierCV
from .cca import CCA
from .column_transformer import ColumnTransformer
from .decision_tree_classifier import DecisionTreeClassifier
from .decision_tree_regressor import DecisionTreeRegressor
Expand Down
147 changes: 147 additions & 0 deletions lale/lib/sklearn/cca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from sklearn.cross_decomposition import CCA as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator

_hyperparams_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "Canonical Correlation Analysis.",
"allOf": [
{
"type": "object",
"required": ["n_components", "scale", "max_iter", "tol", "copy"],
"relevantToOptimizer": ["n_components", "scale", "max_iter", "tol"],
"additionalProperties": False,
"properties": {
"n_components": {
"type": "integer",
"minimum": 1,
"minimumForOptimizer": 2,
"maximumForOptimizer": 256,
"distribution": "uniform",
"default": 2,
"description": "number of components to keep.",
},
"scale": {
"type": "boolean",
"default": True,
"description": "whether to scale the data?",
},
"max_iter": {
"description": "the maximum number of the power method.",
"type": "integer",
"minimum": 0,
"minimumForOptimizer": 10,
"maximumForOptimizer": 1000,
"distribution": "uniform",
"default": 500,
},
"tol": {
"description": "the tolerance used in the iterative algorithm",
"type": "number",
"minimumForOptimizer": 1e-08,
"maximumForOptimizer": 0.01,
"distribution": "loguniform",
"default": 1e-06,
},
"copy": {
"type": "boolean",
"default": True,
"description": "Whether the deflation be done on a copy",
},
},
}
],
}
# Schema for the arguments of `fit`, mirroring sklearn's CCA.fit(X, Y).
# NOTE(review): sklearn's CCA.fit requires the target Y, yet only "X" is
# listed in "required" here. Lale conventionally passes the target under a
# lowercase "y" key — confirm which name the wrapper actually supplies
# before tightening "required", or validation would start rejecting calls.
_input_fit_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Fit model to data.",
    "type": "object",
    "required": ["X"],
    "properties": {
        "X": {
            "type": "array",
            "items": {"type": "array", "items": {"type": "number"}},
            "description": "Training vectors, where n_samples is the number of samples and n_features is the number of predictors.",
        },
        "Y": {
            "type": "array",
            "items": {"type": "array", "items": {"type": "number"}},
            "description": "Target vectors, where n_samples is the number of samples and n_targets is the number of response variables.",
        },
    },
}
# Schema for the arguments of `transform`, mirroring sklearn's
# CCA.transform(X, Y=None, copy=True): Y and copy are optional.
_input_transform_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Apply the dimension reduction learned on the train data.",
    "type": "object",
    "required": ["X"],
    "properties": {
        "X": {
            "type": "array",
            "items": {"type": "array", "items": {"type": "number"}},
            "description": "Training vectors, where n_samples is the number of samples and n_features is the number of predictors.",
        },
        "Y": {
            "type": "array",
            "items": {"type": "array", "items": {"type": "number"}},
            "description": "Target vectors, where n_samples is the number of samples and n_targets is the number of response variables.",
        },
        "copy": {
            "type": "boolean",
            "default": True,
            "description": "Whether to copy X and Y, or perform in-place normalization.",
        },
    },
}
_output_transform_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "Apply the dimension reduction learned on the train data.",
"laleType": "Any",
"XXX TODO XXX": "x_scores if Y is not given, (x_scores, y_scores) otherwise.",
}
# Schema for the arguments of `predict`, mirroring sklearn's
# CCA.predict(X, copy=True).
_input_predict_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Apply the dimension reduction learned on the train data.",
    "type": "object",
    "required": ["X"],
    "properties": {
        "X": {
            "type": "array",
            "items": {"type": "array", "items": {"type": "number"}},
            "description": "Training vectors, where n_samples is the number of samples and n_features is the number of predictors.",
        },
        "copy": {
            "type": "boolean",
            "default": True,
            "description": "Whether to copy X and Y, or perform in-place normalization.",
        },
    },
}
# Schema for the result of `predict`; left as "Any" since the output is a
# target-space array whose exact shape is not pinned down here.
_output_predict_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Apply the dimension reduction learned on the train data.",
    "laleType": "Any",
}
# Combined operator metadata: ties the per-method schemas together and adds
# documentation links and tags consumed by lale's tooling.
_combined_schemas = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    # The blank line before the ".. _" target is required reST syntax, and
    # the scikit-learn API page needs the ".html" suffix to resolve.
    "description": """`Canonical Correlation Analysis`_ from sklearn

.. _`Canonical Correlation Analysis`: https://scikit-learn.org/stable/modules/generated/sklearn.cross_decomposition.CCA.html
""",
    "documentation_url": "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.cca.html",
    "import_from": "sklearn.cross_decomposition",
    "type": "object",
    # CCA both reduces dimensionality (transform) and regresses (predict),
    # hence it is tagged as transformer and estimator.
    "tags": {"pre": [], "op": ["transformer", "estimator"], "post": []},
    "properties": {
        "hyperparams": _hyperparams_schema,
        "input_fit": _input_fit_schema,
        "input_transform": _input_transform_schema,
        "output_transform": _output_transform_schema,
        "input_predict": _input_predict_schema,
        "output_predict": _output_predict_schema,
    },
}
# Lale operator wrapping sklearn.cross_decomposition.CCA with the schemas above.
CCA = make_operator(Op, _combined_schemas)

# Attach the schema-derived docstring to the operator.
set_docstrings(CCA)
1 change: 1 addition & 0 deletions test/test_core_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_classifier(self):
classifiers = [
"lale.lib.sklearn.BernoulliNB",
"lale.lib.sklearn.CalibratedClassifierCV",
"lale.lib.sklearn.CCA",
"lale.lib.sklearn.DummyClassifier",
"lale.lib.sklearn.RandomForestClassifier",
"lale.lib.sklearn.DecisionTreeClassifier",
Expand Down

0 comments on commit 8863bda

Please sign in to comment.