Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Issue 430/all probabilities get predictions #433

Merged
merged 14 commits into from Dec 3, 2018
91 changes: 74 additions & 17 deletions skll/utilities/generate_predictions.py
Expand Up @@ -14,6 +14,7 @@
import argparse
import logging
import os
import sys
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: It seems like this import is never used?


from skll.data.readers import EXT_TO_READER
from skll.learner import Learner
Expand All @@ -26,7 +27,8 @@ class Predictor(object):
predictions for feature strings.
"""

def __init__(self, model_path, threshold=None, positive_label=1, logger=None):
def __init__(self, model_path, threshold=None, positive_label=1,
return_all_probabilities=False, logger=None):
desilinguist marked this conversation as resolved.
Show resolved Hide resolved
"""
Initialize the predictor.

Expand All @@ -46,6 +48,10 @@ def __init__(self, model_path, threshold=None, positive_label=1, logger=None):
predicting. 1 = second class, which is default
for binary classification.
Defaults to 1.
return_all_probabilities: bool
A flag indicating whether to return the probabilities for all
labels in each row instead of just returning the probability of
`positive_label`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nitpick, but can we add "Defaults to False" here?

logger : logging object, optional
A logging object. If ``None`` is passed, get logger from ``__name__``.
Defaults to ``None``.
Expand All @@ -54,6 +60,8 @@ def __init__(self, model_path, threshold=None, positive_label=1, logger=None):
self._learner = Learner.from_file(model_path)
self._pos_index = positive_label
self.threshold = threshold
self.all_probs = return_all_probabilities
desilinguist marked this conversation as resolved.
Show resolved Hide resolved
self.output_file_header = None

