From f5b2d1f24a9b8991d6a9a7c1c171679791b8b343 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Fri, 15 Aug 2025 17:16:56 -0400 Subject: [PATCH 1/7] feat: Modify endpoints for OpenAPI compatibility --- nemoguardrails/server/api.py | 317 ++++++++++++++++++++++---- tests/test_api.py | 65 +++++- tests/test_server_calls_with_state.py | 13 +- tests/test_threads.py | 8 +- 4 files changed, 343 insertions(+), 60 deletions(-) diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 658cffd01..9ccb4a074 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -20,6 +20,7 @@ import os.path import re import time +import uuid import warnings from contextlib import asynccontextmanager from typing import Any, Callable, List, Optional @@ -32,7 +33,6 @@ from nemoguardrails import LLMRails, RailsConfig, utils from nemoguardrails.rails.llm.options import ( - GenerationLog, GenerationOptions, GenerationResponse, ) @@ -88,9 +88,9 @@ async def lifespan(app: GuardrailsApp): # If there is a `config.yml` in the root `app.rails_config_path`, then # that means we are in single config mode. - if os.path.exists(os.path.join(app.rails_config_path, "config.yml")) or os.path.exists( - os.path.join(app.rails_config_path, "config.yaml") - ): + if os.path.exists( + os.path.join(app.rails_config_path, "config.yml") + ) or os.path.exists(os.path.join(app.rails_config_path, "config.yaml")): app.single_config_mode = True app.single_config_id = os.path.basename(app.rails_config_path) else: @@ -228,16 +228,63 @@ class RequestBody(BaseModel): default=None, description="A state object that should be used to continue the interaction.", ) + # Standard OpenAI completion parameters + model: Optional[str] = Field( + default=None, + description="The model to use for chat completion. 
Maps to config_id for backward compatibility.", + ) + max_tokens: Optional[int] = Field( + default=None, + description="The maximum number of tokens to generate.", + ) + temperature: Optional[float] = Field( + default=None, + description="Sampling temperature to use.", + ) + top_p: Optional[float] = Field( + default=None, + description="Top-p sampling parameter.", + ) + stop: Optional[str] = Field( + default=None, + description="Stop sequences.", + ) + presence_penalty: Optional[float] = Field( + default=None, + description="Presence penalty parameter.", + ) + frequency_penalty: Optional[float] = Field( + default=None, + description="Frequency penalty parameter.", + ) + function_call: Optional[dict] = Field( + default=None, + description="Function call parameter.", + ) + logit_bias: Optional[dict] = Field( + default=None, + description="Logit bias parameter.", + ) + log_probs: Optional[bool] = Field( + default=None, + description="Log probabilities parameter.", + ) @root_validator(pre=True) def ensure_config_id(cls, data: Any) -> Any: if isinstance(data, dict): + if data.get("model") is not None and data.get("config_id") is None: + data["config_id"] = data["model"] if data.get("config_id") is not None and data.get("config_ids") is not None: - raise ValueError("Only one of config_id or config_ids should be specified") + raise ValueError( + "Only one of config_id or config_ids should be specified" + ) if data.get("config_id") is None and data.get("config_ids") is not None: data["config_id"] = None if data.get("config_id") is None and data.get("config_ids") is None: - warnings.warn("No config_id or config_ids provided, using default config_id") + warnings.warn( + "No config_id or config_ids provided, using default config_id" + ) return data @validator("config_ids", pre=True, always=True) @@ -248,23 +295,115 @@ def ensure_config_ids(cls, v, values): return v +class Choice(BaseModel): + index: Optional[int] = Field( + default=None, description="The index of the choice in the list of choices." + ) + messages: Optional[dict] = Field( + default=None, description="The message of the choice" + ) + logprobs: Optional[dict] = Field( + default=None, description="The log probabilities of the choice" + ) + finish_reason: Optional[str] = Field( + default=None, description="The reason the model stopped generating tokens." + ) + + class ResponseBody(BaseModel): - messages: Optional[List[dict]] = Field(default=None, description="The new messages in the conversation") - llm_output: Optional[dict] = Field( - default=None, - description="Contains any additional output coming from the LLM.", + # OpenAI-compatible fields + id: Optional[str] = Field( + default=None, description="A unique identifier for the chat completion." ) - output_data: Optional[dict] = Field( + object: str = Field( + default="chat.completion", + description="The object type, which is always chat.completion", + ) + created: Optional[int] = Field( default=None, - description="The output data, i.e. a dict with the values corresponding to the `output_vars`.", + description="The Unix timestamp (in seconds) of when the chat completion was created.", ) - log: Optional[GenerationLog] = Field(default=None, description="Additional logging information.") + model: Optional[str] = Field( + default=None, description="The model used for the chat completion." + ) + choices: Optional[List[Choice]] = Field( + default=None, description="A list of chat completion choices." 
+ ) + # NeMo-Guardrails specific fields for backward compatibility state: Optional[dict] = Field( - default=None, - description="A state object that should be used to continue the interaction in the future.", + default=None, description="State object for continuing the conversation." + ) + llm_output: Optional[dict] = Field( + default=None, description="Additional LLM output data." + ) + output_data: Optional[dict] = Field( + default=None, description="Additional output data." + ) + log: Optional[dict] = Field(default=None, description="Generation log data.") + + +class Model(BaseModel): + id: str = Field( + description="The model identifier, which can be referenced in the API endpoints." + ) + object: str = Field( + default="model", description="The object type, which is always 'model'." + ) + created: int = Field( + description="The Unix timestamp (in seconds) of when the model was created." + ) + owned_by: str = Field( + default="nemo-guardrails", description="The organization that owns the model." ) +class ModelsResponse(BaseModel): + object: str = Field( + default="list", description="The object type, which is always 'list'." + ) + data: List[Model] = Field(description="The list of models.") + + +@app.get( + "/v1/models", + response_model=ModelsResponse, + summary="List available models", + description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.", +) +async def get_models(): + """Returns the list of available models (guardrails configurations) in OpenAI-compatible format.""" + + # Use the same logic as get_rails_configs to find available configurations + if app.single_config_mode: + config_ids = [app.single_config_id] if app.single_config_id else [] + else: + config_ids = [ + f + for f in os.listdir(app.rails_config_path) + if os.path.isdir(os.path.join(app.rails_config_path, f)) + and f[0] != "." + and f[0] != "_" + # Filter out all the configs for which there is no `config.yml` file. + and ( + os.path.exists(os.path.join(app.rails_config_path, f, "config.yml")) + or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml")) + ) + ] + + # Convert configurations to OpenAI model format + models = [] + for config_id in config_ids: + model = Model( + id=config_id, + object="model", + created=int(time.time()), # Use current time as created timestamp + owned_by="nemo-guardrails", + ) + models.append(model) + + return ModelsResponse(data=models) + + @app.get( "/v1/rails/configs", summary="Get List of available rails configurations.", @@ -350,7 +489,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails: llm_rails_instances[configs_cache_key] = llm_rails # If we have a cache for the events, we restore it - llm_rails.events_history_cache = llm_rails_events_history_cache.get(configs_cache_key, {}) + llm_rails.events_history_cache = llm_rails_events_history_cache.get( + configs_cache_key, {} + ) return llm_rails @@ -367,7 +508,9 @@ async def chat_completion(body: RequestBody, request: Request): """ log.info("Got request for config %s", body.config_id) for logger in registered_loggers: - asyncio.get_event_loop().create_task(logger({"endpoint": "/v1/chat/completions", "body": body.json()})) + asyncio.get_event_loop().create_task( + logger({"endpoint": "/v1/chat/completions", "body": body.json()}) + ) # Save the request headers in a context variable. 
api_request_headers.set(request.headers) @@ -379,20 +522,31 @@ async def chat_completion(body: RequestBody, request: Request): if app.default_config_id: config_ids = [app.default_config_id] else: - raise GuardrailsConfigurationError("No request config_ids provided and server has no default configuration") + raise GuardrailsConfigurationError( + "No request config_ids provided and server has no default configuration" + ) try: llm_rails = _get_rails(config_ids) except ValueError as ex: log.exception(ex) return ResponseBody( - messages=[ - { - "role": "assistant", - "content": f"Could not load the {config_ids} guardrails configuration. " - f"An internal error has occurred.", - } - ] + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=config_ids[0] if config_ids else None, + choices=[ + Choice( + index=0, + messages={ + "content": f"Could not load the {config_ids} guardrails configuration. " + f"An internal error has occurred.", + "role": "assistant", + }, + finish_reason="error", + logprobs=None, + ) + ], ) try: @@ -410,12 +564,21 @@ async def chat_completion(body: RequestBody, request: Request): # We make sure the `thread_id` meets the minimum complexity requirement. if len(body.thread_id) < 16: return ResponseBody( - messages=[ - { - "role": "assistant", - "content": "The `thread_id` must have a minimum length of 16 characters.", - } - ] + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=None, + choices=[ + Choice( + index=0, + messages={ + "content": "The `thread_id` must have a minimum length of 16 characters.", + "role": "assistant", + }, + finish_reason="error", + logprobs=None, + ) + ], ) # Fetch the existing thread messages. For easier management, we prepend @@ -426,7 +589,29 @@ async def chat_completion(body: RequestBody, request: Request): # And prepend them. 
messages = thread_messages + messages - if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming: + # Map OpenAI-compatible parameters to generation options + generation_options = body.options + # Initialize llm_params if not already set + if generation_options.llm_params is None: + generation_options.llm_params = {} + if body.max_tokens: + generation_options.llm_params["max_tokens"] = body.max_tokens + if body.temperature is not None: + generation_options.llm_params["temperature"] = body.temperature + if body.top_p is not None: + generation_options.llm_params["top_p"] = body.top_p + if body.stop: + generation_options.llm_params["stop"] = body.stop + if body.presence_penalty is not None: + generation_options.llm_params["presence_penalty"] = body.presence_penalty + if body.frequency_penalty is not None: + generation_options.llm_params["frequency_penalty"] = body.frequency_penalty + + if ( + body.stream + and llm_rails.config.streaming_supported + and llm_rails.main_llm_supports_streaming + ): # Create the streaming handler instance streaming_handler = StreamingHandler() @@ -435,16 +620,16 @@ async def chat_completion(body: RequestBody, request: Request): llm_rails.generate_async( messages=messages, streaming_handler=streaming_handler, - options=body.options, + options=generation_options, state=body.state, ) ) - # TODO: Add support for thread_ids in streaming mode - return StreamingResponse(streaming_handler) else: - res = await llm_rails.generate_async(messages=messages, options=body.options, state=body.state) + res = await llm_rails.generate_async( + messages=messages, options=generation_options, state=body.state + ) if isinstance(res, GenerationResponse): bot_message_content = res.response[0] @@ -462,20 +647,50 @@ async def chat_completion(body: RequestBody, request: Request): if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - result = ResponseBody(messages=[bot_message]) - - # If we have additional GenerationResponse fields, we return as well + # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions + response_kwargs = { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": config_ids[0] if config_ids else None, + "choices": [ + Choice( + index=0, + messages=bot_message, + finish_reason="stop", + logprobs=None, + ) + ], + } + + # If we have additional GenerationResponse fields, include them for backward compatibility if isinstance(res, GenerationResponse): - result.llm_output = res.llm_output - result.output_data = res.output_data - result.log = res.log - result.state = res.state + response_kwargs["llm_output"] = res.llm_output + response_kwargs["output_data"] = res.output_data + response_kwargs["log"] = res.log + response_kwargs["state"] = res.state - return result + return ResponseBody(**response_kwargs) except Exception as ex: log.exception(ex) - return ResponseBody(messages=[{"role": "assistant", "content": "Internal server error."}]) + return ResponseBody( + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=None, + choices=[ + Choice( + index=0, + messages={ + "content": "Internal server error", + "role": "assistant", + }, + finish_reason="error", + logprobs=None, + ) + ], + ) # By default, there are no challenges @@ -525,7 +740,9 @@ def on_any_event(self, event): return None elif event.event_type == "created" or 
event.event_type == "modified": - log.info(f"Watchdog received {event.event_type} event for file {event.src_path}") + log.info( + f"Watchdog received {event.event_type} event for file {event.src_path}" + ) # Compute the relative path src_path_str = str(event.src_path) @@ -549,7 +766,9 @@ def on_any_event(self, event): # We save the events history cache, to restore it on the new instance llm_rails_events_history_cache[config_id] = val - log.info(f"Configuration {config_id} has changed. Clearing cache.") + log.info( + f"Configuration {config_id} has changed. Clearing cache." + ) observer = Observer() event_handler = Handler() @@ -564,7 +783,9 @@ def on_any_event(self, event): except ImportError: # Since this is running in a separate thread, we just print the error. - print("The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`.") + print( + "The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`." + ) # Force close everything. os._exit(-1) diff --git a/tests/test_api.py b/tests/test_api.py index b6619fe7a..0a5966f2b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -41,6 +41,26 @@ def test_get(): assert len(result) > 0 +def test_get_models(): + """Test the OpenAI-compatible /v1/models endpoint.""" + response = client.get("/v1/models") + assert response.status_code == 200 + + result = response.json() + + # Check OpenAI models list format + assert result["object"] == "list" + assert "data" in result + assert len(result["data"]) > 0 + + # Check each model has the required OpenAI format + for model in result["data"]: + assert "id" in model + assert model["object"] == "model" + assert "created" in model + assert model["owned_by"] == "nemo-guardrails" + + @pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") def test_chat_completion(): response = client.post( @@ -57,8 +77,14 @@ def test_chat_completion(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" @pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") @@ -78,8 +104,14 @@ def test_chat_completion_with_default_configs(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" def test_request_body_validation(): @@ -113,6 +145,31 @@ def test_request_body_validation(): assert request_body.config_ids is None +def test_openai_model_field_mapping(): + """Test OpenAI-compatible model field mapping to config_id.""" + + # Test model field maps to config_id + data = { + "model": "test_model", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = RequestBody.model_validate(data) + assert request_body.model == "test_model" + assert request_body.config_id == "test_model" + assert request_body.config_ids == ["test_model"] + + # Test model and config_id 
both provided (config_id takes precedence) + data = { + "model": "test_model", + "config_id": "test_config", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = RequestBody.model_validate(data) + assert request_body.model == "test_model" + assert request_body.config_id == "test_config" + assert request_body.config_ids == ["test_config"] + + def test_request_body_state(): """Test RequestBody state handling.""" data = { diff --git a/tests/test_server_calls_with_state.py b/tests/test_server_calls_with_state.py index 051096432..1d932efe6 100644 --- a/tests/test_server_calls_with_state.py +++ b/tests/test_server_calls_with_state.py @@ -37,8 +37,9 @@ def _test_call(config_id): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] == "Hello!" + print(res) + assert len(res["choices"][0]["messages"]) == 2 + assert res["choices"][0]["messages"]["content"] == "Hello!" assert res.get("state") # When making a second call with the returned state, the conversations should continue @@ -51,13 +52,17 @@ def _test_call(config_id): { "content": "hi", "role": "user", - } + }, + { + "content": "hi", + "role": "assistant", + }, ], "state": res["state"], }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["messages"]["content"] == "Hello again!" def test_1(): diff --git a/tests/test_threads.py b/tests/test_threads.py index 88946007b..5ce14fe6c 100644 --- a/tests/test_threads.py +++ b/tests/test_threads.py @@ -51,8 +51,8 @@ def test_1(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] == "Hello!" + assert len(res["choices"][0]["messages"]) == 2 + assert res["choices"][0]["messages"]["content"] == "Hello!" # When making a second call with the same thread_id, the conversations should continue # and we should get the "Hello again!" message. @@ -70,7 +70,7 @@ def test_1(): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["messages"]["content"] == "Hello again!" @pytest.mark.parametrize( @@ -138,4 +138,4 @@ def test_with_redis(): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"]["messages"][0]["content"] == "Hello again!" 
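
As a quick illustration of the surface PATCH 1 adds (illustrative only, not part of the patch): the new /v1/models endpoint lists each guardrails configuration as an OpenAI-style model object, and /v1/chat/completions accepts the OpenAI "model" field as an alias for "config_id" (see the ensure_config_id validator and test_openai_model_field_mapping above). The sketch below uses the requests library against a locally running server; the base URL and the "general" config id are assumptions.

    # Hypothetical client-side check of the OpenAI-compatible endpoints added in PATCH 1.
    # Assumes a guardrails server on localhost:8000 with a config named "general".
    import requests

    BASE = "http://localhost:8000"

    # /v1/models maps each guardrails configuration to an OpenAI model object.
    models = requests.get(f"{BASE}/v1/models").json()
    assert models["object"] == "list"
    print([m["id"] for m in models["data"]])  # e.g. ["general", ...]

    # "model" is accepted in place of "config_id" and mapped by the server.
    resp = requests.post(
        f"{BASE}/v1/chat/completions",
        json={"model": "general", "messages": [{"role": "user", "content": "Hello"}]},
    ).json()
    # At this point in the series the choice payload key is still "messages";
    # PATCH 3 renames it to the OpenAI-standard "message".
    print(resp["choices"][0]["messages"]["content"])
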
From 9970296b654004f035d81384b7963076f8e9082b Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Wed, 10 Sep 2025 12:27:07 -0400 Subject: [PATCH 2/7] fix: Colang 2.x doesn't support assistant messages --- nemoguardrails/colang/v2_x/runtime/runtime.py | 12 ++++++++---- tests/test_server_calls_with_state.py | 8 +++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index 6980714bc..9cbbcb776 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -31,6 +31,7 @@ ColangSyntaxError, ) from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus +from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state from nemoguardrails.colang.v2_x.runtime.statemachine import ( FlowConfig, InternalEvent, @@ -394,10 +395,13 @@ async def process_events( state = State(flow_states={}, flow_configs=self.flow_configs, rails_config=self.config) initialize_state(state) elif isinstance(state, dict): - # TODO: Implement dict to State conversion - raise NotImplementedError() - # if isinstance(state, dict): - # state = State.from_dict(state) + # Convert dict to State object + if state.get("version") == "2.x" and "state" in state: + # Handle the serialized state format from API calls + state = json_to_state(state["state"]) + else: + # TODO: Implement other dict to State conversion formats if needed + raise NotImplementedError("Unsupported state dict format") assert isinstance(state, State) assert state.main_flow_state is not None diff --git a/tests/test_server_calls_with_state.py b/tests/test_server_calls_with_state.py index 1d932efe6..07ee4bc5d 100644 --- a/tests/test_server_calls_with_state.py +++ b/tests/test_server_calls_with_state.py @@ -44,6 +44,8 @@ def _test_call(config_id): # When making a second call with the returned state, the conversations should continue # and we should get the "Hello again!" message. + # For Colang 2.x, we only send the new user message, not the conversation history + # since the state maintains the conversation context. response = client.post( "/v1/chat/completions", json={ @@ -52,11 +54,7 @@ def _test_call(config_id): { "content": "hi", "role": "user", - }, - { - "content": "hi", - "role": "assistant", - }, + } ], "state": res["state"], }, From 4cdf23292850e8980852887bfe666b8bd9a4a3a2 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Fri, 24 Oct 2025 10:26:51 -0400 Subject: [PATCH 3/7] chore: Move OpenAPI schema and fix typos --- nemoguardrails/server/api.py | 139 +++-------------------- nemoguardrails/server/schemas/openai.py | 143 ++++++++++++++++++++++++ tests/test_server_calls_with_state.py | 6 +- tests/test_threads.py | 8 +- 4 files changed, 168 insertions(+), 128 deletions(-) create mode 100644 nemoguardrails/server/schemas/openai.py diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 9ccb4a074..028c2beba 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import asyncio import contextvars import importlib.util @@ -27,16 +28,20 @@ from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, Field, root_validator, validator +from pydantic import Field, root_validator, validator from starlette.responses import StreamingResponse from starlette.staticfiles import StaticFiles from nemoguardrails import LLMRails, RailsConfig, utils -from nemoguardrails.rails.llm.options import ( - GenerationOptions, - GenerationResponse, -) +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse from nemoguardrails.server.datastore.datastore import DataStore +from nemoguardrails.server.schemas.openai import ( + Choice, + Model, + ModelsResponse, + OpenAIRequestFields, + ResponseBody, +) from nemoguardrails.streaming import StreamingHandler logging.basicConfig(level=logging.INFO) @@ -190,7 +195,7 @@ async def root_handler(): app.single_config_id = None -class RequestBody(BaseModel): +class RequestBody(OpenAIRequestFields): config_id: Optional[str] = Field( default=os.getenv("DEFAULT_CONFIG_ID", None), description="The id of the configuration to be used. If not set, the default configuration will be used.", @@ -228,47 +233,6 @@ class RequestBody(BaseModel): default=None, description="A state object that should be used to continue the interaction.", ) - # Standard OpenAI completion parameters - model: Optional[str] = Field( - default=None, - description="The model to use for chat completion. Maps to config_id for backward compatibility.", - ) - max_tokens: Optional[int] = Field( - default=None, - description="The maximum number of tokens to generate.", - ) - temperature: Optional[float] = Field( - default=None, - description="Sampling temperature to use.", - ) - top_p: Optional[float] = Field( - default=None, - description="Top-p sampling parameter.", - ) - stop: Optional[str] = Field( - default=None, - description="Stop sequences.", - ) - presence_penalty: Optional[float] = Field( - default=None, - description="Presence penalty parameter.", - ) - frequency_penalty: Optional[float] = Field( - default=None, - description="Frequency penalty parameter.", - ) - function_call: Optional[dict] = Field( - default=None, - description="Function call parameter.", - ) - logit_bias: Optional[dict] = Field( - default=None, - description="Logit bias parameter.", - ) - log_probs: Optional[bool] = Field( - default=None, - description="Log probabilities parameter.", - ) @root_validator(pre=True) def ensure_config_id(cls, data: Any) -> Any: @@ -295,75 +259,6 @@ def ensure_config_ids(cls, v, values): return v -class Choice(BaseModel): - index: Optional[int] = Field( - default=None, description="The index of the choice in the list of choices." - ) - messages: Optional[dict] = Field( - default=None, description="The message of the choice" - ) - logprobs: Optional[dict] = Field( - default=None, description="The log probabilities of the choice" - ) - finish_reason: Optional[str] = Field( - default=None, description="The reason the model stopped generating tokens." - ) - - -class ResponseBody(BaseModel): - # OpenAI-compatible fields - id: Optional[str] = Field( - default=None, description="A unique identifier for the chat completion." 
- ) - object: str = Field( - default="chat.completion", - description="The object type, which is always chat.completion", - ) - created: Optional[int] = Field( - default=None, - description="The Unix timestamp (in seconds) of when the chat completion was created.", - ) - model: Optional[str] = Field( - default=None, description="The model used for the chat completion." - ) - choices: Optional[List[Choice]] = Field( - default=None, description="A list of chat completion choices." - ) - # NeMo-Guardrails specific fields for backward compatibility - state: Optional[dict] = Field( - default=None, description="State object for continuing the conversation." - ) - llm_output: Optional[dict] = Field( - default=None, description="Additional LLM output data." - ) - output_data: Optional[dict] = Field( - default=None, description="Additional output data." - ) - log: Optional[dict] = Field(default=None, description="Generation log data.") - - -class Model(BaseModel): - id: str = Field( - description="The model identifier, which can be referenced in the API endpoints." - ) - object: str = Field( - default="model", description="The object type, which is always 'model'." - ) - created: int = Field( - description="The Unix timestamp (in seconds) of when the model was created." - ) - owned_by: str = Field( - default="nemo-guardrails", description="The organization that owns the model." - ) - - -class ModelsResponse(BaseModel): - object: str = Field( - default="list", description="The object type, which is always 'list'." - ) - data: List[Model] = Field(description="The list of models.") - - @app.get( "/v1/models", response_model=ModelsResponse, @@ -538,7 +433,7 @@ async def chat_completion(body: RequestBody, request: Request): choices=[ Choice( index=0, - messages={ + message={ "content": f"Could not load the {config_ids} guardrails configuration. " f"An internal error has occurred.", "role": "assistant", @@ -571,7 +466,7 @@ async def chat_completion(body: RequestBody, request: Request): choices=[ Choice( index=0, - messages={ + message={ "content": "The `thread_id` must have a minimum length of 16 characters.", "role": "assistant", }, @@ -589,11 +484,13 @@ async def chat_completion(body: RequestBody, request: Request): # And prepend them. messages = thread_messages + messages - # Map OpenAI-compatible parameters to generation options generation_options = body.options + # Initialize llm_params if not already set if generation_options.llm_params is None: generation_options.llm_params = {} + + # Set OpenAI-compatible parameters in llm_params if body.max_tokens: generation_options.llm_params["max_tokens"] = body.max_tokens if body.temperature is not None: @@ -656,7 +553,7 @@ async def chat_completion(body: RequestBody, request: Request): "choices": [ Choice( index=0, - messages=bot_message, + message=bot_message, finish_reason="stop", logprobs=None, ) @@ -682,7 +579,7 @@ async def chat_completion(body: RequestBody, request: Request): choices=[ Choice( index=0, - messages={ + message={ "content": "Internal server error", "role": "assistant", }, diff --git a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py new file mode 100644 index 000000000..99ad1a700 --- /dev/null +++ b/nemoguardrails/server/schemas/openai.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenAI API schema definitions for the NeMo Guardrails server.""" + +from typing import List, Optional, Union + +from pydantic import BaseModel, Field + + +class OpenAIRequestFields(BaseModel): + """OpenAI API request fields that can be mixed into other request schemas.""" + + # Standard OpenAI completion parameters + model: Optional[str] = Field( + default=None, + description="The model to use for chat completion. Maps to config_id for backward compatibility.", + ) + max_tokens: Optional[int] = Field( + default=None, + description="The maximum number of tokens to generate.", + ) + temperature: Optional[float] = Field( + default=None, + description="Sampling temperature to use.", + ) + top_p: Optional[float] = Field( + default=None, + description="Top-p sampling parameter.", + ) + stop: Optional[Union[str, List[str]]] = Field( + default=None, + description="Stop sequences.", + ) + presence_penalty: Optional[float] = Field( + default=None, + description="Presence penalty parameter.", + ) + frequency_penalty: Optional[float] = Field( + default=None, + description="Frequency penalty parameter.", + ) + function_call: Optional[dict] = Field( + default=None, + description="Function call parameter.", + ) + logit_bias: Optional[dict] = Field( + default=None, + description="Logit bias parameter.", + ) + log_probs: Optional[bool] = Field( + default=None, + description="Log probabilities parameter.", + ) + + +class Choice(BaseModel): + """OpenAI API choice structure in chat completion responses.""" + + index: Optional[int] = Field( + default=None, description="The index of the choice in the list of choices." + ) + message: Optional[dict] = Field( + default=None, description="The message of the choice" + ) + logprobs: Optional[dict] = Field( + default=None, description="The log probabilities of the choice" + ) + finish_reason: Optional[str] = Field( + default=None, description="The reason the model stopped generating tokens." + ) + + +class ResponseBody(BaseModel): + """OpenAI API response body with NeMo-Guardrails extensions.""" + + # OpenAI API fields + id: Optional[str] = Field( + default=None, description="A unique identifier for the chat completion." + ) + object: str = Field( + default="chat.completion", + description="The object type, which is always chat.completion", + ) + created: Optional[int] = Field( + default=None, + description="The Unix timestamp (in seconds) of when the chat completion was created.", + ) + model: Optional[str] = Field( + default=None, description="The model used for the chat completion." + ) + choices: Optional[List[Choice]] = Field( + default=None, description="A list of chat completion choices." + ) + # NeMo-Guardrails specific fields for backward compatibility + state: Optional[dict] = Field( + default=None, description="State object for continuing the conversation." + ) + llm_output: Optional[dict] = Field( + default=None, description="Additional LLM output data." 
+ ) + output_data: Optional[dict] = Field( + default=None, description="Additional output data." + ) + log: Optional[dict] = Field(default=None, description="Generation log data.") + + +class Model(BaseModel): + """OpenAI API model representation.""" + + id: str = Field( + description="The model identifier, which can be referenced in the API endpoints." + ) + object: str = Field( + default="model", description="The object type, which is always 'model'." + ) + created: int = Field( + description="The Unix timestamp (in seconds) of when the model was created." + ) + owned_by: str = Field( + default="nemo-guardrails", description="The organization that owns the model." + ) + + +class ModelsResponse(BaseModel): + """OpenAI API models list response.""" + + object: str = Field( + default="list", description="The object type, which is always 'list'." + ) + data: List[Model] = Field(description="The list of models.") diff --git a/tests/test_server_calls_with_state.py b/tests/test_server_calls_with_state.py index 07ee4bc5d..736f2592c 100644 --- a/tests/test_server_calls_with_state.py +++ b/tests/test_server_calls_with_state.py @@ -38,8 +38,8 @@ def _test_call(config_id): assert response.status_code == 200 res = response.json() print(res) - assert len(res["choices"][0]["messages"]) == 2 - assert res["choices"][0]["messages"]["content"] == "Hello!" + assert len(res["choices"][0]["message"]) == 2 + assert res["choices"][0]["message"]["content"] == "Hello!" assert res.get("state") # When making a second call with the returned state, the conversations should continue @@ -60,7 +60,7 @@ def _test_call(config_id): }, ) res = response.json() - assert res["choices"][0]["messages"]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!" def test_1(): diff --git a/tests/test_threads.py b/tests/test_threads.py index 5ce14fe6c..3975e266f 100644 --- a/tests/test_threads.py +++ b/tests/test_threads.py @@ -51,8 +51,8 @@ def test_1(): ) assert response.status_code == 200 res = response.json() - assert len(res["choices"][0]["messages"]) == 2 - assert res["choices"][0]["messages"]["content"] == "Hello!" + assert len(res["choices"][0]["message"]) == 2 + assert res["choices"][0]["message"]["content"] == "Hello!" # When making a second call with the same thread_id, the conversations should continue # and we should get the "Hello again!" message. @@ -70,7 +70,7 @@ def test_1(): }, ) res = response.json() - assert res["choices"][0]["messages"]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!" @pytest.mark.parametrize( @@ -138,4 +138,4 @@ def test_with_redis(): }, ) res = response.json() - assert res["choices"]["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!" 
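
With the Choice.message rename in PATCH 3, responses match the shape the official OpenAI Python client expects (the openai package itself is only added as an optional dependency in the next patch). A minimal, hypothetical usage sketch follows; the base URL, dummy API key, and the "general" config id are assumptions rather than part of the patch, and the key is not validated by this server.

    # Hypothetical usage with the official OpenAI Python client pointed at the
    # guardrails server; URL, key, and config id are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-validated")

    completion = client.chat.completions.create(
        model="general",  # mapped to config_id by the server
        messages=[{"role": "user", "content": "Hello"}],
    )
    # With the Choice.message fix above, the standard OpenAI access path works.
    print(completion.choices[0].message.content)
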
From aaac16167387b51c94440304497a3a65a142e68a Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Tue, 11 Nov 2025 15:23:47 -0500 Subject: [PATCH 4/7] Extend existing OpenAI types and add support for streaming chat completion --- nemoguardrails/server/api.py | 134 ++++++++++- nemoguardrails/server/schemas/openai.py | 105 +------- nemoguardrails/streaming.py | 55 ++++- poetry.lock | 2 +- pyproject.toml | 1 + tests/test_api.py | 303 +++++++++++++++++++++++- tests/test_threads.py | 3 +- 7 files changed, 484 insertions(+), 119 deletions(-) diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 028c2beba..eab308563 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -24,24 +24,20 @@ import uuid import warnings from contextlib import asynccontextmanager -from typing import Any, Callable, List, Optional +from typing import Any, AsyncIterator, Callable, List, Optional, Union from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from pydantic import Field, root_validator, validator +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.model import Model +from pydantic import BaseModel, Field, root_validator, validator from starlette.responses import StreamingResponse from starlette.staticfiles import StaticFiles from nemoguardrails import LLMRails, RailsConfig, utils from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse from nemoguardrails.server.datastore.datastore import DataStore -from nemoguardrails.server.schemas.openai import ( - Choice, - Model, - ModelsResponse, - OpenAIRequestFields, - ResponseBody, -) +from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody from nemoguardrails.streaming import StreamingHandler logging.basicConfig(level=logging.INFO) @@ -195,7 +191,7 @@ async def root_handler(): app.single_config_id = None -class RequestBody(OpenAIRequestFields): +class RequestBody(ChatCompletion): config_id: Optional[str] = Field( default=os.getenv("DEFAULT_CONFIG_ID", None), description="The id of the configuration to be used. 
If not set, the default configuration will be used.", @@ -212,6 +208,50 @@ class RequestBody(OpenAIRequestFields): max_length=255, description="The id of an existing thread to which the messages should be added.", ) + model: Optional[str] = Field( + default=None, + description="The model used for the chat completion.", + ) + id: Optional[str] = Field( + default=None, + description="The id of the chat completion.", + ) + object: Optional[str] = Field( + default="chat.completion", + description="The object type, which is always chat.completion", + ) + created: Optional[int] = Field( + default=None, + description="The Unix timestamp (in seconds) of when the chat completion was created.", + ) + choices: Optional[List[Choice]] = Field( + default=None, + description="The list of choices for the chat completion.", + ) + max_tokens: Optional[int] = Field( + default=None, + description="The maximum number of tokens to generate.", + ) + temperature: Optional[float] = Field( + default=None, + description="The temperature to use for the chat completion.", + ) + top_p: Optional[float] = Field( + default=None, + description="The top p to use for the chat completion.", + ) + stop: Optional[Union[str, List[str]]] = Field( + default=None, + description="The stop sequences to use for the chat completion.", + ) + presence_penalty: Optional[float] = Field( + default=None, + description="The presence penalty to use for the chat completion.", + ) + frequency_penalty: Optional[float] = Field( + default=None, + description="The frequency penalty to use for the chat completion.", + ) messages: Optional[List[dict]] = Field( default=None, description="The list of messages in the current conversation." ) @@ -391,6 +431,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails: return llm_rails +async def _format_streaming_response( + streaming_handler: StreamingHandler, model_name: Optional[str] +) -> AsyncIterator[str]: + while True: + try: + chunk = await streaming_handler.__anext__() + except StopAsyncIteration: + # When the stream ends, yield the [DONE] message + yield "data: [DONE]\n\n" + break + + # Determine the payload format based on chunk type + if isinstance(chunk, dict): + # If chunk is a dict, wrap it in OpenAI chunk format with delta + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": chunk, + "index": None, + "finish_reason": None, + } + ], + } + elif isinstance(chunk, str): + try: + # Try parsing as JSON - if it parses, it might be a pre-formed payload + payload = json.loads(chunk) + except Exception: + # treat as plain text content token + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": {"content": chunk}, + "index": None, + "finish_reason": None, + } + ], + } + else: + # For any other type, treat as plain content + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": {"content": str(chunk)}, + "index": None, + "finish_reason": None, + } + ], + } + + # Send the payload as JSON + data = json.dumps(payload, ensure_ascii=False) + yield f"data: {data}\n\n" + + @app.post( "/v1/chat/completions", response_model=ResponseBody, @@ -522,7 +629,12 @@ async def chat_completion(body: RequestBody, request: Request): ) ) - return StreamingResponse(streaming_handler) + return StreamingResponse( + _format_streaming_response( + 
streaming_handler, model_name=config_ids[0] if config_ids else None + ), + media_type="text/event-stream", + ) else: res = await llm_rails.generate_async( messages=messages, options=generation_options, state=body.state diff --git a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py index 99ad1a700..ce79a399a 100644 --- a/nemoguardrails/server/schemas/openai.py +++ b/nemoguardrails/server/schemas/openai.py @@ -15,96 +15,16 @@ """OpenAI API schema definitions for the NeMo Guardrails server.""" -from typing import List, Optional, Union +from typing import List, Optional +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.model import Model from pydantic import BaseModel, Field -class OpenAIRequestFields(BaseModel): - """OpenAI API request fields that can be mixed into other request schemas.""" - - # Standard OpenAI completion parameters - model: Optional[str] = Field( - default=None, - description="The model to use for chat completion. Maps to config_id for backward compatibility.", - ) - max_tokens: Optional[int] = Field( - default=None, - description="The maximum number of tokens to generate.", - ) - temperature: Optional[float] = Field( - default=None, - description="Sampling temperature to use.", - ) - top_p: Optional[float] = Field( - default=None, - description="Top-p sampling parameter.", - ) - stop: Optional[Union[str, List[str]]] = Field( - default=None, - description="Stop sequences.", - ) - presence_penalty: Optional[float] = Field( - default=None, - description="Presence penalty parameter.", - ) - frequency_penalty: Optional[float] = Field( - default=None, - description="Frequency penalty parameter.", - ) - function_call: Optional[dict] = Field( - default=None, - description="Function call parameter.", - ) - logit_bias: Optional[dict] = Field( - default=None, - description="Logit bias parameter.", - ) - log_probs: Optional[bool] = Field( - default=None, - description="Log probabilities parameter.", - ) - - -class Choice(BaseModel): - """OpenAI API choice structure in chat completion responses.""" - - index: Optional[int] = Field( - default=None, description="The index of the choice in the list of choices." - ) - message: Optional[dict] = Field( - default=None, description="The message of the choice" - ) - logprobs: Optional[dict] = Field( - default=None, description="The log probabilities of the choice" - ) - finish_reason: Optional[str] = Field( - default=None, description="The reason the model stopped generating tokens." - ) - - -class ResponseBody(BaseModel): +class ResponseBody(ChatCompletion): """OpenAI API response body with NeMo-Guardrails extensions.""" - # OpenAI API fields - id: Optional[str] = Field( - default=None, description="A unique identifier for the chat completion." - ) - object: str = Field( - default="chat.completion", - description="The object type, which is always chat.completion", - ) - created: Optional[int] = Field( - default=None, - description="The Unix timestamp (in seconds) of when the chat completion was created.", - ) - model: Optional[str] = Field( - default=None, description="The model used for the chat completion." - ) - choices: Optional[List[Choice]] = Field( - default=None, description="A list of chat completion choices." - ) - # NeMo-Guardrails specific fields for backward compatibility state: Optional[dict] = Field( default=None, description="State object for continuing the conversation." 
) @@ -117,23 +37,6 @@ class ResponseBody(BaseModel): log: Optional[dict] = Field(default=None, description="Generation log data.") -class Model(BaseModel): - """OpenAI API model representation.""" - - id: str = Field( - description="The model identifier, which can be referenced in the API endpoints." - ) - object: str = Field( - default="model", description="The object type, which is always 'model'." - ) - created: int = Field( - description="The Unix timestamp (in seconds) of when the model was created." - ) - owned_by: str = Field( - default="nemo-guardrails", description="The organization that owns the model." - ) - - class ModelsResponse(BaseModel): """OpenAI API models list response.""" diff --git a/nemoguardrails/streaming.py b/nemoguardrails/streaming.py index 7cf8ac7c3..65f990c7c 100644 --- a/nemoguardrails/streaming.py +++ b/nemoguardrails/streaming.py @@ -173,18 +173,39 @@ async def __anext__(self): async def _process( self, - chunk: Union[str, object], + chunk: Union[str, dict, object], generation_info: Optional[Dict[str, Any]] = None, ): - """Process a chunk of text. + """Process a chunk of text or dict. If we're in buffering mode, record the text. Otherwise, update the full completion, check for stop tokens, and enqueue the chunk. + Dict chunks bypass completion tracking and go directly to the queue. """ if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass buffering and completion tracking + if isinstance(chunk, dict): + if self.pipe_to: + asyncio.create_task(self.pipe_to.push_chunk(chunk)) + else: + if self.include_generation_metadata: + await self.queue.put( + { + "text": chunk, + "generation_info": ( + self.current_generation_info.copy() + if self.current_generation_info + else {} + ), + } + ) + else: + await self.queue.put(chunk) + return + if self.enable_buffer: if chunk is not END_OF_STREAM: self.buffer += chunk if chunk is not None else "" @@ -254,10 +275,28 @@ async def _process( async def push_chunk( self, - chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None], + chunk: Union[ + str, + dict, + GenerationChunk, + AIMessageChunk, + ChatGenerationChunk, + None, + object, + ], generation_info: Optional[Dict[str, Any]] = None, ): - """Push a new chunk to the stream.""" + """Push a new chunk to the stream. + + Args: + chunk: The chunk to push. Can be: + - str: Plain text content + - dict: Dictionary with fields like role, content, etc. 
+ - GenerationChunk/AIMessageChunk/ChatGenerationChunk: LangChain chunk types + - None: Signals end of stream (converted to END_OF_STREAM) + - object: END_OF_STREAM sentinel + generation_info: Optional metadata about the generation + """ # if generation_info is not explicitly passed, # try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk @@ -281,6 +320,9 @@ async def push_chunk( elif isinstance(chunk, str): # empty string is a valid chunk and should be processed normally pass + elif isinstance(chunk, dict): + # plain dict chunks are allowed (e.g., for OpenAI-compatible streaming) + pass else: raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}") @@ -291,6 +333,11 @@ async def push_chunk( if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass prefix/suffix processing and go directly to _process + if isinstance(chunk, dict): + await self._process(chunk, generation_info) + return + # Process prefix: accumulate until the expected prefix is received, then remove it. if self.prefix: if chunk is not None and chunk is not END_OF_STREAM: diff --git a/poetry.lock b/poetry.lock index 9e24d2a40..026ac4825 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accessible-pygments" diff --git a/pyproject.toml b/pyproject.toml index f3452a964..75f83e3f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ opentelemetry-api = { version = ">=1.27.0,<2.0.0", optional = true } aiofiles = { version = ">=24.1.0", optional = true } # openai +openai = { version = ">=1.0.0, <2.0.0", optional = true } langchain-openai = { version = ">=0.1.0", optional = true } # eval diff --git a/tests/test_api.py b/tests/test_api.py index 0a5966f2b..9fbf05c3c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,13 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio +import json import os import pytest from fastapi.testclient import TestClient from nemoguardrails.server import api -from nemoguardrails.server.api import RequestBody +from nemoguardrails.server.api import RequestBody, _format_streaming_response +from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler client = TestClient(api.app) @@ -199,3 +202,301 @@ def test_request_body_messages(): } request_body = RequestBody.model_validate(data) assert len(request_body.messages) == 1 + + +@pytest.mark.asyncio +async def test_openai_sse_format_basic_chunks(): + """Test basic string chunks are properly formatted as SSE events.""" + handler = StreamingHandler() + + # Collect yielded SSE messages + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push a couple of chunks and then signal completion + await handler.push_chunk("Hello ") + await handler.push_chunk("world") + await handler.push_chunk(END_OF_STREAM) + + # Wait for the collector task to finish + await task + + # We expect three messages: two data: {json}\n\n events and final data: [DONE]\n\n + assert len(collected) == 3 + # First two are JSON SSE events + evt1 = collected[0] + evt2 = collected[1] + done = collected[2] + + assert evt1.startswith("data: ") + j1 = json.loads(evt1[len("data: ") :].strip()) + assert j1["object"] == "chat.completion.chunk" + assert j1["choices"][0]["delta"]["content"] == "Hello " + + assert evt2.startswith("data: ") + j2 = json.loads(evt2[len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "world" + + assert done == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_model_name(): + """Test that model name is properly included in the response.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="gpt-4"): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["model"] == "gpt-4" + assert j["choices"][0]["delta"]["content"] == "Test" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_dict_chunk(): + """Test that dict chunks with role and content are properly formatted.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push a dict chunk that includes role and content + await handler.push_chunk({"role": "assistant", "content": "Hi!"}) + await handler.push_chunk(None) + + await task + + # We expect two messages: one data chunk and final data: [DONE] + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["object"] == "chat.completion.chunk" + assert j["choices"][0]["delta"]["role"] == "assistant" + assert j["choices"][0]["delta"]["content"] == "Hi!" 
+ assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_empty_string(): + """Test that empty strings are handled correctly.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("") + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_none_triggers_done(): + """Test that None (converted to END_OF_STREAM) triggers [DONE].""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Content") + await handler.push_chunk(None) # None converts to END_OF_STREAM + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "Content" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_multiple_dict_chunks(): + """Test multiple dict chunks with different fields.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="test-model"): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push multiple dict chunks + await handler.push_chunk({"role": "assistant"}) + await handler.push_chunk({"content": "Hello"}) + await handler.push_chunk({"content": " world"}) + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 4 + + # Check first chunk (role only) + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["role"] == "assistant" + assert "content" not in j1["choices"][0]["delta"] + + # Check second chunk (content only) + j2 = json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "Hello" + + # Check third chunk (content only) + j3 = json.loads(collected[2][len("data: ") :].strip()) + assert j3["choices"][0]["delta"]["content"] == " world" + + # Check [DONE] message + assert collected[3] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_special_characters(): + """Test that special characters are properly escaped in JSON.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push chunks with special characters + await handler.push_chunk("Line 1\nLine 2") + await handler.push_chunk('Quote: "test"') + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 3 + + # Verify first chunk with newline + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["content"] == "Line 1\nLine 2" + + # Verify second chunk with quotes + j2 = json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == 'Quote: "test"' + + assert collected[2] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def 
test_openai_sse_format_events(): + """Test that all events follow proper SSE format.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + # All events except [DONE] should be valid JSON with proper SSE format + for event in collected[:-1]: + assert event.startswith("data: ") + assert event.endswith("\n\n") + # Verify it's valid JSON + json_str = event[len("data: ") :].strip() + j = json.loads(json_str) + assert "object" in j + assert "choices" in j + assert isinstance(j["choices"], list) + assert len(j["choices"]) > 0 + + # Last event should be [DONE] + assert collected[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_chunk_metadata(): + """Test that chunk metadata is properly formatted.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="test-model"): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + + # Verify all required fields are present + assert j["id"] is None # id can be None for chunks + assert j["object"] == "chat.completion.chunk" + assert isinstance(j["created"], int) + assert j["model"] == "test-model" + assert isinstance(j["choices"], list) + assert len(j["choices"]) == 1 + + choice = j["choices"][0] + assert "delta" in choice + assert choice["index"] is None + assert choice["finish_reason"] is None + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_chat_completion_with_streaming(): + response = client.post( + "/v1/chat/completions", + json={ + "config_id": "general", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/event-stream" + for chunk in response.iter_lines(): + assert chunk.startswith("data: ") + assert chunk.endswith("\n\n") + assert "data: [DONE]\n\n" in response.text diff --git a/tests/test_threads.py b/tests/test_threads.py index 3975e266f..baace32b7 100644 --- a/tests/test_threads.py +++ b/tests/test_threads.py @@ -51,7 +51,8 @@ def test_1(): ) assert response.status_code == 200 res = response.json() - assert len(res["choices"][0]["message"]) == 2 + assert "choices" in res + assert "message" in res["choices"][0] assert res["choices"][0]["message"]["content"] == "Hello!" 
# When making a second call with the same thread_id, the conversations should continue From 416ac3985a1f79b1f5853e78eef99b6885792720 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Wed, 12 Nov 2025 14:40:12 -0500 Subject: [PATCH 5/7] Add OpenAI docs and integration tests --- docs/user-guides/community/openai.md | 16 +++ nemoguardrails/server/api.py | 119 ++++++++--------- nemoguardrails/server/schemas/openai.py | 2 +- poetry.lock | 20 +-- pyproject.toml | 7 +- tests/test_api.py | 2 +- tests/test_openai_integration.py | 167 ++++++++++++++++++++++++ 7 files changed, 258 insertions(+), 75 deletions(-) create mode 100644 docs/user-guides/community/openai.md create mode 100644 tests/test_openai_integration.py diff --git a/docs/user-guides/community/openai.md b/docs/user-guides/community/openai.md new file mode 100644 index 000000000..2150e4be8 --- /dev/null +++ b/docs/user-guides/community/openai.md @@ -0,0 +1,16 @@ +## OpenAI API Compatibility for NeMo Guardrails + +NeMo Guardrails provides server-side compatibility with OpenAI API endpoints, enabling applications that use OpenAI clients to seamlessly integrate with NeMo Guardrails for adding guardrails to LLM interactions. Point your OpenAI client to `http://localhost:8000` (or your server URL) and use the standard `/v1/chat/completions` endpoint. + +## Feature Support Matrix + +The following table outlines which OpenAI API features are currently supported when using NeMo Guardrails: + +| Feature | Status | Notes | +| :------ | :----: | :---- | +| **Basic Chat Completion** | ✔ Supported | Full support for standard chat completions with guardrails applied | +| **Streaming Responses** | ✔ Supported | Server-Sent Events (SSE) streaming with `stream=true` | +| **Multimodal Input** | ✖ Unsupported | Support for text and image inputs (vision models) with guardrails but not yet OpenAI compatible | +| **Function Calling** | ✖ Unsupported | Not yet implemented; guardrails need structured output support | +| **Tools** | ✖ Unsupported | Related to function calling; requires action flow integration | +| **Response Format (JSON Mode)** | ✖ Unsupported | Structured output with guardrails requires additional validation logic | diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index eab308563..4fa36160f 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -28,7 +28,8 @@ from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.model import Model from pydantic import BaseModel, Field, root_validator, validator from starlette.responses import StreamingResponse @@ -191,7 +192,7 @@ async def root_handler(): app.single_config_id = None -class RequestBody(ChatCompletion): +class RequestBody(BaseModel): config_id: Optional[str] = Field( default=os.getenv("DEFAULT_CONFIG_ID", None), description="The id of the configuration to be used. If not set, the default configuration will be used.", @@ -208,25 +209,31 @@ class RequestBody(ChatCompletion): max_length=255, description="The id of an existing thread to which the messages should be added.", ) - model: Optional[str] = Field( - default=None, - description="The model used for the chat completion.", + messages: Optional[List[dict]] = Field( + default=None, description="The list of messages in the current conversation." 
) - id: Optional[str] = Field( + context: Optional[dict] = Field( default=None, - description="The id of the chat completion.", + description="Additional context data to be added to the conversation.", + ) + stream: Optional[bool] = Field( + default=False, + description="If set, partial message deltas will be sent, like in ChatGPT. " + "Tokens will be sent as data-only server-sent events as they become " + "available, with the stream terminated by a data: [DONE] message.", ) - object: Optional[str] = Field( - default="chat.completion", - description="The object type, which is always chat.completion", + options: GenerationOptions = Field( + default_factory=GenerationOptions, + description="Additional options for controlling the generation.", ) - created: Optional[int] = Field( + state: Optional[dict] = Field( default=None, - description="The Unix timestamp (in seconds) of when the chat completion was created.", + description="A state object that should be used to continue the interaction.", ) - choices: Optional[List[Choice]] = Field( + # Standard OpenAI completion parameters + model: Optional[str] = Field( default=None, - description="The list of choices for the chat completion.", + description="The model to use for chat completion. Maps to config_id for backward compatibility.", ) max_tokens: Optional[int] = Field( default=None, @@ -234,44 +241,35 @@ class RequestBody(ChatCompletion): ) temperature: Optional[float] = Field( default=None, - description="The temperature to use for the chat completion.", + description="Sampling temperature to use.", ) top_p: Optional[float] = Field( default=None, - description="The top p to use for the chat completion.", + description="Top-p sampling parameter.", ) - stop: Optional[Union[str, List[str]]] = Field( + stop: Optional[str] = Field( default=None, - description="The stop sequences to use for the chat completion.", + description="Stop sequences.", ) presence_penalty: Optional[float] = Field( default=None, - description="The presence penalty to use for the chat completion.", + description="Presence penalty parameter.", ) frequency_penalty: Optional[float] = Field( default=None, - description="The frequency penalty to use for the chat completion.", + description="Frequency penalty parameter.", ) - messages: Optional[List[dict]] = Field( - default=None, description="The list of messages in the current conversation." - ) - context: Optional[dict] = Field( + function_call: Optional[dict] = Field( default=None, - description="Additional context data to be added to the conversation.", + description="Function call parameter.", ) - stream: Optional[bool] = Field( - default=False, - description="If set, partial message deltas will be sent, like in ChatGPT. 
" - "Tokens will be sent as data-only server-sent events as they become " - "available, with the stream terminated by a data: [DONE] message.", - ) - options: GenerationOptions = Field( - default_factory=GenerationOptions, - description="Additional options for controlling the generation.", + logit_bias: Optional[dict] = Field( + default=None, + description="Logit bias parameter.", ) - state: Optional[dict] = Field( + log_probs: Optional[bool] = Field( default=None, - description="A state object that should be used to continue the interaction.", + description="Log probabilities parameter.", ) @root_validator(pre=True) @@ -453,7 +451,7 @@ async def _format_streaming_response( "choices": [ { "delta": chunk, - "index": None, + "index": 0, "finish_reason": None, } ], @@ -472,7 +470,7 @@ async def _format_streaming_response( "choices": [ { "delta": {"content": chunk}, - "index": None, + "index": 0, "finish_reason": None, } ], @@ -487,7 +485,7 @@ async def _format_streaming_response( "choices": [ { "delta": {"content": str(chunk)}, - "index": None, + "index": 0, "finish_reason": None, } ], @@ -536,16 +534,16 @@ async def chat_completion(body: RequestBody, request: Request): id=f"chatcmpl-{uuid.uuid4()}", object="chat.completion", created=int(time.time()), - model=config_ids[0] if config_ids else None, + model=config_ids[0] if config_ids else "unknown", choices=[ Choice( index=0, - message={ - "content": f"Could not load the {config_ids} guardrails configuration. " + message=ChatCompletionMessage( + content=f"Could not load the {config_ids} guardrails configuration. " f"An internal error has occurred.", - "role": "assistant", - }, - finish_reason="error", + role="assistant", + ), + finish_reason="stop", logprobs=None, ) ], @@ -569,15 +567,15 @@ async def chat_completion(body: RequestBody, request: Request): id=f"chatcmpl-{uuid.uuid4()}", object="chat.completion", created=int(time.time()), - model=None, + model=config_ids[0] if config_ids else "unknown", choices=[ Choice( index=0, - message={ - "content": "The `thread_id` must have a minimum length of 16 characters.", - "role": "assistant", - }, - finish_reason="error", + message=ChatCompletionMessage( + content="The `thread_id` must have a minimum length of 16 characters.", + role="assistant", + ), + finish_reason="stop", logprobs=None, ) ], @@ -661,11 +659,14 @@ async def chat_completion(body: RequestBody, request: Request): "id": f"chatcmpl-{uuid.uuid4()}", "object": "chat.completion", "created": int(time.time()), - "model": config_ids[0] if config_ids else None, + "model": config_ids[0] if config_ids else "unknown", "choices": [ Choice( index=0, - message=bot_message, + message=ChatCompletionMessage( + role="assistant", + content=bot_message["content"], + ), finish_reason="stop", logprobs=None, ) @@ -687,15 +688,15 @@ async def chat_completion(body: RequestBody, request: Request): id=f"chatcmpl-{uuid.uuid4()}", object="chat.completion", created=int(time.time()), - model=None, + model="unknown", choices=[ Choice( index=0, - message={ - "content": "Internal server error", - "role": "assistant", - }, - finish_reason="error", + message=ChatCompletionMessage( + content="Internal server error", + role="assistant", + ), + finish_reason="stop", logprobs=None, ) ], diff --git a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py index ce79a399a..a935a9c91 100644 --- a/nemoguardrails/server/schemas/openai.py +++ b/nemoguardrails/server/schemas/openai.py @@ -17,7 +17,7 @@ from typing import List, Optional -from 
openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion import ChatCompletion from openai.types.model import Model from pydantic import BaseModel, Field diff --git a/poetry.lock b/poetry.lock index 026ac4825..f2811d34b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -22,7 +22,7 @@ tests = ["hypothesis", "pytest"] name = "aiofiles" version = "24.1.0" description = "File support for asyncio." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, @@ -918,7 +918,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1722,7 +1722,7 @@ i18n = ["Babel (>=2.7)"] name = "jiter" version = "0.10.0" description = "Fast iterable JSON parser." -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303"}, @@ -2921,7 +2921,7 @@ sympy = "*" name = "openai" version = "1.102.0" description = "The official Python library for the openai API" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "openai-1.102.0-py3-none-any.whl", hash = "sha256:d751a7e95e222b5325306362ad02a7aa96e1fab3ed05b5888ce1c7ca63451345"}, @@ -4023,13 +4023,13 @@ dev = ["build", "flake8", "mypy", "pytest", "twine"] [[package]] name = "pyright" -version = "1.1.405" +version = "1.1.407" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a"}, - {file = "pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763"}, + {file = "pyright-1.1.407-py3-none-any.whl", hash = "sha256:6dd419f54fcc13f03b52285796d65e639786373f433e243f8b94cf93a7444d21"}, + {file = "pyright-1.1.407.tar.gz", hash = "sha256:099674dba5c10489832d4a4b2d302636152a9a42d317986c38474c76fe562262"}, ] [package.dependencies] @@ -6194,16 +6194,16 @@ files = [ cffi = ["cffi (>=1.17)"] [extras] -all = ["aiofiles", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] +all = ["google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] eval = ["numpy", "numpy", "numpy", "numpy", "streamlit", "tornado", "tqdm"] gcp = ["google-cloud-language"] jailbreak = ["yara-python"] nvidia = ["langchain-nvidia-ai-endpoints"] openai = ["langchain-openai"] sdd = ["presidio-analyzer", "presidio-anonymizer"] -tracing = ["aiofiles", "opentelemetry-api"] +tracing = ["opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "d5e8dc8fdbad5781141f4c65671d115060aa4c99abca0bd72ec025781352b775" +content-hash = "8d456424d7a10f6e08c69755568b81cd8d2779bae98baa8d29f2be06098c3bf5" diff --git a/pyproject.toml b/pyproject.toml index 
75f83e3f9..78d78bbd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,13 +71,13 @@ starlette = ">=0.49.1" typer = ">=0.8" uvicorn = ">=0.23" watchdog = ">=3.0.0," +aiofiles = ">=24.1.0" +openai = ">=1.0.0, <2.0.0" # tracing opentelemetry-api = { version = ">=1.27.0,<2.0.0", optional = true } -aiofiles = { version = ">=24.1.0", optional = true } # openai -openai = { version = ">=1.0.0, <2.0.0", optional = true } langchain-openai = { version = ">=0.1.0", optional = true } # eval @@ -111,7 +111,7 @@ sdd = ["presidio-analyzer", "presidio-anonymizer"] eval = ["tqdm", "numpy", "streamlit", "tornado"] openai = ["langchain-openai"] gcp = ["google-cloud-language"] -tracing = ["opentelemetry-api", "aiofiles"] +tracing = ["opentelemetry-api"] nvidia = ["langchain-nvidia-ai-endpoints"] jailbreak = ["yara-python"] # Poetry does not support recursive dependencies, so we need to add all the dependencies here. @@ -126,7 +126,6 @@ all = [ "langchain-openai", "google-cloud-language", "opentelemetry-api", - "aiofiles", "langchain-nvidia-ai-endpoints", "yara-python", ] diff --git a/tests/test_api.py b/tests/test_api.py index 9fbf05c3c..66fb84406 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -480,7 +480,7 @@ async def collector(): choice = j["choices"][0] assert "delta" in choice - assert choice["index"] is None + assert choice["index"] == 0 assert choice["finish_reason"] is None diff --git a/tests/test_openai_integration.py b/tests/test_openai_integration.py new file mode 100644 index 000000000..735651a66 --- /dev/null +++ b/tests/test_openai_integration.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
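+"""Integration tests for the OpenAI-compatible guardrails server endpoints.
+
+The official ``openai`` client is pointed at the FastAPI app by passing the
+``TestClient`` as its ``http_client``, so requests are handled in-process by
+the ASGI app; the API key and base URL passed to the client below are
+placeholders only.
+"""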
+import os + +import pytest +from fastapi.testclient import TestClient +from openai import OpenAI + +from nemoguardrails.server import api + + +@pytest.fixture(scope="function", autouse=True) +def set_rails_config_path(): + """Set the rails config path to the test configs directory.""" + original_path = api.app.rails_config_path + api.app.rails_config_path = os.path.normpath( + os.path.join(os.path.dirname(__file__), "test_configs/simple_server") + ) + yield + + # Restore the original path and clear cache after the test + api.app.rails_config_path = original_path + api.llm_rails_instances.clear() + api.llm_rails_events_history_cache.clear() + + +@pytest.fixture(scope="function") +def test_client(): + """Create a FastAPI TestClient for the guardrails server.""" + return TestClient(api.app) + + +@pytest.fixture(scope="function") +def openai_client(test_client): + client = OpenAI( + api_key="dummy-key", + base_url="http://dummy-url/v1", + http_client=test_client, + ) + return client + + +def test_openai_client_list_models(openai_client): + models = openai_client.models.list() + + # Verify the response structure matches OpenAI's ModelList + assert models is not None + assert hasattr(models, "data") + assert len(models.data) > 0 + + # Check first model has required fields + model = models.data[0] + assert hasattr(model, "id") + assert hasattr(model, "object") + assert model.object == "model" + assert hasattr(model, "created") + assert hasattr(model, "owned_by") + assert model.owned_by == "nemo-guardrails" + + +def test_openai_client_chat_completion(openai_client): + response = openai_client.chat.completions.create( + model="config_1", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + # Verify response structure matches OpenAI's ChatCompletion object + assert response is not None + assert hasattr(response, "id") + assert response.id is not None + assert hasattr(response, "object") + assert response.object == "chat.completion" + assert hasattr(response, "created") + assert response.created > 0 + assert hasattr(response, "model") + assert response.model == "config_1" + + # Verify choices structure + assert hasattr(response, "choices") + assert len(response.choices) == 1 + choice = response.choices[0] + assert hasattr(choice, "index") + assert choice.index == 0 + assert hasattr(choice, "message") + assert hasattr(choice.message, "role") + assert choice.message.role == "assistant" + assert hasattr(choice.message, "content") + assert choice.message.content is not None + assert isinstance(choice.message.content, str) + assert len(choice.message.content) > 0 + assert hasattr(choice, "finish_reason") + assert choice.finish_reason == "stop" + + +def test_openai_client_chat_completion_parameterized(openai_client): + response = openai_client.chat.completions.create( + model="config_1", + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + max_tokens=100, + stream=False, + ) + + # Verify response exists + assert response is not None + assert response.choices[0].message.content is not None + + +def test_openai_client_chat_completion_input_rails(openai_client): + response = openai_client.chat.completions.create( + model="input_rails", + messages=[{"role": "user", "content": "Hello, how are you?"}], + stream=False, + ) + + # Verify response exists + assert response is not None + assert response.choices[0].message.content is not None + assert isinstance(response.choices[0].message.content, str) + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI 
key.") +def test_openai_client_chat_completion_streaming(openai_client): + stream = openai_client.chat.completions.create( + model="input_rails", + messages=[{"role": "user", "content": "Tell me a short joke."}], + stream=True, + ) + + chunks = list(stream) + assert len(chunks) > 0 + + # Verify at least one chunk has content + has_content = any( + hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content + for chunk in chunks + ) + assert has_content, "At least one chunk should contain content" + + +def test_openai_client_error_handling_invalid_model(openai_client): + response = openai_client.chat.completions.create( + model="nonexistent_config", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + # The server should return a response (not raise an exception) + assert response is not None + # The error should be in the content + assert ( + "Could not load" in response.choices[0].message.content + or "error" in response.choices[0].message.content.lower() + ) From c37edfb2eff8622baf49ed06f520ab99ceb33082 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Tue, 2 Dec 2025 15:56:40 -0500 Subject: [PATCH 6/7] Add model name to response body --- nemoguardrails/rails/llm/llmrails.py | 5 + nemoguardrails/server/api.py | 133 ++++++++++++++---------- nemoguardrails/server/schemas/openai.py | 31 +++--- nemoguardrails/streaming.py | 4 +- poetry.lock | 22 ++-- tests/test_api.py | 13 ++- tests/test_openai_integration.py | 99 +++++++++--------- tests/test_streaming.py | 59 +++++++++++ 8 files changed, 229 insertions(+), 137 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index c4d33f83d..528a4c6d3 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -490,6 +490,11 @@ def _init_llms(self): if not self.llm: self.llm = llm_model self.runtime.register_action_param("llm", self.llm) + self._configure_main_llm_streaming( + self.llm, + model_name=llm_config.model, + provider_name=llm_config.engine, + ) else: model_name = f"{llm_config.type}_llm" if not hasattr(self, model_name): diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 4fa36160f..3fb726588 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -24,13 +24,12 @@ import uuid import warnings from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Callable, List, Optional, Union +from typing import Any, AsyncIterator, Callable, List, Optional from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage -from openai.types.model import Model from pydantic import BaseModel, Field, root_validator, validator from starlette.responses import StreamingResponse from starlette.staticfiles import StaticFiles @@ -38,7 +37,11 @@ from nemoguardrails import LLMRails, RailsConfig, utils from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse from nemoguardrails.server.datastore.datastore import DataStore -from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody +from nemoguardrails.server.schemas.openai import ( + GuardrailsModel, + ModelsResponse, + ResponseBody, +) from nemoguardrails.streaming import StreamingHandler logging.basicConfig(level=logging.INFO) @@ -90,9 +93,9 @@ async def lifespan(app: GuardrailsApp): # If there is a `config.yml` in the root 
`app.rails_config_path`, then # that means we are in single config mode. - if os.path.exists( - os.path.join(app.rails_config_path, "config.yml") - ) or os.path.exists(os.path.join(app.rails_config_path, "config.yaml")): + if os.path.exists(os.path.join(app.rails_config_path, "config.yml")) or os.path.exists( + os.path.join(app.rails_config_path, "config.yaml") + ): app.single_config_mode = True app.single_config_id = os.path.basename(app.rails_config_path) else: @@ -232,8 +235,8 @@ class RequestBody(BaseModel): ) # Standard OpenAI completion parameters model: Optional[str] = Field( - default=None, - description="The model to use for chat completion. Maps to config_id for backward compatibility.", + default="main", + description="The model to use for chat completion. Maps to the main model in the config.", ) max_tokens: Optional[int] = Field( default=None, @@ -278,15 +281,11 @@ def ensure_config_id(cls, data: Any) -> Any: if data.get("model") is not None and data.get("config_id") is None: data["config_id"] = data["model"] if data.get("config_id") is not None and data.get("config_ids") is not None: - raise ValueError( - "Only one of config_id or config_ids should be specified" - ) + raise ValueError("Only one of config_id or config_ids should be specified") if data.get("config_id") is None and data.get("config_ids") is not None: data["config_id"] = None if data.get("config_id") is None and data.get("config_ids") is None: - warnings.warn( - "No config_id or config_ids provided, using default config_id" - ) + warnings.warn("No config_id or config_ids provided, using default config_id") return data @validator("config_ids", pre=True, always=True) @@ -309,6 +308,7 @@ async def get_models(): # Use the same logic as get_rails_configs to find available configurations if app.single_config_mode: config_ids = [app.single_config_id] if app.single_config_id else [] + else: config_ids = [ f @@ -323,16 +323,43 @@ async def get_models(): ) ] - # Convert configurations to OpenAI model format models = [] for config_id in config_ids: - model = Model( - id=config_id, - object="model", - created=int(time.time()), # Use current time as created timestamp - owned_by="nemo-guardrails", - ) - models.append(model) + try: + # Load the RailsConfig to extract model information + if app.single_config_mode: + config_path = app.rails_config_path + else: + config_path = os.path.join(app.rails_config_path, config_id) + + rails_config = RailsConfig.from_path(config_path) + # Extract all models from this config + config_models = rails_config.models + + if len(config_models) == 0: + guardrails_model = GuardrailsModel( + id=config_id, + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + guardrails_config_id=config_id, + ) + models.append(guardrails_model) + else: + for model in config_models: + # Only include models with a model name + if model.model: + guardrails_model = GuardrailsModel( + id=model.model, + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + guardrails_config_id=config_id, + ) + models.append(guardrails_model) + except Exception as ex: + log.warning(f"Could not load model info for config {config_id}: {ex}") + continue return ModelsResponse(data=models) @@ -377,6 +404,14 @@ def _generate_cache_key(config_ids: List[str]) -> str: return "-".join((config_ids)) # remove sorted +def _get_main_model_name(rails_config: RailsConfig) -> Optional[str]: + """Extracts the main model name from a RailsConfig.""" + main_models = [m for m in rails_config.models if m.type == 
"main"] + if main_models and main_models[0].model: + return main_models[0].model + return None + + def _get_rails(config_ids: List[str]) -> LLMRails: """Returns the rails instance for the given config id.""" @@ -422,9 +457,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails: llm_rails_instances[configs_cache_key] = llm_rails # If we have a cache for the events, we restore it - llm_rails.events_history_cache = llm_rails_events_history_cache.get( - configs_cache_key, {} - ) + llm_rails.events_history_cache = llm_rails_events_history_cache.get(configs_cache_key, {}) return llm_rails @@ -508,9 +541,7 @@ async def chat_completion(body: RequestBody, request: Request): """ log.info("Got request for config %s", body.config_id) for logger in registered_loggers: - asyncio.get_event_loop().create_task( - logger({"endpoint": "/v1/chat/completions", "body": body.json()}) - ) + asyncio.get_event_loop().create_task(logger({"endpoint": "/v1/chat/completions", "body": body.json()})) # Save the request headers in a context variable. api_request_headers.set(request.headers) @@ -518,16 +549,16 @@ async def chat_completion(body: RequestBody, request: Request): # Use Request config_ids if set, otherwise use the FastAPI default config. # If neither is available we can't generate any completions as we have no config_id config_ids = body.config_ids + if not config_ids: if app.default_config_id: config_ids = [app.default_config_id] else: - raise GuardrailsConfigurationError( - "No request config_ids provided and server has no default configuration" - ) + raise GuardrailsConfigurationError("No request config_ids provided and server has no default configuration") try: llm_rails = _get_rails(config_ids) + except ValueError as ex: log.exception(ex) return ResponseBody( @@ -550,6 +581,10 @@ async def chat_completion(body: RequestBody, request: Request): ) try: + main_model_name = _get_main_model_name(llm_rails.config) + if main_model_name is None: + main_model_name = config_ids[0] if config_ids else "unknown" + messages = body.messages or [] if body.context: messages.insert(0, {"role": "context", "content": body.context}) @@ -560,14 +595,13 @@ async def chat_completion(body: RequestBody, request: Request): if body.thread_id: if datastore is None: raise RuntimeError("No DataStore has been configured.") - # We make sure the `thread_id` meets the minimum complexity requirement. 
if len(body.thread_id) < 16: return ResponseBody( id=f"chatcmpl-{uuid.uuid4()}", object="chat.completion", created=int(time.time()), - model=config_ids[0] if config_ids else "unknown", + model=main_model_name, choices=[ Choice( index=0, @@ -608,12 +642,7 @@ async def chat_completion(body: RequestBody, request: Request): generation_options.llm_params["presence_penalty"] = body.presence_penalty if body.frequency_penalty is not None: generation_options.llm_params["frequency_penalty"] = body.frequency_penalty - - if ( - body.stream - and llm_rails.config.streaming_supported - and llm_rails.main_llm_supports_streaming - ): + if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming: # Create the streaming handler instance streaming_handler = StreamingHandler() @@ -628,15 +657,11 @@ async def chat_completion(body: RequestBody, request: Request): ) return StreamingResponse( - _format_streaming_response( - streaming_handler, model_name=config_ids[0] if config_ids else None - ), + _format_streaming_response(streaming_handler, model_name=main_model_name), media_type="text/event-stream", ) else: - res = await llm_rails.generate_async( - messages=messages, options=generation_options, state=body.state - ) + res = await llm_rails.generate_async(messages=messages, options=generation_options, state=body.state) if isinstance(res, GenerationResponse): bot_message_content = res.response[0] @@ -654,12 +679,12 @@ async def chat_completion(body: RequestBody, request: Request): if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions + # Build the response with OpenAI-compatible format response_kwargs = { "id": f"chatcmpl-{uuid.uuid4()}", "object": "chat.completion", "created": int(time.time()), - "model": config_ids[0] if config_ids else "unknown", + "model": main_model_name, "choices": [ Choice( index=0, @@ -688,7 +713,7 @@ async def chat_completion(body: RequestBody, request: Request): id=f"chatcmpl-{uuid.uuid4()}", object="chat.completion", created=int(time.time()), - model="unknown", + model=config_ids[0] if config_ids else "unknown", choices=[ Choice( index=0, @@ -750,9 +775,7 @@ def on_any_event(self, event): return None elif event.event_type == "created" or event.event_type == "modified": - log.info( - f"Watchdog received {event.event_type} event for file {event.src_path}" - ) + log.info(f"Watchdog received {event.event_type} event for file {event.src_path}") # Compute the relative path src_path_str = str(event.src_path) @@ -776,9 +799,7 @@ def on_any_event(self, event): # We save the events history cache, to restore it on the new instance llm_rails_events_history_cache[config_id] = val - log.info( - f"Configuration {config_id} has changed. Clearing cache." - ) + log.info(f"Configuration {config_id} has changed. Clearing cache.") observer = Observer() event_handler = Handler() @@ -793,9 +814,7 @@ def on_any_event(self, event): except ImportError: # Since this is running in a separate thread, we just print the error. - print( - "The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`." - ) + print("The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`.") # Force close everything. 
os._exit(-1) diff --git a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py index a935a9c91..fff6d020b 100644 --- a/nemoguardrails/server/schemas/openai.py +++ b/nemoguardrails/server/schemas/openai.py @@ -25,22 +25,27 @@ class ResponseBody(ChatCompletion): """OpenAI API response body with NeMo-Guardrails extensions.""" - state: Optional[dict] = Field( - default=None, description="State object for continuing the conversation." - ) - llm_output: Optional[dict] = Field( - default=None, description="Additional LLM output data." - ) - output_data: Optional[dict] = Field( - default=None, description="Additional output data." + guardrails_config_id: Optional[str] = Field( + default=None, + description="The guardrails configuration ID associated with this response.", ) + state: Optional[dict] = Field(default=None, description="State object for continuing the conversation.") + llm_output: Optional[dict] = Field(default=None, description="Additional LLM output data.") + output_data: Optional[dict] = Field(default=None, description="Additional output data.") log: Optional[dict] = Field(default=None, description="Generation log data.") -class ModelsResponse(BaseModel): - """OpenAI API models list response.""" +class GuardrailsModel(Model): + """OpenAI API model with NeMo-Guardrails extensions.""" - object: str = Field( - default="list", description="The object type, which is always 'list'." + guardrails_config_id: Optional[str] = Field( + default=None, + description="[NeMo Guardrails extension] The guardrails configuration ID associated with this model.", ) - data: List[Model] = Field(description="The list of models.") + + +class ModelsResponse(BaseModel): + """OpenAI API models list response with NeMo-Guardrails extensions.""" + + object: str = Field(default="list", description="The object type, which is always 'list'.") + data: List[GuardrailsModel] = Field(description="The list of models.") diff --git a/nemoguardrails/streaming.py b/nemoguardrails/streaming.py index 65f990c7c..5a862c18d 100644 --- a/nemoguardrails/streaming.py +++ b/nemoguardrails/streaming.py @@ -196,9 +196,7 @@ async def _process( { "text": chunk, "generation_info": ( - self.current_generation_info.copy() - if self.current_generation_info - else {} + self.current_generation_info.copy() if self.current_generation_info else {} ), } ) diff --git a/poetry.lock b/poetry.lock index f2811d34b..9e24d2a40 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -22,7 +22,7 @@ tests = ["hypothesis", "pytest"] name = "aiofiles" version = "24.1.0" description = "File support for asyncio." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, @@ -918,7 +918,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1722,7 +1722,7 @@ i18n = ["Babel (>=2.7)"] name = "jiter" version = "0.10.0" description = "Fast iterable JSON parser." 
-optional = false +optional = true python-versions = ">=3.9" files = [ {file = "jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303"}, @@ -2921,7 +2921,7 @@ sympy = "*" name = "openai" version = "1.102.0" description = "The official Python library for the openai API" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "openai-1.102.0-py3-none-any.whl", hash = "sha256:d751a7e95e222b5325306362ad02a7aa96e1fab3ed05b5888ce1c7ca63451345"}, @@ -4023,13 +4023,13 @@ dev = ["build", "flake8", "mypy", "pytest", "twine"] [[package]] name = "pyright" -version = "1.1.407" +version = "1.1.405" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.407-py3-none-any.whl", hash = "sha256:6dd419f54fcc13f03b52285796d65e639786373f433e243f8b94cf93a7444d21"}, - {file = "pyright-1.1.407.tar.gz", hash = "sha256:099674dba5c10489832d4a4b2d302636152a9a42d317986c38474c76fe562262"}, + {file = "pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a"}, + {file = "pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763"}, ] [package.dependencies] @@ -6194,16 +6194,16 @@ files = [ cffi = ["cffi (>=1.17)"] [extras] -all = ["google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] +all = ["aiofiles", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] eval = ["numpy", "numpy", "numpy", "numpy", "streamlit", "tornado", "tqdm"] gcp = ["google-cloud-language"] jailbreak = ["yara-python"] nvidia = ["langchain-nvidia-ai-endpoints"] openai = ["langchain-openai"] sdd = ["presidio-analyzer", "presidio-anonymizer"] -tracing = ["opentelemetry-api"] +tracing = ["aiofiles", "opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "8d456424d7a10f6e08c69755568b81cd8d2779bae98baa8d29f2be06098c3bf5" +content-hash = "d5e8dc8fdbad5781141f4c65671d115060aa4c99abca0bd72ec025781352b775" diff --git a/tests/test_api.py b/tests/test_api.py index 66fb84406..8c50b5c01 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -24,6 +24,8 @@ from nemoguardrails.server.api import RequestBody, _format_streaming_response from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler +LIVE_TEST_MODE = os.environ.get("LIVE_TEST_MODE") or os.environ.get("TEST_LIVE_MODE") + client = TestClient(api.app) @@ -59,12 +61,16 @@ def test_get_models(): # Check each model has the required OpenAI format for model in result["data"]: assert "id" in model + assert "guardrails_config_id" in model assert model["object"] == "model" assert "created" in model assert model["owned_by"] == "nemo-guardrails" -@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +@pytest.mark.skipif( + not LIVE_TEST_MODE, + reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing", +) def test_chat_completion(): response = client.post( "/v1/chat/completions", @@ -90,7 +96,10 @@ def test_chat_completion(): assert res["choices"][0]["message"]["role"] == "assistant" 
-@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +@pytest.mark.skipif( + not LIVE_TEST_MODE, + reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing", +) def test_chat_completion_with_default_configs(): api.set_default_config_id("general") diff --git a/tests/test_openai_integration.py b/tests/test_openai_integration.py index 735651a66..9d2523668 100644 --- a/tests/test_openai_integration.py +++ b/tests/test_openai_integration.py @@ -13,10 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import time import pytest from fastapi.testclient import TestClient from openai import OpenAI +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.model import Model from nemoguardrails.server import api @@ -25,9 +29,7 @@ def set_rails_config_path(): """Set the rails config path to the test configs directory.""" original_path = api.app.rails_config_path - api.app.rails_config_path = os.path.normpath( - os.path.join(os.path.dirname(__file__), "test_configs/simple_server") - ) + api.app.rails_config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), "test_configs/simple_server")) yield # Restore the original path and clear cache after the test @@ -55,19 +57,14 @@ def openai_client(test_client): def test_openai_client_list_models(openai_client): models = openai_client.models.list() - # Verify the response structure matches OpenAI's ModelList - assert models is not None - assert hasattr(models, "data") - assert len(models.data) > 0 - - # Check first model has required fields - model = models.data[0] - assert hasattr(model, "id") - assert hasattr(model, "object") - assert model.object == "model" - assert hasattr(model, "created") - assert hasattr(model, "owned_by") - assert model.owned_by == "nemo-guardrails" + # Verify the response structure matches the GuardrailsModel schema + assert models.data[0] == Model( + id="config_1", + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + guardrails_config_id="config_1", + ) def test_openai_client_chat_completion(openai_client): @@ -77,32 +74,24 @@ def test_openai_client_chat_completion(openai_client): stream=False, ) - # Verify response structure matches OpenAI's ChatCompletion object - assert response is not None - assert hasattr(response, "id") + assert isinstance(response, ChatCompletion) assert response.id is not None - assert hasattr(response, "object") - assert response.object == "chat.completion" + + assert response.choices[0] == Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + content="Hello!", + refusal=None, + role="assistant", + annotations=None, + audio=None, + function_call=None, + tool_calls=None, + ), + ) assert hasattr(response, "created") - assert response.created > 0 - assert hasattr(response, "model") - assert response.model == "config_1" - - # Verify choices structure - assert hasattr(response, "choices") - assert len(response.choices) == 1 - choice = response.choices[0] - assert hasattr(choice, "index") - assert choice.index == 0 - assert hasattr(choice, "message") - assert hasattr(choice.message, "role") - assert choice.message.role == "assistant" - assert hasattr(choice.message, "content") - assert choice.message.content is not None - assert isinstance(choice.message.content, str) - assert 
len(choice.message.content) > 0 - assert hasattr(choice, "finish_reason") - assert choice.finish_reason == "stop" def test_openai_client_chat_completion_parameterized(openai_client): @@ -115,8 +104,20 @@ def test_openai_client_chat_completion_parameterized(openai_client): ) # Verify response exists - assert response is not None - assert response.choices[0].message.content is not None + assert isinstance(response, ChatCompletion) + assert response.id is not None + assert response.choices[0] == Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + content="Hello!", + refusal=None, + role="assistant", + annotations=None, + ), + ) + assert hasattr(response, "created") def test_openai_client_chat_completion_input_rails(openai_client): @@ -127,9 +128,10 @@ def test_openai_client_chat_completion_input_rails(openai_client): ) # Verify response exists - assert response is not None - assert response.choices[0].message.content is not None - assert isinstance(response.choices[0].message.content, str) + assert isinstance(response, ChatCompletion) + assert response.id is not None + assert isinstance(response.choices[0], Choice) + assert hasattr(response, "created") @pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") @@ -144,10 +146,7 @@ def test_openai_client_chat_completion_streaming(openai_client): assert len(chunks) > 0 # Verify at least one chunk has content - has_content = any( - hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content - for chunk in chunks - ) + has_content = any(hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content for chunk in chunks) assert has_content, "At least one chunk should contain content" @@ -158,8 +157,6 @@ def test_openai_client_error_handling_invalid_model(openai_client): stream=False, ) - # The server should return a response (not raise an exception) - assert response is not None # The error should be in the content assert ( "Could not load" in response.choices[0].message.content diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c7f59a7d1..fa7ffaa49 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -815,3 +815,62 @@ def test_main_llm_supports_streaming_flag_disabled_when_no_streaming(): assert rails.main_llm_supports_streaming is False, ( "main_llm_supports_streaming should be False when streaming is disabled" ) + + +def test_main_llm_supports_streaming_with_multiple_model_types( + custom_streaming_providers, +): + """Test that streaming is properly configured when config has multiple model types.""" + config = RailsConfig.from_content( + config={ + "models": [ + { + "type": "main", + "engine": "custom_streaming", + "model": "test-model", + }, + { + "type": "content_safety", + "engine": "custom_streaming", + "model": "safety-model", + }, + ], + "streaming": True, + } + ) + + rails = LLMRails(config) + + assert rails.main_llm_supports_streaming is True, ( + "main_llm_supports_streaming should be True when streaming is enabled " + "and config has multiple model types including a streaming-capable main LLM" + ) + # Verify the main LLM's streaming attribute was set + assert hasattr(rails.llm, "streaming") and rails.llm.streaming is True, ( + "Main LLM's streaming attribute should be set to True" + ) + + +def test_main_llm_supports_streaming_with_specialized_models_only( + custom_streaming_providers, +): + """Test streaming config when only specialized models are defined (no main).""" + config = RailsConfig.from_content( 
+ config={ + "models": [ + { + "type": "content_safety", + "engine": "custom_streaming", + "model": "safety-model", + }, + ], + "streaming": True, + } + ) + + rails = LLMRails(config) + + # Verify that main_llm_supports_streaming is False when no main LLM is configured + assert rails.main_llm_supports_streaming is False, ( + "main_llm_supports_streaming should be False when no main LLM is configured" + ) From 720e6aaf487001b37ddb3f4700aa6d6bff9f6a32 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Wed, 3 Dec 2025 08:04:19 -0500 Subject: [PATCH 7/7] Update poetry.lock --- poetry.lock | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9e24d2a40..5701c647a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -22,7 +22,7 @@ tests = ["hypothesis", "pytest"] name = "aiofiles" version = "24.1.0" description = "File support for asyncio." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, @@ -918,7 +918,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1722,7 +1722,7 @@ i18n = ["Babel (>=2.7)"] name = "jiter" version = "0.10.0" description = "Fast iterable JSON parser." -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303"}, @@ -2921,7 +2921,7 @@ sympy = "*" name = "openai" version = "1.102.0" description = "The official Python library for the openai API" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "openai-1.102.0-py3-none-any.whl", hash = "sha256:d751a7e95e222b5325306362ad02a7aa96e1fab3ed05b5888ce1c7ca63451345"}, @@ -6194,16 +6194,16 @@ files = [ cffi = ["cffi (>=1.17)"] [extras] -all = ["aiofiles", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] +all = ["google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] eval = ["numpy", "numpy", "numpy", "numpy", "streamlit", "tornado", "tqdm"] gcp = ["google-cloud-language"] jailbreak = ["yara-python"] nvidia = ["langchain-nvidia-ai-endpoints"] openai = ["langchain-openai"] sdd = ["presidio-analyzer", "presidio-anonymizer"] -tracing = ["aiofiles", "opentelemetry-api"] +tracing = ["opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "d5e8dc8fdbad5781141f4c65671d115060aa4c99abca0bd72ec025781352b775" +content-hash = "a048d4ecee654c25ea1be4a65cfccf4bb51289b2aa4db72afd5d096f3d2add1a"
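
The new `docs/user-guides/community/openai.md` page and the integration tests in this series both drive the server through the stock `openai` client. A minimal sketch of that flow, assuming a guardrails server is already listening on `http://localhost:8000` and exposes a streaming-capable configuration named `config_1` (both values are placeholders; substitute your own server URL and config id):

from openai import OpenAI

# The API key is a placeholder (the integration tests use a dummy key);
# the base URL must include the /v1 prefix so the client hits /v1/chat/completions.
client = OpenAI(api_key="dummy-key", base_url="http://localhost:8000/v1")

# Non-streaming completion: the guardrails config id is passed as the model
# name, as the integration tests above do.
response = client.chat.completions.create(
    model="config_1",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
print(response.choices[0].message.content)

# Streaming variant: deltas arrive as server-sent events and the stream is
# terminated by a `data: [DONE]` message, which the client consumes for us.
stream = client.chat.completions.create(
    model="config_1",
    messages=[{"role": "user", "content": "Tell me a short joke."}],
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")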