forked from google-ai-edge/mediapipe
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request google-ai-edge#79 from googlesamples/audio-live-stream-py: Adding python sample for live audio classification
- Loading branch information
Showing
4 changed files
with
339 additions
and
0 deletions.
There are no files selected for viewing
135 changes: 135 additions & 0 deletions
135
examples/audio_classifier/python/audio_classification_live_stream/audio_record.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""A module to record audio in a streaming basis.""" | ||
import threading | ||
import numpy as np | ||
|
||
try:
  # pylint: disable=g-import-not-at-top
  import sounddevice as sd
  # pylint: enable=g-import-not-at-top
except (ImportError, OSError) as error:
  # Defer the failure: importing this module must not crash when
  # sounddevice/PortAudio is unavailable. AudioRecord.__init__ re-raises
  # the captured error if audio capture is actually requested.
  sd = None
  sd_error = error
|
||
|
||
class AudioRecord(object):
  """A class to record audio in a streaming basis."""

  def __init__(self, channels: int, sampling_rate: int,
               buffer_size: int) -> None:
    """Creates an AudioRecord instance.

    Args:
      channels: Number of input channels.
      sampling_rate: Sampling rate in Hertz.
      buffer_size: Size of the ring buffer in number of samples.

    Raises:
      ValueError: if any of the arguments is non-positive.
      ImportError: if failed to import `sounddevice`.
      OSError: if failed to load `PortAudio`.
    """
    if sd is None:
      # Re-raise the error captured when `sounddevice` failed to import.
      raise sd_error

    if channels <= 0:
      raise ValueError('channels must be positive.')
    if sampling_rate <= 0:
      raise ValueError('sampling_rate must be positive.')
    if buffer_size <= 0:
      raise ValueError('buffer_size must be positive.')

    self._buffer_size = buffer_size
    self._channels = channels
    self._sampling_rate = sampling_rate

    # Ring buffer holding the most recent `buffer_size` samples,
    # shape (buffer_size, channels).
    self._buffer = np.zeros([buffer_size, channels], dtype=float)
    # Guards `self._buffer`: it is written by the sounddevice callback
    # thread and read by `read()` on the caller's thread.
    self._lock = threading.Lock()

    # If the microphone is not detected, you may need to change the default
    # input device. List available devices from a Linux terminal with
    # `python3 -m sounddevice` to get the device id, then set e.g.:
    #   sd.default.device = 7
    #   print(sd.query_devices(device=None, kind='input'))

    def audio_callback(data, *_):
      """A callback to receive recorded audio data from sounddevice."""
      with self._lock:
        shift = len(data)
        if shift > buffer_size:
          # Chunk is larger than the whole buffer: keep its head.
          # NOTE(review): keeping data[-buffer_size:] (the newest samples)
          # may be the intent here — confirm before changing.
          self._buffer = np.copy(data[:buffer_size])
        else:
          # Shift out the oldest samples, append the new chunk at the end.
          self._buffer = np.roll(self._buffer, -shift, axis=0)
          self._buffer[-shift:, :] = np.copy(data)

    # Create an input stream to continuously capture the audio data.
    self._stream = sd.InputStream(
        channels=channels,
        samplerate=sampling_rate,
        callback=audio_callback,
    )

  @property
  def channels(self) -> int:
    return self._channels

  @property
  def sampling_rate(self) -> int:
    return self._sampling_rate

  @property
  def buffer_size(self) -> int:
    return self._buffer_size

  def start_recording(self) -> None:
    """Starts the audio recording."""
    # Clear the internal ring buffer.
    self._buffer.fill(0)

    # Start recording using sounddevice's InputStream.
    self._stream.start()

  def stop(self) -> None:
    """Stops the audio recording."""
    self._stream.stop()

  def read(self, size: int) -> np.ndarray:
    """Reads the latest audio data captured in the buffer.

    Args:
      size: Number of samples to read from the buffer.

    Returns:
      A NumPy array containing the audio data.

    Raises:
      ValueError: Raised if `size` is larger than the buffer size or
        non-positive.
    """
    if size > self._buffer_size:
      raise ValueError('Cannot read more samples than the size of the buffer.')
    elif size <= 0:
      raise ValueError('Size must be positive.')

    # Hold the lock so the callback cannot rewrite the buffer mid-copy.
    with self._lock:
      start_index = self._buffer_size - size
      return np.copy(self._buffer[start_index:])
134 changes: 134 additions & 0 deletions
134
examples/audio_classifier/python/audio_classification_live_stream/classify.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Main scripts to run audio classification.""" | ||
|
||
import argparse | ||
import time | ||
import audio_record | ||
|
||
from mediapipe.tasks import python | ||
from mediapipe.tasks.python.components import containers | ||
from mediapipe.tasks.python import audio | ||
from utils import Plotter | ||
|
||
|
||
def run(model: str, max_results: int, score_threshold: float,
        overlapping_factor: float) -> None:
  """Continuously run inference on audio data acquired from the device.

  Args:
    model: Name of the TFLite audio classification model.
    max_results: Maximum number of classification results to display.
    score_threshold: The score threshold of classification results.
    overlapping_factor: Target overlapping between adjacent inferences.

  Raises:
    ValueError: if `overlapping_factor` is not in (0, 1) or
      `score_threshold` is not in [0, 1].
  """
  if (overlapping_factor <= 0) or (overlapping_factor >= 1.0):
    raise ValueError('Overlapping factor must be between 0 and 1.')

  if (score_threshold < 0) or (score_threshold > 1.0):
    raise ValueError('Score threshold must be between (inclusive) 0 and 1.')

  classification_result_list = []
  # Initialize a plotter instance to display the classification results.
  plotter = Plotter()

  def save_result(result: audio.AudioClassifierResult, timestamp_ms: int):
    """Collects streaming classification results from the async callback."""
    result.timestamp_ms = timestamp_ms
    classification_result_list.append(result)

  # Initialize the audio classification model.
  base_options = python.BaseOptions(model_asset_path=model)
  options = audio.AudioClassifierOptions(
      base_options=base_options, running_mode=audio.RunningMode.AUDIO_STREAM,
      max_results=max_results, score_threshold=score_threshold,
      result_callback=save_result)
  classifier = audio.AudioClassifier.create_from_options(options)

  # Initialize the audio recorder and a tensor to store the audio input.
  # The sample rate may need to be changed to match your input device.
  # For example, an AT2020 requires sample_rate 44100.
  buffer_size, sample_rate, num_channels = 15600, 16000, 1
  audio_format = containers.AudioDataFormat(num_channels, sample_rate)
  record = audio_record.AudioRecord(num_channels, sample_rate, buffer_size)
  audio_data = containers.AudioData(buffer_size, audio_format)

  # We'll try to run inference every interval_between_inference seconds.
  # This is usually half of the model's input length to create an overlapping
  # between incoming audio segments to improve classification accuracy.
  input_length_in_second = float(len(
      audio_data.buffer)) / audio_data.audio_format.sample_rate
  interval_between_inference = input_length_in_second * (1 - overlapping_factor)
  pause_time = interval_between_inference * 0.1
  last_inference_time = time.time()

  # Start audio recording in the background.
  record.start_recording()

  # Loop until the user closes the classification results plot.
  while True:
    # Wait until at least interval_between_inference seconds has passed since
    # the last inference.
    now = time.time()
    diff = now - last_inference_time
    if diff < interval_between_inference:
      time.sleep(pause_time)
      continue
    last_inference_time = now

    # Load the input audio from the AudioRecord instance and run classify.
    data = record.read(buffer_size)
    audio_data.load_from_array(data)
    classifier.classify_async(audio_data, round(last_inference_time * 1000))

    # Plot the oldest pending classification result, then drop the rest so
    # the next iteration starts fresh.
    if classification_result_list:
      plotter.plot(classification_result_list[0])
      classification_result_list.clear()
