
Set default value of grid_search in config files to be True. #465

Merged
merged 9 commits on Feb 15, 2019
skll/config.py (13 changes: 11 additions & 2 deletions)
@@ -58,7 +58,7 @@ def __init__(self):
'featuresets': '[]',
'featureset_names': '[]',
'fixed_parameters': '[]',
- 'grid_search': 'False',
+ 'grid_search': 'True',
'grid_search_folds': '3',
'grid_search_jobs': '0',
'hasher_features': '0',
@@ -752,6 +752,13 @@ def _parse_config_file(config_path, log_level=logging.INFO):
'objectives',
logger=logger)

+ # if we are doing learning curves, we don't care about
+ # grid search
+ if task == 'learning_curve' and do_grid_search:
+     do_grid_search = False
+     logger.warning("Grid search is not supported during "
+                    "learning curve generation. Disabling.")

# Check if `param_grids` is specified, but `do_grid_search` is False
if param_grid_list and not do_grid_search:
logger.warning('Since "grid_search" is set to False, the specified'
@@ -790,7 +797,9 @@
'predict.')
if task in ['cross_validate', 'evaluate', 'train']:
if do_grid_search and len(grid_objectives) == 0:
-     raise ValueError('You must specify a list of objectives if doing grid search.')
+     raise ValueError('Grid search is on. Either specify a list of tuning '
+                      'objectives or set `grid_search` to `false` in the '
+                      'Tuning section.')
if not do_grid_search and len(grid_objectives) > 0:
logger.warning('Since "grid_search" is set to False, any specified'
' "objectives" will be ignored.')
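With the default flipped to True, a config file's Tuning section now has to either list tuning objectives or turn grid search off explicitly; otherwise parsing fails with the error above. A minimal sketch of the two acceptable Tuning sections (the objective name is only illustrative):

    [Tuning]
    grid_search=true
    objectives=["f1_score_micro"]

or, to skip hyperparameter tuning entirely:

    [Tuning]
    grid_search=false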
skll/experiments.py (4 changes: 0 additions & 4 deletions)
@@ -1048,10 +1048,6 @@ def run_configuration(config_file, local=False, overwrite=True, queue='all.q',

# No grid search or ablation for learning curve generation
if task == 'learning_curve':
- if do_grid_search:
-     do_grid_search = False
-     logger.warning("Grid search is not supported during "
-                    "learning curve generation. Ignoring.")
if ablation is None or ablation > 0:
ablation = 0
logger.warning("Ablating features is not supported during "
skll/learner.py (11 changes: 10 additions & 1 deletion)
@@ -1946,7 +1946,7 @@ def cross_validate(self,
examples,
stratified=True,
cv_folds=10,
- grid_search=False,
+ grid_search=True,
grid_search_folds=3,
grid_jobs=None,
grid_objective=None,
@@ -2049,6 +2049,15 @@
type_of_target(examples.labels) not in ['binary', 'multiclass']):
raise ValueError("Floating point labels must be encoded as strings for cross-validation.")

+ # check that we have an objective since grid search is on by default.
+ # Note that `train()` would raise this error anyway later, but it's
+ # better to raise it early on rather than after a whole bunch of
+ # stuff has happened
+ if grid_search:
+     if not grid_objective:
+         raise ValueError("Grid search is on by default. You must either "
+                          "specify a grid objective or turn off grid search.")

# Shuffle so that the folds are random for the inner grid search CV.
# If grid search is True but shuffle isn't, shuffle anyway.
# You can't shuffle a scipy sparse matrix in place, so unfortunately
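The same requirement now applies when calling the Learner API directly: cross_validate(), like train(), grid searches by default, so callers must pass a tuning objective or opt out. A small sketch of the two valid call patterns; the tiny FeatureSet built here is made up purely for illustration, and 'f1_score_micro' is just one example objective:

    from skll import Learner
    from skll.data import FeatureSet

    # build a tiny, made-up feature set just for illustration
    ids = ['EXAMPLE_%d' % i for i in range(20)]
    labels = ['yes' if i % 2 else 'no' for i in range(20)]
    features = [{'f1': i, 'f2': i % 3} for i in range(20)]
    train_fs = FeatureSet('toy', ids, labels=labels, features=features)

    learner = Learner('LogisticRegression')

    # grid search is now on by default, so an objective is required ...
    _ = learner.cross_validate(train_fs, grid_objective='f1_score_micro')

    # ... or grid search can be switched off explicitly, as before
    _ = learner.cross_validate(train_fs, grid_search=False)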
tests/configs/test_int_labels_cv.template.cfg (1 change: 1 addition & 0 deletions)
@@ -9,3 +9,4 @@ suffix=.jsonlines
[Output]

[Tuning]
+ grid_search=false
tests/configs/test_single_file_saved_subset.template.cfg (1 change: 1 addition & 0 deletions)
@@ -6,6 +6,7 @@ task=evaluate
learners=["RandomForestClassifier"]

[Tuning]
+ grid_search=false

[Output]
probability=false
tests/test_classification.py (20 changes: 15 additions & 5 deletions)
@@ -585,9 +585,9 @@ def test_new_labels_in_test_set_change_order():
train_fs, test_fs = make_classification_data(num_labels=3,
train_test_ratio=0.8)
# change train labels to create a gap
- train_fs.labels = train_fs.labels*10
+ train_fs.labels = train_fs.labels * 10
# add new test labels
- test_fs.labels = test_fs.labels*10
+ test_fs.labels = test_fs.labels * 10
test_fs.labels[-3:] = 15

learner = Learner('SVC')
@@ -604,7 +604,7 @@ def test_all_new_labels_in_test():
train_fs, test_fs = make_classification_data(num_labels=3,
train_test_ratio=0.8)
# change all test labels
- test_fs.labels = test_fs.labels+3
+ test_fs.labels = test_fs.labels + 3

learner = Learner('SVC')
learner.train(train_fs, grid_search=False)
@@ -644,6 +644,7 @@ def test_xval_float_classes_as_strings():
prediction_prefix = join(_my_dir, 'output', 'float_class')
learner = Learner('LogisticRegression')
learner.cross_validate(float_class_fs,
+ grid_search=True,
grid_objective='accuracy',
prediction_prefix=prediction_prefix)

@@ -663,6 +664,7 @@ def check_bad_xval_float_classes(do_stratified_xval):
learner = Learner('LogisticRegression')
learner.cross_validate(float_class_fs,
stratified=do_stratified_xval,
+ grid_search=True,
grid_objective='accuracy',
prediction_prefix=prediction_prefix)

@@ -710,7 +712,7 @@ def test_train_and_score_function():


@raises(ValueError)
- def test_learner_api_grid_search_no_objective():
+ def check_learner_api_grid_search_no_objective(task='train'):

(train_fs,
test_fs) = make_classification_data(num_examples=500,
@@ -719,7 +721,15 @@ def test_learner_api_grid_search_no_objective():
use_feature_hashing=False,
non_negative=True)
learner = Learner('LogisticRegression')
- _ = learner.train(train_fs)
+ if task == 'train':
+     _ = learner.train(train_fs)
+ else:
+     _ = learner.cross_validate(train_fs)


+ def test_learner_api_grid_search_no_objective():
+     yield check_learner_api_grid_search_no_objective, 'train'
+     yield check_learner_api_grid_search_no_objective, 'cross_validate'


def test_learner_api_load_into_existing_instance():
tests/test_cv.py (1 change: 0 additions & 1 deletion)
@@ -278,7 +278,6 @@ def test_retrieve_cv_folds():
stratified=False,
cv_folds=num_folds,
grid_search=False,
- grid_objective='f1_score_micro',
shuffle=False,
save_cv_folds=True)
assert_equal(skll_fold_ids, custom_cv_folds)