72 changes: 71 additions & 1 deletion backend/app/crud/llm.py
@@ -14,6 +14,8 @@
QueryInput,
ImageInput,
PDFInput,
LLMCallConfig,
QueryParams,
)

logger = logging.getLogger(__name__)
@@ -125,7 +127,7 @@ def create_llm_call(
}
else:
config_dict = {
"config_blob": resolved_config.model_dump(),
"config_blob": resolved_config.model_dump(mode="json"),
}

# Extract conversation info if present
@@ -249,6 +251,74 @@ def update_llm_call_input(
)


def save_rephrase_guardrail_call(
*,
session: Session,
query: QueryParams,
config: LLMCallConfig,
request_metadata: dict | None,
config_blob: ConfigBlob,
guardrail_direct_response: str,
job_id: UUID,
project_id: int,
organization_id: int,
chain_id: UUID | None,
) -> UUID | None:
"""Persist the LLM call record for a guardrail rephrase response.

Returns the llm_call_id on success, None if the DB write fails (non-fatal).
"""
try:
rephrase_call_request = LLMCallRequest(
query=query,
config=config,
request_metadata=request_metadata,
)
rephrase_llm_call = create_llm_call(
session,
request=rephrase_call_request,
job_id=job_id,
project_id=project_id,
organization_id=organization_id,
resolved_config=config_blob,
original_provider=str(config_blob.completion.provider),
chain_id=chain_id,
)
try:
update_llm_call_response(
session,
llm_call_id=rephrase_llm_call.id,
provider_response_id=None,
content={
"type": "text",
"content": {
"format": "text",
"value": guardrail_direct_response,
},
},
# No LLM was invoked, so token counts are genuinely zero.
usage={
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
},
)
except Exception:
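# Best-effort compensation: drop the orphaned call record so a failed
# response update doesn't leave a half-written row, then re-raise so the
# outer handler logs the failure and returns None.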
try:
session.delete(rephrase_llm_call)
session.commit()
except Exception:
pass
raise
return rephrase_llm_call.id
except Exception as e:
logger.error(
f"[save_rephrase_guardrail_call] Failed to record rephrase guardrail call: {e} | job_id={job_id}",
exc_info=True,
)
return None


def get_llm_call_by_id(
session: Session,
llm_call_id: UUID,
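Aside on the model_dump(mode="json") change in create_llm_call above (a sketch of the likely rationale, not stated in the PR): Pydantic v2's default model_dump() keeps rich Python objects such as UUID and datetime, which json.dumps() and JSON database columns cannot serialize directly, while mode="json" coerces them to primitives. A minimal sketch with a hypothetical stand-in model:

import json
import uuid
from datetime import datetime, timezone

from pydantic import BaseModel


class Blob(BaseModel):  # stand-in for ConfigBlob; fields are illustrative
    call_id: uuid.UUID
    created_at: datetime


blob = Blob(call_id=uuid.uuid4(), created_at=datetime.now(timezone.utc))
blob.model_dump()                         # keeps UUID/datetime objects; json.dumps() would raise TypeError
json.dumps(blob.model_dump(mode="json"))  # UUID -> str, datetime -> ISO 8601; serializes cleanly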
2 changes: 1 addition & 1 deletion backend/app/services/llm/guardrails.py
@@ -54,7 +54,7 @@ def run_guardrails_validation(
}

try:
with httpx.Client(timeout=10.0) as client:
with httpx.Client(timeout=45.0) as client:
response = client.post(
f"{settings.KAAPI_GUARDRAILS_URL}/",
json=payload,
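Side note on the 10 s -> 45 s timeout bump above: assuming the slow part is the guardrails service's processing rather than connection setup (an assumption about the motivation, not stated in the PR), httpx also supports per-phase budgets, so unreachable-host failures can stay fast while reads get the full 45 s. A minimal sketch with an illustrative URL:

import httpx

# 45 s budget for read/write/pool, but still fail within 5 s if the service is unreachable.
timeout = httpx.Timeout(45.0, connect=5.0)

with httpx.Client(timeout=timeout) as client:
    response = client.post("https://guardrails.internal.example/", json={"text": "hello"})
    response.raise_for_status()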
22 changes: 22 additions & 0 deletions backend/app/services/llm/jobs.py
@@ -33,6 +33,7 @@
serialize_input,
update_llm_call_input,
update_llm_call_response,
save_rephrase_guardrail_call,
)
from app.crud.llm_chain import create_llm_chain, update_llm_chain_status
from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest
@@ -523,6 +524,12 @@ def execute_llm_call(
else:
config_blob = config.blob

original_input_value = (
query.input.content.value
if isinstance(query.input, TextInput)
else None
)

if config_blob.prompt_template and isinstance(query.input, TextInput):
template = config_blob.prompt_template.template
interpolated = template.replace("{{input}}", query.input.content.value)
@@ -561,10 +568,25 @@
),
usage=guardrail_usage,
)
if original_input_value is not None:
query.input.content.value = original_input_value
llm_call_id = save_rephrase_guardrail_call(
session=session,
query=query,
config=config,
request_metadata=request_metadata,
config_blob=config_blob,
guardrail_direct_response=guardrail_direct_response,
job_id=job_id,
project_id=project_id,
organization_id=organization_id,
chain_id=chain_id,
)
return BlockResult(
response=llm_response,
usage=guardrail_usage,
metadata=request_metadata,
llm_call_id=llm_call_id,
)
if input_error:
guard_span.set_status(
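The capture/restore added above exists because prompt-template interpolation mutates query.input.content.value in place; without writing original_input_value back, the persisted guardrail record would store the template expansion rather than what the user actually sent. A stripped-down sketch of the pattern (illustrative names, not the service's real objects):

def interpolate(template: str, user_input: str) -> tuple[str, str]:
    """Return (prompt, original) so the caller can restore the raw input."""
    original = user_input  # capture before mutation
    prompt = template.replace("{{input}}", user_input)
    return prompt, original


prompt, original = interpolate("Answer briefly: {{input}}", "unsafe user query")
# ... guardrails evaluate `prompt`; on a rephrase response, restore first ...
value_to_persist = original  # the raw input, not the expansion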
124 changes: 123 additions & 1 deletion backend/app/tests/services/llm/test_guardrails.py
@@ -1,14 +1,29 @@
import uuid
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import UUID, uuid4

import httpx
import pytest
from sqlmodel import Session, select

from app.core.config import settings
from app.models.llm.request import Validator
from app.crud.jobs import JobCrud
from app.crud.llm import save_rephrase_guardrail_call
from app.models import Job, JobType
from app.models.llm.request import (
ConfigBlob,
LLMCallConfig,
LlmCall,
NativeCompletionConfig,
QueryParams,
Validator,
)
from app.services.llm.guardrails import (
list_validators_config,
run_guardrails_validation,
)
from app.tests.utils.utils import get_project


TEST_JOB_ID = uuid.uuid4()
@@ -253,3 +268,110 @@ def test_list_validators_config_network_error_fails_open(mock_client_cls) -> None

assert input_guardrails == []
assert output_guardrails == []


_SAFE_TEXT = "Please rephrase: content not allowed."
_CONFIG_BLOB = ConfigBlob(
completion=NativeCompletionConfig(
provider="openai-native",
params={"model": "gpt-4o"},
type="text",
),
input_guardrails=[Validator(validator_config_id=uuid4())],
)
_CONFIG = LLMCallConfig(blob=_CONFIG_BLOB)


class TestSaveRephraseGuardrailCall:
@pytest.fixture
def job(self, db: Session) -> Job:
j = JobCrud(session=db).create(
job_type=JobType.LLM_API, trace_id="rephrase-test"
)
db.commit()
return j

def _call(self, db: Session, job: Job, **overrides: Any) -> UUID | None:
project = get_project(db)
kwargs = dict(
session=db,
query=QueryParams(input="original unsafe input"),
config=_CONFIG,
request_metadata=None,
config_blob=_CONFIG_BLOB,
guardrail_direct_response=_SAFE_TEXT,
job_id=job.id,
project_id=project.id,
organization_id=project.organization_id,
chain_id=None,
)
kwargs.update(overrides)
return save_rephrase_guardrail_call(**kwargs)

def test_success_returns_uuid(self, db: Session, job: Job) -> None:
result = self._call(db, job)
assert isinstance(result, UUID)

def test_success_saves_original_input_and_job_id(
self, db: Session, job: Job
) -> None:
llm_call_id = self._call(db, job)
llm_call = db.exec(select(LlmCall).where(LlmCall.id == llm_call_id)).first()
assert llm_call is not None
assert llm_call.input == "original unsafe input"
assert llm_call.job_id == job.id

def test_success_saves_safe_text_as_content(self, db: Session, job: Job) -> None:
llm_call_id = self._call(db, job)
llm_call = db.exec(select(LlmCall).where(LlmCall.id == llm_call_id)).first()
assert llm_call.content == {
"type": "text",
"content": {"format": "text", "value": _SAFE_TEXT},
}

def test_success_saves_zero_usage(self, db: Session, job: Job) -> None:
llm_call_id = self._call(db, job)
llm_call = db.exec(select(LlmCall).where(LlmCall.id == llm_call_id)).first()
assert llm_call.usage == {
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
}

def test_create_llm_call_error_returns_none(self, db: Session, job: Job) -> None:
with patch(
"app.crud.llm.create_llm_call",
side_effect=Exception("DB insert failed"),
):
result = self._call(db, job)
assert result is None

def test_update_llm_call_response_error_returns_none(
self, db: Session, job: Job
) -> None:
with patch(
"app.crud.llm.update_llm_call_response",
side_effect=Exception("DB update failed"),
):
result = self._call(db, job)
assert result is None

def test_chain_id_forwarded_to_create_llm_call(self, db: Session, job: Job) -> None:
chain_id = uuid4()
with patch("app.crud.llm.create_llm_call") as mock_create:
mock_create.return_value = MagicMock(id=uuid4())
with patch("app.crud.llm.update_llm_call_response"):
self._call(db, job, chain_id=chain_id)
_, kwargs = mock_create.call_args
assert kwargs["chain_id"] == chain_id

def test_request_metadata_forwarded_to_llm_call_request(
self, db: Session, job: Job
) -> None:
metadata = {"request_id": "abc", "user": "test"}
with patch("app.crud.llm.create_llm_call") as mock_create:
mock_create.return_value = MagicMock(id=uuid4())
with patch("app.crud.llm.update_llm_call_response"):
self._call(db, job, request_metadata=metadata)
_, kwargs = mock_create.call_args
assert kwargs["request"].request_metadata == metadata
56 changes: 56 additions & 0 deletions backend/app/tests/services/llm/test_jobs.py
@@ -1174,6 +1174,62 @@ def test_guardrails_rephrase_needed_allows_job_with_sanitized_input(
result["data"]["response"]["output"]["content"]["value"] == "Rephrased text"
)

def test_guardrails_rephrase_saves_original_input_and_safe_text_in_db(
self, db, job_env, job_for_execution
):
from app.models.llm.request import LlmCall

env = job_env

with (
patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails,
patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs,
):
mock_guardrails.return_value = {
"success": True,
"bypassed": False,
"data": {
"safe_text": "Please rephrase the query without unsafe content. Input is outside the allowed topic scope.",
"rephrase_needed": True,
},
}
mock_fetch_configs.return_value = (
[{"type": "policy", "stage": "input"}],
[],
)

request_data = {
"query": {"input": "unsafe user query"},
"config": {
"blob": {
"completion": {
"provider": "openai-native",
"type": "text",
"params": {"model": "gpt-4o"},
},
"input_guardrails": [
{"validator_config_id": VALIDATOR_CONFIG_ID_1}
],
"output_guardrails": [],
}
},
}
self._execute_job(job_for_execution, db, request_data)

llm_call = db.exec(
select(LlmCall).where(LlmCall.job_id == job_for_execution.id)
).first()

assert llm_call is not None
assert llm_call.input == "unsafe user query"
assert llm_call.content == {
"type": "text",
"content": {
"format": "text",
"value": "Please rephrase the query without unsafe content. Input is outside the allowed topic scope.",
},
}

def test_execute_job_fetches_validator_configs_from_blob_refs(
self, db, job_env, job_for_execution
):