-
Notifications
You must be signed in to change notification settings - Fork 67
/
google.py
150 lines (124 loc) · 5.86 KB
/
google.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
''' Google-based Converter classes. '''
import base64
import os
import tempfile
from .audio import AudioToTextConverter
from .image import ImageToTextConverter
from pliers.stimuli.text import TextStim, ComplexTextStim
from pliers.transformers import (GoogleVisionAPITransformer,
GoogleAPITransformer)
class GoogleSpeechAPIConverter(GoogleAPITransformer, AudioToTextConverter):

    ''' Uses the Google Speech API to do speech-to-text transcription.

    Args:
        language_code (str): The language of the supplied AudioStim.
        profanity_filter (bool): If set to True, will ask Google to try and
            filter out profanity from the resulting Text.
        speech_contexts (list): A list of a list of favored phrases or words
            to assist the API. The inner list is a sequence of word tokens,
            each outer element is a potential context.
    '''

    api_name = 'speech'
    resource = 'speech'
    _log_attributes = ('language_code', 'profanity_filter', 'speech_contexts')

    def __init__(self, language_code='en-US', profanity_filter=False,
                 speech_contexts=None, *args, **kwargs):
        self.language_code = language_code
        self.profanity_filter = profanity_filter
        self.speech_contexts = speech_contexts
        super(GoogleSpeechAPIConverter, self).__init__(*args, **kwargs)

    def _query_api(self, request):
        ''' Execute a synchronous recognize call against the Speech API.

        `self.service` and `self.num_retries` are provided by the
        GoogleAPITransformer superclass.
        '''
        request_obj = self.service.speech().recognize(body=request)
        return request_obj.execute(num_retries=self.num_retries)

    def _build_request(self, stim):
        ''' Build the JSON request body for the Speech API from an AudioStim.

        The audio is re-encoded to mono FLAC (matching the 'encoding' field
        declared in the config below) and base64-embedded in the request.
        '''
        # Use mkstemp rather than the deprecated, race-prone mktemp, and
        # guarantee cleanup of the temp file even if encoding/reading fails.
        fd, tmp = tempfile.mkstemp(suffix='.flac')
        os.close(fd)
        try:
            stim.clip.write_audiofile(tmp, fps=stim.sampling_rate,
                                      codec='flac',
                                      ffmpeg_params=['-ac', '1'])
            with open(tmp, 'rb') as f:
                data = f.read()
        finally:
            os.remove(tmp)

        content = base64.b64encode(data).decode()

        if self.speech_contexts:
            speech_contexts = [{'phrases': c} for c in self.speech_contexts]
        else:
            speech_contexts = []

        request = {
            'audio': {
                'content': content
            },
            'config': {
                'encoding': 'FLAC',
                'sampleRateHertz': stim.sampling_rate,
                'languageCode': self.language_code,
                'maxAlternatives': 1,
                'profanityFilter': self.profanity_filter,
                'speechContexts': speech_contexts,
                'enableWordTimeOffsets': True
            }
        }
        return request

    def _convert(self, stim):
        ''' Transcribe an AudioStim into a ComplexTextStim of timed words.

        Raises:
            Exception: if the API response contains an 'error' payload.
        '''
        request = self._build_request(stim)
        response = self._query_api(request)

        if 'error' in response:
            raise Exception(response['error']['message'])

        offset = 0.0 if stim.onset is None else stim.onset

        # Initialize before the loop so that (a) an empty response yields an
        # empty ComplexTextStim instead of a NameError at the return below,
        # and (b) words accumulate across all results rather than being
        # clobbered per-result.
        words = []
        for result in response.get('results', []):
            transcription = result['alternatives'][0]
            for w in transcription.get('words', []):
                # Timestamps arrive as strings like '1.500s'; strip the
                # trailing 's' and convert to float seconds.
                onset = float(w['startTime'][:-1])
                duration = float(w['endTime'][:-1]) - onset
                words.append(TextStim(text=w['word'],
                                      onset=offset + onset,
                                      duration=duration))

        return ComplexTextStim(elements=words, onset=stim.onset)
class GoogleVisionAPITextConverter(GoogleVisionAPITransformer,
                                   ImageToTextConverter):

    ''' Detects text within images using the Google Cloud Vision API.

    Args:
        handle_annotations (str): How to handle cases where there are multiple
            detected text labels. Valid values are 'first' (only return the
            first response as a TextStim), 'concatenate' (concatenate all
            responses into a single TextStim), or 'list' (return a list of
            TextStims).
        args, kwargs: Optional positional and keyword arguments to pass to
            the superclass init.

    Raises:
        ValueError: if handle_annotations is not one of the valid values.
    '''

    request_type = 'TEXT_DETECTION'
    response_object = 'textAnnotations'
    VERSION = '1.0'
    _log_attributes = ('handle_annotations', 'api_version')

    def __init__(self, handle_annotations='first', *args, **kwargs):
        # Validate eagerly so a typo fails at construction time rather than
        # silently producing no output per-stim in _convert().
        if handle_annotations not in ('first', 'concatenate', 'list'):
            raise ValueError("handle_annotations must be one of 'first', "
                             "'concatenate', or 'list'.")
        self.handle_annotations = handle_annotations
        super(GoogleVisionAPITextConverter, self).__init__(*args, **kwargs)

    def _convert(self, stims):
        ''' Run text detection over a batch of image stims.

        Returns:
            A flat list of TextStims. With handle_annotations='list', a
            single stim may contribute multiple TextStims.

        Raises:
            Exception: if any individual response contains an 'error'.
        '''
        request = self._build_request(stims)
        responses = self._query_api(request)
        texts = []

        for stim, response in zip(stims, responses):
            # Guard against a missing/empty response before membership tests;
            # the original 'elif "error" in response' would raise TypeError
            # on a None response.
            if not response:
                texts.append(TextStim(text='', onset=stim.onset,
                                      duration=stim.duration))
                continue

            if 'error' in response:
                raise Exception(response['error']['message'])

            annotations = response.get(self.response_object)
            if not annotations:
                # Valid response, but no text detected.
                texts.append(TextStim(text='', onset=stim.onset,
                                      duration=stim.duration))
                continue

            # Combine the annotations according to the configured policy.
            if self.handle_annotations == 'first':
                texts.append(TextStim(text=annotations[0]['description'],
                                      onset=stim.onset,
                                      duration=stim.duration))
            elif self.handle_annotations == 'concatenate':
                # Single join avoids the spurious leading space produced by
                # iteratively joining onto an initially-empty string.
                text = ' '.join(a['description'] for a in annotations)
                texts.append(TextStim(text=text, onset=stim.onset,
                                      duration=stim.duration))
            elif self.handle_annotations == 'list':
                for annotation in annotations:
                    texts.append(TextStim(text=annotation['description'],
                                          onset=stim.onset,
                                          duration=stim.duration))

        return texts