-
Notifications
You must be signed in to change notification settings - Fork 67
/
clarifai.py
247 lines (213 loc) · 10.3 KB
/
clarifai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
'''
Extractors that interact with the Clarifai API.
'''
import logging
import os
from contextlib import ExitStack
from pliers.extractors.image import ImageExtractor
from pliers.extractors.video import VideoExtractor
from pliers.extractors.base import ExtractorResult
from pliers.transformers import BatchTransformerMixin
from pliers.transformers.api import APITransformer
from pliers.utils import listify, attempt_to_import, verify_dependencies
import pandas as pd
# Lazily import the Clarifai client so the module loads even when the
# optional dependency is absent; the listed attribute names are exposed on
# the proxy object. Availability is checked later via
# verify_dependencies(['clarifai_client']) before any API use (see the
# extractor __init__ / _extract methods below).
clarifai_client = attempt_to_import('clarifai.rest.client', 'clarifai_client',
                                    ['ClarifaiApp',
                                     'Concept',
                                     'ModelOutputConfig',
                                     'ModelOutputInfo',
                                     'Image',
                                     'Video'])
class ClarifaiAPIExtractor(APITransformer):

    ''' Uses the Clarifai API to extract tags of visual stimuli.

    Args:
        api_key (str): A valid API_KEY for the Clarifai API. Only needs to be
            passed the first time the extractor is initialized.
        model (str): The name of the Clarifai model to use. If None, defaults
            to the general image tagger.
        min_value (float): A value between 0.0 and 1.0 indicating the minimum
            confidence required to return a prediction. Defaults to 0.0.
        max_concepts (int): A value between 0 and 200 indicating the maximum
            number of label predictions returned.
        select_concepts (list): List of concepts (strings) to query from the
            API. For example, ['food', 'animal'].
        rate_limit (int): The minimum number of seconds required between
            transform calls on this Transformer.
    '''

    _log_attributes = ('api_key', 'model', 'model_name', 'min_value',
                       'max_concepts', 'select_concepts')
    _env_keys = ('CLARIFAI_API_KEY',)
    VERSION = '1.0'

    def __init__(self, api_key=None, model='general-v1.3', min_value=None,
                 max_concepts=None, select_concepts=None, rate_limit=None,
                 batch_size=None):
        verify_dependencies(['clarifai_client'])
        # Fall back to the environment variable when no key is passed.
        if api_key is None:
            try:
                api_key = os.environ['CLARIFAI_API_KEY']
            except KeyError:
                raise ValueError("A valid Clarifai API API_KEY "
                                 "must be passed the first time a Clarifai "
                                 "extractor is initialized.")
        self.api_key = api_key
        # A bad key shouldn't crash construction; check_valid_keys() reports
        # validity based on whether the client could be created.
        try:
            self.api = clarifai_client.ClarifaiApp(api_key=api_key)
            self.model = self.api.models.get(model)
        except clarifai_client.ApiError as e:
            logging.warning(str(e))
            self.api = None
            self.model = None
        self.model_name = model
        self.min_value = min_value
        self.max_concepts = max_concepts
        self.select_concepts = select_concepts
        if select_concepts:
            # Wrap plain strings into Concept objects the client expects.
            select_concepts = listify(select_concepts)
            self.select_concepts = [clarifai_client.Concept(concept_name=n)
                                    for n in select_concepts]
        super(ClarifaiAPIExtractor, self).__init__(rate_limit=rate_limit)

    @property
    def api_keys(self):
        # Single-key transformer; returned as a list for the API-key protocol.
        return [self.api_key]

    def check_valid_keys(self):
        # The client is only non-None if construction with the key succeeded.
        return self.api is not None

    def _query_api(self, objects):
        ''' Send `objects` (client Image/Video wrappers) to the model and
        return the list of per-input 'outputs' JSON objects. '''
        verify_dependencies(['clarifai_client'])
        moc = clarifai_client.ModelOutputConfig(
            min_value=self.min_value,
            max_concepts=self.max_concepts,
            select_concepts=self.select_concepts)
        model_output_info = clarifai_client.ModelOutputInfo(output_config=moc)
        tags = self.model.predict(objects, model_output_info=model_output_info)
        return tags['outputs']

    def _parse_annotations(self, annotation, handle_annotations=None):
        """
        Parse a single output JSON object returned by the Clarifai API.

        Args:
            annotation (dict): One element of the API's 'outputs' list.
            handle_annotations (str): How returned face annotations should be
                handled in cases where there are multiple faces.
                'first' indicates to only use the first face JSON object, all
                other values will default to including every face.

        Returns:
            For the 'face' model, a list of dicts (one bounding box per
            detected face), or an empty dict when no face was found. For any
            other model, a dict mapping concept names to confidence values.
        """
        # check whether the model is the face detection model
        if self.model_name == 'face':
            # if a face was detected, get at least the boundaries
            if annotation['data']:
                regions = annotation['data']['regions']
                # if specified, only return the first face
                # BUG FIX: previously this indexed a nonexistent 'region' key
                # and rebound `annotation` to a list, which broke the
                # subsequent regions lookup.
                if handle_annotations == 'first':
                    regions = regions[:1]
                # collate all faces into a multi-row structure
                face_results = []
                for region in regions:
                    data_dict = {}
                    for k, v in region['region_info']['bounding_box'].items():
                        data_dict[k] = v
                    face_results.append(data_dict)
                return face_results
            # return an empty dict if there was no face
            # BUG FIX: previously fell through and returned None despite the
            # stated intent.
            return {}
        # Non-face models: map each returned concept name to its confidence.
        data_dict = {}
        for tag in annotation['data']['concepts']:
            data_dict[tag['name']] = tag['value']
        return data_dict
