Skip to content

Commit

Permalink
Merge pull request google-ai-edge#79 from googlesamples/audio-live-stream-py
Browse files Browse the repository at this point in the history

Adding python sample for live audio classification
  • Loading branch information
PaulTR committed Mar 21, 2023
2 parents e8d4275 + fd57675 commit d3d097f
Show file tree
Hide file tree
Showing 4 changed files with 339 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module to record audio in a streaming basis."""
import threading
import numpy as np

# Import `sounddevice` defensively: on machines without the PortAudio
# native library (OSError) or without the package installed (ImportError),
# remember the error and re-raise it later when an AudioRecord is built,
# instead of failing at module import time.
try:
  # pylint: disable=g-import-not-at-top
  import sounddevice as sd
  # pylint: enable=g-import-not-at-top
except OSError as oe:
  # PortAudio failed to load; defer the error until AudioRecord.__init__.
  sd = None
  sd_error = oe
except ImportError as ie:
  # `sounddevice` is not installed; defer the error until AudioRecord.__init__.
  sd = None
  sd_error = ie


class AudioRecord(object):
  """A class to record audio in a streaming basis.

  Audio captured by `sounddevice` is written into an internal ring buffer
  from a background callback thread; `read()` returns the most recent
  samples. Access to the ring buffer is guarded by a lock because the
  callback and `read()` run on different threads.
  """

  def __init__(self, channels: int, sampling_rate: int,
               buffer_size: int) -> None:
    """Creates an AudioRecord instance.

    Args:
      channels: Number of input channels.
      sampling_rate: Sampling rate in Hertz.
      buffer_size: Size of the ring buffer in number of samples.

    Raises:
      ValueError: if any of the arguments is non-positive.
      ImportError: if failed to import `sounddevice`.
      OSError: if failed to load `PortAudio`.
    """
    # Surface the import/load failure captured at module import time so the
    # caller sees it at construction rather than a cryptic NameError.
    if sd is None:
      raise sd_error

    if channels <= 0:
      raise ValueError('channels must be positive.')
    if sampling_rate <= 0:
      raise ValueError('sampling_rate must be positive.')
    if buffer_size <= 0:
      raise ValueError('buffer_size must be positive.')

    self._buffer_size = buffer_size
    self._channels = channels
    self._sampling_rate = sampling_rate

    # Ring buffer holding the most recent `buffer_size` samples.
    self._buffer = np.zeros([buffer_size, channels], dtype=float)
    # Guards `self._buffer`: written by the sounddevice callback thread,
    # read by `read()`.
    self._lock = threading.Lock()

    def audio_callback(data, *_):
      """A callback to receive recorded audio data from sounddevice."""
      with self._lock:
        shift = len(data)
        if shift > buffer_size:
          # Incoming chunk is larger than the whole ring buffer: keep only
          # the newest `buffer_size` samples (the tail of the chunk).
          self._buffer = np.copy(data[-buffer_size:])
        else:
          # Shift out the oldest samples and append the new chunk.
          self._buffer = np.roll(self._buffer, -shift, axis=0)
          self._buffer[-shift:, :] = np.copy(data)

    # Create an input stream to continuously capture the audio data.
    #
    # If the microphone is not detected, you may need to change the default
    # device. List available devices from a terminal with
    # `python3 -m sounddevice` to get the device id, then set e.g.:
    #   sd.default.device = 7
    #   print(sd.query_devices(device=None, kind='input'))
    self._stream = sd.InputStream(
        channels=channels,
        samplerate=sampling_rate,
        callback=audio_callback,
    )

  @property
  def channels(self) -> int:
    return self._channels

  @property
  def sampling_rate(self) -> int:
    return self._sampling_rate

  @property
  def buffer_size(self) -> int:
    return self._buffer_size

  def start_recording(self) -> None:
    """Starts the audio recording."""
    # Clear the internal ring buffer.
    self._buffer.fill(0)

    # Start recording using sounddevice's InputStream.
    self._stream.start()

  def stop(self) -> None:
    """Stops the audio recording."""
    self._stream.stop()

  def read(self, size: int) -> np.ndarray:
    """Reads the latest audio data captured in the buffer.

    Args:
      size: Number of samples to read from the buffer.

    Returns:
      A NumPy array of shape (size, channels) containing the audio data.

    Raises:
      ValueError: Raised if `size` is larger than the buffer size or
        non-positive.
    """
    if size > self._buffer_size:
      raise ValueError('Cannot read more samples than the size of the buffer.')
    elif size <= 0:
      raise ValueError('Size must be positive.')

    start_index = self._buffer_size - size
    # Hold the lock so the callback thread cannot mutate the buffer while
    # we copy from it.
    with self._lock:
      return np.copy(self._buffer[start_index:])
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main scripts to run audio classification."""

import argparse
import time
import audio_record

from mediapipe.tasks import python
from mediapipe.tasks.python.components import containers
from mediapipe.tasks.python import audio
from utils import Plotter


