17 changes: 9 additions & 8 deletions nemoguardrails/rails/llm/config.py
@@ -1350,12 +1350,13 @@ def check_reasoning_traces_with_dialog_rails(cls, values):
    @root_validator(pre=True, allow_reuse=True)
    def check_prompt_exist_for_self_check_rails(cls, values):
        rails = values.get("rails", {})
+        prompts = values.get("prompts", []) or []

        enabled_input_rails = rails.get("input", {}).get("flows", [])
        enabled_output_rails = rails.get("output", {}).get("flows", [])
        provided_task_prompts = [
            prompt.task if hasattr(prompt, "task") else prompt.get("task")
-            for prompt in values.get("prompts", [])
+            for prompt in prompts
        ]

        # Input moderation prompt verification
@@ -1410,7 +1411,7 @@ def check_output_parser_exists(cls, values):
            # "content_safety_check input $model",
            # "content_safety_check output $model",
        ]
-        prompts = values.get("prompts", [])
+        prompts = values.get("prompts") or []
        for prompt in prompts:
            task = prompt.task if hasattr(prompt, "task") else prompt.get("task")
            output_parser = (
@@ -1657,12 +1658,12 @@ def _join_rails_configs(
    combined_rails_config_dict = _join_dict(
        base_rails_config.dict(), updated_rails_config.dict()
    )
-    combined_rails_config_dict["config_path"] = ",".join(
-        [
-            base_rails_config.dict()["config_path"],
-            updated_rails_config.dict()["config_path"],
-        ]
-    )
+    # filter out empty strings to avoid leading/trailing commas
+    config_paths = [
+        base_rails_config.dict()["config_path"] or "",
+        updated_rails_config.dict()["config_path"] or "",
+    ]
+    combined_rails_config_dict["config_path"] = ",".join(filter(None, config_paths))
    combined_rails_config = RailsConfig(**combined_rails_config_dict)
    return combined_rails_config

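Taken together, the two `or []` guards and the filtered join mean that a config whose YAML leaves `prompts:` empty (YAML parses that to None) no longer trips the prompt-related root validators, and that merging configs without a `config_path` no longer yields a stray comma. Below is a minimal sketch of the intended behavior, assuming `RailsConfig.from_content` accepts a `yaml_content` string as in current releases; the model entry is a placeholder:

# Illustrative sketch, not part of the diff.
from nemoguardrails import RailsConfig

# `prompts:` with no value parses to None; the `or []` guards keep the
# prompt-related root validators from iterating over None.
yaml_with_null_prompts = """
models:
  - type: main
    engine: openai
    model: gpt-3.5-turbo
prompts:
"""

config_a = RailsConfig.from_content(yaml_content=yaml_with_null_prompts)
config_b = RailsConfig.from_content(yaml_content="models: []")

# Neither config carries a config_path, so the filtered join should yield ""
# rather than ",".
merged = config_a + config_b
print(repr(merged.config_path))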
186 changes: 185 additions & 1 deletion tests/rails/llm/test_config.py
@@ -16,7 +16,13 @@
import pytest
from pydantic import ValidationError

-from nemoguardrails.rails.llm.config import TaskPrompt
+from nemoguardrails.rails.llm.config import (
+    Document,
+    Instruction,
+    Model,
+    RailsConfig,
+    TaskPrompt,
+)


def test_task_prompt_valid_content():
@@ -123,3 +129,181 @@ def test_task_prompt_max_tokens_validation():
    with pytest.raises(ValidationError) as excinfo:
        TaskPrompt(task="example_task", content="Test prompt", max_tokens=-1)
    assert "Input should be greater than or equal to 1" in str(excinfo.value)


def test_rails_config_addition():
    """Tests that adding two RailsConfig objects merges both into a single RailsConfig."""
    config1 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path="test_config.yml",
    )
    config2 = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        config_path="test_config2.yml",
    )

    result = config1 + config2

    assert isinstance(result, RailsConfig)
    assert len(result.models) == 2
    assert result.config_path == "test_config.yml,test_config2.yml"


def test_rails_config_model_conflicts():
    """Tests that adding two RailsConfig objects with conflicting models raises an error."""
    config1 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path="config1.yml",
    )

    # Different engine for same model type
    config2 = RailsConfig(
        models=[Model(type="main", engine="nim", model="gpt-3.5-turbo")],
        config_path="config2.yml",
    )
    with pytest.raises(
        ValueError,
        match="Both config files should have the same engine for the same model type",
    ):
        config1 + config2

    # Different model for same model type
    config3 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-4")],
        config_path="config3.yml",
    )
    with pytest.raises(
        ValueError,
        match="Both config files should have the same model for the same model type",
    ):
        config1 + config3


def test_rails_config_actions_server_url_conflicts():
    """Tests that adding two RailsConfig objects with different values for `actions_server_url` raises an error."""
    config1 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        actions_server_url="http://localhost:8000",
    )

    config2 = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        actions_server_url="http://localhost:9000",
    )

    with pytest.raises(
        ValueError, match="Both config files should have the same actions_server_url"
    ):
        config1 + config2


def test_rails_config_simple_field_overwriting():
    """Tests that fields from the second config overwrite fields from the first config."""
    config1 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        streaming=False,
        lowest_temperature=0.1,
        colang_version="1.0",
    )

    config2 = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        streaming=True,
        lowest_temperature=0.5,
        colang_version="2.x",
    )

    result = config1 + config2

    assert result.streaming is True
    assert result.lowest_temperature == 0.5
    assert result.colang_version == "2.x"


def test_rails_config_nested_dictionary_merging():
    """Tests nested dictionaries are merged correctly."""
    config1 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        rails={
            "input": {"flows": ["flow1"], "parallel": False},
            "output": {"flows": ["flow2"]},
        },
        knowledge_base={
            "folder": "kb1",
            "embedding_search_provider": {"name": "provider1"},
        },
        custom_data={"setting1": "value1", "nested": {"key1": "val1"}},
    )

    config2 = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        rails={
            "input": {"flows": ["flow3"], "parallel": True},
            "retrieval": {"flows": ["flow4"]},
        },
        knowledge_base={
            "folder": "kb2",
            "embedding_search_provider": {"name": "provider2"},
        },
        custom_data={"setting2": "value2", "nested": {"key2": "val2"}},
    )

    result = config1 + config2

    assert result.rails.input.flows == ["flow3", "flow1"]
    assert result.rails.input.parallel is True
    assert result.rails.output.flows == ["flow2"]
    assert result.rails.retrieval.flows == ["flow4"]

    assert result.knowledge_base.folder == "kb2"
    assert result.knowledge_base.embedding_search_provider.name == "provider2"

    assert result.custom_data["setting1"] == "value1"
    assert result.custom_data["setting2"] == "value2"
    assert result.custom_data["nested"]["key1"] == "val1"
    assert result.custom_data["nested"]["key2"] == "val2"


def test_rails_config_none_prompts():
    """Test that configs with None prompts can be added without errors."""
    config1 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        prompts=None,
        rails={"input": {"flows": ["self_check_input"]}},
    )
    config2 = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        prompts=[],
    )

    result = config1 + config2
    assert result is not None
    assert result.prompts is not None


def test_rails_config_none_config_path():
    """Test that configs with None config_path can be added."""
    config1 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path=None,
    )
    config2 = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        config_path="config2.yml",
    )

    result = config1 + config2
    # should not have leading comma after fix
    assert result.config_path == "config2.yml"

    config3 = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path=None,
    )
    config4 = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        config_path=None,
    )

    result2 = config3 + config4
    assert result2.config_path == ""
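
The expected values in this last test follow directly from the new join logic: each path is normalized with `or ""` and empty strings are dropped by `filter(None, ...)` before joining, so a missing path contributes nothing and two missing paths yield the empty string. A small sketch of that expression in isolation, using a hypothetical helper name and example path values:

# Hypothetical helper mirroring the _join_rails_configs change.
def join_config_paths(base_path, updated_path):
    # Normalize None to "" and drop empties so no leading/trailing comma appears.
    config_paths = [base_path or "", updated_path or ""]
    return ",".join(filter(None, config_paths))

assert join_config_paths(None, "config2.yml") == "config2.yml"
assert join_config_paths("test_config.yml", "test_config2.yml") == "test_config.yml,test_config2.yml"
assert join_config_paths(None, None) == ""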