|
||
|
||
def main():
  """Parses command-line arguments and runs the classification loop."""
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument(
      '--model',
      help='Name of the audio classification model.',
      required=False,
      default='yamnet.tflite')
  # Use argparse `type=` conversion so invalid numeric arguments are
  # rejected at parse time with a proper usage error instead of a
  # ValueError from a later cast.
  parser.add_argument(
      '--maxResults',
      help='Maximum number of results to show.',
      type=int,
      required=False,
      default=5)
  parser.add_argument(
      '--overlappingFactor',
      help='Target overlapping between adjacent inferences. Value must be in (0, 1)',
      type=float,
      required=False,
      default=0.5)
  parser.add_argument(
      '--scoreThreshold',
      help='The score threshold of classification results.',
      type=float,
      required=False,
      default=0.0)
  args = parser.parse_args()

  run(args.model, args.maxResults, args.scoreThreshold, args.overlappingFactor)


if __name__ == '__main__':
  main()
2 changes: 2 additions & 0 deletions
2
examples/audio_classifier/python/audio_classification_live_stream/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
sounddevice | ||
mediapipe |
68 changes: 68 additions & 0 deletions
68
examples/audio_classifier/python/audio_classification_live_stream/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""A module with util functions.""" | ||
import sys | ||
|
||
from mediapipe.tasks.python import audio | ||
from matplotlib import rcParams | ||
import matplotlib.pyplot as plt | ||
|
||
# Configure matplotlib's global settings before any figure is created.
# Set the plot left margin so that the labels are visible.
rcParams['figure.subplot.left'] = 0.3
# Hide the bottom toolbar.
rcParams['toolbar'] = 'None'
|
||
|
||
class Plotter(object):
  """A utility class to display the classification results."""

  # Time (in seconds) matplotlib waits for UI events on each update.
  _PAUSE_TIME = 0.05

  def __init__(self) -> None:
    """Creates the results figure and wires up the ESC-to-exit handler."""
    figure, self._axes = plt.subplots()
    figure.canvas.manager.set_window_title('Audio classification')

    def on_key_press(event):
      # Stop the program when the ESC key is pressed.
      if event.key == 'escape':
        sys.exit(0)

    figure.canvas.mpl_connect('key_press_event', on_key_press)

    # Show the window without blocking so the caller's loop keeps running.
    plt.show(block=False)

  def plot(self, result: audio.AudioClassifierResult) -> None:
    """Plot the audio classification result.

    Args:
      result: Classification results returned by an audio classification
        model.
    """
    axes = self._axes

    # Reset the axes for this frame.
    axes.cla()
    axes.set_title('Press ESC to exit.')
    axes.set_xlim((0, 1))

    # Plot the results so that the most probable category comes at the top.
    categories = result.classifications[0].categories
    labels = [category.category_name for category in categories]
    scores = [category.score for category in categories]
    axes.barh(labels[::-1], scores[::-1])

    # Wait for the UI event.
    plt.pause(self._PAUSE_TIME)