Skip to content

Commit

Permalink
Partial implementation of StreamManager in OpenAIService
Browse files Browse the repository at this point in the history
  • Loading branch information
GabrielSCabrera committed Nov 27, 2023
1 parent 8d0334a commit 743b4bf
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
16 changes: 10 additions & 6 deletions banterbot/managers/stream_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ def __init__(self) -> None:
"log": threading.Lock(),
}

self._parser: Callable[[IndexedEvent], Any] = lambda x: x
self._parser: Callable[[IndexedEvent, bool], Any] = lambda x, y: x
self._parser_finalizer: Optional[Callable[[IndexedEvent], Any]] = None
self._interrupt: int = 0
self._idx: int = 0
self._idx_max: Optional[int] = None
self._log: list[StreamLogEntry] = []

self._reset()
Expand All @@ -46,12 +48,13 @@ def interrupt(self, timestamp: Optional[int] = None) -> None:
self._interrupt = timestamp
self._events["kill"].set()

def connect_parser(self, func: Callable[[IndexedEvent, bool], Any]) -> None:
    """
    Connects a parser function for processing each streamed item. The parser function should take an IndexedEvent
    and a boolean indicating whether the stream is on its final iteration.

    Args:
        func (Callable[[IndexedEvent, bool], Any]): The parser function to be used.
    """
    # Replaces the default identity parser installed in __init__.
    self._parser = func

Expand Down Expand Up @@ -126,7 +129,7 @@ def _wrap_parser(self, timestamp: int) -> None:
Yields:
Any: The result of processing a log entry.
"""
while timestamp < self._interrupt:
while timestamp < self._interrupt and (self._idx_max is None or self._idx <= self._idx_max):
self._events["indexed"].wait()
yield self._parser(self._log[self._idx])
self._idx += 1
Expand All @@ -151,6 +154,7 @@ def _wrap_stream(self, iterable: Iterable[Any]) -> None:
Args:
iterable (Iterable[Any]): The iterable to stream data from.
"""
for value in iterable:
for n, value in enumerate(iterable):
self._append_to_log(value=value)
self._events["indexed"].increment()
self._idx_max = n - 1
68 changes: 65 additions & 3 deletions banterbot/services/openai_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

from banterbot.config import RETRY_LIMIT, RETRY_TIME
from banterbot.data.enums import EnvVar
from banterbot.managers.stream_manager import StreamManager
from banterbot.models.message import Message
from banterbot.models.openai_model import OpenAIModel
from banterbot.models.stream_log_entry import StreamLogEntry
from banterbot.utils.nlp import NLP

# Set the OpenAI API key
Expand Down Expand Up @@ -46,6 +48,13 @@ def __init__(self, model: OpenAIModel) -> None:
# Set the interruption flag to zero: if interruptions are raised, this will be updated.
self._interrupt: int = 0

# The text that is currently being processed by the OpenAI ChatCompletion API.
self._text: str = ""

# Initialize the StreamManager for handling streaming processes.
self._stream_manager = StreamManager()
self._stream_manager.connect_parser(self._response_parse_stream)

def count_tokens(self, string: str) -> int:
"""
Counts the number of tokens in the provided string.
Expand Down Expand Up @@ -107,8 +116,8 @@ def prompt_stream(self, messages: list[Message], **kwargs) -> Generator[tuple[st
block contains one or more sentences that form a part of the generated response. This can be used to display
the response to the user in real-time or for further processing.
"""
# Record the time at which the request was made, in order to account for future interruptions.
init_time = time.perf_counter_ns()
# Reset the state of the current instance of OpenAIService
self._reset()

# Obtain a response from the OpenAI ChatCompletion API
response = self._request(messages=messages, stream=True, **kwargs)
Expand All @@ -117,7 +126,7 @@ def prompt_stream(self, messages: list[Message], **kwargs) -> Generator[tuple[st
self._streaming = True

# Yield the responses as they are streamed
for block in self._response_parse_stream(response=response, init_time=init_time):
for block in self._response_parse_stream(response=response):
yield block

# Reset the streaming flag to False
Expand All @@ -143,6 +152,53 @@ def streaming(self) -> bool:
"""
return self._streaming

def _parser(self, entry: StreamLogEntry, final_iteration: bool) -> list[str]:
    """
    Parses a chunk of data from the OpenAI API response.

    Accumulates streamed text into `self._text`. On the final iteration, all
    accumulated text is segmented and returned. Otherwise, only fully formed
    sentences are returned and the trailing (possibly incomplete) sentence is
    kept in `self._text` for the next call. When no complete sentence is
    available yet, returns None -- callers must handle that case.

    Args:
        entry (StreamLogEntry): A log entry from the OpenAI API response.
        final_iteration (bool): Whether the current chunk is the final chunk of data from the OpenAI API response.

    Returns:
        list[str]: A list of sentences parsed from the chunk, or None when no
            complete sentence has been accumulated yet.
    """
    # Hoist the incremental payload for this chunk instead of re-indexing the
    # nested structure; `in delta` replaces the redundant `.keys()` call.
    delta = entry.value["choices"][0]["delta"]
    if "content" in delta:
        self._text += delta["content"]

    # If the current chunk is the final chunk of data from the OpenAI API response, parse the final chunk.
    if final_iteration:
        sentences = NLP.segment_sentences(self._text)
        logging.debug(f"OpenAIService yielded final sentences: {sentences[:-1]}")
        logging.debug("OpenAIService stream stopped")
        return sentences

    # If the current chunk is not the final chunk of data from the OpenAI API response, parse the chunk.
    elif len(sentences := NLP.segment_sentences(self._text)) > 1:
        self._text = sentences[-1]
        logging.debug(f"OpenAIService yielded sentences: {sentences[:-1]}")
        return sentences[:-1]

    # Explicit fall-through: nothing ready to yield for this chunk.
    return None

def _parser_finalizer(self, chunk: openai.openai_object.OpenAIObject) -> list[str]:
    """
    Parses the final chunk of data from the OpenAI API response.

    Args:
        chunk (openai.openai_object.OpenAIObject): The final chunk of data from the OpenAI API response.

    Returns:
        list[str]: A list of sentences parsed from the chunk.
    """
    final_delta = chunk["choices"][0]["delta"]

    # Extract the trailing content, if any, from the final delta payload.
    remaining_text = final_delta["content"] if "content" in final_delta.keys() else ""

    sentences = NLP.segment_sentences(remaining_text)
    logging.debug(f"OpenAIService yielded final sentences: {sentences[:-1]}")
    return sentences

def _response_parse_stream(self, response: Iterator, init_time: int) -> Generator[list[str], None, None]:
"""
Parses a streaming response from the OpenAI API and yields blocks of text as they are received.
Expand Down Expand Up @@ -233,3 +289,9 @@ def _request(self, messages: list[Message], stream: bool, **kwargs) -> Union[Ite
raise openai.error.APIError

return response if stream else response.choices[0].message.content.strip()

def _reset(self) -> None:
"""
Resets the state of the current instance of OpenAIService.
"""
self._text = ""

0 comments on commit 743b4bf

Please sign in to comment.