From fd5767545963cfbad1987c3ca795dd4d5b343892 Mon Sep 17 00:00:00 2001 From: Paul Ruiz Date: Tue, 21 Mar 2023 15:41:53 -0600 Subject: [PATCH] Adding python sample for live audio classification --- .../audio_record.py | 135 ++++++++++++++++++ .../classify.py | 134 +++++++++++++++++ .../requirements.txt | 2 + .../audio_classification_live_stream/utils.py | 68 +++++++++ 4 files changed, 339 insertions(+) create mode 100644 examples/audio_classifier/python/audio_classification_live_stream/audio_record.py create mode 100644 examples/audio_classifier/python/audio_classification_live_stream/classify.py create mode 100644 examples/audio_classifier/python/audio_classification_live_stream/requirements.txt create mode 100644 examples/audio_classifier/python/audio_classification_live_stream/utils.py diff --git a/examples/audio_classifier/python/audio_classification_live_stream/audio_record.py b/examples/audio_classifier/python/audio_classification_live_stream/audio_record.py new file mode 100644 index 0000000000..27e6f0127f --- /dev/null +++ b/examples/audio_classifier/python/audio_classification_live_stream/audio_record.py @@ -0,0 +1,135 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""A module to record audio in a streaming basis.""" +import threading +import numpy as np + +try: +# pylint: disable=g-import-not-at-top + import sounddevice as sd +# pylint: enable=g-import-not-at-top +except OSError as oe: + sd = None + sd_error = oe +except ImportError as ie: + sd = None + sd_error = ie + + +class AudioRecord(object): + """A class to record audio in a streaming basis.""" + + def __init__(self, channels: int, sampling_rate: int, + buffer_size: int) -> None: + + # If microphone is not detected, may need to change default + # You can list available devices from a Linux terminal with + # python3 -m sounddevice to get the device id for below. + + # sd.default.device = 7 + # print(sd.query_devices(device=None, kind='input')) + + """Creates an AudioRecord instance. + + Args: + channels: Number of input channels. + sampling_rate: Sampling rate in Hertz. + buffer_size: Size of the ring buffer in number of samples. + + Raises: + ValueError: if any of the arguments is non-positive. + ImportError: if failed to import `sounddevice`. + OSError: if failed to load `PortAudio`. + """ + if sd is None: + raise sd_error + + if channels <= 0: + raise ValueError('channels must be postive.') + if sampling_rate <= 0: + raise ValueError('sampling_rate must be postive.') + if buffer_size <= 0: + raise ValueError('buffer_size must be postive.') + + + self._audio_buffer = [] + self._buffer_size = buffer_size + self._channels = channels + self._sampling_rate = sampling_rate + + # Create a ring buffer to store the input audio. 
+ self._buffer = np.zeros([buffer_size, channels], dtype=float) + self._lock = threading.Lock() + + def audio_callback(data, *_): + """A callback to receive recorded audio data from sounddevice.""" + self._lock.acquire() + shift = len(data) + if shift > buffer_size: + self._buffer = np.copy(data[:buffer_size]) + else: + self._buffer = np.roll(self._buffer, -shift, axis=0) + self._buffer[-shift:, :] = np.copy(data) + self._lock.release() + + # Create an input stream to continuously capture the audio data. + self._stream = sd.InputStream( + channels=channels, + samplerate=sampling_rate, + callback=audio_callback, + ) + + @property + def channels(self) -> int: + return self._channels + + @property + def sampling_rate(self) -> int: + return self._sampling_rate + + @property + def buffer_size(self) -> int: + return self._buffer_size + + def start_recording(self) -> None: + """Starts the audio recording.""" + # Clear the internal ring buffer. + self._buffer.fill(0) + + # Start recording using sounddevice's InputStream. + self._stream.start() + + def stop(self) -> None: + """Stops the audio recording.""" + self._stream.stop() + + def read(self, size: int) -> np.ndarray: + """Reads the latest audio data captured in the buffer. + + Args: + size: Number of samples to read from the buffer. + + Returns: + A NumPy array containing the audio data. + + Raises: + ValueError: Raised if `size` is larger than the buffer size. 
+ """ + if size > self._buffer_size: + raise ValueError('Cannot read more samples than the size of the buffer.') + elif size <= 0: + raise ValueError('Size must be positive.') + + start_index = self._buffer_size - size + return np.copy(self._buffer[start_index:]) diff --git a/examples/audio_classifier/python/audio_classification_live_stream/classify.py b/examples/audio_classifier/python/audio_classification_live_stream/classify.py new file mode 100644 index 0000000000..1da004365f --- /dev/null +++ b/examples/audio_classifier/python/audio_classification_live_stream/classify.py @@ -0,0 +1,134 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Main scripts to run audio classification.""" + +import argparse +import time +import audio_record + +from mediapipe.tasks import python +from mediapipe.tasks.python.components import containers +from mediapipe.tasks.python import audio +from utils import Plotter + + +def run(model: str, max_results: int, score_threshold: float, + overlapping_factor: float) -> None: + """Continuously run inference on audio data acquired from the device. + + Args: + model: Name of the TFLite audio classification model. + max_results: Maximum number of classification results to display. + score_threshold: The score threshold of classification results. + overlapping_factor: Target overlapping between adjacent inferences. 
+ """ + + if (overlapping_factor <= 0) or (overlapping_factor >= 1.0): + raise ValueError('Overlapping factor must be between 0 and 1.') + + if (score_threshold < 0) or (score_threshold > 1.0): + raise ValueError('Score threshold must be between (inclusive) 0 and 1.') + + classification_result_list = [] + # Initialize a plotter instance to display the classification results. + plotter = Plotter() + + def save_result(result: audio.AudioClassifierResult, timestamp_ms: int): + result.timestamp_ms = timestamp_ms + classification_result_list.append(result) + + # Initialize the audio classification model. + base_options = python.BaseOptions(model_asset_path=model) + options = audio.AudioClassifierOptions( + base_options=base_options, running_mode=audio.RunningMode.AUDIO_STREAM, + max_results=max_results, score_threshold=score_threshold, + result_callback=save_result) + classifier = audio.AudioClassifier.create_from_options(options) + + # Initialize the audio recorder and a tensor to store the audio input. + # The sample rate may need to be changed to match your input device. + # For example, an AT2020 requires sample_rate 44100. + buffer_size, sample_rate, num_channels = 15600, 16000, 1 + audio_format = containers.AudioDataFormat(num_channels, sample_rate) + record = audio_record.AudioRecord(num_channels, sample_rate, buffer_size) + audio_data = containers.AudioData(buffer_size, audio_format) + + # We'll try to run inference every interval_between_inference seconds. + # This is usually half of the model's input length to create an overlapping + # between incoming audio segments to improve classification accuracy. + input_length_in_second = float(len( + audio_data.buffer)) / audio_data.audio_format.sample_rate + interval_between_inference = input_length_in_second * (1 - overlapping_factor) + pause_time = interval_between_inference * 0.1 + last_inference_time = time.time() + + # Start audio recording in the background. 
+ record.start_recording() + + # Loop until the user close the classification results plot. + while True: + # Wait until at least interval_between_inference seconds has passed since + # the last inference. + now = time.time() + diff = now - last_inference_time + if diff < interval_between_inference: + time.sleep(pause_time) + continue + last_inference_time = now + + # Load the input audio from the AudioRecord instance and run classify. + data = record.read(buffer_size) + # audio_data.load_from_array(data.astype(np.float32)) + audio_data.load_from_array(data) + classifier.classify_async(audio_data, round(last_inference_time * 1000)) + + # print(classification_result_list) + # # Plot the classification results. + if classification_result_list: + print(classification_result_list) + plotter.plot(classification_result_list[0]) + classification_result_list.clear() + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--model', + help='Name of the audio classification model.', + required=False, + default='yamnet.tflite') + parser.add_argument( + '--maxResults', + help='Maximum number of results to show.', + required=False, + default=5) + parser.add_argument( + '--overlappingFactor', + help='Target overlapping between adjacent inferences. 
Value must be in (0, 1)', + required=False, + default=0.5) + parser.add_argument( + '--scoreThreshold', + help='The score threshold of classification results.', + required=False, + default=0.0) + args = parser.parse_args() + + run(args.model, int(args.maxResults), float(args.scoreThreshold), + float(args.overlappingFactor)) + + +if __name__ == '__main__': + main() diff --git a/examples/audio_classifier/python/audio_classification_live_stream/requirements.txt b/examples/audio_classifier/python/audio_classification_live_stream/requirements.txt new file mode 100644 index 0000000000..64283f962b --- /dev/null +++ b/examples/audio_classifier/python/audio_classification_live_stream/requirements.txt @@ -0,0 +1,2 @@ +sounddevice +mediapipe diff --git a/examples/audio_classifier/python/audio_classification_live_stream/utils.py b/examples/audio_classifier/python/audio_classification_live_stream/utils.py new file mode 100644 index 0000000000..fecb62f7dc --- /dev/null +++ b/examples/audio_classifier/python/audio_classification_live_stream/utils.py @@ -0,0 +1,68 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A module with util functions.""" +import sys + +from mediapipe.tasks.python import audio +from matplotlib import rcParams +import matplotlib.pyplot as plt + +rcParams.update({ + # Set the plot left margin so that the labels are visible. + 'figure.subplot.left': 0.3, + + # Hide the bottom toolbar. 
class Plotter(object):
  """An util class to display the classification results."""

  # Time (in seconds) matplotlib waits for UI events on each update.
  # (Fix: this was a bare string "attribute docstring" after the assignment,
  # which is a no-op statement, not documentation.)
  _PAUSE_TIME = 0.05

  def __init__(self) -> None:
    """Creates the plot window and wires up the ESC-to-exit handler."""
    fig, self._axes = plt.subplots()
    fig.canvas.manager.set_window_title('Audio classification')

    # Stop the program when the ESC key is pressed.
    def event_callback(event):
      if event.key == 'escape':
        sys.exit(0)

    fig.canvas.mpl_connect('key_press_event', event_callback)

    # Non-blocking so the caller's classification loop keeps running.
    plt.show(block=False)

  def plot(self, result: audio.AudioClassifierResult) -> None:
    """Plot the audio classification result.

    Args:
      result: Classification results returned by an audio classification
        model.
    """
    # Clear the axes so each update redraws from scratch.
    self._axes.cla()
    self._axes.set_title('Press ESC to exit.')
    self._axes.set_xlim((0, 1))

    # Plot the results so that the most probable category comes at the top.
    classification = result.classifications[0]
    label_list = [category.category_name
                  for category in classification.categories]
    score_list = [category.score for category in classification.categories]
    self._axes.barh(label_list[::-1], score_list[::-1])

    # Wait for the UI event.
    plt.pause(self._PAUSE_TIME)