Skip to content

Commit

Permalink
Fixing bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
GabrielSCabrera committed Dec 2, 2023
1 parent 570da7a commit 5fdf4be
Show file tree
Hide file tree
Showing 14 changed files with 98 additions and 92 deletions.
11 changes: 5 additions & 6 deletions banterbot/extensions/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from banterbot.models.azure_neural_voice_profile import AzureNeuralVoiceProfile
from banterbot.models.message import Message
from banterbot.models.openai_model import OpenAIModel
from banterbot.models.word import Word
from banterbot.paths import chat_logs
from banterbot.services.openai_service import OpenAIService
from banterbot.services.speech_recognition_service import SpeechRecognitionService
Expand Down Expand Up @@ -303,15 +304,13 @@ def respond(self, init_time: int) -> None:
# Initialize the generator for asynchronous yielding of sentence blocks
for block in open_ai_stream:
phrases, context = self._prosody_selector.select(sentences=block, context=content, system=self._system)

if phrases is None:
raise FormatMismatchError()

print("PHRASES")
print(phrases)
for word in self._speech_synthesis_service.synthesize(phrases=phrases, init_time=init_time):
self.update_conversation_area(word.word)
content.append(word.word)
for item in self._speech_synthesis_service.synthesize(phrases=phrases, init_time=init_time):
print("SYNTHESIS: ", item)
self.update_conversation_area(item.value.word)
content.append(item.value.word)

self.update_conversation_area(" ")
content.append(" ")
Expand Down
5 changes: 4 additions & 1 deletion banterbot/gui/tk_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import threading
import time
import tkinter as tk
from tkinter import ttk
from typing import Optional, Union
Expand Down Expand Up @@ -70,7 +71,9 @@ def listener_activate(self, idx: int) -> None:
def request_response(self) -> None:
if self._messages:
# Interrupt any currently active ChatCompletion, text-to-speech, or speech-to-text streams
self._thread_queue.add_task(threading.Thread(target=self.respond, daemon=True))
self._thread_queue.add_task(
threading.Thread(target=self.respond, kwargs={"init_time": time.perf_counter_ns()}, daemon=True)
)

