-
Notifications
You must be signed in to change notification settings - Fork 82
/
dask_testing.py
36 lines (25 loc) · 1.28 KB
/
dask_testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from collections import namedtuple
from evalml.objectives.utils import get_objective
from evalml.pipelines import BinaryClassificationPipeline
from evalml.preprocessing.data_splitters import TrainingValidationSplit
# Top-level replacement for AutoML object to supply data for testing purposes.
def err_call(*args, **kwargs):
    """Accept any arguments, ignore them, and return 1.

    Assigned below as the ``error_callback`` value of the testing struct.

    Args:
        *args: Ignored positional arguments.
        **kwargs: Ignored keyword arguments.

    Returns:
        int: Always 1.
    """
    return 1
# Lightweight stand-in for an AutoML search object: a namedtuple exposing only
# the attributes named in the field string below.
AutoMLSearchStruct = namedtuple(
    "AutoML",
    "data_splitter problem_type objective additional_objectives optimize_thresholds error_callback random_seed ensembling_indices",
)

# Values used to populate the struct. They stay bound at module level under
# the same names as the struct fields.
data_splitter = TrainingValidationSplit()
problem_type = "binary"
objective = get_objective("Log Loss Binary", return_instance=True)
additional_objectives = []
optimize_thresholds = False
error_callback = err_call
random_seed = 0
ensembling_indices = [0]

# Passed by keyword so each value is visibly tied to the field it fills.
automl_data = AutoMLSearchStruct(
    data_splitter=data_splitter,
    problem_type=problem_type,
    objective=objective,
    additional_objectives=additional_objectives,
    optimize_thresholds=optimize_thresholds,
    error_callback=error_callback,
    random_seed=random_seed,
    ensembling_indices=ensembling_indices,
)
class TestLRCPipeline(BinaryClassificationPipeline):
    """Binary classification pipeline with a single Logistic Regression Classifier component."""

    component_graph = ["Logistic Regression Classifier"]
class TestSVMPipeline(BinaryClassificationPipeline):
    """Binary classification pipeline with a single SVM Classifier component."""

    component_graph = ["SVM Classifier"]
class TestCBPipeline(BinaryClassificationPipeline):
    """Binary classification pipeline with a single CatBoost Classifier component."""

    component_graph = ["CatBoost Classifier"]