Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed May 7, 2024
1 parent 5b58fdb commit cd6c460
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions alpha_automl/hyperparameter_tuning/smac.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from alpha_automl.scorer import make_scorer, make_splitter
from alpha_automl.utils import create_object
from alpha_automl.primitive_loader import PRIMITIVE_TYPES
from alpha_automl.pipeline_synthesis.pipeline_builder import extract_estimators

logger = logging.getLogger(__name__)
SMAC_PARAMETERS_PATH = join(dirname(__file__), 'smac_parameters.json')
Expand Down Expand Up @@ -47,6 +48,10 @@ def gen_pipeline(config, pipeline):
transformers.append((trans_name, trans_obj, trans_index))
step_obj.__dict__['transformers'] = transformers
new_pipeline.steps.append([step_name, create_object(step_name, step_obj.__dict__)])
elif step_type == 'CLASSIFICATION_MULTI_ENSEMBLER' or step_type == 'REGRESSION_MULTI_ENSEMBLER':
estimators = extract_estimators(pipeline, PRIMITIVE_TYPES)
primitive_object = create_object(step_name, {'estimators': estimators})
new_pipeline.steps.append([step_name, primitive_object])
else:
new_pipeline.steps.append([step_name, create_object(step_name, get_primitive_params(config, step_name))])

Expand Down

0 comments on commit cd6c460

Please sign in to comment.