def run(model: str, max_results: int, score_threshold: float,
        overlapping_factor: float) -> None:
  """Continuously run inference on audio data acquired from the device.

  Args:
    model: Name of the TFLite audio classification model.
    max_results: Maximum number of classification results to display.
    score_threshold: The score threshold of classification results.
    overlapping_factor: Target overlapping between adjacent inferences.

  Raises:
    ValueError: if `overlapping_factor` is not in (0, 1) or
      `score_threshold` is not in [0, 1].
  """
  if (overlapping_factor <= 0) or (overlapping_factor >= 1.0):
    raise ValueError('Overlapping factor must be between 0 and 1.')

  if (score_threshold < 0) or (score_threshold > 1.0):
    raise ValueError('Score threshold must be between (inclusive) 0 and 1.')

  # Results arrive asynchronously via `save_result`; this list is the
  # hand-off between the classifier callback and the plotting loop.
  classification_result_list = []
  # Initialize a plotter instance to display the classification results.
  plotter = Plotter()

  def save_result(result: audio.AudioClassifierResult, timestamp_ms: int):
    result.timestamp_ms = timestamp_ms
    classification_result_list.append(result)

  # Initialize the audio classification model.
  base_options = python.BaseOptions(model_asset_path=model)
  options = audio.AudioClassifierOptions(
      base_options=base_options, running_mode=audio.RunningMode.AUDIO_STREAM,
      max_results=max_results, score_threshold=score_threshold,
      result_callback=save_result)
  classifier = audio.AudioClassifier.create_from_options(options)

  # Initialize the audio recorder and a tensor to store the audio input.
  # The sample rate may need to be changed to match your input device.
  # For example, an AT2020 requires sample_rate 44100.
  buffer_size, sample_rate, num_channels = 15600, 16000, 1
  audio_format = containers.AudioDataFormat(num_channels, sample_rate)
  record = audio_record.AudioRecord(num_channels, sample_rate, buffer_size)
  audio_data = containers.AudioData(buffer_size, audio_format)

  # We'll try to run inference every interval_between_inference seconds.
  # This is usually half of the model's input length to create an overlapping
  # between incoming audio segments to improve classification accuracy.
  input_length_in_second = float(len(
      audio_data.buffer)) / audio_data.audio_format.sample_rate
  interval_between_inference = input_length_in_second * (1 - overlapping_factor)
  pause_time = interval_between_inference * 0.1
  last_inference_time = time.time()

  # Start audio recording in the background.
  record.start_recording()

  # Loop until the user closes the classification results plot.
  while True:
    # Wait until at least interval_between_inference seconds has passed since
    # the last inference.
    now = time.time()
    diff = now - last_inference_time
    if diff < interval_between_inference:
      time.sleep(pause_time)
      continue
    last_inference_time = now

    # Load the input audio from the AudioRecord instance and run classify.
    data = record.read(buffer_size)
    audio_data.load_from_array(data)
    classifier.classify_async(audio_data, round(last_inference_time * 1000))

    # Plot the classification results.
    if classification_result_list:
      plotter.plot(classification_result_list[0])
      classification_result_list.clear()

def main():
  """Parses command-line arguments and starts the classification loop."""
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument(
      '--model',
      help='Name of the audio classification model.',
      required=False,
      default='yamnet.tflite')
  parser.add_argument(
      '--maxResults',
      help='Maximum number of results to show.',
      # Let argparse validate and convert, so bad input yields a clean
      # usage error instead of a ValueError traceback.
      type=int,
      required=False,
      default=5)
  parser.add_argument(
      '--overlappingFactor',
      help='Target overlapping between adjacent inferences. Value must be in (0, 1)',
      type=float,
      required=False,
      default=0.5)
  parser.add_argument(
      '--scoreThreshold',
      help='The score threshold of classification results.',
      type=float,
      required=False,
      default=0.0)
  args = parser.parse_args()

  run(args.model, args.maxResults, args.scoreThreshold, args.overlappingFactor)


if __name__ == '__main__':
  main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sounddevice
mediapipe
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module with util functions."""
import sys

from mediapipe.tasks.python import audio
from matplotlib import rcParams
import matplotlib.pyplot as plt

# Global matplotlib configuration, applied once at import time.
rcParams.update({
    # Set the plot left margin so that the labels are visible.
    'figure.subplot.left': 0.3,

    # Hide the bottom toolbar.
    'toolbar': 'None'
})


class Plotter(object):
  """A utility class that displays classification results as a bar chart."""

  # Time (in seconds) for matplotlib to wait for UI events on each update.
  _PAUSE_TIME = 0.05

  def __init__(self) -> None:
    """Creates the figure window and wires up the ESC-to-exit handler."""
    figure, self._axes = plt.subplots()
    figure.canvas.manager.set_window_title('Audio classification')

    # Stop the program when the ESC key is pressed.
    def on_key_press(key_event):
      if key_event.key == 'escape':
        sys.exit(0)

    figure.canvas.mpl_connect('key_press_event', on_key_press)

    plt.show(block=False)

  def plot(self, result: audio.AudioClassifierResult) -> None:
    """Plot the audio classification result.

    Args:
      result: Classification results returned by an audio classification
        model.
    """
    axes = self._axes

    # Reset the axes before drawing the new frame.
    axes.cla()
    axes.set_title('Press ESC to exit.')
    axes.set_xlim((0, 1))

    # Draw horizontal bars, ordered so the most probable category is on top.
    categories = result.classifications[0].categories
    names = [category.category_name for category in categories]
    scores = [category.score for category in categories]
    axes.barh(list(reversed(names)), list(reversed(scores)))

    # Wait for the UI event.
    plt.pause(self._PAUSE_TIME)

0 comments on commit d3d097f

Please sign in to comment.