class ClarifaiAPIImageExtractor(ClarifaiAPIExtractor, BatchTransformerMixin,
                                ImageExtractor):

    ''' Uses the Clarifai API to extract tags of images.

    Args:
        api_key (str): A valid API_KEY for the Clarifai API. Only needs to be
            passed the first time the extractor is initialized.
        model (str): The name of the Clarifai model to use. If None, defaults
            to the general image tagger.
        min_value (float): A value between 0.0 and 1.0 indicating the minimum
            confidence required to return a prediction. Defaults to 0.0.
        max_concepts (int): A value between 0 and 200 indicating the maximum
            number of label predictions returned.
        select_concepts (list): List of concepts (strings) to query from the
            API. For example, ['food', 'animal'].
        rate_limit (int): The minimum number of seconds required between
            transform calls on this Transformer.
        batch_size (int): Number of stims to send per batched API request.
    '''

    _batch_size = 32

    def __init__(self, api_key=None, model='general-v1.3', min_value=None,
                 max_concepts=None, select_concepts=None, rate_limit=None,
                 batch_size=None):
        super(ClarifaiAPIImageExtractor, self).__init__(
            api_key=api_key,
            model=model,
            min_value=min_value,
            max_concepts=max_concepts,
            select_concepts=select_concepts,
            rate_limit=rate_limit,
            batch_size=batch_size)

    def _extract(self, stims):
        verify_dependencies(['clarifai_client'])
        # ExitStack keeps every temporary-filename context open at once, so
        # all local files still exist when the batched API call is made.
        with ExitStack() as stack:
            imgs = []
            for stim in stims:
                if stim.url:
                    # Remote stim: let the API fetch it by URL.
                    imgs.append(clarifai_client.Image(url=stim.url))
                else:
                    # Local stim: materialize a filename for upload.
                    fname = stack.enter_context(stim.get_filename())
                    imgs.append(clarifai_client.Image(filename=fname))
            outputs = self._query_api(imgs)
        # One ExtractorResult per API output, paired with its source stim.
        return [ExtractorResult(resp, stims[idx], self)
                for idx, resp in enumerate(outputs)]

    def _to_df(self, result):
        parsed = self._parse_annotations(result._data)
        if self.model_name == 'face':
            # Face parsing already yields a list of row dicts.
            return pd.DataFrame(parsed)
        # Other models yield one dict; wrap it to form a single row.
        return pd.DataFrame([parsed])
class ClarifaiAPIVideoExtractor(ClarifaiAPIExtractor, VideoExtractor):

    ''' Uses the Clarifai API to extract tags from videos, one row of
    results per sampled frame. See ClarifaiAPIExtractor for arguments. '''

    def _extract(self, stim):
        verify_dependencies(['clarifai_client'])
        # The temporary filename must remain valid while the API call runs.
        with stim.get_filename() as filename:
            vids = [clarifai_client.Video(filename=filename)]
            outputs = self._query_api(vids)
        return ExtractorResult(outputs, stim, self)

    def _to_df(self, result):
        ''' Build a DataFrame of per-frame tags.

        Also sets result._onsets, result._durations and result.features as a
        side effect. Each frame contributes one row per parsed annotation
        (the face model can yield several rows — one per detected face).
        '''
        onsets = []
        durations = []
        data = []
        frames = result._data[0]['data']['frames']
        for i, frame_res in enumerate(frames):
            tmp_res = self._parse_annotations(frame_res)
            # Normalize to a list of row dicts so the multi-face (list) and
            # single-dict cases share one code path (previously duplicated).
            rows = tmp_res if isinstance(tmp_res, list) else [tmp_res]
            # Frame timestamps arrive in milliseconds.
            onset = frame_res['frame_info']['time'] / 1000.0
            if (i + 1) == len(frames):
                end = result.stim.duration
            else:
                end = frames[i + 1]['frame_info']['time'] / 1000.0
            # NOTE: As of Clarifai API v2 and client library 2.6.1, the API
            # returns more frames than it should—at least for some videos.
            # E.g., given a 5.5 second clip, it may return 7 frames, with the
            # last beginning at 6000 ms. Since this appears to be a problem on
            # the Clarifai end, and it's not actually clear how they're
            # getting this imaginary frame (I'm guessing it's the very last
            # frame?), we're not going to do anything about it here, except
            # to make sure that durations aren't negative.
            duration = max(end - onset, 0)
            for row in rows:
                data.append(row)
                onsets.append(onset)
                durations.append(duration)
        result._onsets = onsets
        result._durations = durations
        df = pd.DataFrame(data)
        result.features = list(df.columns)
        return df