def run(self, greet: bool = False) -> None:
"""
Expand Down
12 changes: 4 additions & 8 deletions banterbot/handlers/speech_synthesis_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,17 @@ def __iter__(self) -> Generator[Word, None, None]:
self._synthesizer.speak_ssml_async(self._ssml)
logging.debug("SpeechSynthesisHandler synthesizer started")

print("SSML")
print(self._ssml)

# Process the words as they are synthesized.
for stream_log_entry in self._queue:
word = stream_log_entry.value
for item in self._queue:
# Determine if a delay is needed to match the word's offset.
dt = 1e-9 * (word["time"] - time.perf_counter_ns())
dt = 1e-9 * (item["time"] - time.perf_counter_ns())
# If a delay is needed, wait for the specified time.
if dt > 0:
time.sleep(dt)

# Yield the word.
yield word
logging.debug(f"SpeechSynthesisHandler yielded word: `{word['text']}`")
yield item["word"]
logging.debug(f"SpeechSynthesisHandler yielded word: `{item['word']}`")

def close(self):
self._synthesizer.stop_speaking()
Expand Down
2 changes: 1 addition & 1 deletion banterbot/handlers/stream_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import threading
import time

from banterbot.models.number import Number
from banterbot.utils.closeable_queue import CloseableQueue
from banterbot.utils.number import Number


class StreamHandler:
Expand Down
31 changes: 11 additions & 20 deletions banterbot/managers/stream_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,24 @@
from typing import Any, Optional

from banterbot.handlers.stream_handler import StreamHandler
from banterbot.models.number import Number
from banterbot.models.stream_log_entry import StreamLogEntry
from banterbot.utils.closeable_queue import CloseableQueue
from banterbot.utils.indexed_event import IndexedEvent
from banterbot.utils.number import Number


class StreamManager:
"""
Manages streaming of data through threads and allows hard or soft interruption of the streamed data.
"""

def __init__(self, lock: Optional[threading.Lock] = None) -> None:
def __init__(self) -> None:
"""
Initializes the StreamManager with default values.
"""
self._processor: Callable[[IndexedEvent, int, dict], Any] = lambda log, index, shared_data: log[index]
self._exception_handler: Optional[Callable[[IndexedEvent, dict], Any]] = None
self._completion_handler: Callable[[IndexedEvent, int, dict], Any] = None
self._lock = lock

def connect_processor(self, func: Callable[[list[StreamLogEntry], int, dict], Any]) -> None:
"""
Expand Down Expand Up @@ -101,7 +100,6 @@ def stream(

# Creating the interrupt, index, and index_max values to be used.
interrupt = Number(value=0)
index_max = Number(None)

# Creating the queue and log to be used.
queue = CloseableQueue()
Expand All @@ -122,7 +120,6 @@ def stream(
stream_thread = threading.Thread(
target=self._wrap_stream,
kwargs={
"index_max": index_max,
"indexed_event": indexed_event,
"kill_event": kill_event,
"log": log,
Expand All @@ -138,8 +135,8 @@ def stream(
kwargs={
"timestamp": timestamp,
"interrupt": interrupt,
"index_max": index_max,
"indexed_event": indexed_event,
"kill_event": kill_event,
"queue": queue,
"log": log,
"processor": self._processor,
Expand All @@ -164,7 +161,6 @@ def stream(

def _wrap_stream(
self,
index_max: Number,
indexed_event: IndexedEvent,
kill_event: threading.Event,
log: list[StreamLogEntry],
Expand All @@ -175,7 +171,6 @@ def _wrap_stream(
Wraps the `_stream` thread to allow for instant interruption using the `kill` event.
Args:
index_max (Number): The maximum index to stream to.
indexed_event (IndexedEvent): The indexed event to use for tracking the current index.
kill_event (threading.Event): The event to use for interrupting the stream.
log (list[StreamLogEntry]): The log to store streamed data in.
Expand All @@ -186,7 +181,6 @@ def _wrap_stream(
thread = threading.Thread(
target=self._stream,
kwargs={
"index_max": index_max,
"indexed_event": indexed_event,
"kill_event": kill_event,
"log": log,
Expand All @@ -207,7 +201,6 @@ def _wrap_stream(

def _stream(
self,
index_max: Number,
indexed_event: IndexedEvent,
kill_event: threading.Event,
log: list[StreamLogEntry],
Expand All @@ -223,19 +216,18 @@ def _stream(
log (list[StreamLogEntry]): The log to store streamed data in.
iterable (Iterable[Any]): The iterable to stream data from.
"""
for n, value in enumerate(iterable):
for value in iterable:
print("STREAM: ", value)
log.append(StreamLogEntry(value=value))
indexed_event.increment()
index_max.set(n - 1)
indexed_event.increment()
kill_event.set()

def _wrap_processor(
self,
timestamp: float,
interrupt: Number,
index_max: Number,
indexed_event: IndexedEvent,
kill_event: threading.Event,
queue: CloseableQueue,
log: list[StreamLogEntry],
processor: Callable[[list[StreamLogEntry], int, dict], Any],
Expand All @@ -249,8 +241,8 @@ def _wrap_processor(
Args:
timestamp (float): The timestamp of the stream.
interrupt (Number): The interrupt time of the stream.
index_max (Number): The maximum index to stream to.
indexed_event (IndexedEvent): The indexed event to use for tracking the current index.
kill_event (threading.Event): The event to use for interrupting the stream.
queue (CloseableQueue): The queue to store processed data in.
log (list[StreamLogEntry]): The log to store streamed data in.
stream_processor (Callable[[list[StreamLogEntry], int, dict], Any]): The stream processor function to
Expand All @@ -262,11 +254,13 @@ def _wrap_processor(
shared_data (Optional[dict[str, Any]]): The shared data to be used.
"""
index = 0
while interrupt < timestamp and (index_max.is_null() or index < index_max):
while interrupt < timestamp and (not kill_event.is_set() or index < len(log)):
indexed_event.wait()
if not index_max.is_null() and index >= index_max:
indexed_event.decrement()
if kill_event.is_set() and index >= len(log):
continue
try:
print(log[index])
output = processor(log=log, index=index, shared_data=shared_data)
if output is not None:
queue.put(output)
Expand All @@ -289,6 +283,3 @@ def _wrap_processor(
break
index += 1
queue.close()

if self._lock and self._lock.locked():
self._lock.release()
2 changes: 2 additions & 0 deletions banterbot/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from banterbot.models.azure_neural_voice_profile import AzureNeuralVoiceProfile
from banterbot.models.memory import Memory
from banterbot.models.message import Message
from banterbot.models.number import Number
from banterbot.models.openai_model import OpenAIModel
from banterbot.models.phrase import Phrase
from banterbot.models.speech_recognition_input import SpeechRecognitionInput
Expand All @@ -11,6 +12,7 @@
"AzureNeuralVoiceProfile",
"Memory",
"Message",
"Number",
"OpenAIModel",
"Phrase",
"SpeechRecognitionInput",
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion banterbot/services/speech_recognition_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self._last_total_offset = 0

# Initialize the `StreamManager` for handling streaming processes.
self._stream_manager = StreamManager(lock=self.__class__._recognition_lock)
self._stream_manager = StreamManager()

# Indicates whether the current instance of `SpeechRecognitionService` is listening.
self._recognizing = False
Expand Down
44 changes: 23 additions & 21 deletions banterbot/services/speech_synthesis_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self._init_synthesizer(output_format=output_format)

# Initialize the StreamManager for handling streaming processes.
self._stream_manager = StreamManager(lock=self.__class__._synthesis_lock)
self._stream_manager = StreamManager()

# Indicates whether the current instance of `SpeechSynthesisService` is speaking.
self._speaking = False
Expand Down Expand Up @@ -81,6 +81,7 @@ def synthesize(self, phrases: list[Phrase], init_time: Optional[int] = None) ->

with self.__class__._synthesis_lock:
self._queue = CloseableQueue()
self._first_word = True

iterable = SpeechSynthesisHandler(phrases=phrases, synthesizer=self._synthesizer, queue=self._queue)
handler = self._stream_manager.stream(iterable=iterable, close_stream=iterable.close)
Expand Down Expand Up @@ -147,26 +148,27 @@ def _callback_word_boundary(self, event: speechsdk.SessionEventArgs) -> None:
# Check if the type is not a sentence boundary
if event.boundary_type != speechsdk.SpeechSynthesisBoundaryType.Sentence:
# Add the event and timing information to the list of events
self._queue.put(
StreamLogEntry(
{
"event": event,
"time": (
self._start_synthesis_time
+ 5e8
+ 100 * event.audio_offset
+ 1e9 * event.duration.total_seconds() / event.word_length
),
"word": Word(
word=event.text if self._queue.empty() else " " + event.text,
offset=datetime.timedelta(microseconds=event.audio_offset / 10),
duration=event.duration,
category=event.boundary_type,
source=SpeechProcessingType.TTS,
),
}
)
)
self._queue.put({
"event": event,
"time": (
self._start_synthesis_time
+ 5e8
+ 100 * event.audio_offset
+ 1e9 * event.duration.total_seconds() / event.word_length
),
"word": Word(
word=(
event.text
if event.boundary_type == speechsdk.SpeechSynthesisBoundaryType.Word and self._first_word
else " " + event.text
),
offset=datetime.timedelta(microseconds=event.audio_offset / 10),
duration=event.duration,
category=event.boundary_type,
source=SpeechProcessingType.TTS,
),
})
self._first_word = False

def _callbacks_connect(self):
"""
Expand Down
3 changes: 1 addition & 2 deletions banterbot/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from banterbot.utils.closeable_queue import CloseableQueue
from banterbot.utils.indexed_event import IndexedEvent
from banterbot.utils.nlp import NLP
from banterbot.utils.number import Number
from banterbot.utils.thread_queue import ThreadQueue
from banterbot.utils.time_resolver import TimeResolver

__all__ = ["CloseableQueue", "IndexedEvent", "NLP", "Number", "ThreadQueue", "TimeResolver"]
__all__ = ["CloseableQueue", "IndexedEvent", "NLP", "ThreadQueue", "TimeResolver"]
1 change: 1 addition & 0 deletions banterbot/utils/closeable_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __iter__(self) -> Self:

while not self.finished():
self._indexed_event.wait()
self._indexed_event.decrement()
if not self.empty():
yield super().get()

Expand Down
Loading

0 comments on commit 5fdf4be

Please sign in to comment.