diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 2a57e1c26..fbaa4c0b7 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -83,7 +83,7 @@ class LLMGenerationActions: def __init__( self, config: RailsConfig, - llm: Union[BaseLLM, BaseChatModel], + llm: Optional[Union[BaseLLM, BaseChatModel]], llm_task_manager: LLMTaskManager, get_embedding_search_provider_instance: Callable[ [Optional[EmbeddingSearchProvider]], EmbeddingsIndex diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py index e66f1a0d5..a92faefaa 100644 --- a/nemoguardrails/context.py +++ b/nemoguardrails/context.py @@ -14,25 +14,42 @@ # limitations under the License. import contextvars -from typing import Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional -streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None) +if TYPE_CHECKING: + from nemoguardrails.logging.explain import ExplainInfo + from nemoguardrails.rails.llm.options import GenerationOptions, LLMStats + from nemoguardrails.streaming import StreamingHandler + +streaming_handler_var: contextvars.ContextVar[ + Optional["StreamingHandler"] +] = contextvars.ContextVar("streaming_handler", default=None) # The object that holds additional explanation information. -explain_info_var = contextvars.ContextVar("explain_info", default=None) +explain_info_var: contextvars.ContextVar[ + Optional["ExplainInfo"] +] = contextvars.ContextVar("explain_info", default=None) # The current LLM call. -llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None) +llm_call_info_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + "llm_call_info", default=None +) # All the generation options applicable to the current context. -generation_options_var = contextvars.ContextVar("generation_options", default=None) +generation_options_var: contextvars.ContextVar[ + Optional["GenerationOptions"] +] = contextvars.ContextVar("generation_options", default=None) # The stats about the LLM calls. -llm_stats_var = contextvars.ContextVar("llm_stats", default=None) +llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar( + "llm_stats", default=None +) # The raw LLM request that comes from the user. # This is used in passthrough mode. -raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None) +raw_llm_request: contextvars.ContextVar[Optional[Any]] = contextvars.ContextVar( + "raw_llm_request", default=None +) reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( "reasoning_trace", default=None diff --git a/nemoguardrails/rails/llm/buffer.py b/nemoguardrails/rails/llm/buffer.py index 30e48c4e3..541f52915 100644 --- a/nemoguardrails/rails/llm/buffer.py +++ b/nemoguardrails/rails/llm/buffer.py @@ -14,7 +14,10 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, NamedTuple +from typing import TYPE_CHECKING, AsyncGenerator, List, NamedTuple + +if TYPE_CHECKING: + from collections.abc import AsyncIterator from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig @@ -111,9 +114,7 @@ def format_chunks(self, chunks: List[str]) -> str: ... @abstractmethod - async def process_stream( - self, streaming_handler - ) -> AsyncGenerator[ChunkBatch, None]: + async def process_stream(self, streaming_handler): """Process streaming chunks and yield chunk batches. 
This is the main method that concrete buffer strategies must implement. @@ -138,9 +139,9 @@ async def process_stream( ... print(f"Processing: {context_formatted}") ... print(f"User: {user_formatted}") """ - ... + yield ChunkBatch([], []) # pragma: no cover - async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]: + async def __call__(self, streaming_handler): """Callable interface that delegates to process_stream. It delegates to the `process_stream` method and can @@ -256,9 +257,7 @@ def from_config(cls, config: OutputRailsStreamingConfig): buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size ) - async def process_stream( - self, streaming_handler - ) -> AsyncGenerator[ChunkBatch, None]: + async def process_stream(self, streaming_handler): """Process streaming chunks using rolling buffer strategy. This method implements the rolling buffer logic, accumulating chunks diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index bc12569a1..8557a43b0 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -1048,7 +1048,9 @@ def _load_path( # the first .railsignore file found from cwd down to its subdirectories railsignore_path = utils.get_railsignore_path(config_path) - ignore_patterns = utils.get_railsignore_patterns(railsignore_path) + ignore_patterns = ( + utils.get_railsignore_patterns(railsignore_path) if railsignore_path else set() + ) if os.path.isdir(config_path): for root, _, files in os.walk(config_path, followlinks=True): @@ -1165,8 +1167,8 @@ def _parse_colang_files_recursively( current_file, current_path = colang_files[len(parsed_colang_files)] with open(current_path, "r", encoding="utf-8") as f: + content = f.read() try: - content = f.read() _parsed_config = parse_colang_file( current_file, content=content, version=colang_version ) @@ -1668,7 +1670,7 @@ def streaming_supported(self): # if we have output rails streaming enabled # we keep it in case it was needed when we have # support per rails - if self.rails.output.streaming.enabled: + if self.rails.output.streaming and self.rails.output.streaming.enabled: return True return False diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 0027b7fc5..d84cdb860 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -24,7 +24,18 @@ import threading import time from functools import partial -from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Type, Union, cast +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + cast, +) from langchain_core.language_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM @@ -67,7 +78,11 @@ from nemoguardrails.logging.verbose import set_verbose from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop from nemoguardrails.rails.llm.buffer import get_buffer_strategy -from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig +from nemoguardrails.rails.llm.config import ( + EmbeddingSearchProvider, + OutputRailsStreamingConfig, + RailsConfig, +) from nemoguardrails.rails.llm.options import ( GenerationLog, GenerationOptions, @@ -203,17 +218,18 @@ def __init__( # We check if the configuration or any of the imported ones have config.py modules. 
config_modules = [] - for _path in list(self.config.imported_paths.values()) + [ - self.config.config_path - ]: + for _path in list( + self.config.imported_paths.values() if self.config.imported_paths else [] + ) + [self.config.config_path]: if _path: filepath = os.path.join(_path, "config.py") if os.path.exists(filepath): filename = os.path.basename(filepath) spec = importlib.util.spec_from_file_location(filename, filepath) - config_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config_module) - config_modules.append(config_module) + if spec and spec.loader: + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + config_modules.append(config_module) # First, we initialize the runtime. if config.colang_version == "1.0": @@ -393,8 +409,8 @@ def _configure_main_llm_streaming( if not self.config.streaming: return - if "streaming" in llm.model_fields: - llm.streaming = True + if hasattr(llm, "streaming"): + setattr(llm, "streaming", True) self.main_llm_supports_streaming = True else: self.main_llm_supports_streaming = False @@ -623,9 +639,13 @@ def _create_action_llm_copy( # isolate model_kwargs to prevent shared mutable state if ( hasattr(isolated_llm, "model_kwargs") - and isolated_llm.model_kwargs is not None + and getattr(isolated_llm, "model_kwargs", None) is not None ): - isolated_llm.model_kwargs = isolated_llm.model_kwargs.copy() + setattr( + isolated_llm, + "model_kwargs", + getattr(isolated_llm, "model_kwargs").copy(), + ) log.debug( "Successfully created isolated LLM copy for action: %s", action_name @@ -853,6 +873,19 @@ async def generate_async( The completion (when a prompt is provided) or the next message. System messages are not yet supported.""" + # convert options to gen_options of type GenerationOptions + gen_options: Optional[GenerationOptions] = None + + if prompt is None and messages is None: + raise ValueError("Either prompt or messages must be provided.") + + if prompt is not None and messages is not None: + raise ValueError("Only one of prompt or messages can be provided.") + + if prompt is not None: + # Currently, we transform the prompt request into a single turn conversation + messages = [{"role": "user", "content": prompt}] + # If a state object is specified, then we switch to "generation options" mode. # This is because we want the output to be a GenerationResponse which will contain # the output state. @@ -862,14 +895,25 @@ async def generate_async( state = json_to_state(state["state"]) if options is None: - options = GenerationOptions() - - # We allow options to be specified both as a dict and as an object. - if options and isinstance(options, dict): - options = GenerationOptions(**options) + gen_options = GenerationOptions() + elif isinstance(options, dict): + gen_options = GenerationOptions(**options) + else: + gen_options = options + else: + # We allow options to be specified both as a dict and as an object. + if options and isinstance(options, dict): + gen_options = GenerationOptions(**options) + elif isinstance(options, GenerationOptions): + gen_options = options + elif options is None: + gen_options = None + else: + raise TypeError("options must be a dict or GenerationOptions") # Save the generation options in the current async context. 
- generation_options_var.set(options) + # At this point, gen_options is either None or GenerationOptions + generation_options_var.set(gen_options) if streaming_handler: streaming_handler_var.set(streaming_handler) @@ -879,26 +923,25 @@ async def generate_async( # requests are made. self.explain_info = self._ensure_explain_info() - if prompt is not None: - # Currently, we transform the prompt request into a single turn conversation - messages = [{"role": "user", "content": prompt}] - raw_llm_request.set(prompt) - else: - raw_llm_request.set(messages) + raw_llm_request.set(messages) # If we have generation options, we also add them to the context - if options: + if gen_options: messages = [ - {"role": "context", "content": {"generation_options": options.dict()}} - ] + messages + { + "role": "context", + "content": {"generation_options": gen_options.model_dump()}, + } + ] + (messages or []) # If the last message is from the assistant, rather than the user, then # we move that to the `$bot_message` variable. This is to enable a more # convenient interface. (only when dialog rails are disabled) if ( - messages[-1]["role"] == "assistant" - and options - and options.rails.dialog is False + messages + and messages[-1]["role"] == "assistant" + and gen_options + and gen_options.rails.dialog is False ): # We already have the first message with a context update, so we use that messages[0]["content"]["bot_message"] = messages[-1]["content"] @@ -915,7 +958,7 @@ async def generate_async( processing_log = [] # The array of events corresponding to the provided sequence of messages. - events = self._get_events_for_messages(messages, state) + events = self._get_events_for_messages(messages, state) # type: ignore if self.config.colang_version == "1.0": # If we had a state object, we also need to prepend the events from the state. @@ -939,10 +982,10 @@ async def generate_async( # Push an error chunk instead of None. 
error_message = str(e) error_dict = extract_error_json(error_message) - error_payload = json.dumps(error_dict) + error_payload: str = json.dumps(error_dict) await streaming_handler.push_chunk(error_payload) # push a termination signal - await streaming_handler.push_chunk(END_OF_STREAM) + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore # Re-raise the exact exception raise else: @@ -1013,7 +1056,7 @@ async def generate_async( response_events.append(event) if exception: - new_message = {"role": "exception", "content": exception} + new_message: dict = {"role": "exception", "content": exception} else: # Ensure all items in responses are strings @@ -1021,7 +1064,7 @@ async def generate_async( str(response) if not isinstance(response, str) else response for response in responses ] - new_message = {"role": "assistant", "content": "\n".join(responses)} + new_message: dict = {"role": "assistant", "content": "\n".join(responses)} if response_tool_calls: new_message["tool_calls"] = response_tool_calls if response_events: @@ -1034,7 +1077,7 @@ async def generate_async( # If a state object is not used, then we use the implicit caching if state is None: # Save the new events in the history and update the cache - cache_key = get_history_cache_key(messages + [new_message]) + cache_key = get_history_cache_key((messages) + [new_message]) # type: ignore self.events_history_cache[cache_key] = events else: output_state = {"events": events} @@ -1057,35 +1100,31 @@ async def generate_async( streaming_handler = streaming_handler_var.get() if streaming_handler: # print("Closing the stream handler explicitly") - await streaming_handler.push_chunk(END_OF_STREAM) + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore # IF tracing is enabled we need to set GenerationLog attrs original_log_options = None if self.config.tracing.enabled: - if options is None: - options = GenerationOptions() + if gen_options is None: + gen_options = GenerationOptions() else: - # create a copy of the options to avoid modifying the original - if isinstance(options, GenerationOptions): - options = options.model_copy(deep=True) - else: - # If options is a dict, convert it to GenerationOptions - options = GenerationOptions(**options) - original_log_options = options.log.model_copy(deep=True) + # create a copy of the gen_options to avoid modifying the original + gen_options = gen_options.model_copy(deep=True) + original_log_options = gen_options.log.model_copy(deep=True) # enable log options # it is aggressive, but these are required for tracing if ( - not options.log.activated_rails - or not options.log.llm_calls - or not options.log.internal_events + not gen_options.log.activated_rails + or not gen_options.log.llm_calls + or not gen_options.log.internal_events ): - options.log.activated_rails = True - options.log.llm_calls = True - options.log.internal_events = True + gen_options.log.activated_rails = True + gen_options.log.llm_calls = True + gen_options.log.internal_events = True # If we have generation options, we prepare a GenerationResponse instance. - if options: + if gen_options: # If a prompt was used, we only need to return the content of the message. 
if prompt: res = GenerationResponse(response=new_message["content"]) @@ -1094,21 +1133,24 @@ async def generate_async( if reasoning_trace := get_and_clear_reasoning_trace_contextvar(): if prompt: - res.response = reasoning_trace + res.response + # For prompt mode, response should be a string + if isinstance(res.response, str): + res.response = reasoning_trace + res.response else: - res.response[0]["content"] = ( - reasoning_trace + res.response[0]["content"] - ) + # For message mode, response should be a list + if isinstance(res.response, list) and len(res.response) > 0: + res.response[0]["content"] = ( + reasoning_trace + res.response[0]["content"] + ) if self.config.colang_version == "1.0": # If output variables are specified, we extract their values - if options.output_vars: + if gen_options and gen_options.output_vars: context = compute_context(events) - if isinstance(options.output_vars, list): + output_vars = gen_options.output_vars + if isinstance(output_vars, list): # If we have only a selection of keys, we filter to only that. - res.output_data = { - k: context.get(k) for k in options.output_vars - } + res.output_data = {k: context.get(k) for k in output_vars} else: # Otherwise, we return the full context res.output_data = context @@ -1116,37 +1158,40 @@ async def generate_async( _log = compute_generation_log(processing_log) # Include information about activated rails and LLM calls if requested - if options.log.activated_rails or options.log.llm_calls: + log_options = gen_options.log if gen_options else None + if log_options and ( + log_options.activated_rails or log_options.llm_calls + ): res.log = GenerationLog() # We always include the stats res.log.stats = _log.stats - if options.log.activated_rails: + if log_options.activated_rails: res.log.activated_rails = _log.activated_rails - if options.log.llm_calls: + if log_options.llm_calls: res.log.llm_calls = [] for activated_rail in _log.activated_rails: for executed_action in activated_rail.executed_actions: res.log.llm_calls.extend(executed_action.llm_calls) # Include internal events if requested - if options.log.internal_events: + if log_options and log_options.internal_events: if res.log is None: res.log = GenerationLog() res.log.internal_events = new_events # Include the Colang history if requested - if options.log.colang_history: + if log_options and log_options.colang_history: if res.log is None: res.log = GenerationLog() res.log.colang_history = get_colang_history(events) # Include the raw llm output if requested - if options.llm_output: + if gen_options and gen_options.llm_output: # Currently, we include the output from the generation LLM calls. for activated_rail in _log.activated_rails: if activated_rail.type == "generation": @@ -1154,22 +1199,23 @@ async def generate_async( for llm_call in executed_action.llm_calls: res.llm_output = llm_call.raw_response else: - if options.output_vars: + if gen_options and gen_options.output_vars: raise ValueError( "The `output_vars` option is not supported for Colang 2.0 configurations." ) - if ( - options.log.activated_rails - or options.log.llm_calls - or options.log.internal_events - or options.log.colang_history + log_options = gen_options.log if gen_options else None + if log_options and ( + log_options.activated_rails + or log_options.llm_calls + or log_options.internal_events + or log_options.colang_history ): raise ValueError( "The `log` option is not supported for Colang 2.0 configurations." 
) - if options.llm_output: + if gen_options and gen_options.llm_output: raise ValueError( "The `llm_output` option is not supported for Colang 2.0 configurations." ) @@ -1211,12 +1257,14 @@ async def generate_async( ): res.log = None else: - if not original_log_options.internal_events: - res.log.internal_events = [] - if not original_log_options.activated_rails: - res.log.activated_rails = [] - if not original_log_options.llm_calls: - res.log.llm_calls = [] + # Ensure res.log exists before setting attributes + if res.log is not None: + if not original_log_options.internal_events: + res.log.internal_events = [] + if not original_log_options.activated_rails: + res.log.activated_rails = [] + if not original_log_options.llm_calls: + res.log.llm_calls = [] return res else: @@ -1243,9 +1291,13 @@ def stream_async( # if an external generator is provided, use it directly if generator: - if self.config.rails.output.streaming.enabled: + if ( + self.config.rails.output.streaming + and self.config.rails.output.streaming.enabled + ): return self._run_output_rails_in_streaming( streaming_handler=generator, + output_rails_streaming_config=self.config.rails.output.streaming, messages=messages, prompt=prompt, ) @@ -1276,7 +1328,7 @@ async def _generation_task(): error_dict = extract_error_json(error_message) error_payload = json.dumps(error_dict) await streaming_handler.push_chunk(error_payload) - await streaming_handler.push_chunk(END_OF_STREAM) + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore task = asyncio.create_task(_generation_task()) @@ -1294,10 +1346,14 @@ def task_done_callback(task): # when we have output rails we wrap the streaming handler # if len(self.config.rails.output.flows) > 0: # - if self.config.rails.output.streaming.enabled: + if ( + self.config.rails.output.streaming + and self.config.rails.output.streaming.enabled + ): # returns an async generator return self._run_output_rails_in_streaming( streaming_handler=streaming_handler, + output_rails_streaming_config=self.config.rails.output.streaming, messages=messages, prompt=prompt, ) @@ -1449,7 +1505,7 @@ def process_events( self.process_events_async(events, state, blocking) ) - def register_action(self, action: callable, name: Optional[str] = None) -> Self: + def register_action(self, action: Callable, name: Optional[str] = None) -> Self: """Register a custom action for the rails configuration.""" self.runtime.register_action(action, name) return self @@ -1459,12 +1515,12 @@ def register_action_param(self, name: str, value: Any) -> Self: self.runtime.register_action_param(name, value) return self - def register_filter(self, filter_fn: callable, name: Optional[str] = None) -> Self: + def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self: """Register a custom filter for the rails configuration.""" self.runtime.llm_task_manager.register_filter(filter_fn, name) return self - def register_output_parser(self, output_parser: callable, name: str) -> Self: + def register_output_parser(self, output_parser: Callable, name: str) -> Self: """Register a custom output parser for the rails configuration.""" self.runtime.llm_task_manager.register_output_parser(output_parser, name) return self @@ -1509,6 +1565,8 @@ def register_embedding_provider( def explain(self) -> ExplainInfo: """Helper function to return the latest ExplainInfo object.""" + if self.explain_info is None: + self.explain_info = self._ensure_explain_info() return self.explain_info def __getstate__(self): @@ -1524,6 +1582,7 @@ def 
__setstate__(self, state): async def _run_output_rails_in_streaming( self, streaming_handler: AsyncIterator[str], + output_rails_streaming_config: OutputRailsStreamingConfig, prompt: Optional[str] = None, messages: Optional[List[dict]] = None, stream_first: Optional[bool] = None, @@ -1626,7 +1685,6 @@ def _prepare_params( **action_params, } - output_rails_streaming_config = self.config.rails.output.streaming buffer_strategy = get_buffer_strategy(output_rails_streaming_config) output_rails_flows_id = self.config.rails.output.flows stream_first = stream_first or output_rails_streaming_config.stream_first @@ -1701,9 +1759,10 @@ def _prepare_params( pass else: # if there are any stop events, content was blocked or internal error occurred - if result.events: + result_events = getattr(result, "events", None) + if result_events: # extract the flow info from the first stop event - stop_event = result.events[0] + stop_event = result_events[0] blocked_flow = stop_event.get("flow_id", "output rails") error_type = stop_event.get("error_type") diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py index 51c712f03..a64e8cedb 100644 --- a/nemoguardrails/rails/llm/options.py +++ b/nemoguardrails/rails/llm/options.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Generation options give more control over the generation and the result. +"""Generation options give more control over the generation and the result. For example, to run only the input rails:: @@ -76,6 +76,7 @@ # {..., log: {"llm_calls": [...]}} """ + from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field, root_validator @@ -146,7 +147,7 @@ class GenerationOptions(BaseModel): default=None, description="Additional parameters that should be used for the LLM call", ) - llm_output: Optional[bool] = Field( + llm_output: bool = Field( default=False, description="Whether the response should also include any custom LLM output.", ) @@ -221,7 +222,7 @@ class ActivatedRail(BaseModel): ) decisions: List[str] = Field( default_factory=list, - descriptino="A sequence of decisions made by the rail, e.g., 'bot refuse to respond', 'stop', 'continue'.", + description="A sequence of decisions made by the rail, e.g., 'bot refuse to respond', 'stop', 'continue'.", ) executed_actions: List[ExecutedAction] = Field( default_factory=list, description="The list of actions executed by the rail." 
@@ -315,7 +316,7 @@ def print_summary(self): duration = 0 print(f"- Total time: {self.stats.total_duration:.2f}s") - if self.stats.input_rails_duration: + if self.stats.input_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.input_rails_duration / self.stats.total_duration, 2 ) @@ -323,7 +324,7 @@ def print_summary(self): duration += self.stats.input_rails_duration print(f" - [{self.stats.input_rails_duration:.2f}s][{_pc}%]: INPUT Rails") - if self.stats.dialog_rails_duration: + if self.stats.dialog_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.dialog_rails_duration / self.stats.total_duration, 2 ) @@ -333,7 +334,7 @@ def print_summary(self): print( f" - [{self.stats.dialog_rails_duration:.2f}s][{_pc}%]: DIALOG Rails" ) - if self.stats.generation_rails_duration: + if self.stats.generation_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.generation_rails_duration / self.stats.total_duration, 2, @@ -344,7 +345,7 @@ def print_summary(self): print( f" - [{self.stats.generation_rails_duration:.2f}s][{_pc}%]: GENERATION Rails" ) - if self.stats.output_rails_duration: + if self.stats.output_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.output_rails_duration / self.stats.total_duration, 2 ) @@ -355,12 +356,12 @@ def print_summary(self): f" - [{self.stats.output_rails_duration:.2f}s][{_pc}%]: OUTPUT Rails" ) - processing_overhead = self.stats.total_duration - duration + processing_overhead = (self.stats.total_duration or 0) - duration if processing_overhead >= 0.01: _pc = round(100 - pc, 2) print(f" - [{processing_overhead:.2f}s][{_pc}%]: Processing overhead ") - if self.stats.llm_calls_count > 0: + if self.stats.llm_calls_count and self.stats.llm_calls_count > 0: print( f"- {self.stats.llm_calls_count} LLM calls, " f"{self.stats.llm_calls_duration:.2f}s total duration, " @@ -379,7 +380,10 @@ def print_summary(self): for action in activated_rail.executed_actions: llm_calls_count += len(action.llm_calls) llm_calls_durations.extend( - [f"{round(llm_call.duration, 2)}s" for llm_call in action.llm_calls] + [ + f"{round(llm_call.duration or 0, 2)}s" + for llm_call in action.llm_calls + ] ) print( f"- [{activated_rail.duration:.2f}s] {activated_rail.type.upper()} ({activated_rail.name}): " @@ -411,4 +415,6 @@ class GenerationResponse(BaseModel): if __name__ == "__main__": - print(GenerationOptions(**{"rails": {"input": False}})) + print( + GenerationOptions(rails=GenerationRailsOptions(input=False)) + ) # pragma: no cover (Can't run as script for test coverage) diff --git a/nemoguardrails/utils.py b/nemoguardrails/utils.py index a337a978f..bc27a6c74 100644 --- a/nemoguardrails/utils.py +++ b/nemoguardrails/utils.py @@ -375,7 +375,7 @@ def get_railsignore_patterns(railsignore_path: Path) -> Set[str]: return ignored_patterns -def is_ignored_by_railsignore(filename: str, ignore_patterns: str) -> bool: +def is_ignored_by_railsignore(filename: str, ignore_patterns: Set[str]) -> bool: """Verify if a filename should be ignored by a railsignore pattern""" ignore = False diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index 7b4a3cfe1..f79dbc0ad 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json
+from unittest.mock import MagicMock
+
 import pytest
+from langchain.llms.base import BaseLLM
 from pydantic import ValidationError

-from nemoguardrails.rails.llm.config import (
-    Document,
-    Instruction,
-    Model,
-    RailsConfig,
-    TaskPrompt,
-)
+from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt
+from nemoguardrails.rails.llm.llmrails import LLMRails


 def test_task_prompt_valid_content():
@@ -307,3 +306,76 @@ def test_rails_config_none_config_path():
     result2 = config3 + config4

     assert result2.config_path == ""
+
+
+def test_llm_rails_configure_streaming_with_attr():
+    """Check the LLM streaming attribute is set when streaming is enabled in RailsConfig"""
+
+    mock_llm = MagicMock(spec=BaseLLM)
+    config = RailsConfig(
+        models=[],
+        streaming=True,
+    )
+
+    rails = LLMRails(config, llm=mock_llm)
+    setattr(mock_llm, "streaming", None)
+    rails._configure_main_llm_streaming(llm=mock_llm)
+
+    assert mock_llm.streaming
+
+
+def test_llm_rails_configure_streaming_without_attr(caplog):
+    """Check a warning is logged when the LLM does not have a streaming attribute"""
+
+    mock_llm = MagicMock(spec=BaseLLM)
+    config = RailsConfig(
+        models=[],
+        streaming=True,
+    )
+
+    rails = LLMRails(config, llm=mock_llm)
+    rails._configure_main_llm_streaming(mock_llm)
+
+    assert caplog.messages[-1] == "Provided main LLM does not support streaming."
+
+
+def test_rails_config_streaming_supported_no_output_flows():
+    """Check the `streaming_supported` property doesn't depend on RailsConfig.streaming when there are no output flows"""
+
+    config = RailsConfig(
+        models=[],
+        streaming=False,
+    )
+    assert config.streaming_supported
+
+
+def test_rails_config_flows_streaming_supported_true():
+    """Check the `streaming_supported` property is True when output rails streaming is enabled"""
+
+    rails = {
+        "output": {
+            "flows": ["content_safety_check_output"],
+            "streaming": {"enabled": True},
+        }
+    }
+    prompts = [{"task": "content safety check output", "content": "..."}]
+    rails_config = RailsConfig.model_validate(
+        {"models": [], "rails": rails, "prompts": prompts}
+    )
+    assert rails_config.streaming_supported
+
+
+def test_rails_config_flows_streaming_supported_false():
+    """Check the `streaming_supported` property is False when output rails streaming is disabled"""
+
+    rails = {
+        "output": {
+            "flows": ["content_safety_check_output"],
+            "streaming": {"enabled": False},
+        }
+    }
+    prompts = [{"task": "content safety check output", "content": "..."}]
+    rails_config = RailsConfig.model_validate(
+        {"models": [], "rails": rails, "prompts": prompts}
+    )
+    assert not rails_config.streaming_supported
diff --git a/tests/test_llm_options.py b/tests/test_llm_options.py
new file mode 100644
index 000000000..72226afda
--- /dev/null
+++ b/tests/test_llm_options.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for LLM isolation with models that don't have model_kwargs field.""" + +from typing import Any, Dict, List, Optional +from unittest.mock import Mock + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from pydantic import BaseModel, Field + +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.llmrails import LLMRails +from nemoguardrails.rails.llm.options import GenerationLog, GenerationStats + + +def test_generation_log_print_summary(capsys): + """Test printing rais stats with dummy data""" + + stats = GenerationStats( + input_rails_duration=1.0, + dialog_rails_duration=2.0, + generation_rails_duration=3.0, + output_rails_duration=4.0, + total_duration=10.0, # Sum of all previous rail durations + llm_calls_duration=8.0, # Less than total duration + llm_calls_count=4, # Input, dialog, generation and output calls + llm_calls_total_prompt_tokens=1000, + llm_calls_total_completion_tokens=2000, + llm_calls_total_tokens=3000, # Sum of prompt and completion tokens + ) + + generation_log = GenerationLog(activated_rails=[], stats=stats) + + generation_log.print_summary() + capture = capsys.readouterr() + capture_lines = capture.out.splitlines() + + # Check the correct times were printed + assert capture_lines[1] == "# General stats" + assert capture_lines[3] == "- Total time: 10.00s" + assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails" + assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails" + assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails" + assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails" + assert ( + capture_lines[8] + == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens." + )