New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Issue 430/all probabilities get predictions #433
Changes from 4 commits
db514c2
3c4c4a5
b5f72f9
cecce05
d72d001
8620280
6a124c8
70b22ad
20b1b8f
4b53eaf
4da69cd
8c1eaca
bbd8f50
67587d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
import argparse | ||
import logging | ||
import os | ||
import sys | ||
|
||
from skll.data.readers import EXT_TO_READER | ||
from skll.learner import Learner | ||
|
@@ -26,7 +27,8 @@ class Predictor(object): | |
predictions for feature strings. | ||
""" | ||
|
||
def __init__(self, model_path, threshold=None, positive_label=1, logger=None): | ||
def __init__(self, model_path, threshold=None, positive_label=1, | ||
return_all_probabilities=False, logger=None): | ||
desilinguist marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Initialize the predictor. | ||
|
||
|
@@ -46,6 +48,10 @@ def __init__(self, model_path, threshold=None, positive_label=1, logger=None): | |
predicting. 1 = second class, which is default | ||
for binary classification. | ||
Defaults to 1. | ||
return_all_probabilities: bool | ||
A flag indicating whether to return the probabilities for all | ||
labels in each row instead of just returning the probability of | ||
`positive_label`. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a nitpick, but can we add "Defaults to ``False``."? |
||
logger : logging object, optional | ||
A logging object. If ``None`` is passed, get logger from ``__name__``. | ||
Defaults to ``None``. | ||
|
@@ -54,6 +60,8 @@ def __init__(self, model_path, threshold=None, positive_label=1, logger=None): | |
self._learner = Learner.from_file(model_path) | ||
self._pos_index = positive_label | ||
self.threshold = threshold | ||
self.all_probs = return_all_probabilities | ||
desilinguist marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.output_file_header = None | ||
|
||
def predict(self, data): | ||
""" | ||
|
@@ -71,18 +79,29 @@ def predict(self, data): | |
# compute the predictions from the learner | ||
preds = self._learner.predict(data) | ||
preds = preds.tolist() | ||
labels = self._learner.label_list | ||
|
||
# Create file header list, and transform predictions as needed | ||
# depending on the specified prediction arguments. | ||
if self._learner.probability: | ||
if self.threshold is None: | ||
return [pred[self._pos_index] for pred in preds] | ||
if self.all_probs: | ||
desilinguist marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.output_file_header = ["id"] + [str(x) for x in labels] | ||
elif self.threshold is None: | ||
label = self._learner.label_dict[self._pos_index] | ||
desilinguist marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.output_file_header = ["id", | ||
"Probability of '{}'".format(label)] | ||
preds = [pred[self._pos_index] for pred in preds] | ||
else: | ||
return [int(pred[self._pos_index] >= self.threshold) | ||
for pred in preds] | ||
self.output_file_header = ["id", "prediction"] | ||
preds = [int(pred[self._pos_index] >= self.threshold) | ||
for pred in preds] | ||
elif self._learner.model._estimator_type == 'regressor': | ||
return preds | ||
self.output_file_header = ["id", "prediction"] | ||
else: | ||
return [self._learner.label_list[pred if isinstance(pred, int) else | ||
int(pred[0])] for pred in preds] | ||
self.output_file_header = ["id", "prediction"] | ||
preds = [labels[pred if isinstance(pred, int) else int(pred[0])] | ||
for pred in preds] | ||
return preds | ||
|
||
|
||
def main(argv=None): | ||
|
@@ -130,14 +149,25 @@ def main(argv=None): | |
parser.add_argument('-q', '--quiet', | ||
help='Suppress printing of "Loading..." messages.', | ||
action='store_true') | ||
parser.add_argument('-t', '--threshold', | ||
help="If the model we're using is generating \ | ||
probabilities of the positive label, return 1 \ | ||
if it meets/exceeds the given threshold and 0 \ | ||
otherwise.", | ||
type=float) | ||
parser.add_argument('--output_file', '-o', | ||
help="Path to output tsv file. If not specified, " | ||
"predictions will be printed to stdout.") | ||
parser.add_argument('--version', action='version', | ||
version='%(prog)s {0}'.format(__version__)) | ||
probability_handling = parser.add_mutually_exclusive_group() | ||
probability_handling.add_argument('-t', '--threshold', | ||
help="If the model we're using is " | ||
"generating probabilities of the " | ||
"positive label, return 1 if it " | ||
"meets/exceeds the given threshold " | ||
"and 0 otherwise.", type=float) | ||
probability_handling.add_argument('--all_probabilities', '-a', | ||
action='store_true', | ||
help="Flag indicating whether to output " | ||
"the probabilities of all labels " | ||
"instead of just the probability " | ||
"of the positive label.") | ||
|
||
args = parser.parse_args(argv) | ||
|
||
# Make warnings from built-in warnings module get formatted more nicely | ||
|
@@ -150,10 +180,11 @@ def main(argv=None): | |
predictor = Predictor(args.model_file, | ||
positive_label=args.positive_label, | ||
threshold=args.threshold, | ||
return_all_probabilities=args.all_probabilities, | ||
logger=logger) | ||
|
||
# Iterate over all the specified input files | ||
for input_file in args.input_file: | ||
for i, input_file in enumerate(args.input_file): | ||
desilinguist marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# make sure each file extension is one we can process | ||
input_extension = os.path.splitext(input_file)[1].lower() | ||
|
@@ -169,8 +200,34 @@ def main(argv=None): | |
label_col=args.label_col, | ||
id_col=args.id_col) | ||
feature_set = reader.read() | ||
for pred in predictor.predict(feature_set): | ||
print(pred) | ||
preds = predictor.predict(feature_set) | ||
header = predictor.output_file_header | ||
|
||
if args.output_file is not None: | ||
with open(args.output_file, "a") as fout: | ||
desilinguist marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if i == 0: # Only write header once per set of input files | ||
print("\t".join(header), file=fout) | ||
if args.all_probabilities: | ||
for i, probabilities in enumerate(preds): | ||
id_ = feature_set.ids[i] | ||
probs_str = "\t".join([str(p) for p in probabilities]) | ||
print("{}\t{}".format(id_, probs_str), file=fout) | ||
else: | ||
for i, pred in enumerate(preds): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we call this |
||
id_ = feature_set.ids[i] | ||
print("{}\t{}".format(id_, pred), file=fout) | ||
else: | ||
if i == 0: # Only write header once per set of input files | ||
print("\t".join(header)) | ||
if args.all_probabilities: | ||
for i, probabilities in enumerate(preds): | ||
id_ = feature_set.ids[i] | ||
probs_str = "\t".join([str(p) for p in probabilities]) | ||
print("{}\t{}".format(id_, probs_str)) | ||
else: | ||
for i, pred in enumerate(preds): | ||
id_ = feature_set.ids[i] | ||
print("{}\t{}".format(id_, pred)) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: It seems like this import is never used?