diff --git a/nemoguardrails/rails/llm/buffer.py b/nemoguardrails/rails/llm/buffer.py
index e3e299f90..30e48c4e3 100644
--- a/nemoguardrails/rails/llm/buffer.py
+++ b/nemoguardrails/rails/llm/buffer.py
@@ -14,95 +14,372 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, Tuple
+from typing import AsyncGenerator, List, NamedTuple
 
 from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig
 
+__all__ = ["ChunkBatch", "BufferStrategy", "RollingBuffer", "get_buffer_strategy"]
+
+
+class ChunkBatch(NamedTuple):
+    """Represents a batch of processed chunks from a buffer strategy.
+
+    This class contains the raw chunk data from buffer processing. For a string
+    representation of the chunks, use the buffer strategy's format_chunks() method.
+
+    Attributes:
+        processing_context (List[str]): Chunks to be used for output rails
+            processing, including context from previous chunks.
+        user_output_chunks (List[str]): New chunks to be streamed to the end user
+            in their original token format. Use this for user output or when you
+            only need the newly processed content.
+
+    Example:
+        >>> async for chunk_batch in buffer_strategy.process_stream(handler):
+        ...     # for output rails processing (needs context):
+        ...     context_str = buffer_strategy.format_chunks(chunk_batch.processing_context)
+        ...     analyze_content(context_str)
+        ...
+        ...     # for user output (only new content):
+        ...     user_output = buffer_strategy.format_chunks(chunk_batch.user_output_chunks)
+        ...     yield_to_user(user_output)
+        ...
+        ...     # or iterate over the raw chunks:
+        ...     for chunk in chunk_batch.user_output_chunks:
+        ...         process_individual_chunk(chunk)
+    """
+
+    processing_context: List[str]
+    user_output_chunks: List[str]
+
 
 class BufferStrategy(ABC):
+    """Abstract base class for buffer strategies in streaming output rails.
+
+    This class defines the interface for buffer strategies that manage how
+    streaming chunks are buffered and processed for output rails. Concrete
+    implementations should handle the accumulation and yielding of chunks in
+    a way that optimizes output rails processing while maintaining streaming
+    performance.
+
+    The interface separates two concerns:
+    - buffer management logic (process_stream)
+    - chunk representation formatting (format_chunks)
+
+    Note:
+        All concrete implementations must implement the `from_config`,
+        `process_stream`, and `format_chunks` methods to provide
+        configuration-based instantiation, chunk processing, and string
+        representation capabilities.
+    """
+
     @classmethod
     @abstractmethod
     def from_config(cls, config: OutputRailsStreamingConfig) -> "BufferStrategy":
-        pass
+        """Create a buffer strategy instance from configuration.
+
+        Args:
+            config (OutputRailsStreamingConfig): Configuration object containing
+                buffer strategy parameters.
+
+        Returns:
+            BufferStrategy: A configured buffer strategy instance.
+        """
+        ...
 
-    # The abstract method is not async to ensure the return type
-    # matches the async generator in the concrete implementation.
     @abstractmethod
-    def __call__(
-        self, streaming_handler
-    ) -> AsyncGenerator[Tuple[List[str], str], None]:
-        pass
+    def format_chunks(self, chunks: List[str]) -> str:
+        """Format chunks into a string representation for user consumption.
+
+        This method defines how chunks should be formatted into a string
+        representation. Different strategies might join chunks differently
+        (e.g., preserving spaces, adding separators, etc.).
+
+        Args:
+            chunks (List[str]): List of chunk tokens to be formatted.
+
+        Returns:
+            str: String representation of the chunks ready for consumers.
+
+        Example:
+            >>> strategy = SomeBufferStrategy()
+            >>> chunks = ["Hello", " ", "world"]
+            >>> result = strategy.format_chunks(chunks)
+            >>> print(result)  # "Hello world"
+        """
+        ...
 
     @abstractmethod
-    def generate_chunk_str(self, *args, **kwargs) -> str:
-        pass
+    async def process_stream(
+        self, streaming_handler
+    ) -> AsyncGenerator[ChunkBatch, None]:
+        """Process streaming chunks and yield chunk batches.
+
+        This is the main method that concrete buffer strategies must implement.
+        It defines how chunks from the streaming handler should be buffered,
+        processed, and yielded as ChunkBatch objects.
+
+        Args:
+            streaming_handler: An async iterator that yields individual string
+                chunks from the LLM stream.
+
+        Yields:
+            ChunkBatch: Named tuple containing processing_context and
+                user_output_chunks.
+
+        Example:
+            >>> strategy = SomeBufferStrategy()
+            >>> async for chunk_batch in strategy.process_stream(handler):
+            ...     # for output rails processing (needs context):
+            ...     context_formatted = strategy.format_chunks(chunk_batch.processing_context)
+            ...     # for user output (new content only):
+            ...     user_formatted = strategy.format_chunks(chunk_batch.user_output_chunks)
+            ...     print(f"Processing: {context_formatted}")
+            ...     print(f"User: {user_formatted}")
+        """
+        ...
+
+    async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]:
+        """Callable interface that delegates to process_stream.
+
+        Subclasses can extend this method to add common functionality such as
+        validation, logging, or error handling.
+
+        Args:
+            streaming_handler: An async iterator that yields individual string
+                chunks from the LLM stream.
+
+        Yields:
+            ChunkBatch: Named tuple containing processing_context and
+                user_output_chunks.
+
+        Example:
+            >>> strategy = SomeBufferStrategy()
+            >>> # both of these work:
+            >>> async for batch in strategy.process_stream(handler):
+            ...     context_formatted = strategy.format_chunks(batch.processing_context)
+            >>> async for batch in strategy(handler):  # delegates to process_stream
+            ...     user_formatted = strategy.format_chunks(batch.user_output_chunks)
+        """
+        async for chunk_batch in self.process_stream(streaming_handler):
+            yield chunk_batch
 
 
 class RollingBuffer(BufferStrategy):
-    """A minimal buffer strategy that buffers chunks and yields them when the buffer is full.
+    """A rolling buffer strategy for streaming output rails processing.
+
+    This strategy accumulates incoming chunks in a buffer and yields them in
+    batches when the buffer reaches the specified chunk size. It retains
+    context from previous chunks to ensure continuity when processing output
+    rails.
+
+    The buffer operates by:
+    1. Accumulating incoming chunks until reaching the chunk size threshold
+    2. Yielding a processing buffer (with context) and the new chunks to process
+    3. Retaining context tokens for the next processing round
+    4. Yielding any remaining chunks at the end of the stream
 
     Args:
-        buffer_context_size (int): The number of tokens carried over from the previous chunk to provide context for continuity in processing.
-        buffer_chunk_size (int): The number of tokens in each processing chunk. This is the size of the token block on which output rails are applied.
+        buffer_context_size (int, optional): Number of tokens carried over from
+            previous chunks to provide context for continuity. Defaults to 5.
+        buffer_chunk_size (int, optional): Number of tokens in each processing
+            chunk. This determines the size of the token blocks on which output
+            rails are applied. Defaults to 10.
+
+    Attributes:
+        buffer_context_size (int): Number of context tokens retained between chunks.
+        buffer_chunk_size (int): Number of tokens in each processing chunk.
+        total_yielded (int): Tracks the total number of chunks yielded to the user.
+
+    Example:
+        >>> config = OutputRailsStreamingConfig(context_size=2, chunk_size=4)
+        >>> buffer = RollingBuffer.from_config(config)
+        >>> async for chunk_batch in buffer.process_stream(stream_handler):
+        ...     # for output rails processing (needs context)
+        ...     processing_text = buffer.format_chunks(chunk_batch.processing_context)
+        ...     # for user output (new content only)
+        ...     user_text = buffer.format_chunks(chunk_batch.user_output_chunks)
+        ...     pass
+        >>> # or use the callable interface:
+        >>> async for chunk_batch in buffer(stream_handler):
+        ...     # same as above, delegates to process_stream
+        ...     processing_text = buffer.format_chunks(chunk_batch.processing_context)
+        ...     pass
+
+    Note:
+        The processing buffer includes context from previous chunks, while
+        user_output_chunks contains only the tokens to be yielded to the user.
     """
 
     def __init__(self, buffer_context_size: int = 5, buffer_chunk_size: int = 10):
+        """Initialize the RollingBuffer with the specified buffer sizes.
+
+        Args:
+            buffer_context_size (int, optional): Number of context tokens to
+                retain between chunks. Defaults to 5.
+            buffer_chunk_size (int, optional): Number of tokens per processing
+                chunk. Defaults to 10.
+
+        Raises:
+            ValueError: If buffer_context_size or buffer_chunk_size is negative.
+        """
+        if buffer_context_size < 0:
+            raise ValueError("buffer_context_size must be non-negative")
+        if buffer_chunk_size < 0:
+            raise ValueError("buffer_chunk_size must be non-negative")
+
         self.buffer_context_size = buffer_context_size
         self.buffer_chunk_size = buffer_chunk_size
-        self.last_index = 0
+        # track total chunks yielded to user
+        self.total_yielded = 0
 
     @classmethod
     def from_config(cls, config: OutputRailsStreamingConfig):
+        """Create a RollingBuffer instance from a streaming configuration.
+
+        Args:
+            config (OutputRailsStreamingConfig): Configuration object containing
+                context_size and chunk_size parameters.
+
+        Returns:
+            RollingBuffer: A new RollingBuffer instance configured with the
+                provided parameters.
+
+        Example:
+            >>> config = OutputRailsStreamingConfig(context_size=3, chunk_size=6)
+            >>> buffer = RollingBuffer.from_config(config)
+        """
         return cls(
             buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size
         )
 
-    async def __call__(
+    async def process_stream(
         self, streaming_handler
-    ) -> AsyncGenerator[Tuple[List[str], str], None]:
+    ) -> AsyncGenerator[ChunkBatch, None]:
+        """Process streaming chunks using the rolling buffer strategy.
+
+        This method implements the rolling buffer logic, accumulating chunks
+        and yielding them in batches with context for output rails processing.
+        The buffer maintains a sliding window of context tokens for continuity.
+
+        Args:
+            streaming_handler: An async iterator that yields individual string
+                chunks from the LLM stream.
+
+        Yields:
+            ChunkBatch: Named tuple containing processing_context and
+                user_output_chunks.
+
+        Example:
+            >>> async def stream_handler():
+            ...     for chunk in ["Hello", " ", "world", "!"]:
+            ...         yield chunk
+            >>>
+            >>> buffer = RollingBuffer(buffer_context_size=1, buffer_chunk_size=2)
+            >>> async for chunk_batch in buffer.process_stream(stream_handler()):
+            ...     print(f"Processing buffer: {chunk_batch.processing_context}")
+            ...     print(f"New chunks: {chunk_batch.user_output_chunks}")
+            ...     # for output rails processing (with context):
+            ...     context_str = buffer.format_chunks(chunk_batch.processing_context)
+            ...     # for user output (new content only):
+            ...     user_str = buffer.format_chunks(chunk_batch.user_output_chunks)
+            ...     print(f"Processing: '{context_str}', User: '{user_str}'")
+
+        Note:
+            The method resets the total_yielded counter at the start of each
+            streaming session to ensure accurate tracking.
+        """
+        # reset state for each streaming session
+        self.total_yielded = 0
         buffer = []
-        index = 0
+        total_chunks = 0
 
         async for chunk in streaming_handler:
             buffer.append(chunk)
-            index += 1
+            total_chunks += 1
 
             if len(buffer) >= self.buffer_chunk_size:
-                yield (
-                    # we apply output rails on the buffer
-                    buffer[-self.buffer_chunk_size - self.buffer_context_size :],
-                    # generate_chunk_str is what gets printed in the console or yield to user
-                    # to avoid repeating the already streamed/printed chunk
-                    self.generate_chunk_str(
-                        buffer[-self.buffer_chunk_size - self.buffer_context_size :],
-                        index,
-                    ),
+                # calculate how many new chunks should be yielded
+                new_chunks_to_yield = min(
+                    self.buffer_chunk_size, total_chunks - self.total_yielded
+                )
+
+                # create the processing buffer (includes context)
+                processing_buffer = buffer[
+                    -self.buffer_chunk_size - self.buffer_context_size :
+                ]
+
+                # get the new chunks to yield to the user, preserving the original
+                # token format; the new chunks are at the end of the buffer
+                chunks_to_yield = buffer[-new_chunks_to_yield:]
+                self.total_yielded += new_chunks_to_yield
+
+                yield ChunkBatch(
+                    processing_context=processing_buffer,
+                    user_output_chunks=chunks_to_yield,
                 )
                 buffer = buffer[-self.buffer_context_size :]
 
-        # Yield any remaining buffer if it's not empty
+        # yield any remaining buffer if it's not empty
         if buffer:
-            yield (
-                buffer,
-                self.generate_chunk_str(
-                    buffer[-self.buffer_chunk_size - self.buffer_context_size :], index
-                ),
+            # calculate how many chunks from the remaining buffer haven't been yielded yet
+            remaining_chunks_to_yield = total_chunks - self.total_yielded
+            chunks_to_yield = (
+                buffer[-remaining_chunks_to_yield:]
+                if remaining_chunks_to_yield > 0
+                else []
+            )
+
+            yield ChunkBatch(
+                processing_context=buffer,
+                user_output_chunks=chunks_to_yield,
             )
 
-    def generate_chunk_str(self, buffer, current_index) -> str:
-        if current_index <= self.last_index:
-            return ""
+    def format_chunks(self, chunks: List[str]) -> str:
+        """Generate a string representation of chunks, preserving the original token format.
+
+        The RollingBuffer strategy preserves the original token format by
+        joining chunks without modification, maintaining spaces and formatting
+        as they appeared in the original LLM output.
 
-        new_chunks = buffer[self.last_index - current_index :]
-        self.last_index = current_index
-        # TODO: something causes duplicate whitespaces between tokens, figure out why,
-        # If using `return "".join(new_chunks)` works, then the issue might be elsewhere in the code where the chunks are being generated or processed.
-        # Ensure that the chunks themselves do not contain extra spaces.
-        # WAR: return "".join(new_chunks)
-        return "".join(new_chunks)
+        Args:
+            chunks (List[str]): List of chunk tokens to be formatted.
+
+        Returns:
+            str: String representation preserving the original token spacing and format.
+
+        Example:
+            >>> buffer = RollingBuffer()
+            >>> chunks = ["Hello", " ", "world", "!"]
+            >>> result = buffer.format_chunks(chunks)
+            >>> print(result)  # "Hello world!"
+        """
+        return "".join(chunks)
 
 
 def get_buffer_strategy(config: OutputRailsStreamingConfig) -> BufferStrategy:
+    """Create a buffer strategy from the given configuration.
+
+    Args:
+        config (OutputRailsStreamingConfig): Configuration object specifying
+            the buffer strategy parameters.
+
+    Returns:
+        BufferStrategy: A configured buffer strategy instance. Currently
+            returns a RollingBuffer instance.
+
+    Example:
+        >>> config = OutputRailsStreamingConfig(context_size=2, chunk_size=4)
+        >>> strategy = get_buffer_strategy(config)
+        >>> isinstance(strategy, RollingBuffer)
+        True
+
+    Note:
+        This is currently a simple factory that only returns RollingBuffer
+        instances. Future versions may support multiple buffer strategies
+        with a registry pattern.
+    """
     # TODO: use a factory function or class
-    # currently we only have RollingBuffer, in future we use a registry
    return RollingBuffer.from_config(config)
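For orientation before the integration changes below, here is a minimal end-to-end sketch of the buffer API introduced above. It is illustrative only and not part of the patch; the `fake_stream` helper is a hypothetical stand-in for an LLM token stream.

import asyncio

from nemoguardrails.rails.llm.buffer import RollingBuffer
from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig


async def fake_stream():
    # hypothetical stand-in for an LLM token stream
    for token in ["Hello", " ", "there", " ", "world", "!"]:
        yield token


async def main():
    config = OutputRailsStreamingConfig(context_size=2, chunk_size=4)
    buffer = RollingBuffer.from_config(config)
    async for batch in buffer(fake_stream()):
        # rails inspect the new chunk plus carried-over context...
        rails_view = buffer.format_chunks(batch.processing_context)
        # ...while the user receives only the new tokens, verbatim
        user_view = buffer.format_chunks(batch.user_output_chunks)
        print(f"rails see {rails_view!r} | user gets {user_view!r}")


asyncio.run(main())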
"$user_message": action_params[key] = user_message @@ -1377,24 +1399,28 @@ def _prepare_params( _get_action_details_from_flow_id, flows=self.config.flows ) - async for chunk_list, chunk_str_rep in buffer_strategy(streaming_handler): - chunk_str = " ".join(chunk_list) + async for chunk_batch in buffer_strategy(streaming_handler): + user_output_chunks = chunk_batch.user_output_chunks + # format processing_context for output rails processing (needs full context) + bot_response_chunk = buffer_strategy.format_chunks( + chunk_batch.processing_context + ) + + # check if user_output_chunks is a list of individual chunks + # or if it's a JSON string, by convention this means an error occurred and the error dict is stored as a JSON + if not isinstance(user_output_chunks, list): + try: + json.loads(user_output_chunks) + yield user_output_chunks + return + except (json.JSONDecodeError, TypeError): + # if it's not JSON, treat it as empty list + user_output_chunks = [] - # Check if chunk_str_rep is a JSON string - # we yield a json error payload in generate_async when - # streaming has errors - try: - json.loads(chunk_str_rep) - yield chunk_str_rep - return - except json.JSONDecodeError: - pass if stream_first: - words = chunk_str_rep.split() - if words: - yield words[0] - for word in words[1:]: - yield f" {word}" + # yield the individual chunks directly from the buffer strategy + for chunk in user_output_chunks: + yield chunk for flow_id in output_rails_flows_id: action_name, action_params = get_action_details(flow_id) @@ -1402,20 +1428,17 @@ def _prepare_params( params = _prepare_params( flow_id=flow_id, action_name=action_name, - chunk_str=chunk_str, + bot_response_chunk=bot_response_chunk, prompt=prompt, messages=messages, action_params=action_params, ) - # Execute the action. (Your execute_action returns only the result.) result = await self.runtime.action_dispatcher.execute_action( action_name, params ) - # Include explain info (whatever _update_explain_info does) self.explain_info = self._ensure_explain_info() - # Retrieve the action function from the dispatcher action_func = self.runtime.action_dispatcher.get_action(action_name) # Use the mapping to decide if the result indicates blocked content. @@ -1443,11 +1466,9 @@ def _prepare_params( return if not stream_first: - words = chunk_str_rep.split() - if words: - yield words[0] - for word in words[1:]: - yield f" {word}" + # yield the individual chunks directly from the buffer strategy + for chunk in user_output_chunks: + yield chunk def _get_action_details_from_flow_id( diff --git a/tests/test_buffer_strategy.py b/tests/test_buffer_strategy.py index c0062551f..7c56dc762 100644 --- a/tests/test_buffer_strategy.py +++ b/tests/test_buffer_strategy.py @@ -15,7 +15,12 @@ import pytest -from nemoguardrails.rails.llm.buffer import RollingBuffer as BufferStrategy +from nemoguardrails.rails.llm.buffer import ( + BufferStrategy, + RollingBuffer, + get_buffer_strategy, +) +from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig async def fake_streaming_handler(): @@ -24,12 +29,40 @@ async def fake_streaming_handler(): yield f"chunk{i}" +async def realistic_streaming_handler(): + """Simulate realistic LLM streaming with proper tokens including spaces.""" + response = "This is a safe and compliant response that should pass." 
diff --git a/tests/test_buffer_strategy.py b/tests/test_buffer_strategy.py
index c0062551f..7c56dc762 100644
--- a/tests/test_buffer_strategy.py
+++ b/tests/test_buffer_strategy.py
@@ -15,7 +15,12 @@
 
 import pytest
 
-from nemoguardrails.rails.llm.buffer import RollingBuffer as BufferStrategy
+from nemoguardrails.rails.llm.buffer import (
+    BufferStrategy,
+    RollingBuffer,
+    get_buffer_strategy,
+)
+from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig
 
 
 async def fake_streaming_handler():
@@ -24,12 +29,40 @@ async def fake_streaming_handler():
         yield f"chunk{i}"
 
 
+async def realistic_streaming_handler():
+    """Simulate realistic LLM streaming with proper tokens, including spaces."""
+    response = "This is a safe and compliant response that should pass."
+    tokens = []
+    words = response.split(" ")
+    for i, word in enumerate(words):
+        if i < len(words) - 1:
+            # add a space to every token except the last one
+            tokens.append(word + " ")
+        else:
+            tokens.append(word)
+
+    for token in tokens:
+        yield token
+
+
+async def short_streaming_handler():
+    """Stream shorter than the buffer size."""
+    for token in ["Hello", " ", "world"]:
+        yield token
+
+
+async def empty_streaming_handler():
+    """Empty stream."""
+    return
+    yield  # unreachable
+
+
 @pytest.mark.asyncio
 async def test_buffer_strategy():
-    buffer_strategy = BufferStrategy(buffer_context_size=5, buffer_chunk_size=10)
+    buffer_strategy = RollingBuffer(buffer_context_size=5, buffer_chunk_size=10)
     streaming_handler = fake_streaming_handler()
 
-    expected_buffers = [
+    expected_processing_contexts = [
         [
             "chunk0",
             "chunk1",
@@ -57,8 +90,269 @@ async def test_buffer_strategy():
         ["chunk10", "chunk11", "chunk12", "chunk13", "chunk14"],
     ]
 
-    async for idx, (buffer, _) in async_enumerate(buffer_strategy(streaming_handler)):
-        assert buffer == expected_buffers[idx]
+    expected_user_output_chunks = [
+        [
+            "chunk0",
+            "chunk1",
+            "chunk2",
+            "chunk3",
+            "chunk4",
+            "chunk5",
+            "chunk6",
+            "chunk7",
+            "chunk8",
+            "chunk9",
+        ],
+        ["chunk10", "chunk11", "chunk12", "chunk13", "chunk14"],
+        [],
+    ]
+
+    results = []
+    async for idx, chunk_batch in async_enumerate(buffer_strategy(streaming_handler)):
+        results.append(
+            {
+                "processing_context": chunk_batch.processing_context,
+                "user_output_chunks": chunk_batch.user_output_chunks,
+            }
+        )
+
+    for idx, result in enumerate(results):
+        assert result["processing_context"] == expected_processing_contexts[idx]
+        assert result["user_output_chunks"] == expected_user_output_chunks[idx]
+
+
+@pytest.mark.asyncio
+async def test_buffer_strategy_realistic_data():
+    """Test with realistic token data including spaces."""
+    buffer_strategy = RollingBuffer(buffer_context_size=2, buffer_chunk_size=4)
+    streaming_handler = realistic_streaming_handler()
+
+    expected_results = [
+        {
+            "processing_context": ["This ", "is ", "a ", "safe "],
+            "user_output_chunks": ["This ", "is ", "a ", "safe "],
+        },
+        {
+            "processing_context": ["a ", "safe ", "and ", "compliant "],
+            "user_output_chunks": ["and ", "compliant "],
+        },
+        {
+            "processing_context": ["and ", "compliant ", "response ", "that "],
+            "user_output_chunks": ["response ", "that "],
+        },
+        {
+            "processing_context": ["response ", "that ", "should ", "pass."],
+            "user_output_chunks": ["should ", "pass."],
+        },
+        {
+            "processing_context": ["should ", "pass."],
+            "user_output_chunks": [],
+        },
+    ]
+
+    results = []
+    async for chunk_batch in buffer_strategy(streaming_handler):
+        results.append(
+            {
+                "processing_context": chunk_batch.processing_context,
+                "user_output_chunks": chunk_batch.user_output_chunks,
+            }
+        )
+
+    assert results == expected_results
+
+
+@pytest.mark.asyncio
+async def test_both_interfaces_identical():
+    """Test that the process_stream() and __call__() interfaces behave identically."""
+    buffer_strategy = RollingBuffer(buffer_context_size=1, buffer_chunk_size=3)
+
+    # process_stream interface
+    results_process_stream = []
+    async for chunk_batch in buffer_strategy.process_stream(
+        realistic_streaming_handler()
+    ):
+        results_process_stream.append(
+            (
+                chunk_batch.processing_context.copy(),
+                chunk_batch.user_output_chunks.copy(),
+            )
+        )
+
+    # __call__ interface
+    results_call = []
+    async for chunk_batch in buffer_strategy(realistic_streaming_handler()):
+        results_call.append(
+            (
+                chunk_batch.processing_context.copy(),
+                chunk_batch.user_output_chunks.copy(),
+            )
+        )
+
+    assert results_process_stream == results_call
+
+
+@pytest.mark.asyncio
+async def test_edge_cases():
+    """Test various edge cases."""
+
+    # empty stream
+    buffer_strategy = RollingBuffer(buffer_context_size=2, buffer_chunk_size=4)
+    results = []
+    async for chunk_batch in buffer_strategy(empty_streaming_handler()):
+        results.append(chunk_batch)
+    assert results == [], "Empty stream should yield no results"
+
+    # stream shorter than the buffer
+    results = []
+    async for chunk_batch in buffer_strategy(short_streaming_handler()):
+        results.append(chunk_batch)
+
+    assert len(results) == 1
+    assert results[0].processing_context == ["Hello", " ", "world"]
+    assert results[0].user_output_chunks == ["Hello", " ", "world"]
+
+
+def test_validation():
+    """Test input validation."""
+    with pytest.raises(ValueError, match="buffer_context_size must be non-negative"):
+        RollingBuffer(buffer_context_size=-1)
+
+    with pytest.raises(ValueError, match="buffer_chunk_size must be non-negative"):
+        RollingBuffer(buffer_chunk_size=-1)
+
+    buffer = RollingBuffer(buffer_context_size=0, buffer_chunk_size=1)
+    assert buffer.buffer_context_size == 0
+    assert buffer.buffer_chunk_size == 1
+
+
+def test_from_config():
+    """Test configuration-based instantiation."""
+    config = OutputRailsStreamingConfig(context_size=3, chunk_size=6)
+    buffer = RollingBuffer.from_config(config)
+
+    assert buffer.buffer_context_size == 3
+    assert buffer.buffer_chunk_size == 6
+
+
+def test_get_buffer_strategy():
+    """Test the factory function."""
+    config = OutputRailsStreamingConfig(context_size=2, chunk_size=5)
+    strategy = get_buffer_strategy(config)
+
+    assert isinstance(strategy, RollingBuffer)
+    assert strategy.buffer_context_size == 2
+    assert strategy.buffer_chunk_size == 5
+
+
+def test_format_chunks():
+    buffer_strategy = RollingBuffer(buffer_context_size=5, buffer_chunk_size=10)
+    chunks = ["chunk0", "chunk1", "chunk2", "chunk3", "chunk4", "chunk5"]
+
+    result = buffer_strategy.format_chunks(chunks)
+    assert result == "chunk0chunk1chunk2chunk3chunk4chunk5"
+
+
+def test_format_chunks_realistic():
+    """Test format_chunks with realistic token data."""
+    buffer_strategy = RollingBuffer()
+
+    chunks = ["Hello", " ", "world", "!"]
+    result = buffer_strategy.format_chunks(chunks)
+    assert result == "Hello world!"
+
+    # empty chunks
+    assert buffer_strategy.format_chunks([]) == ""
+
+    # single chunk
+    assert buffer_strategy.format_chunks(["test"]) == "test"
+
+
+@pytest.mark.asyncio
+async def test_total_yielded_tracking():
+    """Test that total_yielded is correctly tracked and reset."""
+    buffer_strategy = RollingBuffer(buffer_context_size=1, buffer_chunk_size=2)
+
+    # first stream
+    user_chunks_1 = []
+    async for chunk_batch in buffer_strategy(short_streaming_handler()):
+        user_chunks_1.extend(chunk_batch.user_output_chunks)
+
+    # second stream: total_yielded should reset
+    user_chunks_2 = []
+    async for chunk_batch in buffer_strategy(short_streaming_handler()):
+        user_chunks_2.extend(chunk_batch.user_output_chunks)
+
+    # verifies the reset worked
+    assert user_chunks_1 == user_chunks_2
+
+
+@pytest.mark.asyncio
+async def test_boundary_conditions():
+    """Test exact buffer size boundaries."""
+
+    async def exact_size_handler():
+        """Stream exactly buffer_chunk_size tokens."""
+        for i in range(4):
+            yield f"token{i} "
+
+    buffer_strategy = RollingBuffer(buffer_context_size=1, buffer_chunk_size=4)
+    results = []
+    async for chunk_batch in buffer_strategy(exact_size_handler()):
+        results.append(chunk_batch)
+
+    # should get exactly one full chunk plus a final empty one
+    assert len(results) == 2
+    assert len(results[0].user_output_chunks) == 4
+    # final empty yield
+    assert len(results[1].user_output_chunks) == 0
+
+
+@pytest.mark.asyncio
+async def test_subword_token_preservation():
+    """Test that subword tokens are preserved without extra spaces (issue #1197)."""
+
+    async def subword_token_stream():
+        # simulate subword tokens like BPE tokenization
+        # example: "assisting" becomes ["ass", "isting"]
+        yield "ass"
+        yield "isting"
+        yield " with "
+        yield "help"
+        yield "ing"
+        yield " you"
+
+    buffer_strategy = RollingBuffer(buffer_context_size=2, buffer_chunk_size=3)
+
+    # collect all data in a single pass to avoid creating duplicate streams
+    processing_contexts = []
+    user_output_parts = []
+
+    async for chunk_batch in buffer_strategy(subword_token_stream()):
+        formatted_text = buffer_strategy.format_chunks(chunk_batch.processing_context)
+        processing_contexts.append(formatted_text)
+
+        user_chunk_text = buffer_strategy.format_chunks(chunk_batch.user_output_chunks)
+        user_output_parts.append(user_chunk_text)
+
+    # reconstruct the full text from the user output chunks
+    full_text = "".join(user_output_parts)
+
+    # subword tokens should be properly joined
+    assert "assisting" in full_text, f"Expected 'assisting' but got: {full_text}"
+    assert "helping" in full_text, f"Expected 'helping' but got: {full_text}"
+
+    # verify no extra spaces were introduced between subword tokens
+    assert (
+        "ass isting" not in full_text
+    ), f"Found extra space in subword tokens: {full_text}"
+    assert (
+        "help ing" not in full_text
+    ), f"Found extra space in subword tokens: {full_text}"
+
+    # the expected result should be: "assisting with helping you"
+    expected = "assisting with helping you"
+    assert full_text == expected, f"Expected '{expected}' but got '{full_text}'"
 
 
 async def async_enumerate(aiterable, start=0):
@@ -68,13 +362,139 @@
         idx += 1
 
 
-def test_generate_chunk_str():
-    buffer_strategy = BufferStrategy(buffer_context_size=5, buffer_chunk_size=10)
-    buffer = ["chunk0", "chunk1", "chunk2", "chunk3", "chunk4", "chunk5"]
-    current_index = 6
+def test_abstract_base_class_cannot_be_instantiated():
+    """Test that the abstract BufferStrategy cannot be instantiated directly."""
+
+    with pytest.raises(TypeError):
+        BufferStrategy()
+
+
+def test_incomplete_implementation_raises_error():
+    """Test that incomplete implementations of BufferStrategy raise TypeError."""
+
+    class IncompleteBufferStrategy(BufferStrategy):
+        pass
+
+    with pytest.raises(TypeError):
+        IncompleteBufferStrategy()
+
+    class MissingProcessStreamStrategy(BufferStrategy):
+        @classmethod
+        def from_config(cls, config):
+            return cls()
+
+        def format_chunks(self, chunks):
+            return "".join(chunks)
+
+    with pytest.raises(TypeError):
+        MissingProcessStreamStrategy()
+
+    class MissingFormatChunksStrategy(BufferStrategy):
+        @classmethod
+        def from_config(cls, config):
+            return cls()
+
+        async def process_stream(self, streaming_handler):
+            async for chunk in streaming_handler:
+                yield chunk
+
+    with pytest.raises(TypeError):
+        MissingFormatChunksStrategy()
+
+    class MissingFromConfigStrategy(BufferStrategy):
+        def format_chunks(self, chunks):
+            return "".join(chunks)
+
+        async def process_stream(self, streaming_handler):
+            async for chunk in streaming_handler:
+                yield chunk
+
+    with pytest.raises(TypeError):
+        MissingFromConfigStrategy()
+
+
+def test_additional_validation_errors():
+    """Test additional validation errors beyond the existing ones."""
+
+    with pytest.raises(ValueError, match="buffer_context_size must be non-negative"):
+        RollingBuffer(buffer_context_size=-100)
+
+    with pytest.raises(ValueError, match="buffer_chunk_size must be non-negative"):
+        RollingBuffer(buffer_chunk_size=-1000)
+
+    with pytest.raises(ValueError, match="buffer_context_size must be non-negative"):
+        RollingBuffer(buffer_context_size=-1, buffer_chunk_size=-1)
+
+
+def test_validation_with_zero_values():
+    """Test that zero values are accepted for buffer parameters."""
+
+    buffer = RollingBuffer(buffer_context_size=0, buffer_chunk_size=5)
+    assert buffer.buffer_context_size == 0
+    assert buffer.buffer_chunk_size == 5
+
+    buffer = RollingBuffer(buffer_context_size=5, buffer_chunk_size=0)
+    assert buffer.buffer_context_size == 5
+    assert buffer.buffer_chunk_size == 0
+
+    buffer = RollingBuffer(buffer_context_size=0, buffer_chunk_size=0)
+    assert buffer.buffer_context_size == 0
+    assert buffer.buffer_chunk_size == 0
+
+
+@pytest.mark.asyncio
+async def test_complete_implementation_works():
+    """Test that a complete implementation of BufferStrategy works correctly."""
+
+    class CompleteBufferStrategy(BufferStrategy):
+        def __init__(self, test_param=None):
+            self.test_param = test_param
+
+        @classmethod
+        def from_config(cls, config):
+            return cls(test_param="from_config")
+
+        def format_chunks(self, chunks):
+            return "|".join(chunks)
+
+        async def process_stream(self, streaming_handler):
+            buffer = []
+            async for chunk in streaming_handler:
+                buffer.append(chunk)
+                if len(buffer) >= 2:
+                    from nemoguardrails.rails.llm.buffer import ChunkBatch
+
+                    yield ChunkBatch(
+                        processing_context=buffer, user_output_chunks=buffer
+                    )
+                    buffer = []
+
+            if buffer:
+                from nemoguardrails.rails.llm.buffer import ChunkBatch
+
+                yield ChunkBatch(processing_context=buffer, user_output_chunks=buffer)
+
+    strategy = CompleteBufferStrategy()
+    assert strategy.test_param is None
+
+    config = OutputRailsStreamingConfig(context_size=1, chunk_size=1)
+    strategy = CompleteBufferStrategy.from_config(config)
+    assert strategy.test_param == "from_config"
+
+    chunks = ["hello", "world"]
+    result = strategy.format_chunks(chunks)
+    assert result == "hello|world"
+
+    async def test_handler():
+        for chunk in ["a", "b", "c"]:
+            yield chunk
 
-    # we've already processed chunks 0 to 4 by setting last_index to 5
-    buffer_strategy.last_index = 5
+    results = []
+    async for chunk_batch in strategy.process_stream(test_handler()):
+        results.append(chunk_batch)
 
-    result = buffer_strategy.generate_chunk_str(buffer, current_index)
-    assert result == "chunk5"
+    assert len(results) == 2
+    assert results[0].processing_context == ["a", "b"]
+    assert results[0].user_output_chunks == ["a", "b"]
+    assert results[1].processing_context == ["c"]
+    assert results[1].user_output_chunks == ["c"]
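The window arithmetic these tests exercise can also be modeled standalone. The sketch below is illustrative only (not part of the patch) and mirrors the RollingBuffer logic: each batch's processing_context holds up to chunk_size + context_size trailing tokens, while user_output_chunks holds only the not-yet-yielded tail.

def rolling_batches(tokens, chunk_size, context_size):
    # pure-list model of RollingBuffer.process_stream for a quick trace
    batches, buffer, yielded = [], [], 0
    for total, token in enumerate(tokens, start=1):
        buffer.append(token)
        if len(buffer) >= chunk_size:
            new = min(chunk_size, total - yielded)
            batches.append((buffer[-chunk_size - context_size :], buffer[-new:]))
            yielded += new
            buffer = buffer[-context_size:]
    if buffer:
        remaining = len(tokens) - yielded
        batches.append((buffer, buffer[-remaining:] if remaining > 0 else []))
    return batches


# chunk_size=4, context_size=2 reproduces test_buffer_strategy_realistic_data
tokens = ["This ", "is ", "a ", "safe ", "and ", "compliant ",
          "response ", "that ", "should ", "pass."]
for context, new in rolling_batches(tokens, chunk_size=4, context_size=2):
    print("rails see:", "".join(context), "| user gets:", "".join(new))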
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index a590569ad..74e215ce2 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -336,22 +336,22 @@ async def test_streaming_output_rails_allowed(output_rails_streaming_config):
     ]
 
     expected_tokens = [
-        "This",
-        " is",
-        " a",
-        " funny",
-        "joke",
-        " but",
-        "you",
-        " should",
-        "not",
-        " laught",
-        "at",
-        " it",
-        "because",
-        " you",
-        "will",
-        " be",
+        "This ",
+        "is ",
+        "a ",
+        "funny ",
+        "joke ",
+        "but ",
+        "you ",
+        "should ",
+        "not ",
+        "laught ",
+        "at ",
+        "it ",
+        "because ",
+        "you ",
+        "will ",
+        "be ",
         "cursed!.",
     ]
     tokens = await run_self_check_test(output_rails_streaming_config, llm_completions)
@@ -366,6 +366,32 @@ async def test_streaming_output_rails_allowed(output_rails_streaming_config):
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
 
 
+@pytest.mark.asyncio
+async def test_sequential_streaming_output_rails_allowed(
+    output_rails_streaming_config,
+):
+    """Tests that sequential output rails allow content when no blocking keywords are present."""
+
+    llm_completions = [
+        " bot express insult",
+        ' "Hi, how are you doing?"',
+        ' "This is a safe and compliant high quality joke that should pass all checks."',
+    ]
+
+    chunks = await run_self_check_test(output_rails_streaming_config, llm_completions)
+
+    response = "".join(chunks)
+    assert len(response) > 0
+    assert len(chunks) > 1
+    assert "This is a safe" in response
+    assert "compliant high quality" in response
+
+    error_chunks = [chunk for chunk in chunks if chunk.startswith('{"error":')]
+    assert len(error_chunks) == 0
+
+    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
+
+
 @pytest.mark.asyncio
 async def test_streaming_output_rails_blocked(output_rails_streaming_config):
     """This test checks if the streaming output rails block the completions when a BLOCK keyword is present.