Skip to content

Commit

Permalink
add ensembler supports
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed May 7, 2024
1 parent cd6c460 commit fb617a4
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion alpha_automl/hyperparameter_tuning/smac.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ def gen_pipeline(config, pipeline):
transformers.append((trans_name, trans_obj, trans_index))
step_obj.__dict__['transformers'] = transformers
new_pipeline.steps.append([step_name, create_object(step_name, step_obj.__dict__)])
elif step_type == 'CLASSIFICATION_SINGLE_ENSEMBLER' or step_type == 'REGRESSION_SINGLE_ENSEMBLER':
estimator = step_obj.estimator
primitive_object = create_object(step_name, {'estimator': estimator})
new_pipeline.steps.append([step_name, primitive_object])
elif step_type == 'CLASSIFICATION_MULTI_ENSEMBLER' or step_type == 'REGRESSION_MULTI_ENSEMBLER':
estimators = extract_estimators(pipeline, PRIMITIVE_TYPES)
estimators = extract_estimators_smac(step_obj, PRIMITIVE_TYPES)
primitive_object = create_object(step_name, {'estimators': estimators})
new_pipeline.steps.append([step_name, primitive_object])
else:
Expand All @@ -58,6 +62,17 @@ def gen_pipeline(config, pipeline):
return new_pipeline


def extract_estimators_smac(step_obj, config):
new_estimators = []
estimators = step_obj.estimators
while estimators:
estimator_name, estimator_obj = estimators.pop()
estimator_name_lookup, estimator_name_counter = estimator_name.split('-')
new_estimators.append((estimator_name, create_object(estimator_name_lookup, get_primitive_params(config, estimator_name_lookup))))

return new_estimators


def get_primitive_params(config, step_name):
params = list(SMAC_DICT[step_name].keys())
class_params = {}
Expand All @@ -80,6 +95,16 @@ def gen_configspace(pipeline):
trans_prim_name = trans_name.split('-')[0]
params = SMAC_DICT[trans_prim_name]
configspace.add_hyperparameters(cast_primitive(params))
# elif step_type == 'CLASSIFICATION_SINGLE_ENSEMBLER' or step_type == 'REGRESSION_SINGLE_ENSEMBLER':
# estimator_obj = prim_obj.estimator
# for smac_name, smac_params in SMAC_DICT.items():
# if estimator_obj.__class__.__name__ in smac_name:
# configspace.add_hyperparameters(cast_primitive(smac_params))
elif step_type == 'CLASSIFICATION_MULTI_ENSEMBLER' or step_type == 'REGRESSION_MULTI_ENSEMBLER':
for estimator_name, _ in prim_obj.estimators:
estimator_name_lookup, _ = estimator_name.split('-')
params = SMAC_DICT[estimator_name_lookup]
configspace.add_hyperparameters(cast_primitive(params))
except Exception as e:
logger.critical(f'[SMAC] {str(e)}')
return configspace
Expand Down

0 comments on commit fb617a4

Please sign in to comment.