Skip to content

Commit

Permalink
Merge pull request #571 from EducationalTestingService/make-pos-label…
Browse files Browse the repository at this point in the history
…-default-values-consistent

Make `pos_label_str` values consistent between API and configuration file
  • Loading branch information
desilinguist committed Oct 22, 2019
2 parents d70436d + 0109b1f commit 9a81b14
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
4 changes: 4 additions & 0 deletions skll/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,11 @@ def _parse_config_file(config_path, log_level=logging.INFO):
"sampler_parameters"))
fixed_sampler_parameters = yaml.safe_load(fixed_sampler_parameters)
param_grid_list = yaml.safe_load(_fix_json(config.get("Tuning", "param_grids")))

# read and normalize the value of `pos_label_str`
pos_label_str = safe_float(config.get("Tuning", "pos_label_str"))
if pos_label_str == '':
pos_label_str = None

# ensure that feature_scaling is specified only as one of the
# four available choices
Expand Down
41 changes: 41 additions & 0 deletions tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,3 +2016,44 @@ def test_config_parsing_param_grids_fixed_parameters_conflict():
'parameter values will take precedence.')
matches = re.findall(warning_pattern, f.read())
assert_equal(len(matches), 1)


def test_config_parsing_default_pos_label_str_value():
"""
Check that the default value of `pos_label_str` gets set to `None`
"""

train_dir = join('..', 'train')
test_dir = join('..', 'test')
output_dir = join(_my_dir, 'output')

values_to_fill_dict = {'experiment_name': 'config_parsing',
'task': 'evaluate',
'train_directory': train_dir,
'test_directory': test_dir,
'featuresets': "[['f1', 'f2', 'f3']]",
'learners': "['LogisticRegression']",
'objectives': "['accuracy']",
'log': output_dir,
'results': output_dir}

config_template_path = join(_my_dir, 'configs',
'test_config_parsing.template.cfg')

config_path = fill_in_config_options(config_template_path,
values_to_fill_dict,
'default_value_pos_label_str')

(experiment_name, task, sampler, fixed_sampler_parameters,
feature_hasher, hasher_features, id_col, label_col, train_set_name,
test_set_name, suffix, featuresets, do_shuffle, model_path,
do_grid_search, grid_objectives, probability, pipeline, results_path,
pos_label_str, feature_scaling, min_feature_count, folds_file,
grid_search_jobs, grid_search_folds, cv_folds, save_cv_folds,
save_cv_models, use_folds_file_for_grid_search, do_stratified_folds,
fixed_parameter_list, param_grid_list, featureset_names, learners,
prediction_dir, log_path, train_path, test_path, ids_to_floats,
class_map, custom_learner_path, learning_curve_cv_folds_list,
learning_curve_train_sizes, output_metrics) = _parse_config_file(config_path)

eq_(pos_label_str, None)

0 comments on commit 9a81b14

Please sign in to comment.