# predict.py (91 lines, 70 loc, 3.08 KB) — file-viewer header and line-number gutter removed
class Predict:
    '''Class for making predictions on the models that are stored
    in the Scan() object'''

    # Tasks understood by `predict_classes` and the error message used
    # otherwise. AttributeError (rather than ValueError) is kept because
    # existing callers may already catch it.
    _TASKS = ('binary', 'multi_label')
    _TASK_ERROR = 'Only `binary` and `multi_label` are supported'

    def __init__(self, scan_object):
        '''Takes in as input a Scan() object and returns an object
        with properties for `predict` and `predict_classes`

        scan_object | Scan | the Scan() object whose models are used;
                             its `data` attribute is kept as `self.data`
        '''
        self.scan_object = scan_object
        self.data = scan_object.data

    def predict(self,
                x,
                metric,
                asc,
                model_id=None,
                saved=False,
                custom_objects=None):
        '''Makes a probability prediction from input x. If model_id
        is not given, then best_model will be used.

        x | array | data to be used for the predictions
        metric | str | the metric to be used for picking best model
        asc | bool | True if `metric` is something to be minimized
        model_id | int | the id of the model from the Scan() object
        saved | bool | if a model saved on local machine should be used
        custom_objects | dict | if the model has a custom object,
                                pass it here

        Returns whatever the activated model's `predict(x)` returns.
        '''
        model = self._load_model(metric, asc, model_id, saved,
                                 custom_objects)
        return model.predict(x)

    def predict_classes(self,
                        x,
                        metric,
                        asc,
                        task,
                        model_id=None,
                        saved=False,
                        custom_objects=None,
                        threshold=0.5):
        '''Makes a class prediction from input x. If model_id
        is not given, then best_model will be used.

        x | array | data to be used for the predictions
        metric | str | the metric to be used for picking best model
        asc | bool | True if `metric` is something to be minimized
        task | string | 'binary' or 'multi_label'
        model_id | int | the id of the model from the Scan() object
        saved | bool | if a model saved on local machine should be used
        custom_objects | dict | if the model has a custom object,
                                pass it here
        threshold | float | cut-off used when task is 'binary':
                            predictions >= threshold become 1, others 0
                            (default 0.5)

        Returns an integer array. For 'binary' it has the same shape as
        the model output; for 'multi_label' it holds the argmax along
        axis 1, i.e. one class index per row.

        Raises AttributeError if `task` is not supported. The task is
        checked before any model is loaded, so a bad value fails fast.
        '''
        # Validate first so an invalid task never triggers model
        # activation / inference.
        if task not in self._TASKS:
            raise AttributeError(self._TASK_ERROR)

        model = self._load_model(metric, asc, model_id, saved,
                                 custom_objects)
        # make (class) predictions with the model
        preds = model.predict(x)
        return self._preds_to_classes(preds, task, threshold)

    def _load_model(self, metric, asc, model_id, saved, custom_objects):
        '''Activates and returns model `model_id` from the Scan() object.
        When `model_id` is None, the id is first chosen with
        best_model(scan_object, metric, asc).'''
        if model_id is None:
            from ..utils.best_model import best_model
            model_id = best_model(self.scan_object, metric, asc)
        from ..utils.best_model import activate_model
        return activate_model(self.scan_object,
                              model_id,
                              saved,
                              custom_objects)

    @classmethod
    def _preds_to_classes(cls, preds, task, threshold=0.5):
        '''Converts raw model outputs `preds` into class labels for `task`.

        Raises AttributeError for an unsupported `task`.'''
        import numpy as np
        # asarray lets plain lists through; an ndarray is returned as-is
        preds = np.asarray(preds)
        if task == 'binary':
            return np.where(preds >= threshold, 1, 0)
        if task == 'multi_label':
            # NOTE(review): despite the name, this yields a single class
            # per row (multi-class), not a multi-label indicator matrix.
            return np.argmax(preds, 1)
        raise AttributeError(cls._TASK_ERROR)