/
operator_wrapper.py
122 lines (100 loc) · 3.96 KB
/
operator_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2019-2022 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from typing import Container, List, Optional, Set
from lale.operators import Operator, clone_op, get_op_from_lale_lib
# Module-level logger; used to report which imported operators were
# (or were not) replaced by Lale wrappers.
logger = logging.getLogger(__name__)
def _wrap_operators_in_symtab(
    symtab,
    exclude_classes: Optional[Container[str]] = None,
    wrapper_modules: Optional[List[str]] = None,
) -> None:
    """Replace known non-Lale operator classes in ``symtab`` with Lale wrappers.

    A value in ``symtab`` is a candidate when it is a class that is not
    already a Lale ``Operator`` and that has a ``predict`` or ``transform``
    attribute.  For each candidate whose name is not excluded,
    ``get_op_from_lale_lib`` is asked for a matching Lale operator.  If one
    is found, ``symtab[name]`` is replaced in place by a clone of it named
    ``name``; otherwise the entry is left untouched.

    Parameters
    ----------
    symtab : mutable mapping of str to object
        Symbol table to update in place (e.g. a frame's ``f_globals``).
    exclude_classes : container of str, optional, default None
        Names (keys of ``symtab``) that must not be wrapped.
    wrapper_modules : list of str, optional, default None
        Lale module names forwarded to ``get_op_from_lale_lib``.
    """
    # Iterate over a snapshot: entries are reassigned below, and ``symtab``
    # may be a live namespace (module globals) whose size could change while
    # we are iterating over it.
    for name, impl in list(symtab.items()):
        is_candidate = (
            inspect.isclass(impl)
            and not issubclass(impl, Operator)
            and (hasattr(impl, "predict") or hasattr(impl, "transform"))
        )
        if not is_candidate:
            continue
        if exclude_classes is not None and name in exclude_classes:
            continue
        operator = get_op_from_lale_lib(impl, wrapper_modules)
        if operator is None:
            # Unknown classes are deliberately left as-is; automatically
            # wrapping them (e.g. via make_operator) is not done here.
            logger.info("Lale:Not wrapping unknown operator:%s", name)
            continue
        symtab[name] = clone_op(operator, name)
        if operator.class_name().startswith("lale.lib.autogen"):
            logger.info("Lale:Wrapped autogen operator:%s", name)
        else:
            logger.info("Lale:Wrapped known operator:%s", name)
def wrap_imported_operators(
    exclude_classes: Optional[Container[str]] = None,
    wrapper_modules: Optional[List[str]] = None,
) -> None:
    """Wrap the currently imported operators from the symbol table
    to their lale wrappers.

    The caller's global namespace is updated in place.  When the caller is
    executing at module level (code name ``"<module>"``, e.g. under
    ``exec()``), its local namespace is processed as well.

    Parameters
    ----------
    exclude_classes : container of string, optional, default None
        Class names to exclude from wrapping; use the alias names if
        aliases were used while importing.
    wrapper_modules : list of string, optional, default None
        Additional Lale modules to use for wrapping operators.  They are
        combined (caller's modules first) with the modules registered via
        ``register_lale_wrapper_modules``.  The list passed in is not
        modified.
    """
    current_frame = inspect.currentframe()
    assert (
        current_frame is not None
    ), "Try to use inspect.stack()[1][0] to get the calling frame"
    calling_frame = current_frame.f_back
    assert (
        calling_frame is not None
    ), "Try to use inspect.stack()[1][0] to get the calling frame"
    try:
        # Build a fresh list so the caller's ``wrapper_modules`` argument is
        # never mutated (extending it in place made a reused list grow with
        # duplicate module names on every call).
        modules: List[str] = [] if wrapper_modules is None else list(wrapper_modules)
        modules.extend(get_lale_wrapper_modules())
        _wrap_operators_in_symtab(
            calling_frame.f_globals, exclude_classes, wrapper_modules=modules
        )
        if calling_frame.f_code.co_name == "<module>":  # for testing with exec()
            _wrap_operators_in_symtab(
                calling_frame.f_locals, exclude_classes, wrapper_modules=modules
            )
    finally:
        # Drop frame references explicitly so the reference cycle created by
        # holding our own frame is broken deterministically.
        del current_frame, calling_frame
# Module-level registry of Lale wrapper module names.  It is populated by
# register_lale_wrapper_modules and returned by get_lale_wrapper_modules.
_lale_wrapper_modules: Set[str] = set()


def register_lale_wrapper_modules(m: str) -> None:
    """Register a module with lale's import system
    so that :meth:`lale.helpers.import_from_sklearn_pipeline` will look for
    replacement classes in that module.

    Registering the same module name more than once has no additional effect.

    Example (in the ``__init__.py`` file for the module):

    .. code-block:: python

        from lale import register_lale_wrapper_modules

        register_lale_wrapper_modules(__name__)

    Parameters
    ----------
    m : str
        The module name.
    """
    _lale_wrapper_modules.add(m)


def get_lale_wrapper_modules() -> Set[str]:
    """Return the names of all registered Lale wrapper modules.

    Returns
    -------
    set of str
        The registry set itself, not a copy; mutating it changes the
        registry.
    """
    return _lale_wrapper_modules
# Register the wrapper modules that Lale knows about out of the box, so that
# wrap_imported_operators includes them even when the user registers none.
# NOTE: the loop variable stays bound at module scope after the loop.
for builtin_lale_modules in (
    "lale.lib.sklearn",
    "lale.lib.autoai_libs",
    "lale.lib.xgboost",
    "lale.lib.lightgbm",
    "lale.lib.snapml",
    "autoai_ts_libs.lale",
):
    register_lale_wrapper_modules(builtin_lale_modules)