From 5fdf4be53c4b0bf198f1d58cdb1cfcbaa12add99 Mon Sep 17 00:00:00 2001
From: GabrielSCabrera
Date: Sat, 2 Dec 2023 19:45:30 +0100
Subject: [PATCH] Fixing bugs

---
 banterbot/extensions/interface.py           | 11 ++--
 banterbot/gui/tk_interface.py               |  5 +-
 .../handlers/speech_synthesis_handler.py    | 12 ++--
 banterbot/handlers/stream_handler.py        |  2 +-
 banterbot/managers/stream_manager.py        | 31 ++++------
 banterbot/models/__init__.py                |  2 +
 banterbot/{utils => models}/number.py       |  0
 .../services/speech_recognition_service.py  |  2 +-
 .../services/speech_synthesis_service.py    | 44 ++++++-------
 banterbot/utils/__init__.py                 |  3 +-
 banterbot/utils/closeable_queue.py          |  1 +
 banterbot/utils/indexed_event.py            | 61 +++++++++++--------
 docs/banterbot.models.rst                   |  8 +++
 docs/banterbot.utils.rst                    |  8 ---
 14 files changed, 98 insertions(+), 92 deletions(-)
 rename banterbot/{utils => models}/number.py (100%)

diff --git a/banterbot/extensions/interface.py b/banterbot/extensions/interface.py
index 1940cc3..b00abaf 100644
--- a/banterbot/extensions/interface.py
+++ b/banterbot/extensions/interface.py
@@ -13,6 +13,7 @@
 from banterbot.models.azure_neural_voice_profile import AzureNeuralVoiceProfile
 from banterbot.models.message import Message
 from banterbot.models.openai_model import OpenAIModel
+from banterbot.models.word import Word
 from banterbot.paths import chat_logs
 from banterbot.services.openai_service import OpenAIService
 from banterbot.services.speech_recognition_service import SpeechRecognitionService
@@ -303,15 +304,13 @@ def respond(self, init_time: int) -> None:
         # Initialize the generator for asynchronous yielding of sentence blocks
         for block in open_ai_stream:
             phrases, context = self._prosody_selector.select(sentences=block, context=content, system=self._system)
-
             if phrases is None:
                 raise FormatMismatchError()
 
-            print("PHRASES")
-            print(phrases)
-            for word in self._speech_synthesis_service.synthesize(phrases=phrases, init_time=init_time):
-                self.update_conversation_area(word.word)
-                content.append(word.word)
+            for item in self._speech_synthesis_service.synthesize(phrases=phrases, init_time=init_time):
+                print("SYNTHESIS: ", item)
+                self.update_conversation_area(item.value.word)
+                content.append(item.value.word)
 
             self.update_conversation_area(" ")
             content.append(" ")
diff --git a/banterbot/gui/tk_interface.py b/banterbot/gui/tk_interface.py
index 8e07d29..2235113 100644
--- a/banterbot/gui/tk_interface.py
+++ b/banterbot/gui/tk_interface.py
@@ -1,5 +1,6 @@
 import logging
 import threading
+import time
 import tkinter as tk
 from tkinter import ttk
 from typing import Optional, Union
@@ -70,7 +71,9 @@ def listener_activate(self, idx: int) -> None:
     def request_response(self) -> None:
         if self._messages:
             # Interrupt any currently active ChatCompletion, text-to-speech, or speech-to-text streams
-            self._thread_queue.add_task(threading.Thread(target=self.respond, daemon=True))
+            self._thread_queue.add_task(
+                threading.Thread(target=self.respond, kwargs={"init_time": time.perf_counter_ns()}, daemon=True)
+            )
 
     def run(self, greet: bool = False) -> None:
         """
diff --git a/banterbot/handlers/speech_synthesis_handler.py b/banterbot/handlers/speech_synthesis_handler.py
index 176f47b..48d346e 100644
--- a/banterbot/handlers/speech_synthesis_handler.py
+++ b/banterbot/handlers/speech_synthesis_handler.py
@@ -55,21 +55,17 @@ def __iter__(self) -> Generator[Word, None, None]:
 
         self._synthesizer.speak_ssml_async(self._ssml)
         logging.debug("SpeechSynthesisHandler synthesizer started")
 
-        print("SSML")
-        print(self._ssml)
-
         # Process the words as they are synthesized.
-        for stream_log_entry in self._queue:
-            word = stream_log_entry.value
+        for item in self._queue:
             # Determine if a delay is needed to match the word's offset.
-            dt = 1e-9 * (word["time"] - time.perf_counter_ns())
+            dt = 1e-9 * (item["time"] - time.perf_counter_ns())
             # If a delay is needed, wait for the specified time.
             if dt > 0:
                 time.sleep(dt)
             # Yield the word.
-            yield word
-            logging.debug(f"SpeechSynthesisHandler yielded word: `{word['text']}`")
+            yield item["word"]
+            logging.debug(f"SpeechSynthesisHandler yielded word: `{item['word']}`")
 
     def close(self):
         self._synthesizer.stop_speaking()
diff --git a/banterbot/handlers/stream_handler.py b/banterbot/handlers/stream_handler.py
index 9ace366..d48aee8 100644
--- a/banterbot/handlers/stream_handler.py
+++ b/banterbot/handlers/stream_handler.py
@@ -2,8 +2,8 @@
 import threading
 import time
 
+from banterbot.models.number import Number
 from banterbot.utils.closeable_queue import CloseableQueue
-from banterbot.utils.number import Number
 
 
 class StreamHandler:
diff --git a/banterbot/managers/stream_manager.py b/banterbot/managers/stream_manager.py
index 5fc6228..4984668 100644
--- a/banterbot/managers/stream_manager.py
+++ b/banterbot/managers/stream_manager.py
@@ -7,10 +7,10 @@
 from typing import Any, Optional
 
 from banterbot.handlers.stream_handler import StreamHandler
+from banterbot.models.number import Number
 from banterbot.models.stream_log_entry import StreamLogEntry
 from banterbot.utils.closeable_queue import CloseableQueue
 from banterbot.utils.indexed_event import IndexedEvent
-from banterbot.utils.number import Number
 
 
 class StreamManager:
@@ -18,14 +18,13 @@ class StreamManager:
     Manages streaming of data through threads and allows hard or soft interruption of the streamed data.
     """
 
-    def __init__(self, lock: Optional[threading.Lock] = None) -> None:
+    def __init__(self) -> None:
         """
         Initializes the StreamManager with default values.
         """
         self._processor: Callable[[IndexedEvent, int, dict], Any] = lambda log, index, shared_data: log[index]
         self._exception_handler: Optional[Callable[[IndexedEvent, dict], Any]] = None
         self._completion_handler: Callable[[IndexedEvent, int, dict], Any] = None
-        self._lock = lock
 
     def connect_processor(self, func: Callable[[list[StreamLogEntry], int, dict], Any]) -> None:
         """
@@ -101,7 +100,6 @@ def stream(
 
         # Creating the interrupt, index, and index_max values to be used.
         interrupt = Number(value=0)
-        index_max = Number(None)
 
         # Creating the queue and log to be used.
         queue = CloseableQueue()
@@ -122,7 +120,6 @@
         stream_thread = threading.Thread(
             target=self._wrap_stream,
             kwargs={
-                "index_max": index_max,
                 "indexed_event": indexed_event,
                 "kill_event": kill_event,
                 "log": log,
@@ -138,8 +135,8 @@
             kwargs={
                 "timestamp": timestamp,
                 "interrupt": interrupt,
-                "index_max": index_max,
                 "indexed_event": indexed_event,
+                "kill_event": kill_event,
                 "queue": queue,
                 "log": log,
                 "processor": self._processor,
@@ -164,7 +161,6 @@
 
     def _wrap_stream(
         self,
-        index_max: Number,
         indexed_event: IndexedEvent,
         kill_event: threading.Event,
         log: list[StreamLogEntry],
@@ -175,7 +171,6 @@
         Wraps the `_stream` thread to allow for instant interruption using the `kill` event.
 
         Args:
-            index_max (Number): The maximum index to stream to.
             indexed_event (IndexedEvent): The indexed event to use for tracking the current index.
             kill_event (threading.Event): The event to use for interrupting the stream.
             log (list[StreamLogEntry]): The log to store streamed data in.
@@ -186,7 +181,6 @@ def _wrap_stream(
         thread = threading.Thread(
             target=self._stream,
             kwargs={
-                "index_max": index_max,
                 "indexed_event": indexed_event,
                 "kill_event": kill_event,
                 "log": log,
@@ -207,7 +201,6 @@ def _wrap_stream(
 
     def _stream(
         self,
-        index_max: Number,
         indexed_event: IndexedEvent,
         kill_event: threading.Event,
         log: list[StreamLogEntry],
@@ -223,19 +216,18 @@
             log (list[StreamLogEntry]): The log to store streamed data in.
             iterable (Iterable[Any]): The iterable to stream data from.
         """
-        for n, value in enumerate(iterable):
+        for value in iterable:
+            print("STREAM: ", value)
             log.append(StreamLogEntry(value=value))
             indexed_event.increment()
-        index_max.set(n - 1)
-        indexed_event.increment()
         kill_event.set()
 
     def _wrap_processor(
         self,
         timestamp: float,
         interrupt: Number,
-        index_max: Number,
         indexed_event: IndexedEvent,
+        kill_event: threading.Event,
         queue: CloseableQueue,
         log: list[StreamLogEntry],
         processor: Callable[[list[StreamLogEntry], int, dict], Any],
@@ -249,8 +241,8 @@
         Args:
             timestamp (float): The timestamp of the stream.
             interrupt (Number): The interrupt time of the stream.
-            index_max (Number): The maximum index to stream to.
             indexed_event (IndexedEvent): The indexed event to use for tracking the current index.
+            kill_event (threading.Event): The event to use for interrupting the stream.
             queue (CloseableQueue): The queue to store processed data in.
             log (list[StreamLogEntry]): The log to store streamed data in.
             stream_processor (Callable[[list[StreamLogEntry], int, dict], Any]): The stream processor function to
@@ -262,11 +254,13 @@
             shared_data (Optional[dict[str, Any]]): The shared data to be used.
         """
         index = 0
-        while interrupt < timestamp and (index_max.is_null() or index < index_max):
-            indexed_event.wait()
-            if not index_max.is_null() and index >= index_max:
+        while interrupt < timestamp and (not kill_event.is_set() or index < len(log)):
+            indexed_event.wait()
+            indexed_event.decrement()
+            if kill_event.is_set() and index >= len(log):
                 continue
             try:
+                print(log[index])
                 output = processor(log=log, index=index, shared_data=shared_data)
                 if output is not None:
                     queue.put(output)
@@ -289,6 +283,3 @@ def _wrap_processor(
                 break
             index += 1
         queue.close()
-
-        if self._lock and self._lock.locked():
-            self._lock.release()
diff --git a/banterbot/models/__init__.py b/banterbot/models/__init__.py
index b1c0fc7..7cb30ef 100644
--- a/banterbot/models/__init__.py
+++ b/banterbot/models/__init__.py
@@ -1,6 +1,7 @@
 from banterbot.models.azure_neural_voice_profile import AzureNeuralVoiceProfile
 from banterbot.models.memory import Memory
 from banterbot.models.message import Message
+from banterbot.models.number import Number
 from banterbot.models.openai_model import OpenAIModel
 from banterbot.models.phrase import Phrase
 from banterbot.models.speech_recognition_input import SpeechRecognitionInput
@@ -11,6 +12,7 @@
     "AzureNeuralVoiceProfile",
     "Memory",
     "Message",
+    "Number",
     "OpenAIModel",
     "Phrase",
     "SpeechRecognitionInput",
diff --git a/banterbot/utils/number.py b/banterbot/models/number.py
similarity index 100%
rename from banterbot/utils/number.py
rename to banterbot/models/number.py
diff --git a/banterbot/services/speech_recognition_service.py b/banterbot/services/speech_recognition_service.py
index 8828be4..677804e 100644
--- a/banterbot/services/speech_recognition_service.py
+++ b/banterbot/services/speech_recognition_service.py
@@ -50,7 +50,7 @@ def __init__(
         self._last_total_offset = 0
 
         # Initialize the `StreamManager` for handling streaming processes.
-        self._stream_manager = StreamManager(lock=self.__class__._recognition_lock)
+        self._stream_manager = StreamManager()
 
         # Indicates whether the current instance of `SpeechRecognitionService` is listening.
         self._recognizing = False
diff --git a/banterbot/services/speech_synthesis_service.py b/banterbot/services/speech_synthesis_service.py
index b63c992..4a4a054 100644
--- a/banterbot/services/speech_synthesis_service.py
+++ b/banterbot/services/speech_synthesis_service.py
@@ -42,7 +42,7 @@ def __init__(
         self._init_synthesizer(output_format=output_format)
 
         # Initialize the StreamManager for handling streaming processes.
-        self._stream_manager = StreamManager(lock=self.__class__._synthesis_lock)
+        self._stream_manager = StreamManager()
 
         # Indicates whether the current instance of `SpeechSynthesisService` is speaking.
         self._speaking = False
@@ -81,6 +81,7 @@ def synthesize(self, phrases: list[Phrase], init_time: Optional[int] = None) ->
 
         with self.__class__._synthesis_lock:
             self._queue = CloseableQueue()
+            self._first_word = True
 
             iterable = SpeechSynthesisHandler(phrases=phrases, synthesizer=self._synthesizer, queue=self._queue)
             handler = self._stream_manager.stream(iterable=iterable, close_stream=iterable.close)
@@ -147,26 +148,27 @@ def _callback_word_boundary(self, event: speechsdk.SessionEventArgs) -> None:
         # Check if the type is not a sentence boundary
         if event.boundary_type != speechsdk.SpeechSynthesisBoundaryType.Sentence:
             # Add the event and timing information to the list of events
-            self._queue.put(
-                StreamLogEntry(
-                    {
-                        "event": event,
-                        "time": (
-                            self._start_synthesis_time
-                            + 5e8
-                            + 100 * event.audio_offset
-                            + 1e9 * event.duration.total_seconds() / event.word_length
-                        ),
-                        "word": Word(
-                            word=event.text if self._queue.empty() else " " + event.text,
-                            offset=datetime.timedelta(microseconds=event.audio_offset / 10),
-                            duration=event.duration,
-                            category=event.boundary_type,
-                            source=SpeechProcessingType.TTS,
-                        ),
-                    }
-                )
-            )
+            self._queue.put({
+                "event": event,
+                "time": (
+                    self._start_synthesis_time
+                    + 5e8
+                    + 100 * event.audio_offset
+                    + 1e9 * event.duration.total_seconds() / event.word_length
+                ),
+                "word": Word(
+                    word=(
+                        event.text
+                        if event.boundary_type == speechsdk.SpeechSynthesisBoundaryType.Word and self._first_word
+                        else " " + event.text
+                    ),
+                    offset=datetime.timedelta(microseconds=event.audio_offset / 10),
+                    duration=event.duration,
+                    category=event.boundary_type,
+                    source=SpeechProcessingType.TTS,
+                ),
+            })
+            self._first_word = False
 
     def _callbacks_connect(self):
         """
diff --git a/banterbot/utils/__init__.py b/banterbot/utils/__init__.py
index a9943ec..01fa60d 100644
--- a/banterbot/utils/__init__.py
+++ b/banterbot/utils/__init__.py
@@ -1,8 +1,7 @@
 from banterbot.utils.closeable_queue import CloseableQueue
 from banterbot.utils.indexed_event import IndexedEvent
 from banterbot.utils.nlp import NLP
-from banterbot.utils.number import Number
 from banterbot.utils.thread_queue import ThreadQueue
 from banterbot.utils.time_resolver import TimeResolver
 
-__all__ = ["CloseableQueue", "IndexedEvent", "NLP", "Number", "ThreadQueue", "TimeResolver"]
+__all__ = ["CloseableQueue", "IndexedEvent", "NLP", "ThreadQueue", "TimeResolver"]
diff --git a/banterbot/utils/closeable_queue.py b/banterbot/utils/closeable_queue.py
index d8c32a9..88a717d 100644
--- a/banterbot/utils/closeable_queue.py
+++ b/banterbot/utils/closeable_queue.py
@@ -59,6 +59,7 @@ def __iter__(self) -> Self:
 
         while not self.finished():
             self._indexed_event.wait()
+            self._indexed_event.decrement()
             if not self.empty():
                yield super().get()
 
diff --git a/banterbot/utils/indexed_event.py b/banterbot/utils/indexed_event.py
index 8da1211..4c9d053 100644
--- a/banterbot/utils/indexed_event.py
+++ b/banterbot/utils/indexed_event.py
@@ -84,7 +84,37 @@ def increment(self, N: int = 1) -> None:
 
         with self._lock:
             self._counter += N
-            super().set()
+            if self._counter > 0:
+                super().set()
+
+    def decrement(self, N: int = 1) -> None:
+        """
+        Decrements the counter by a specified amount. It also clears the event if zero is reached, blocking the
+        consumer.
+
+        Args:
+            N (int): The amount to decrement the counter by. Must be non-negative.
+
+        Raises:
+            ValueError: If N is negative or not an integer.
+        """
+        if N < 0 or not isinstance(N, int):
+            raise ValueError(
+                "Argument `N` in class `IndexedEvent` method `decrement(N: int)` must be a non-negative integer."
+            )
+
+        with self._lock:
+            if self._counter - N < 0:
+                raise ValueError(
+                    "Argument `N` in class `IndexedEvent` method `decrement(N: int)` must be less than or equal to the"
+                    f" current counter value ({self._counter})."
+                )
+
+            self._counter -= N
+            if self._counter > 0:
+                super().set()
+            else:
+                super().clear()
 
     def is_set(self) -> bool:
         """
@@ -105,30 +135,13 @@ def set(self, N: int = 1) -> None:
         Raises:
             ValueError: If N is less than 1 or N is not a number.
         """
-        if N < 0 or not isinstance(N, int):
-            raise ValueError(
-                "Argument `N` in class `IndexedEvent` method `set(N: int)` must be a non-negative integer."
-            )
         with self._lock:
-            self._counter = N
-            if self._counter > 0:
+            if N < 0 or not isinstance(N, int):
+                raise ValueError(
+                    "Argument `N` in class `IndexedEvent` method `set(N: int)` must be a non-negative integer."
+                )
+            elif N > 0:
                 super().set()
             else:
                 super().clear()
-
-    def wait(self, timeout: float = None) -> bool:
-        """
-        Waits for the event to be set (data to be available), then decrements the counter, indicating a data chunk has
-        been processed. If the counter reaches zero, indicating no more data, the event is cleared.
-
-        Args:
-            timeout (float): Optional timeout for the wait.
-
-        Returns:
-            bool: True if the event was set (data was processed), False if it timed out.
-        """
-        with self._lock:
-            self._counter = self._counter - 1 if self._counter > 0 else self._counter
-            if self._counter <= 0:
-                return super().wait(timeout)
-            return True
+            self._counter = N
diff --git a/docs/banterbot.models.rst b/docs/banterbot.models.rst
index 2da8ed6..536d3da 100644
--- a/docs/banterbot.models.rst
+++ b/docs/banterbot.models.rst
@@ -25,6 +25,14 @@ banterbot.models.message module
    :undoc-members:
    :show-inheritance:
 
+banterbot.models.number module
+------------------------------
+
+.. automodule:: banterbot.models.number
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 banterbot.models.openai\_model module
 -------------------------------------
 
diff --git a/docs/banterbot.utils.rst b/docs/banterbot.utils.rst
index 7bcc353..ac68bb8 100644
--- a/docs/banterbot.utils.rst
+++ b/docs/banterbot.utils.rst
@@ -25,14 +25,6 @@ banterbot.utils.nlp module
    :undoc-members:
    :show-inheritance:
 
-banterbot.utils.number module
------------------------------
-
-.. automodule:: banterbot.utils.number
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
 banterbot.utils.thread\_queue module
 ------------------------------------
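
For reference, the counting-event handshake this patch standardizes on - increment() when an item is produced, wait() followed by decrement() when an item is consumed, with a separate kill/done flag to end the loop - can be sketched outside the library roughly as follows. CountingEvent, producer, and consumer below are illustrative stand-ins written for this sketch (for example, over-decrementing is silently clamped here, whereas the patched IndexedEvent.decrement raises ValueError); they are not banterbot's actual IndexedEvent, CloseableQueue, or StreamManager.

import queue
import threading
from typing import Optional


class CountingEvent:
    """Stand-in for the patched IndexedEvent: a lock-guarded counter whose
    underlying Event is set only while the counter is positive."""

    def __init__(self) -> None:
        self._event = threading.Event()
        self._lock = threading.Lock()
        self._counter = 0

    def increment(self, n: int = 1) -> None:
        with self._lock:
            self._counter += n
            if self._counter > 0:
                self._event.set()

    def decrement(self, n: int = 1) -> None:
        with self._lock:
            self._counter = max(0, self._counter - n)
            if self._counter > 0:
                self._event.set()
            else:
                self._event.clear()

    def wait(self, timeout: Optional[float] = None) -> bool:
        return self._event.wait(timeout)


def producer(items, event: CountingEvent, out: queue.Queue, done: threading.Event) -> None:
    for item in items:
        out.put(item)
        event.increment()  # one wake-up per queued item
    done.set()  # plays the role of kill_event.set() at the end of _stream


def consumer(event: CountingEvent, out: queue.Queue, done: threading.Event) -> list:
    consumed = []
    # Drain until the producer is done and nothing is left, mirroring the
    # wait()/decrement() pairing now used on the consumer side of the patch.
    while not (done.is_set() and out.empty()):
        if not event.wait(timeout=0.1):
            continue  # timed out; re-check the done flag
        event.decrement()
        if not out.empty():
            consumed.append(out.get())
    return consumed


if __name__ == "__main__":
    ev, q, done = CountingEvent(), queue.Queue(), threading.Event()
    t = threading.Thread(target=producer, args=(range(5), ev, q, done), daemon=True)
    t.start()
    print(consumer(ev, q, done))  # expected: [0, 1, 2, 3, 4]
    t.join()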