diff --git a/backend/app/crud/llm.py b/backend/app/crud/llm.py index 37fd7368a..44e605fb7 100644 --- a/backend/app/crud/llm.py +++ b/backend/app/crud/llm.py @@ -14,6 +14,8 @@ QueryInput, ImageInput, PDFInput, + LLMCallConfig, + QueryParams, ) logger = logging.getLogger(__name__) @@ -125,7 +127,7 @@ def create_llm_call( } else: config_dict = { - "config_blob": resolved_config.model_dump(), + "config_blob": resolved_config.model_dump(mode="json"), } # Extract conversation info if present @@ -249,6 +251,74 @@ def update_llm_call_input( ) +def save_rephrase_guardrail_call( + *, + session: Session, + query: QueryParams, + config: LLMCallConfig, + request_metadata: dict | None, + config_blob: ConfigBlob, + guardrail_direct_response: str, + job_id: UUID, + project_id: int, + organization_id: int, + chain_id: UUID | None, +) -> UUID | None: + """Persist the LLM call record for a guardrail rephrase response. + + Returns the llm_call_id on success, None if the DB write fails (non-fatal). + """ + try: + rephrase_call_request = LLMCallRequest( + query=query, + config=config, + request_metadata=request_metadata, + ) + rephrase_llm_call = create_llm_call( + session, + request=rephrase_call_request, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + resolved_config=config_blob, + original_provider=str(config_blob.completion.provider), + chain_id=chain_id, + ) + try: + update_llm_call_response( + session, + llm_call_id=rephrase_llm_call.id, + provider_response_id=None, + content={ + "type": "text", + "content": { + "format": "text", + "value": guardrail_direct_response, + }, + }, + # No LLM was invoked, so token counts are genuinely zero. + usage={ + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + }, + ) + except Exception: + try: + session.delete(rephrase_llm_call) + session.commit() + except Exception: + pass + raise + return rephrase_llm_call.id + except Exception as e: + logger.error( + f"[save_rephrase_guardrail_call] Failed to record rephrase guardrail call: {e} | job_id={job_id}", + exc_info=True, + ) + return None + + def get_llm_call_by_id( session: Session, llm_call_id: UUID, diff --git a/backend/app/services/llm/guardrails.py b/backend/app/services/llm/guardrails.py index 7ba8d72fe..916c8bd94 100644 --- a/backend/app/services/llm/guardrails.py +++ b/backend/app/services/llm/guardrails.py @@ -54,7 +54,7 @@ def run_guardrails_validation( } try: - with httpx.Client(timeout=10.0) as client: + with httpx.Client(timeout=45.0) as client: response = client.post( f"{settings.KAAPI_GUARDRAILS_URL}/", json=payload, diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 76757b06f..f818ea489 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -33,6 +33,7 @@ serialize_input, update_llm_call_input, update_llm_call_response, + save_rephrase_guardrail_call, ) from app.crud.llm_chain import create_llm_chain, update_llm_chain_status from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest @@ -523,6 +524,12 @@ def execute_llm_call( else: config_blob = config.blob + original_input_value = ( + query.input.content.value + if isinstance(query.input, TextInput) + else None + ) + if config_blob.prompt_template and isinstance(query.input, TextInput): template = config_blob.prompt_template.template interpolated = template.replace("{{input}}", query.input.content.value) @@ -561,10 +568,25 @@ def execute_llm_call( ), usage=guardrail_usage, ) + if original_input_value is not None: + query.input.content.value = original_input_value + llm_call_id = save_rephrase_guardrail_call( + session=session, + query=query, + config=config, + request_metadata=request_metadata, + config_blob=config_blob, + guardrail_direct_response=guardrail_direct_response, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + chain_id=chain_id, + ) return BlockResult( response=llm_response, usage=guardrail_usage, metadata=request_metadata, + llm_call_id=llm_call_id, ) if input_error: guard_span.set_status( diff --git a/backend/app/tests/services/llm/test_guardrails.py b/backend/app/tests/services/llm/test_guardrails.py index 161056980..990b7364f 100644 --- a/backend/app/tests/services/llm/test_guardrails.py +++ b/backend/app/tests/services/llm/test_guardrails.py @@ -1,14 +1,29 @@ import uuid +from typing import Any from unittest.mock import MagicMock, patch +from uuid import UUID, uuid4 import httpx +import pytest +from sqlmodel import Session, select from app.core.config import settings -from app.models.llm.request import Validator +from app.crud.jobs import JobCrud +from app.crud.llm import save_rephrase_guardrail_call +from app.models import Job, JobType +from app.models.llm.request import ( + ConfigBlob, + LLMCallConfig, + LlmCall, + NativeCompletionConfig, + QueryParams, + Validator, +) from app.services.llm.guardrails import ( list_validators_config, run_guardrails_validation, ) +from app.tests.utils.utils import get_project TEST_JOB_ID = uuid.uuid4() @@ -253,3 +268,110 @@ def test_list_validators_config_network_error_fails_open(mock_client_cls) -> Non assert input_guardrails == [] assert output_guardrails == [] + + +_SAFE_TEXT = "Please rephrase: content not allowed." +_CONFIG_BLOB = ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + params={"model": "gpt-4o"}, + type="text", + ), + input_guardrails=[Validator(validator_config_id=uuid4())], +) +_CONFIG = LLMCallConfig(blob=_CONFIG_BLOB) + + +class TestSaveRephraseGuardrailCall: + @pytest.fixture + def job(self, db: Session) -> Job: + j = JobCrud(session=db).create( + job_type=JobType.LLM_API, trace_id="rephrase-test" + ) + db.commit() + return j + + def _call(self, db: Session, job: Job, **overrides: Any) -> UUID | None: + project = get_project(db) + kwargs = dict( + session=db, + query=QueryParams(input="original unsafe input"), + config=_CONFIG, + request_metadata=None, + config_blob=_CONFIG_BLOB, + guardrail_direct_response=_SAFE_TEXT, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + chain_id=None, + ) + kwargs.update(overrides) + return save_rephrase_guardrail_call(**kwargs) + + def test_success_returns_uuid(self, db: Session, job: Job) -> None: + result = self._call(db, job) + assert isinstance(result, UUID) + + def test_success_saves_original_input_and_job_id( + self, db: Session, job: Job + ) -> None: + llm_call_id = self._call(db, job) + llm_call = db.exec(select(LlmCall).where(LlmCall.id == llm_call_id)).first() + assert llm_call is not None + assert llm_call.input == "original unsafe input" + assert llm_call.job_id == job.id + + def test_success_saves_safe_text_as_content(self, db: Session, job: Job) -> None: + llm_call_id = self._call(db, job) + llm_call = db.exec(select(LlmCall).where(LlmCall.id == llm_call_id)).first() + assert llm_call.content == { + "type": "text", + "content": {"format": "text", "value": _SAFE_TEXT}, + } + + def test_success_saves_zero_usage(self, db: Session, job: Job) -> None: + llm_call_id = self._call(db, job) + llm_call = db.exec(select(LlmCall).where(LlmCall.id == llm_call_id)).first() + assert llm_call.usage == { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + + def test_create_llm_call_error_returns_none(self, db: Session, job: Job) -> None: + with patch( + "app.crud.llm.create_llm_call", + side_effect=Exception("DB insert failed"), + ): + result = self._call(db, job) + assert result is None + + def test_update_llm_call_response_error_returns_none( + self, db: Session, job: Job + ) -> None: + with patch( + "app.crud.llm.update_llm_call_response", + side_effect=Exception("DB update failed"), + ): + result = self._call(db, job) + assert result is None + + def test_chain_id_forwarded_to_create_llm_call(self, db: Session, job: Job) -> None: + chain_id = uuid4() + with patch("app.crud.llm.create_llm_call") as mock_create: + mock_create.return_value = MagicMock(id=uuid4()) + with patch("app.crud.llm.update_llm_call_response"): + self._call(db, job, chain_id=chain_id) + _, kwargs = mock_create.call_args + assert kwargs["chain_id"] == chain_id + + def test_request_metadata_forwarded_to_llm_call_request( + self, db: Session, job: Job + ) -> None: + metadata = {"request_id": "abc", "user": "test"} + with patch("app.crud.llm.create_llm_call") as mock_create: + mock_create.return_value = MagicMock(id=uuid4()) + with patch("app.crud.llm.update_llm_call_response"): + self._call(db, job, request_metadata=metadata) + _, kwargs = mock_create.call_args + assert kwargs["request"].request_metadata == metadata diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index c15f28fcb..4f134e782 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -1174,6 +1174,62 @@ def test_guardrails_rephrase_needed_allows_job_with_sanitized_input( result["data"]["response"]["output"]["content"]["value"] == "Rephrased text" ) + def test_guardrails_rephrase_saves_original_input_and_safe_text_in_db( + self, db, job_env, job_for_execution + ): + from app.models.llm.request import LlmCall + + env = job_env + + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): + mock_guardrails.return_value = { + "success": True, + "bypassed": False, + "data": { + "safe_text": "Please rephrase the query without unsafe content. Input is outside the allowed topic scope.", + "rephrase_needed": True, + }, + } + mock_fetch_configs.return_value = ( + [{"type": "policy", "stage": "input"}], + [], + ) + + request_data = { + "query": {"input": "unsafe user query"}, + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4o"}, + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [], + } + }, + } + self._execute_job(job_for_execution, db, request_data) + + llm_call = db.exec( + select(LlmCall).where(LlmCall.job_id == job_for_execution.id) + ).first() + + assert llm_call is not None + assert llm_call.input == "unsafe user query" + assert llm_call.content == { + "type": "text", + "content": { + "format": "text", + "value": "Please rephrase the query without unsafe content. Input is outside the allowed topic scope.", + }, + } + def test_execute_job_fetches_validator_configs_from_blob_refs( self, db, job_env, job_for_execution ):