def predict(self, data):
"""
Expand All @@ -71,18 +79,29 @@ def predict(self, data):
# compute the predictions from the learner
preds = self._learner.predict(data)
preds = preds.tolist()
labels = self._learner.label_list

# Create file header list, and transform predictions as needed
# depending on the specified prediction arguments.
if self._learner.probability:
if self.threshold is None:
return [pred[self._pos_index] for pred in preds]
if self.all_probs:
desilinguist marked this conversation as resolved.
Show resolved Hide resolved
self.output_file_header = ["id"] + [str(x) for x in labels]
elif self.threshold is None:
label = self._learner.label_dict[self._pos_index]
desilinguist marked this conversation as resolved.
Show resolved Hide resolved
self.output_file_header = ["id",
"Probability of '{}'".format(label)]
preds = [pred[self._pos_index] for pred in preds]
else:
return [int(pred[self._pos_index] >= self.threshold)
for pred in preds]
self.output_file_header = ["id", "prediction"]
preds = [int(pred[self._pos_index] >= self.threshold)
for pred in preds]
elif self._learner.model._estimator_type == 'regressor':
return preds
self.output_file_header = ["id", "prediction"]
else:
return [self._learner.label_list[pred if isinstance(pred, int) else
int(pred[0])] for pred in preds]
self.output_file_header = ["id", "prediction"]
preds = [labels[pred if isinstance(pred, int) else int(pred[0])]
for pred in preds]
return preds


def main(argv=None):
Expand Down Expand Up @@ -130,14 +149,25 @@ def main(argv=None):
parser.add_argument('-q', '--quiet',
help='Suppress printing of "Loading..." messages.',
action='store_true')
parser.add_argument('-t', '--threshold',
help="If the model we're using is generating \
probabilities of the positive label, return 1 \
if it meets/exceeds the given threshold and 0 \
otherwise.",
type=float)
parser.add_argument('--output_file', '-o',
help="Path to output tsv file. If not specified, "
"predictions will be printed to stdout.")
parser.add_argument('--version', action='version',
version='%(prog)s {0}'.format(__version__))
probability_handling = parser.add_mutually_exclusive_group()
probability_handling.add_argument('-t', '--threshold',
help="If the model we're using is "
"generating probabilities of the "
"positive label, return 1 if it "
"meets/exceeds the given threshold "
"and 0 otherwise.", type=float)
probability_handling.add_argument('--all_probabilities', '-a',
action='store_true',
help="Flag indicating whether to output "
"the probabilities of all labels "
"instead of just the probability "
"of the positive label.")

args = parser.parse_args(argv)

# Make warnings from built-in warnings module get formatted more nicely
Expand All @@ -150,10 +180,11 @@ def main(argv=None):
predictor = Predictor(args.model_file,
positive_label=args.positive_label,
threshold=args.threshold,
return_all_probabilities=args.all_probabilities,
logger=logger)

# Iterate over all the specified input files
for input_file in args.input_file:
for i, input_file in enumerate(args.input_file):
desilinguist marked this conversation as resolved.
Show resolved Hide resolved

# make sure each file extension is one we can process
input_extension = os.path.splitext(input_file)[1].lower()
Expand All @@ -169,8 +200,34 @@ def main(argv=None):
label_col=args.label_col,
id_col=args.id_col)
feature_set = reader.read()
for pred in predictor.predict(feature_set):
print(pred)
preds = predictor.predict(feature_set)
header = predictor.output_file_header

if args.output_file is not None:
with open(args.output_file, "a") as fout:
desilinguist marked this conversation as resolved.
Show resolved Hide resolved
if i == 0: # Only write header once per set of input files
print("\t".join(header), file=fout)
if args.all_probabilities:
for i, probabilities in enumerate(preds):
id_ = feature_set.ids[i]
probs_str = "\t".join([str(p) for p in probabilities])
print("{}\t{}".format(id_, probs_str), file=fout)
else:
for i, pred in enumerate(preds):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call this j or something else, since we're using i in the outer loop? Same comment below.

id_ = feature_set.ids[i]
print("{}\t{}".format(id_, pred), file=fout)
else:
if i == 0: # Only write header once per set of input files
print("\t".join(header))
if args.all_probabilities:
for i, probabilities in enumerate(preds):
id_ = feature_set.ids[i]
probs_str = "\t".join([str(p) for p in probabilities])
print("{}\t{}".format(id_, probs_str))
else:
for i, pred in enumerate(preds):
id_ = feature_set.ids[i]
print("{}\t{}".format(id_, pred))


if __name__ == '__main__':
Expand Down
51 changes: 39 additions & 12 deletions tests/test_utilities.py
Expand Up @@ -286,16 +286,17 @@ def test_compute_eval_from_predictions_random_choice():

def check_generate_predictions(use_feature_hashing=False,
use_threshold=False,
test_on_subset=False):
test_on_subset=False,
all_probs=False):

# create some simple classification feature sets for training and testing
train_fs, test_fs = make_classification_data(num_examples=1000,
num_features=5,
use_feature_hashing=use_feature_hashing,
feature_bins=4)

proba = use_threshold or all_probs
desilinguist marked this conversation as resolved.
Show resolved Hide resolved
# create a learner that uses an SGD classifier
learner = Learner('SGDClassifier', probability=use_threshold)
learner = Learner('SGDClassifier', probability=proba)

# train the learner with grid search
learner.train(train_fs, grid_search=True)
Expand Down Expand Up @@ -325,21 +326,31 @@ def check_generate_predictions(use_feature_hashing=False,

# now use Predictor to generate the predictions and make
# sure that they are the same as before saving the model
p = gp.Predictor(model_file, threshold=threshold)
p = gp.Predictor(model_file, threshold=threshold,
return_all_probabilities=all_probs)
predictions_after_saving = p.predict(test_fs)

eq_(predictions, predictions_after_saving)


def test_generate_predictions():
    """
    Test ``generate_predictions`` over all valid combinations of
    ``use_feature_hashing``, ``use_threshold``, ``test_on_subset``,
    and ``all_probabilities``.
    """
    # local import so this test does not depend on module-level imports
    from itertools import product

    # `-t`/`--threshold` and `-a`/`--all_probabilities` are mutually
    # exclusive command-line options, so generate every combination of
    # the four boolean flags EXCEPT those where both `use_threshold`
    # (index 1) and `all_probabilities` (index 3) are set.  This yields
    # the same 12 cases as the previous hand-written list, without the
    # risk of mistyping a tuple.
    possibilities = [combo for combo in product([True, False], repeat=4)
                     if not (combo[1] and combo[3])]

    for (use_feature_hashing,
         use_threshold,
         test_on_subset,
         all_probabilities) in possibilities:
        yield (check_generate_predictions, use_feature_hashing,
               use_threshold, test_on_subset, all_probabilities)


def check_generate_predictions_console(use_threshold=False):
def check_generate_predictions_console(use_threshold=False, all_probs=False):
desilinguist marked this conversation as resolved.
Show resolved Hide resolved

# create some simple classification data without feature hashing
train_fs, test_fs = make_classification_data(num_examples=1000,
Expand All @@ -351,8 +362,9 @@ def check_generate_predictions_console(use_threshold=False):
writer = NDJWriter(input_file, test_fs)
writer.write()

proba = use_threshold or all_probs
desilinguist marked this conversation as resolved.
Show resolved Hide resolved
# create a learner that uses an SGD classifier
learner = Learner('SGDClassifier', probability=use_threshold)
learner = Learner('SGDClassifier', probability=proba)

# train the learner with grid search
learner.train(train_fs, grid_search=True)
Expand All @@ -378,6 +390,9 @@ def check_generate_predictions_console(use_threshold=False):
generate_cmd = []
if use_threshold:
generate_cmd.append('-t {}'.format(threshold))
elif all_probs:
generate_cmd.append('-a')

generate_cmd.extend([model_file, input_file])

# we need to capture stdout since that's what main() writes to
Expand All @@ -390,8 +405,19 @@ def check_generate_predictions_console(use_threshold=False):
gp.main(generate_cmd)
out = mystdout.getvalue()
err = mystderr.getvalue()
predictions_after_saving = [int(x) for x in out.strip().split('\n')]
eq_(predictions, predictions_after_saving)
output_lines = out.strip().split('\n')[1:] # Skip headers
if all_probs:
# Ignore the id (first column) in output.
predictions_after_saving = [[float(p) for p in x.split('\t')[1:]]
for x in output_lines]
else:
# Ignore the id (first column) in output.
predictions_after_saving = [int(x.split('\t')[1])
for x in output_lines]
if all_probs:
assert_array_almost_equal(predictions, predictions_after_saving)
else:
eq_(predictions, predictions_after_saving)
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
Expand All @@ -403,8 +429,9 @@ def test_generate_predictions_console():
Test generate_predictions as a console script with/without a threshold
"""

yield check_generate_predictions_console, False
yield check_generate_predictions_console, True
yield check_generate_predictions_console, False, False
yield check_generate_predictions_console, False, True
yield check_generate_predictions_console, True, False


def check_skll_convert(from_suffix, to_suffix):
Expand Down