Skip to content

Commit

Permalink
Merge pull request #296 from chriamue/tensorflow
Browse files Browse the repository at this point in the history
Tensorflow
  • Loading branch information
Uberi committed Dec 11, 2017
2 parents c932096 + 0a7bf7c commit 19dc36e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
26 changes: 26 additions & 0 deletions examples/tensorflow_commands.py
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
import time
import speech_recognition as sr
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio # noqa

# obtain audio from the microphone
r = sr.Recognizer()
m = sr.Microphone()

with m as source:
r.adjust_for_ambient_noise(source)


def callback(recognizer, audio):
try:
# You can download the data here: http://download.tensorflow.org/models/speech_commands_v0.01.zip
spoken = recognizer.recognize_tensorflow(audio, tensor_graph='speech_recognition/tensorflow-data/conv_actions_frozen.pb', tensor_label='speech_recognition/tensorflow-data/conv_actions_labels.txt')
print(spoken)
except sr.UnknownValueError:
print("Tensorflow could not understand audio")
except sr.RequestError as e:
print("Could not request results from Tensorflow service; {0}".format(e))


stop_listening = r.listen_in_background(m, callback, phrase_time_limit=0.6)
time.sleep(100)
47 changes: 47 additions & 0 deletions speech_recognition/__init__.py
Expand Up @@ -1214,6 +1214,53 @@ def recognize_ibm(self, audio_data, username, password, language="en-US", show_a
transcription.append(hypothesis["transcript"])
return "\n".join(transcription)

lasttfgraph = ''
tflabels = None

def recognize_tensorflow(self, audio_data, tensor_graph='tensorflow-data/conv_actions_frozen.pb', tensor_label='tensorflow-data/conv_actions_labels.txt'):
"""
Performs speech recognition on ``audio_data`` (an ``AudioData`` instance).
Path to Tensor loaded from ``tensor_graph``. You can download a model here: http://download.tensorflow.org/models/speech_commands_v0.01.zip
Path to Tensor Labels file loaded from ``tensor_label``.
"""
assert isinstance(audio_data, AudioData), "Data must be audio data"
assert isinstance(tensor_graph, str), "``tensor_graph`` must be a string"
assert isinstance(tensor_label, str), "``tensor_label`` must be a string"

try:
import tensorflow as tf
except ImportError:
raise RequestError("missing tensorflow module: ensure that tensorflow is set up correctly.")

if not (tensor_graph == self.lasttfgraph):
self.lasttfgraph = tensor_graph

# load graph
with tf.gfile.FastGFile(tensor_graph, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# load labels
self.tflabels = [line.rstrip() for line in tf.gfile.GFile(tensor_label)]

wav_data = audio_data.get_wav_data(
convert_rate=16000, convert_width=2
)

with tf.Session() as sess:
input_layer_name = 'wav_data:0'
output_layer_name = 'labels_softmax:0'
softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

# Sort labels in order of confidence
top_k = predictions.argsort()[-1:][::-1]
for node_id in top_k:
human_string = self.tflabels[node_id]
return human_string


def get_flac_converter():
"""Returns the absolute path of a FLAC converter executable, or raises an OSError if none can be found."""
Expand Down

0 comments on commit 19dc36e

Please sign in to comment.