diff --git a/nemoguardrails/colang/runtime.py b/nemoguardrails/colang/runtime.py index ba61eaaf5..4ffa1817d 100644 --- a/nemoguardrails/colang/runtime.py +++ b/nemoguardrails/colang/runtime.py @@ -37,6 +37,12 @@ def __init__(self, config: RailsConfig, verbose: bool = False): import_paths=list(config.imported_paths.values()), ) + if hasattr(self, "_run_output_rails_in_parallel_streaming"): + self.action_dispatcher.register_action( + self._run_output_rails_in_parallel_streaming, + name="run_output_rails_in_parallel_streaming", + ) + # The list of additional parameters that can be passed to the actions. self.registered_action_params: dict = {} diff --git a/nemoguardrails/colang/v1_0/runtime/runtime.py b/nemoguardrails/colang/v1_0/runtime/runtime.py index 56fa00efc..661d5ad83 100644 --- a/nemoguardrails/colang/v1_0/runtime/runtime.py +++ b/nemoguardrails/colang/v1_0/runtime/runtime.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio import inspect import logging -import uuid from textwrap import indent from time import time from typing import Any, Dict, List, Optional, Tuple @@ -25,10 +24,13 @@ from langchain.chains.base import Chain from nemoguardrails.actions.actions import ActionResult +from nemoguardrails.actions.output_mapping import is_output_blocked from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.runtime import Runtime from nemoguardrails.colang.v1_0.runtime.flows import ( FlowConfig, + _get_flow_params, + _normalize_flow_id, compute_context, compute_next_steps, ) @@ -259,6 +261,89 @@ def _internal_error_action_result(message: str): ] ) + async def _run_output_rails_in_parallel_streaming( + self, flows_with_params: Dict[str, dict], events: List[dict] + ) -> ActionResult: + """Run the output rails in parallel for streaming chunks. + + This is a streamlined version that avoids the full flow state management + which can cause issues with hide_prev_turn logic during streaming. + + Args: + flows_with_params: Dictionary mapping flow_id to {"action_name": str, "params": dict} + events: The events list for context + """ + tasks = [] + + async def run_single_rail(flow_id: str, action_info: dict) -> tuple: + """Run a single rail flow and return (flow_id, result)""" + + try: + action_name = action_info["action_name"] + params = action_info["params"] + + result_tuple = await self.action_dispatcher.execute_action( + action_name, params + ) + result, status = result_tuple + + if status != "success": + log.error(f"Action {action_name} failed with status: {status}") + return flow_id, False # Allow on failure + + action_func = self.action_dispatcher.get_action(action_name) + + # use the mapping to decide if the result indicates blocked content. 
+ # True means blocked, False means allowed + result = is_output_blocked(result, action_func) + + return flow_id, result + + except Exception as e: + log.error(f"Error executing rail {flow_id}: {e}") + return flow_id, False # Allow on error + + # create tasks for all flows + for flow_id, action_info in flows_with_params.items(): + task = asyncio.create_task(run_single_rail(flow_id, action_info)) + tasks.append(task) + + stopped_events = [] + + try: + for future in asyncio.as_completed(tasks): + try: + flow_id, is_blocked = await future + + # check if this rail blocked the content + if is_blocked: + # create stop events + stopped_events = [ + { + "type": "BotIntent", + "intent": "stop", + "flow_id": flow_id, + } + ] + + # cancel remaining tasks + for pending_task in tasks: + if not pending_task.done(): + pending_task.cancel() + break + + except asyncio.CancelledError: + pass + except Exception as e: + log.error(f"Error in parallel rail task: {e}") + continue + + except Exception as e: + log.error(f"Error in parallel rail execution: {e}") + return ActionResult(events=[]) + + return ActionResult(events=stopped_events) + async def _process_start_action(self, events: List[dict]) -> List[dict]: """ Start the specified action, wait for it to finish, and post back the result. @@ -458,8 +543,9 @@ async def _get_action_resp( ) resp = await resp.json() - result, status = resp.get("result", result), resp.get( - "status", status + result, status = ( + resp.get("result", result), + resp.get("status", status), ) except Exception as e: log.info(f"Exception {e} while making request to {action_name}") diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 306413155..d0e0cf03e 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -455,6 +455,11 @@ class OutputRailsStreamingConfig(BaseModel): class OutputRails(BaseModel): """Configuration of output rails.""" + parallel: Optional[bool] = Field( + default=False, + description="If True, the output rails are executed in parallel.", + ) + flows: List[str] = Field( default_factory=list, description="The names of all the flows that implement output rails.", diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 1999904e9..5c39f80a3 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -66,7 +66,7 @@ from nemoguardrails.logging.verbose import set_verbose from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop from nemoguardrails.rails.llm.buffer import get_buffer_strategy -from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, Model, RailsConfig +from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig from nemoguardrails.rails.llm.options import ( GenerationLog, GenerationOptions, @@ -1351,6 +1351,32 @@ def _get_latest_user_message( return message return {} + def _prepare_context_for_parallel_rails( + chunk_str: str, + prompt: Optional[str] = None, + messages: Optional[List[dict]] = None, + ) -> dict: + """Prepare context for parallel rails execution.""" + context_message = _get_last_context_message(messages) + user_message = prompt or _get_latest_user_message(messages) + + context = { + "user_message": user_message, + "bot_message": chunk_str, + } + + if context_message: + context.update(context_message["content"]) + + return context + + def _create_events_for_chunk(chunk_str: str, context: dict) -> List[dict]: + """Create events for running output rails on 
a chunk.""" + return [ + {"type": "ContextUpdate", "data": context}, + {"type": "BotMessage", "text": chunk_str}, + ] + def _prepare_params( flow_id: str, action_name: str, @@ -1404,6 +1430,8 @@ def _prepare_params( _get_action_details_from_flow_id, flows=self.config.flows ) + parallel_mode = getattr(self.config.rails.output, "parallel", False) + async for chunk_batch in buffer_strategy(streaming_handler): user_output_chunks = chunk_batch.user_output_chunks # format processing_context for output rails processing (needs full context) @@ -1427,48 +1455,118 @@ def _prepare_params( for chunk in user_output_chunks: yield chunk - for flow_id in output_rails_flows_id: - action_name, action_params = get_action_details(flow_id) + if parallel_mode: + try: + context = _prepare_context_for_parallel_rails( + bot_response_chunk, prompt, messages + ) + events = _create_events_for_chunk(bot_response_chunk, context) + + flows_with_params = {} + for flow_id in output_rails_flows_id: + action_name, action_params = get_action_details(flow_id) + params = _prepare_params( + flow_id=flow_id, + action_name=action_name, + bot_response_chunk=bot_response_chunk, + prompt=prompt, + messages=messages, + action_params=action_params, + ) + flows_with_params[flow_id] = { + "action_name": action_name, + "params": params, + } + + result_tuple = await self.runtime.action_dispatcher.execute_action( + "run_output_rails_in_parallel_streaming", + { + "flows_with_params": flows_with_params, + "events": events, + }, + ) - params = _prepare_params( - flow_id=flow_id, - action_name=action_name, - bot_response_chunk=bot_response_chunk, - prompt=prompt, - messages=messages, - action_params=action_params, - ) + # ActionDispatcher.execute_action always returns (result, status) + result, status = result_tuple - result = await self.runtime.action_dispatcher.execute_action( - action_name, params - ) + if status != "success": + log.error( + f"Parallel rails execution failed with status: {status}" + ) + # continue processing the chunk even if rails fail + pass + else: + # if there are any stop events, content was blocked + if result.events: + # extract the blocked flow from the first stop event + blocked_flow = result.events[0].get( + "flow_id", "output rails" + ) + + reason = f"Blocked by {blocked_flow} rails." + error_data = { + "error": { + "message": reason, + "type": "guardrails_violation", + "param": blocked_flow, + "code": "content_blocked", + } + } + yield json.dumps(error_data) + return + + except Exception as e: + log.error(f"Error in parallel rail execution: {e}") + # don't block the stream for rail execution errors + # continue processing the chunk + pass + + # update explain info for parallel mode self.explain_info = self._ensure_explain_info() - action_func = self.runtime.action_dispatcher.get_action(action_name) - - # Use the mapping to decide if the result indicates blocked content. - if is_output_blocked(result, action_func): - reason = f"Blocked by {flow_id} rails." - - # return the error as a plain JSON string (not in SSE format) - # NOTE: When integrating with the OpenAI Python client, the server code should: - # 1. detect this JSON error object in the stream - # 2. terminate the stream - # 3. 
format the error following OpenAI's SSE format - # the OpenAI client will then properly raise an APIError with this error message - - error_data = { - "error": { - "message": reason, - "type": "guardrails_violation", - "param": flow_id, - "code": "content_blocked", + else: + for flow_id in output_rails_flows_id: + action_name, action_params = get_action_details(flow_id) + + params = _prepare_params( + flow_id=flow_id, + action_name=action_name, + bot_response_chunk=bot_response_chunk, + prompt=prompt, + messages=messages, + action_params=action_params, + ) + + result = await self.runtime.action_dispatcher.execute_action( + action_name, params + ) + self.explain_info = self._ensure_explain_info() + + action_func = self.runtime.action_dispatcher.get_action(action_name) + + # Use the mapping to decide if the result indicates blocked content. + if is_output_blocked(result, action_func): + reason = f"Blocked by {flow_id} rails." + + # return the error as a plain JSON string (not in SSE format) + # NOTE: When integrating with the OpenAI Python client, the server code should: + # 1. detect this JSON error object in the stream + # 2. terminate the stream + # 3. format the error following OpenAI's SSE format + # the OpenAI client will then properly raise an APIError with this error message + + error_data = { + "error": { + "message": reason, + "type": "guardrails_violation", + "param": flow_id, + "code": "content_blocked", + } } - } - # return as plain JSON: the server should detect this JSON and convert it to an HTTP error - yield json.dumps(error_data) - return + # return as plain JSON: the server should detect this JSON and convert it to an HTTP error + yield json.dumps(error_data) + return if not stream_first: # yield the individual chunks directly from the buffer strategy diff --git a/tests/test_parallel_streaming_output_rails.py b/tests/test_parallel_streaming_output_rails.py new file mode 100644 index 000000000..a286c5fb9 --- /dev/null +++ b/tests/test_parallel_streaming_output_rails.py @@ -0,0 +1,1221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the parallel output rails streaming functionality.""" + +import asyncio +import json +import time +from json.decoder import JSONDecodeError + +import pytest + +from nemoguardrails import RailsConfig +from nemoguardrails.actions import action +from tests.utils import TestChat + + +@pytest.fixture +def parallel_output_rails_streaming_config(): + """Config for testing parallel output rails with streaming enabled and multiple flows""" + + return RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": True, + "flows": [ + "self check output safety", + "self check output compliance", + "self check output quality", + ], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + "stream_first": False, + }, + } + }, + "streaming": False, + "prompts": [ + {"task": "self_check_output", "content": "Check: {{ bot_response }}"}, + ], + }, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot tell joke + + define subflow self check output safety + $allowed = execute self_check_output_safety + if not $allowed + bot refuse to respond + stop + + define subflow self check output compliance + $allowed = execute self_check_output_compliance + if not $allowed + bot refuse to respond + stop + + define subflow self check output quality + $allowed = execute self_check_output_quality + if not $allowed + bot refuse to respond + stop + """, + ) + + +@pytest.fixture +def parallel_output_rails_streaming_single_flow_config(): + """Config for testing parallel output rails with single flow""" + + return RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": True, + "flows": ["self check output"], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + "stream_first": False, + }, + } + }, + "streaming": False, + "prompts": [ + {"task": "self_check_output", "content": "Check: {{ bot_response }}"}, + ], + }, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot tell joke + + define subflow self check output + execute self_check_output + """, + ) + + +@pytest.fixture +def parallel_output_rails_default_config(): + """Config for testing parallel output rails with default streaming settings""" + + return RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": True, + "flows": [ + "self check output safety", + "self check output compliance", + ], + } + }, + "streaming": True, + "prompts": [ + {"task": "self_check_output", "content": "Check: {{ bot_response }}"}, + ], + }, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot tell joke + + define subflow self check output safety + execute self_check_output_safety + + define subflow self check output compliance + execute self_check_output_compliance + """, + ) + + +@action(is_system_action=True) +def self_check_output_safety(context=None, **params): + """Safety check that blocks content containing UNSAFE keyword.""" + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "UNSAFE" in bot_message_chunk: + return False + return True + + +@action(is_system_action=True) +def self_check_output_compliance(context=None, **params): + """Compliance check that blocks content containing VIOLATION keyword.""" + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "VIOLATION" in 
bot_message_chunk: + return False + return True + + +@action(is_system_action=True) +def self_check_output_quality(context=None, **params): + """Quality check that blocks content containing LOWQUALITY keyword.""" + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "LOWQUALITY" in bot_message_chunk: + return False + return True + + +@action(is_system_action=True) +def self_check_output(context=None, **params): + """Generic check that blocks content containing BLOCK keyword.""" + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "BLOCK" in bot_message_chunk: + return False + return True + + +@action(is_system_action=True, output_mapping=lambda result: not result) +async def slow_self_check_output_safety(**params): + """Slow safety check for timing tests.""" + await asyncio.sleep(0.1) + return self_check_output_safety(**params) + + +@action(is_system_action=True, output_mapping=lambda result: not result) +async def slow_self_check_output_compliance(**params): + """Slow compliance check for timing tests.""" + await asyncio.sleep(0.1) + return self_check_output_compliance(**params) + + +@action(is_system_action=True, output_mapping=lambda result: not result) +async def slow_self_check_output_quality(**params): + """Slow quality check for timing tests.""" + await asyncio.sleep(0.1) + return self_check_output_quality(**params) + + +async def run_parallel_self_check_test(config, llm_completions, register_actions=True): + """Helper function to run parallel self check test with the given config and llm completions""" + + chat = TestChat( + config, + llm_completions=llm_completions, + streaming=True, + ) + + if register_actions: + chat.app.register_action(self_check_output_safety) + chat.app.register_action(self_check_output_compliance) + chat.app.register_action(self_check_output_quality) + chat.app.register_action(self_check_output) + + chunks = [] + async for chunk in chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + chunks.append(chunk) + + return chunks + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_allowed( + parallel_output_rails_streaming_config, +): + """Tests that parallel output rails allow content when no blocking keywords are present""" + + llm_completions = [ + " bot express insult", + ' "Hi, how are you doing?"', + ' "This is a safe and compliant high quality joke that should pass all checks."', + ] + + chunks = await run_parallel_self_check_test( + parallel_output_rails_streaming_config, llm_completions + ) + + # should receive all chunks without blocking + response = "".join(chunks) + assert len(response) > 0 + assert len(chunks) > 1 + assert "This is a safe" in response + assert "compliant high quality" in response + + error_chunks = [chunk for chunk in chunks if chunk.startswith('{"error":')] + assert len(error_chunks) == 0 + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_blocked_by_safety( + parallel_output_rails_streaming_config, +): + """Tests that parallel output rails block content when safety rail detects UNSAFE keyword""" + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is an UNSAFE joke that should be blocked by safety check."', + ] + + chunks = await run_parallel_self_check_test( + parallel_output_rails_streaming_config, llm_completions + ) + + expected_error = { + "error": 
{ + "message": "Blocked by self check output safety rails.", + "type": "guardrails_violation", + "param": "self check output safety", + "code": "content_blocked", + } + } + + error_found = False + for chunk in chunks: + try: + parsed = json.loads(chunk) + if "error" in parsed and parsed == expected_error: + error_found = True + break + except JSONDecodeError: + continue + + assert error_found, f"Expected error not found in chunks: {chunks}" + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_blocked_by_compliance( + parallel_output_rails_streaming_config, +): + """Tests that parallel output rails block content when compliance rail detects VIOLATION keyword""" + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This joke contains a policy VIOLATION and should be blocked."', + ] + + chunks = await run_parallel_self_check_test( + parallel_output_rails_streaming_config, llm_completions + ) + + expected_error = { + "error": { + "message": "Blocked by self check output compliance rails.", + "type": "guardrails_violation", + "param": "self check output compliance", + "code": "content_blocked", + } + } + + error_found = False + for chunk in chunks: + try: + parsed = json.loads(chunk) + if "error" in parsed and parsed == expected_error: + error_found = True + break + except JSONDecodeError: + continue + + assert error_found, f"Expected error not found in chunks: {chunks}" + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_blocked_by_quality( + parallel_output_rails_streaming_config, +): + """Tests that parallel output rails block content when quality rail detects LOWQUALITY keyword""" + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is a LOWQUALITY joke that should be blocked by quality check."', + ] + + chunks = await run_parallel_self_check_test( + parallel_output_rails_streaming_config, llm_completions + ) + + expected_error = { + "error": { + "message": "Blocked by self check output quality rails.", + "type": "guardrails_violation", + "param": "self check output quality", + "code": "content_blocked", + } + } + + error_found = False + for chunk in chunks: + try: + parsed = json.loads(chunk) + if "error" in parsed and parsed == expected_error: + error_found = True + break + except JSONDecodeError: + continue + + assert error_found, f"Expected error not found in chunks: {chunks}" + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_blocked_at_start( + parallel_output_rails_streaming_single_flow_config, +): + """Tests parallel blocking when BLOCK keyword appears at the very beginning""" + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "[BLOCK] This should be blocked immediately at the start."', + ] + + chunks = await run_parallel_self_check_test( + parallel_output_rails_streaming_single_flow_config, llm_completions + ) + + expected_error = { + "error": { + "message": "Blocked by self check output rails.", + "type": "guardrails_violation", + "param": "self check output", + "code": "content_blocked", + } + } + + # should be blocked immediately with only one error chunk + assert len(chunks) == 1 + assert json.loads(chunks[0]) == expected_error + + await 
asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_multiple_blocking_keywords( + parallel_output_rails_streaming_config, +): + """Tests parallel rails when multiple blocking keywords are present - should block on first detected""" + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This contains both UNSAFE content and a VIOLATION which is also LOWQUALITY."', + ] + + chunks = await run_parallel_self_check_test( + parallel_output_rails_streaming_config, llm_completions + ) + + # should be blocked by one of the rails (whichever detects first in parallel execution) + error_chunks = [] + for chunk in chunks: + try: + parsed = json.loads(chunk) + if "error" in parsed: + error_chunks.append(parsed) + except JSONDecodeError: + continue + + assert ( + len(error_chunks) == 1 + ), f"Expected exactly one error chunk, got {len(error_chunks)}" + + error = error_chunks[0] + assert error["error"]["type"] == "guardrails_violation" + assert error["error"]["code"] == "content_blocked" + assert "Blocked by" in error["error"]["message"] + + # should be blocked by one of the three rails + blocked_by_options = [ + "self check output safety", + "self check output compliance", + "self check output quality", + ] + assert error["error"]["param"] in blocked_by_options + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_performance_benefits(): + """Tests that parallel rails execution provides performance benefits over sequential""" + + parallel_config = RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": True, + "flows": [ + "slow self check output safety", + "slow self check output compliance", + "slow self check output quality", + ], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + }, + } + }, + "streaming": False, + }, + colang_content=""" + define user express greeting + "hi" + define flow + user express greeting + bot tell joke + + define subflow slow self check output safety + execute slow_self_check_output_safety + + define subflow slow self check output compliance + execute slow_self_check_output_compliance + + define subflow slow self check output quality + execute slow_self_check_output_quality + """, + ) + + sequential_config = RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": False, + "flows": [ + "slow self check output safety", + "slow self check output compliance", + "slow self check output quality", + ], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + }, + } + }, + "streaming": False, + }, + colang_content=""" + define user express greeting + "hi" + define flow + user express greeting + bot tell joke + + define subflow slow self check output safety + execute slow_self_check_output_safety + + define subflow slow self check output compliance + execute slow_self_check_output_compliance + + define subflow slow self check output quality + execute slow_self_check_output_quality + """, + ) + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is a safe and compliant high quality response for timing tests."', + ] + + parallel_chat = TestChat( + parallel_config, llm_completions=llm_completions, streaming=True + ) + parallel_chat.app.register_action(slow_self_check_output_safety) + 
parallel_chat.app.register_action(slow_self_check_output_compliance) + parallel_chat.app.register_action(slow_self_check_output_quality) + + start_time = time.time() + parallel_chunks = [] + async for chunk in parallel_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + parallel_chunks.append(chunk) + parallel_time = time.time() - start_time + + sequential_chat = TestChat( + sequential_config, llm_completions=llm_completions, streaming=True + ) + sequential_chat.app.register_action(slow_self_check_output_safety) + sequential_chat.app.register_action(slow_self_check_output_compliance) + sequential_chat.app.register_action(slow_self_check_output_quality) + + start_time = time.time() + sequential_chunks = [] + async for chunk in sequential_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + sequential_chunks.append(chunk) + sequential_time = time.time() - start_time + + # Parallel should be faster than sequential (allowing some margin for test variability) + print( + f"Parallel time: {parallel_time:.2f}s, Sequential time: {sequential_time:.2f}s" + ) + + # with 3 rails each taking ~0.1 s sequential should take ~0.3 s per chunk, parallel should be closer to 0.1s + # we allow some margin for test execution overhead + assert parallel_time < sequential_time * 0.8, ( + f"Parallel execution ({parallel_time:.2f}s) should be significantly faster than " + f"sequential execution ({sequential_time:.2f}s)" + ) + + parallel_response = "".join(parallel_chunks) + sequential_response = "".join(sequential_chunks) + assert parallel_response == sequential_response + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_default_config_behavior( + parallel_output_rails_default_config, +): + """Tests parallel output rails with default streaming configuration""" + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is a test message with default streaming config."', + ] + + chunks = await run_parallel_self_check_test( + parallel_output_rails_default_config, llm_completions + ) + + response = "".join(chunks) + assert len(response) > 0 + assert len(chunks) > 0 + assert "test message" in response + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_error_handling(): + """Tests error handling in parallel streaming when rails fail""" + + @action(is_system_action=True, output_mapping=lambda result: not result) + def failing_rail(**params): + raise Exception("Simulated rail failure") + + @action(is_system_action=True, output_mapping=lambda result: not result) + def working_rail(**params): + return True + + config = RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": True, + "flows": ["failing rail", "working rail"], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + }, + } + }, + "streaming": False, + }, + colang_content=""" + define user express greeting + "hi" + define flow + user express greeting + bot tell joke + + define subflow failing rail + execute failing_rail + + define subflow working rail + execute working_rail + """, + ) + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This message should still be processed despite one rail failing."', + ] + + chat = TestChat(config, llm_completions=llm_completions, 
streaming=True) + chat.app.register_action(failing_rail) + chat.app.register_action(working_rail) + + chunks = [] + async for chunk in chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + chunks.append(chunk) + + # should continue processing despite one rail failing + response = "".join(chunks) + assert len(response) > 0 + assert "should still be processed" in response + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_stream_first_enabled(): + """Tests parallel streaming with stream_first option enabled""" + + config = RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": True, + "flows": ["self check output"], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + "stream_first": True, + }, + } + }, + "streaming": False, + "prompts": [ + {"task": "self_check_output", "content": "Check: {{ bot_response }}"}, + ], + }, + colang_content=""" + define user express greeting + "hi" + define flow + user express greeting + bot tell joke + + define subflow self check output + execute self_check_output + """, + ) + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is a test message for stream first functionality."', + ] + + chunks = await run_parallel_self_check_test(config, llm_completions) + + assert len(chunks) > 1 + response = "".join(chunks) + assert "test message" in response + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_streaming_output_rails_large_chunk_processing(): + """Tests parallel streaming with larger chunks to ensure proper processing""" + + config = RailsConfig.from_content( + config={ + "models": [], + "rails": { + "output": { + "parallel": True, + "flows": [ + "self check output safety", + "self check output compliance", + ], + "streaming": { + "enabled": True, + "chunk_size": 10, + "context_size": 3, + }, + } + }, + "streaming": False, + }, + colang_content=""" + define user express greeting + "hi" + define flow + user express greeting + bot tell joke + + define subflow self check output safety + execute self_check_output_safety + + define subflow self check output compliance + execute self_check_output_compliance + """, + ) + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is a much longer response that will be processed in larger chunks to test the parallel rail processing functionality with bigger chunk sizes and ensure that everything works correctly."', + ] + + chunks = await run_parallel_self_check_test(config, llm_completions) + + response = "".join(chunks) + assert len(response) > 50 + assert "much longer response" in response + assert "parallel rail processing" in response + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_sequential_vs_parallel_streaming_output_rails_comparison(): + """Direct comparison test between sequential and parallel streaming output rails. + + This test demonstrates the differences between sequential and parallel execution + using identical content and configurations, except for the parallel flag. 
+ """ + + @action(is_system_action=True, output_mapping=lambda result: not result) + def test_self_check_output(context=None, **params): + """Test check that blocks content containing BLOCK keyword.""" + + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "BLOCK" in bot_message_chunk: + return False + return True + + base_config = { + "models": [], + "rails": { + "output": { + "flows": ["test self check output"], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + "stream_first": False, + }, + } + }, + "streaming": False, + } + + colang_content = """ + define user express greeting + "hi" + + define flow + user express greeting + bot tell joke + + define subflow test self check output + execute test_self_check_output + """ + + sequential_config = RailsConfig.from_content( + config=base_config, + colang_content=colang_content, + ) + + parallel_config_dict = base_config.copy() + parallel_config_dict["rails"]["output"]["parallel"] = True + + parallel_config = RailsConfig.from_content( + config=parallel_config_dict, + colang_content=colang_content, + ) + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is a safe and compliant high quality joke that should pass all checks."', + ] + + sequential_chat = TestChat( + sequential_config, + llm_completions=llm_completions, + streaming=True, + ) + sequential_chat.app.register_action(test_self_check_output) + + parallel_chat = TestChat( + parallel_config, + llm_completions=llm_completions, + streaming=True, + ) + parallel_chat.app.register_action(test_self_check_output) + + import time + + start_time = time.time() + sequential_chunks = [] + async for chunk in sequential_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + sequential_chunks.append(chunk) + sequential_time = time.time() - start_time + + start_time = time.time() + parallel_chunks = [] + async for chunk in parallel_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + parallel_chunks.append(chunk) + parallel_time = time.time() - start_time + + # both should produce the same successful output + sequential_response = "".join(sequential_chunks) + parallel_response = "".join(parallel_chunks) + + assert len(sequential_response) > 0 + assert len(parallel_response) > 0 + assert "This is a safe" in sequential_response + assert "This is a safe" in parallel_response + assert "compliant high quality" in sequential_response + assert "compliant high quality" in parallel_response + + # neither should have error chunks + sequential_error_chunks = [ + chunk for chunk in sequential_chunks if chunk.startswith('{"error":') + ] + parallel_error_chunks = [ + chunk for chunk in parallel_chunks if chunk.startswith('{"error":') + ] + + assert ( + len(sequential_error_chunks) == 0 + ), f"Sequential had errors: {sequential_error_chunks}" + assert ( + len(parallel_error_chunks) == 0 + ), f"Parallel had errors: {parallel_error_chunks}" + + assert sequential_response == parallel_response, ( + f"Sequential and parallel should produce identical content:\n" + f"Sequential: {sequential_response}\n" + f"Parallel: {parallel_response}" + ) + + # log timing comparison (parallel should be faster or similar for single rail) + print(f"\nTiming Comparison:") + print(f"Sequential: {sequential_time:.4f}s") + print(f"Parallel: {parallel_time:.4f}s") + print(f"Speedup: {sequential_time / parallel_time:.2f}x") + + await asyncio.gather(*asyncio.all_tasks() - 
{asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_sequential_vs_parallel_streaming_blocking_comparison(): + """Test that both sequential and parallel handle blocking scenarios identically""" + + @action(is_system_action=True, output_mapping=lambda result: not result) + def test_self_check_output_blocking(context=None, **params): + """Test check that blocks content containing BLOCK keyword.""" + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "BLOCK" in bot_message_chunk: + return False + return True + + base_config = { + "models": [], + "rails": { + "output": { + "flows": ["test self check output blocking"], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + "stream_first": False, + }, + } + }, + "streaming": False, + } + + colang_content = """ + define user express greeting + "hi" + + define flow + user express greeting + bot tell joke + + define subflow test self check output blocking + execute test_self_check_output_blocking + """ + + sequential_config = RailsConfig.from_content( + config=base_config, colang_content=colang_content + ) + + parallel_config_dict = base_config.copy() + parallel_config_dict["rails"]["output"]["parallel"] = True + parallel_config = RailsConfig.from_content( + config=parallel_config_dict, colang_content=colang_content + ) + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This contains a BLOCK keyword that should be blocked."', + ] + + sequential_chat = TestChat( + sequential_config, + llm_completions=llm_completions, + streaming=True, + ) + sequential_chat.app.register_action(test_self_check_output_blocking) + + parallel_chat = TestChat( + parallel_config, + llm_completions=llm_completions, + streaming=True, + ) + parallel_chat.app.register_action(test_self_check_output_blocking) + + sequential_chunks = [] + async for chunk in sequential_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + sequential_chunks.append(chunk) + + parallel_chunks = [] + async for chunk in parallel_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + parallel_chunks.append(chunk) + + sequential_errors = [] + parallel_errors = [] + + for chunk in sequential_chunks: + try: + parsed = json.loads(chunk) + if "error" in parsed: + sequential_errors.append(parsed) + except JSONDecodeError: + continue + + for chunk in parallel_chunks: + try: + parsed = json.loads(chunk) + if "error" in parsed: + parallel_errors.append(parsed) + except JSONDecodeError: + continue + + assert ( + len(sequential_errors) == 1 + ), f"Sequential should have 1 error, got {len(sequential_errors)}" + assert ( + len(parallel_errors) == 1 + ), f"Parallel should have 1 error, got {len(parallel_errors)}" + + seq_error = sequential_errors[0] + par_error = parallel_errors[0] + + assert seq_error["error"]["type"] == "guardrails_violation" + assert par_error["error"]["type"] == "guardrails_violation" + assert seq_error["error"]["code"] == "content_blocked" + assert par_error["error"]["code"] == "content_blocked" + assert "Blocked by" in seq_error["error"]["message"] + assert "Blocked by" in par_error["error"]["message"] + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + + +@pytest.mark.asyncio +async def test_parallel_vs_sequential_with_slow_actions(): + """Test that demonstrates real parallel speedup with slow actions""" + + import time + + @action(is_system_action=True, output_mapping=lambda result: not 
result) + async def slow_safety_check(context=None, **params): + """Slow safety check that simulates real processing time.""" + # simulate 100ms of processing + await asyncio.sleep(0.1) + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "UNSAFE" in bot_message_chunk: + return False + return True + + @action(is_system_action=True, output_mapping=lambda result: not result) + async def slow_compliance_check(context=None, **params): + """Slow compliance check that simulates real processing time.""" + await asyncio.sleep(0.1) + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "VIOLATION" in bot_message_chunk: + return False + return True + + @action(is_system_action=True, output_mapping=lambda result: not result) + async def slow_quality_check(context=None, **params): + """Slow quality check that simulates real processing time.""" + await asyncio.sleep(0.1) + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") + if "LOWQUALITY" in bot_message_chunk: + return False + return True + + base_config = { + "models": [], + "rails": { + "output": { + "flows": [ + "slow safety check", + "slow compliance check", + "slow quality check", + ], + "streaming": { + "enabled": True, + "chunk_size": 4, + "context_size": 2, + "stream_first": False, + }, + } + }, + "streaming": False, + } + + colang_content = """ + define user express greeting + "hi" + + define flow + user express greeting + bot tell joke + + define subflow slow safety check + execute slow_safety_check + + define subflow slow compliance check + execute slow_compliance_check + + define subflow slow quality check + execute slow_quality_check + """ + + sequential_config = RailsConfig.from_content( + config=base_config, + colang_content=colang_content, + ) + + parallel_config_dict = base_config.copy() + parallel_config_dict["rails"]["output"]["parallel"] = True + + parallel_config = RailsConfig.from_content( + config=parallel_config_dict, + colang_content=colang_content, + ) + + llm_completions = [ + ' express greeting\nbot express greeting\n "Hi, how are you doing?"', + ' "This is a safe and compliant high quality joke that should pass all checks."', + ] + + sequential_chat = TestChat( + sequential_config, + llm_completions=llm_completions, + streaming=True, + ) + sequential_chat.app.register_action(slow_safety_check) + sequential_chat.app.register_action(slow_compliance_check) + sequential_chat.app.register_action(slow_quality_check) + + parallel_chat = TestChat( + parallel_config, + llm_completions=llm_completions, + streaming=True, + ) + parallel_chat.app.register_action(slow_safety_check) + parallel_chat.app.register_action(slow_compliance_check) + parallel_chat.app.register_action(slow_quality_check) + + print(f"\n=== SLOW ACTIONS PERFORMANCE TEST ===") + print(f"Each action takes 100ms, 3 actions total") + print(f"Expected: Sequential ~300ms per chunk, Parallel ~100ms per chunk") + + start_time = time.time() + sequential_chunks = [] + async for chunk in sequential_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + sequential_chunks.append(chunk) + sequential_time = time.time() - start_time + + start_time = time.time() + parallel_chunks = [] + async for chunk in parallel_chat.app.stream_async( + messages=[{"role": "user", "content": "Hi!"}] + ): + parallel_chunks.append(chunk) + parallel_time = time.time() - start_time + + sequential_response = "".join(sequential_chunks) + 
parallel_response = "".join(parallel_chunks) + + assert len(sequential_response) > 0 + assert len(parallel_response) > 0 + assert "This is a safe" in sequential_response + assert "This is a safe" in parallel_response + + sequential_error_chunks = [ + chunk for chunk in sequential_chunks if chunk.startswith('{"error":') + ] + parallel_error_chunks = [ + chunk for chunk in parallel_chunks if chunk.startswith('{"error":') + ] + + assert len(sequential_error_chunks) == 0 + assert len(parallel_error_chunks) == 0 + + assert sequential_response == parallel_response + + speedup = sequential_time / parallel_time + + print(f"\nSlow Actions Timing Results:") + print(f"Sequential: {sequential_time:.4f}s") + print(f"Parallel: {parallel_time:.4f}s") + print(f"Speedup: {speedup:.2f}x") + + # with slow actions, parallel should be significantly faster + # we expect at least 1.5x speedup (theoretical max ~3x, but overhead reduces it) + assert speedup >= 1.5, ( + f"With slow actions, parallel should be at least 1.5x faster than sequential. " + f"Got speedup of {speedup:.2f}x. Sequential: {sequential_time:.4f}s, Parallel: {parallel_time:.4f}s" + ) + + print(f" Parallel execution achieved {speedup:.2f}x speedup as expected!") + + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
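A minimal usage sketch of the new `rails.output.parallel` flag combined with streaming output rails, mirroring the test fixtures above; the model entry, the flow names, and the placeholder `self_check_output` action (and the keyword it checks) are illustrative assumptions, not part of this change:

import asyncio

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.actions import action


# Placeholder output rail: blocks any chunk containing the keyword "BLOCK".
@action(is_system_action=True, output_mapping=lambda result: not result)
def self_check_output(context=None, **params):
    bot_message = (context or {}).get("bot_message", "")
    return "BLOCK" not in bot_message


config = RailsConfig.from_content(
    config={
        # assumed main model; any configured engine/model works here
        "models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}],
        "rails": {
            "output": {
                "parallel": True,  # run the output rails concurrently per chunk
                "flows": ["self check output"],
                "streaming": {
                    "enabled": True,
                    "chunk_size": 4,
                    "context_size": 2,
                    "stream_first": False,
                },
            }
        },
        "streaming": True,
    },
    colang_content="""
    define user express greeting
        "hi"

    define flow
        user express greeting
        bot tell joke

    define subflow self check output
        execute self_check_output
    """,
)

rails = LLMRails(config)
rails.register_action(self_check_output)


async def main():
    async for chunk in rails.stream_async(
        messages=[{"role": "user", "content": "Hi!"}]
    ):
        # a blocked chunk surfaces as a plain JSON error object, e.g.
        # {"error": {"type": "guardrails_violation", "code": "content_blocked", ...}}
        print(chunk, end="")


asyncio.run(main())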