diff --git a/backend/app/crud/config/config.py b/backend/app/crud/config/config.py index 00ac3b92..69d4bced 100644 --- a/backend/app/crud/config/config.py +++ b/backend/app/crud/config/config.py @@ -47,7 +47,7 @@ def create_or_raise( version = ConfigVersion( config_id=config.id, version=1, - config_blob=config_create.config_blob, + config_blob=config_create.config_blob.model_dump(), commit_message=config_create.commit_message, ) diff --git a/backend/app/crud/config/version.py b/backend/app/crud/config/version.py index cf4a3ae2..f834c168 100644 --- a/backend/app/crud/config/version.py +++ b/backend/app/crud/config/version.py @@ -34,7 +34,7 @@ def create_or_raise(self, version_create: ConfigVersionCreate) -> ConfigVersion: version = ConfigVersion( config_id=self.config_id, version=next_version, - config_blob=version_create.config_blob, + config_blob=version_create.config_blob.model_dump(), commit_message=version_create.commit_message, ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index d2246e75..9a351825 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -87,6 +87,8 @@ from .job import Job, JobType, JobStatus, JobUpdate from .llm import ( + ConfigBlob, + CompletionConfig, LLMCallRequest, LLMCallResponse, ) diff --git a/backend/app/models/config/config.py b/backend/app/models/config/config.py index f1378980..18bbbcdf 100644 --- a/backend/app/models/config/config.py +++ b/backend/app/models/config/config.py @@ -2,10 +2,11 @@ from datetime import datetime from typing import TYPE_CHECKING, Any -from sqlmodel import Field, SQLModel, UniqueConstraint, Index, text +from sqlmodel import Field, SQLModel, Index, text from pydantic import field_validator from app.core.util import now +from app.models.llm.request import ConfigBlob from .version import ConfigVersionPublic @@ -56,7 +57,7 @@ class ConfigCreate(ConfigBase): """Create new configuration""" # Initial version data - config_blob: dict[str, Any] = Field(description="Provider-specific parameters") + config_blob: ConfigBlob = Field(description="Provider-specific parameters") commit_message: str | None = Field( default=None, max_length=512, diff --git a/backend/app/models/config/version.py b/backend/app/models/config/version.py index 0169b048..bb44531d 100644 --- a/backend/app/models/config/version.py +++ b/backend/app/models/config/version.py @@ -8,6 +8,7 @@ from sqlmodel import Field, SQLModel, UniqueConstraint, Index, text from app.core.util import now +from app.models.llm.request import ConfigBlob class ConfigVersionBase(SQLModel): @@ -60,7 +61,12 @@ class ConfigVersion(ConfigVersionBase, table=True): class ConfigVersionCreate(ConfigVersionBase): - pass + # Store config_blob as JSON in the DB. Validation uses ConfigBlob only at creation + # time, since schema may evolve. When fetching, it is returned as a raw dict and + # re-validated against the latest schema before use. + config_blob: ConfigBlob = Field( + description="Provider-specific configuration parameters (temperature, max_tokens, etc.)", + ) class ConfigVersionPublic(ConfigVersionBase): diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index c1db4a0e..f06954de 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -1,2 +1,7 @@ -from app.models.llm.request import LLMCallRequest, CompletionConfig, QueryParams +from app.models.llm.request import ( + LLMCallRequest, + CompletionConfig, + QueryParams, + ConfigBlob, +) from app.models.llm.response import LLMCallResponse, LLMResponse, LLMOutput, Usage diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 249bb2fa..a63de1eb 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -1,5 +1,6 @@ from typing import Any, Literal +from uuid import UUID from sqlmodel import Field, SQLModel from pydantic import model_validator, HttpUrl @@ -57,8 +58,8 @@ class CompletionConfig(SQLModel): ) -class LLMCallConfig(SQLModel): - """Complete configuration for LLM call including all processing stages.""" +class ConfigBlob(SQLModel): + """Raw JSON blob of config.""" completion: CompletionConfig = Field(..., description="Completion configuration") # Future additions: @@ -66,11 +67,85 @@ class LLMCallConfig(SQLModel): # pre_filter: PreFilterConfig | None = None +class LLMCallConfig(SQLModel): + """ + Complete configuration for LLM call including all processing stages. + Either references a stored config (id + version) or provides an ad-hoc config blob. + Depending on which is provided, only one of the two options should be used. + """ + + id: UUID | None = Field( + default=None, + description=( + "Identifier for an existing LLM call configuration. [require version if provided]" + ), + ) + version: int | None = Field( + default=None, + ge=1, + description=( + "Version of the stored config to use. [require if id is provided]" + ), + ) + + blob: ConfigBlob | None = Field( + default=None, + description=( + "Raw JSON blob of the full configuration. Used for ad-hoc configurations without storing." + "Either this or (id + version) must be provided." + ), + ) + + @model_validator(mode="after") + def validate_config_logic(self): + has_stored = self.id is not None or self.version is not None + has_blob = self.blob is not None + + if has_stored and has_blob: + raise ValueError( + "Provide either 'id' with 'version' for stored config OR 'blob' for ad-hoc config, not both." + ) + + if has_stored: + if not self.id or not self.version: + raise ValueError( + "'id' and 'version' must both be provided together for stored config." + ) + return self + + if not has_blob: + raise ValueError( + "Must provide either a stored config (id + version) or an ad-hoc config (blob)." + ) + + return self + + @property + def is_stored_config(self) -> bool: + """Check if the config refers to a stored config or not.""" + return self.id is not None and self.version is not None + + class LLMCallRequest(SQLModel): - """User-facing API request for LLM completion.""" + """ + API request for an LLM completion. + + The `config` field accepts either: + - **Stored config (id + version)** — recommended for all production use. + - **Inline config blob** — for testing or validating new configs. + + Prefer stored configs in production; use blobs only for development/testing/validations. + """ query: QueryParams = Field(..., description="Query-specific parameters") - config: LLMCallConfig = Field(..., description="Configuration for the LLM call") + config: LLMCallConfig = Field( + ..., + description=( + "Complete LLM call configuration, provided either by reference (id + version) " + "or as config blob. Use the blob only for testing/validation; " + "in production, always use the id + version." + ), + ) callback_url: HttpUrl | None = Field( default=None, description="Webhook URL for async response delivery" ) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index ca9e77c1..a8ad9d83 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -6,8 +6,10 @@ from sqlmodel import Session from app.core.db import engine +from app.crud.config import ConfigVersionCrud from app.crud.jobs import JobCrud -from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMCallResponse +from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest +from app.models.llm.request import ConfigBlob, LLMCallConfig from app.utils import APIResponse, send_callback from app.celery.utils import start_high_priority_job from app.services.llm.providers.registry import get_llm_provider @@ -76,6 +78,41 @@ def handle_job_error( return callback_response.model_dump() +def resolve_config_blob( + config_crud: ConfigVersionCrud, config: LLMCallConfig +) -> tuple[ConfigBlob | None, str | None]: + """Fetch and parse stored config version into ConfigBlob. + + Returns: + (config_blob, error_message) + - config_blob: ConfigBlob if successful, else None + - error_message: human-safe error string if an error occurs, else None + """ + try: + config_version = config_crud.exists_or_raise(version_number=config.version) + except HTTPException as e: + return None, f"Failed to retrieve stored configuration: {e.detail}" + except Exception: + logger.error( + f"[resolve_config_blob] Unexpected error retrieving config version | " + f"config_id={config.id}, version={config.version}", + exc_info=True, + ) + return None, "Unexpected error occurred while retrieving stored configuration" + + try: + return ConfigBlob(**config_version.config_blob), None + except (TypeError, ValueError) as e: + return None, f"Stored configuration blob is invalid: {str(e)}" + except Exception: + logger.error( + f"[resolve_config_blob] Unexpected error parsing config blob | " + f"config_id={config.id}, version={config.version}", + exc_info=True, + ) + return None, "Unexpected error occurred while parsing stored configuration" + + def execute_job( request_data: dict, project_id: int, @@ -93,53 +130,72 @@ def execute_job( request = LLMCallRequest(**request_data) job_id: UUID = UUID(job_id) + # one of (id, version) or blob is guaranteed to be present due to prior validation config = request.config - provider = config.completion.provider - callback = None + callback_response = None + config_blob: ConfigBlob | None = None logger.info( f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}, " - f"provider={provider}" ) try: - # Update job status to PROCESSING with Session(engine) as session: + # Update job status to PROCESSING job_crud = JobCrud(session=session) job_crud.update( job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) ) + # if stored config, fetch blob from DB + if config.is_stored_config: + config_crud = ConfigVersionCrud( + session=session, project_id=project_id, config_id=config.id + ) + + # blob is dynamic, need to resolve to ConfigBlob format + config_blob, error = resolve_config_blob(config_crud, config) + + if error: + callback_response = APIResponse.failure_response( + error=error, + metadata=request.request_metadata, + ) + return handle_job_error( + job_id, request.callback_url, callback_response + ) + + else: + config_blob = config.blob + try: provider_instance = get_llm_provider( session=session, - provider_type=provider, + provider_type=config_blob.completion.provider, project_id=project_id, organization_id=organization_id, ) except ValueError as ve: - callback = APIResponse.failure_response( + callback_response = APIResponse.failure_response( error=str(ve), metadata=request.request_metadata, ) - - if callback: - return handle_job_error(job_id, request.callback_url, callback) + return handle_job_error(job_id, request.callback_url, callback_response) response, error = provider_instance.execute( - completion_config=config.completion, + completion_config=config_blob.completion, query=request.query, include_provider_raw_response=request.include_provider_raw_response, ) if response: - callback = APIResponse.success_response( + callback_response = APIResponse.success_response( data=response, metadata=request.request_metadata ) if request.callback_url: send_callback( callback_url=request.callback_url, - data=callback.model_dump(), + data=callback_response.model_dump(), ) with Session(engine) as session: @@ -152,21 +208,21 @@ def execute_job( f"[execute_job] Successfully completed LLM job | job_id={job_id}, " f"provider_response_id={response.response.provider_response_id}, tokens={response.usage.total_tokens}" ) - return callback.model_dump() + return callback_response.model_dump() - callback = APIResponse.failure_response( + callback_response = APIResponse.failure_response( error=error or "Unknown error occurred", metadata=request.request_metadata, ) - return handle_job_error(job_id, request.callback_url, callback) + return handle_job_error(job_id, request.callback_url, callback_response) except Exception as e: - callback = APIResponse.failure_response( + callback_response = APIResponse.failure_response( error=f"Unexpected error occurred", metadata=request.request_metadata, ) logger.error( - f"[execute_job] {callback.error} {str(e)} | job_id={job_id}, task_id={task_id}", + f"[execute_job] Unknown error occurred: {str(e)} | job_id={job_id}, task_id={task_id}", exc_info=True, ) - return handle_job_error(job_id, request.callback_url, callback) + return handle_job_error(job_id, request.callback_url, callback_response) diff --git a/backend/app/tests/api/routes/configs/test_config.py b/backend/app/tests/api/routes/configs/test_config.py index 631f746e..8f094f53 100644 --- a/backend/app/tests/api/routes/configs/test_config.py +++ b/backend/app/tests/api/routes/configs/test_config.py @@ -18,9 +18,14 @@ def test_create_config_success( "name": "test-llm-config", "description": "A test LLM configuration", "config_blob": { - "model": "gpt-4", - "temperature": 0.8, - "max_tokens": 2000, + "completion": { + "provider": "openai", + "params": { + "model": "gpt-4", + "temperature": 0.8, + "max_tokens": 2000, + }, + } }, "commit_message": "Initial configuration", } @@ -81,7 +86,12 @@ def test_create_config_duplicate_name_fails( config_data = { "name": "duplicate-config", "description": "Should fail", - "config_blob": {"model": "gpt-4"}, + "config_blob": { + "completion": { + "provider": "openai", + "params": {"model": "gpt-4"}, + } + }, "commit_message": "Initial", } @@ -406,7 +416,12 @@ def test_create_config_requires_authentication( config_data = { "name": "test-config", "description": "Test", - "config_blob": {"model": "gpt-4"}, + "config_blob": { + "completion": { + "provider": "openai", + "params": {"model": "gpt-4"}, + } + }, "commit_message": "Initial", } diff --git a/backend/app/tests/api/routes/configs/test_version.py b/backend/app/tests/api/routes/configs/test_version.py index 882e57fc..acb9f252 100644 --- a/backend/app/tests/api/routes/configs/test_version.py +++ b/backend/app/tests/api/routes/configs/test_version.py @@ -10,6 +10,7 @@ create_test_project, create_test_version, ) +from app.models import ConfigBlob, CompletionConfig def test_create_version_success( @@ -26,9 +27,14 @@ def test_create_version_success( version_data = { "config_blob": { - "model": "gpt-4-turbo", - "temperature": 0.9, - "max_tokens": 3000, + "completion": { + "provider": "openai", + "params": { + "model": "gpt-4-turbo", + "temperature": 0.9, + "max_tokens": 3000, + }, + } }, "commit_message": "Updated model to gpt-4-turbo", } @@ -83,7 +89,12 @@ def test_create_version_nonexistent_config( """Test creating a version for a non-existent config returns 404.""" fake_uuid = uuid4() version_data = { - "config_blob": {"model": "gpt-4"}, + "config_blob": { + "completion": { + "provider": "openai", + "params": {"model": "gpt-4"}, + } + }, "commit_message": "Test", } @@ -109,7 +120,12 @@ def test_create_version_different_project_fails( ) version_data = { - "config_blob": {"model": "gpt-4"}, + "config_blob": { + "completion": { + "provider": "openai", + "params": {"model": "gpt-4"}, + } + }, "commit_message": "Should fail", } @@ -136,7 +152,12 @@ def test_create_version_auto_increments( # Create multiple versions and verify they increment for i in range(2, 5): version_data = { - "config_blob": {"model": f"gpt-4-version-{i}"}, + "config_blob": { + "completion": { + "provider": "openai", + "params": {"model": f"gpt-4-version-{i}"}, + } + }, "commit_message": f"Version {i}", } @@ -277,12 +298,16 @@ def test_get_version_by_number( name="test-config", ) - # Create additional version version = create_test_version( db=db, config_id=config.id, project_id=user_api_key.project_id, - config_blob={"model": "gpt-4-turbo", "temperature": 0.5}, + config_blob=ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4-turbo", "temperature": 0.5}, + ) + ), commit_message="Updated config", ) @@ -421,7 +446,12 @@ def test_create_version_requires_authentication( ) -> None: """Test that creating a version without authentication fails.""" version_data = { - "config_blob": {"model": "gpt-4"}, + "config_blob": { + "completion": { + "provider": "openai", + "params": {"model": "gpt-4"}, + } + }, "commit_message": "Test", } diff --git a/backend/app/tests/api/routes/test_llm.py b/backend/app/tests/api/routes/test_llm.py index 08414cfd..430ca77c 100644 --- a/backend/app/tests/api/routes/test_llm.py +++ b/backend/app/tests/api/routes/test_llm.py @@ -1,7 +1,12 @@ from unittest.mock import patch from fastapi.testclient import TestClient from app.models import LLMCallRequest -from app.models.llm.request import QueryParams, LLMCallConfig, CompletionConfig +from app.models.llm.request import ( + QueryParams, + LLMCallConfig, + CompletionConfig, + ConfigBlob, +) def test_llm_call_success(client: TestClient, user_api_key_header: dict[str, str]): @@ -12,12 +17,14 @@ def test_llm_call_success(client: TestClient, user_api_key_header: dict[str, str payload = LLMCallRequest( query=QueryParams(input="What is the capital of France?"), config=LLMCallConfig( - completion=CompletionConfig( - provider="openai", - params={ - "model": "gpt-4", - "temperature": 0.7, - }, + blob=ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={ + "model": "gpt-4", + "temperature": 0.7, + }, + ) ) ), callback_url="https://example.com/callback", diff --git a/backend/app/tests/crud/config/test_config.py b/backend/app/tests/crud/config/test_config.py index 6b753e0b..e7837b98 100644 --- a/backend/app/tests/crud/config/test_config.py +++ b/backend/app/tests/crud/config/test_config.py @@ -5,6 +5,8 @@ from app.models import ( Config, + ConfigBlob, + CompletionConfig, ConfigCreate, ConfigUpdate, ) @@ -13,21 +15,30 @@ from app.tests.utils.utils import random_lower_string -def test_create_config(db: Session) -> None: +@pytest.fixture +def example_config_blob(): + return ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={ + "model": "gpt-4", + "temperature": 0.8, + "max_tokens": 1500, + }, + ) + ) + + +def test_create_config(db: Session, example_config_blob: ConfigBlob) -> None: """Test creating a new configuration with initial version.""" project = create_test_project(db) config_crud = ConfigCrud(session=db, project_id=project.id) config_name = f"test-config-{random_lower_string()}" - config_blob = { - "model": "gpt-4", - "temperature": 0.7, - "max_tokens": 1000, - } config_create = ConfigCreate( name=config_name, description="Test configuration", - config_blob=config_blob, + config_blob=example_config_blob, commit_message="Initial version", ) @@ -43,11 +54,13 @@ def test_create_config(db: Session) -> None: assert version.id is not None assert version.config_id == config.id assert version.version == 1 - assert version.config_blob == config_blob + assert version.config_blob == example_config_blob.model_dump() assert version.commit_message == "Initial version" -def test_create_config_duplicate_name(db: Session) -> None: +def test_create_config_duplicate_name( + db: Session, example_config_blob: ConfigBlob +) -> None: """Test creating a configuration with a duplicate name raises HTTPException.""" project = create_test_project(db) config_crud = ConfigCrud(session=db, project_id=project.id) @@ -56,7 +69,7 @@ def test_create_config_duplicate_name(db: Session) -> None: config_create = ConfigCreate( name=config_name, description="Test configuration", - config_blob={"model": "gpt-4"}, + config_blob=example_config_blob, commit_message="Initial version", ) @@ -70,13 +83,15 @@ def test_create_config_duplicate_name(db: Session) -> None: config_crud.create_or_raise(config_create) -def test_create_config_different_projects_same_name(db: Session) -> None: +def test_create_config_different_projects_same_name( + db: Session, example_config_blob: ConfigBlob +) -> None: """Test creating configs with same name in different projects succeeds.""" project1 = create_test_project(db) project2 = create_test_project(db) config_name = f"test-config-{random_lower_string()}" - config_blob = {"model": "gpt-4"} + config_blob = example_config_blob # Create config in project1 config_crud1 = ConfigCrud(session=db, project_id=project1.id) diff --git a/backend/app/tests/crud/config/test_version.py b/backend/app/tests/crud/config/test_version.py index d62265d5..c3c4bd58 100644 --- a/backend/app/tests/crud/config/test_version.py +++ b/backend/app/tests/crud/config/test_version.py @@ -3,7 +3,7 @@ from sqlmodel import Session from fastapi import HTTPException -from app.models import ConfigVersionCreate +from app.models import ConfigVersionCreate, ConfigBlob, CompletionConfig from app.crud.config import ConfigVersionCrud from app.tests.utils.test_data import ( create_test_project, @@ -12,18 +12,28 @@ ) -def test_create_version(db: Session) -> None: +@pytest.fixture +def example_config_blob(): + return ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={ + "model": "gpt-4", + "temperature": 0.8, + "max_tokens": 1500, + }, + ) + ) + + +def test_create_version(db: Session, example_config_blob: ConfigBlob) -> None: """Test creating a new version for an existing configuration.""" config = create_test_config(db) version_crud = ConfigVersionCrud( session=db, project_id=config.project_id, config_id=config.id ) - config_blob = { - "model": "gpt-4-turbo", - "temperature": 0.8, - "max_tokens": 2000, - } + config_blob = example_config_blob.model_dump() version_create = ConfigVersionCreate( config_blob=config_blob, commit_message="Updated model and parameters", @@ -39,7 +49,9 @@ def test_create_version(db: Session) -> None: assert version.deleted_at is None -def test_create_version_auto_increment(db: Session) -> None: +def test_create_version_auto_increment( + db: Session, example_config_blob: ConfigBlob +) -> None: """Test that version numbers auto-increment correctly.""" config = create_test_config(db) version_crud = ConfigVersionCrud( @@ -48,13 +60,13 @@ def test_create_version_auto_increment(db: Session) -> None: # Create multiple versions version2 = version_crud.create_or_raise( - ConfigVersionCreate(config_blob={"model": "gpt-4"}, commit_message="Version 2") + ConfigVersionCreate(config_blob=example_config_blob, commit_message="Version 2") ) version3 = version_crud.create_or_raise( - ConfigVersionCreate(config_blob={"model": "gpt-4"}, commit_message="Version 3") + ConfigVersionCreate(config_blob=example_config_blob, commit_message="Version 3") ) version4 = version_crud.create_or_raise( - ConfigVersionCreate(config_blob={"model": "gpt-4"}, commit_message="Version 4") + ConfigVersionCreate(config_blob=example_config_blob, commit_message="Version 4") ) assert version2.version == 2 @@ -62,7 +74,9 @@ def test_create_version_auto_increment(db: Session) -> None: assert version4.version == 4 -def test_create_version_config_not_found(db: Session) -> None: +def test_create_version_config_not_found( + db: Session, example_config_blob: ConfigBlob +) -> None: """Test creating a version for a non-existent config raises HTTPException.""" project = create_test_project(db) non_existent_config_id = uuid4() @@ -72,7 +86,7 @@ def test_create_version_config_not_found(db: Session) -> None: ) version_create = ConfigVersionCreate( - config_blob={"model": "gpt-4"}, commit_message="Test" + config_blob=example_config_blob, commit_message="Test" ) with pytest.raises( @@ -81,14 +95,14 @@ def test_create_version_config_not_found(db: Session) -> None: version_crud.create_or_raise(version_create) -def test_read_one_version(db: Session) -> None: +def test_read_one_version(db: Session, example_config_blob: ConfigBlob) -> None: """Test reading a specific version by its version number.""" config = create_test_config(db) version = create_test_version( db, config_id=config.id, project_id=config.project_id, - config_blob={"model": "gpt-4-turbo"}, + config_blob=example_config_blob, commit_message="Test version", ) @@ -102,7 +116,7 @@ def test_read_one_version(db: Session) -> None: assert fetched_version.id == version.id assert fetched_version.version == version.version assert fetched_version.config_id == config.id - assert fetched_version.config_blob == {"model": "gpt-4-turbo"} + assert fetched_version.config_blob == example_config_blob.model_dump() def test_read_one_version_not_found(db: Session) -> None: @@ -228,7 +242,6 @@ def test_read_all_versions_excludes_blob(db: Session) -> None: db, config_id=config.id, project_id=config.project_id, - config_blob={"model": "gpt-4-turbo"}, ) version_crud = ConfigVersionCrud( @@ -360,7 +373,9 @@ def test_exists_version_deleted(db: Session) -> None: version_crud.exists_or_raise(version.version) -def test_create_version_different_configs(db: Session) -> None: +def test_create_version_different_configs( + db: Session, example_config_blob: ConfigBlob +) -> None: """Test that version numbers are independent across different configs.""" project = create_test_project(db) @@ -373,7 +388,7 @@ def test_create_version_different_configs(db: Session) -> None: session=db, project_id=project.id, config_id=config1.id ) version2_config1 = version_crud1.create_or_raise( - ConfigVersionCreate(config_blob={"model": "gpt-4"}, commit_message="V2") + ConfigVersionCreate(config_blob=example_config_blob, commit_message="V2") ) # Create versions for config2 @@ -381,7 +396,7 @@ def test_create_version_different_configs(db: Session) -> None: session=db, project_id=project.id, config_id=config2.id ) version2_config2 = version_crud2.create_or_raise( - ConfigVersionCreate(config_blob={"model": "gpt-4"}, commit_message="V2") + ConfigVersionCreate(config_blob=example_config_blob, commit_message="V2") ) # Both should have version 2 (independent numbering) diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index e301c74b..2f08b40c 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -5,11 +5,12 @@ from unittest.mock import patch, MagicMock from fastapi import HTTPException -from sqlmodel import Session +from sqlmodel import Session, select from app.crud import JobCrud +from app.crud.config import ConfigVersionCrud from app.utils import APIResponse -from app.models import JobStatus, JobType +from app.models import ConfigVersion, JobStatus, JobType from app.models.llm import ( LLMCallRequest, CompletionConfig, @@ -19,9 +20,15 @@ LLMOutput, Usage, ) -from app.models.llm.request import LLMCallConfig -from app.services.llm.jobs import start_job, handle_job_error, execute_job +from app.models.llm.request import ConfigBlob, LLMCallConfig +from app.services.llm.jobs import ( + start_job, + handle_job_error, + execute_job, + resolve_config_blob, +) from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_config class TestStartJob: @@ -32,9 +39,11 @@ def llm_call_request(self): return LLMCallRequest( query=QueryParams(input="Test query"), config=LLMCallConfig( - completion=CompletionConfig( - provider="openai", - params={"model": "gpt-4"}, + blob=ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4"}, + ) ) ), ) @@ -215,7 +224,9 @@ def request_data(self): return { "query": {"input": "Test query"}, "config": { - "completion": {"provider": "openai", "params": {"model": "gpt-4"}} + "blob": { + "completion": {"provider": "openai", "params": {"model": "gpt-4"}} + } }, "include_provider_raw_response": False, "callback_url": None, @@ -378,3 +389,282 @@ def test_metadata_in_error_callback( env["send_callback"].assert_called_once() callback_data = env["send_callback"].call_args[1]["data"] assert callback_data["metadata"] == {"tracking_id": "track-456"} + + def test_stored_config_success(self, db, job_for_execution, mock_llm_response): + """Test successful execution with stored config (id + version).""" + project = get_project(db) + + # Create a real config in the database + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4", "temperature": 0.7}, + ) + ) + config = create_test_config(db, project_id=project.id, config_blob=config_blob) + db.commit() + + # Build request data with stored config + stored_request_data = { + "query": {"input": "Test query"}, + "config": { + "id": str(config.id), + "version": 1, + }, + "include_provider_raw_response": False, + "callback_url": None, + } + + with ( + patch("app.services.llm.jobs.Session") as mock_session_class, + patch("app.services.llm.jobs.get_llm_provider") as mock_get_provider, + ): + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + # Mock LLM provider + mock_provider = MagicMock() + mock_provider.execute.return_value = (mock_llm_response, None) + mock_get_provider.return_value = mock_provider + + result = self._execute_job(job_for_execution, db, stored_request_data) + + # Verify provider was called + mock_get_provider.assert_called_once() + mock_provider.execute.assert_called_once() + + # Verify success + assert result["success"] + db.refresh(job_for_execution) + assert job_for_execution.status == JobStatus.SUCCESS + + def test_stored_config_with_callback( + self, db, job_for_execution, mock_llm_response + ): + """Test stored config with callback URL.""" + project = get_project(db) + + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-3.5-turbo", "temperature": 0.5}, + ) + ) + config = create_test_config(db, project_id=project.id, config_blob=config_blob) + db.commit() + + stored_request_data = { + "query": {"input": "Test query with callback"}, + "config": { + "id": str(config.id), + "version": 1, + }, + "include_provider_raw_response": False, + "callback_url": "https://example.com/callback", + } + + with ( + patch("app.services.llm.jobs.Session") as mock_session_class, + patch("app.services.llm.jobs.get_llm_provider") as mock_get_provider, + patch("app.services.llm.jobs.send_callback") as mock_send_callback, + ): + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + # Mock LLM provider + mock_provider = MagicMock() + mock_provider.execute.return_value = (mock_llm_response, None) + mock_get_provider.return_value = mock_provider + + result = self._execute_job(job_for_execution, db, stored_request_data) + + # Verify callback was sent + mock_send_callback.assert_called_once() + callback_data = mock_send_callback.call_args[1]["data"] + assert callback_data["success"] + + # Verify success + assert result["success"] + db.refresh(job_for_execution) + assert job_for_execution.status == JobStatus.SUCCESS + + def test_stored_config_version_not_found(self, db, job_for_execution): + """Test stored config when version doesn't exist.""" + project = get_project(db) + + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4"}, + ) + ) + config = create_test_config(db, project_id=project.id, config_blob=config_blob) + db.commit() + + stored_request_data = { + "query": {"input": "Test query"}, + "config": { + "id": str(config.id), + "version": 999, + }, + "include_provider_raw_response": False, + "callback_url": None, + } + + with patch("app.services.llm.jobs.Session") as mock_session_class: + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + result = self._execute_job(job_for_execution, db, stored_request_data) + + # Verify failure + assert not result["success"] + assert "Failed to retrieve stored configuration" in result["error"] + db.refresh(job_for_execution) + assert job_for_execution.status == JobStatus.FAILED + + +class TestResolveConfigBlob: + """Test suite for resolve_config_blob function.""" + + def test_resolve_config_blob_success(self, db: Session): + """Test successful resolution of stored config blob.""" + project = get_project(db) + + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4", "temperature": 0.8}, + ) + ) + config = create_test_config(db, project_id=project.id, config_blob=config_blob) + db.commit() + + config_crud = ConfigVersionCrud( + session=db, project_id=project.id, config_id=config.id + ) + llm_call_config = LLMCallConfig(id=str(config.id), version=1) + + resolved_blob, error = resolve_config_blob(config_crud, llm_call_config) + + assert error is None + assert resolved_blob is not None + assert resolved_blob.completion.provider == "openai" + assert resolved_blob.completion.params["model"] == "gpt-4" + assert resolved_blob.completion.params["temperature"] == 0.8 + + def test_resolve_config_blob_version_not_found(self, db: Session): + """Test resolve_config_blob when version doesn't exist.""" + project = get_project(db) + + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4"}, + ) + ) + config = create_test_config(db, project_id=project.id, config_blob=config_blob) + db.commit() + + config_crud = ConfigVersionCrud( + session=db, project_id=project.id, config_id=config.id + ) + llm_call_config = LLMCallConfig(id=str(config.id), version=999) + + resolved_blob, error = resolve_config_blob(config_crud, llm_call_config) + + assert resolved_blob is None + assert error is not None + assert "Failed to retrieve stored configuration" in error + + def test_resolve_config_blob_invalid_blob_data(self, db: Session): + """Test resolve_config_blob when config blob is malformed.""" + + project = get_project(db) + + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4"}, + ) + ) + config = create_test_config(db, project_id=project.id, config_blob=config_blob) + db.commit() + + # Query the config version directly from the database + statement = select(ConfigVersion).where(ConfigVersion.config_id == config.id) + config_version = db.exec(statement).first() + + # Manually corrupt the config_blob in the database + # Set invalid data that can't be parsed as ConfigBlob + config_version.config_blob = {"invalid": "structure", "missing": "completion"} + db.add(config_version) + db.commit() + + config_crud = ConfigVersionCrud( + session=db, project_id=project.id, config_id=config.id + ) + llm_call_config = LLMCallConfig(id=str(config.id), version=1) + + resolved_blob, error = resolve_config_blob(config_crud, llm_call_config) + + assert resolved_blob is None + assert error is not None + assert "Stored configuration blob is invalid" in error + + def test_resolve_config_blob_with_multiple_versions(self, db: Session): + """Test resolving specific version when multiple versions exist.""" + from app.models.config import ConfigVersionCreate + + project = get_project(db) + + # Create a config with version 1 + config_blob_v1 = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-3.5-turbo", "temperature": 0.5}, + ) + ) + config = create_test_config( + db, project_id=project.id, config_blob=config_blob_v1 + ) + db.commit() + + # Create version 2 using ConfigVersionCrud + config_version_crud = ConfigVersionCrud( + session=db, project_id=project.id, config_id=config.id + ) + config_blob_v2 = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={"model": "gpt-4", "temperature": 0.9}, + ) + ) + version_create = ConfigVersionCreate( + config_blob=config_blob_v2, + commit_message="Updated to gpt-4", + ) + config_version_crud.create_or_raise(version_create) + db.commit() + + # Test resolving version 1 + llm_call_config_v1 = LLMCallConfig(id=str(config.id), version=1) + resolved_blob_v1, error_v1 = resolve_config_blob( + config_version_crud, llm_call_config_v1 + ) + + assert error_v1 is None + assert resolved_blob_v1 is not None + assert resolved_blob_v1.completion.params["model"] == "gpt-3.5-turbo" + assert resolved_blob_v1.completion.params["temperature"] == 0.5 + + # Test resolving version 2 + llm_call_config_v2 = LLMCallConfig(id=str(config.id), version=2) + resolved_blob_v2, error_v2 = resolve_config_blob( + config_version_crud, llm_call_config_v2 + ) + + assert error_v2 is None + assert resolved_blob_v2 is not None + assert resolved_blob_v2.completion.params["model"] == "gpt-4" + assert resolved_blob_v2.completion.params["temperature"] == 0.9 diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index abb62f54..b33656b2 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -8,6 +8,8 @@ Credential, OrganizationCreate, ProjectCreate, + ConfigBlob, + CompletionConfig, CredsCreate, FineTuningJobCreate, Fine_Tuning, @@ -18,6 +20,7 @@ ConfigCreate, ConfigVersion, ConfigVersionCreate, + ConfigBase, ) from app.crud import ( create_organization, @@ -238,7 +241,7 @@ def create_test_config( project_id: int | None = None, name: str | None = None, description: str | None = None, - config_blob: dict | None = None, + config_blob: ConfigBlob | None = None, ) -> Config: """ Creates and returns a test configuration with an initial version. @@ -253,11 +256,16 @@ def create_test_config( name = f"test-config-{random_lower_string()}" if config_blob is None: - config_blob = { - "model": "gpt-4", - "temperature": 0.7, - "max_tokens": 1000, - } + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={ + "model": "gpt-4", + "temperature": 0.7, + "max_tokens": 1000, + }, + ) + ) config_create = ConfigCreate( name=name, @@ -276,7 +284,7 @@ def create_test_version( db: Session, config_id, project_id: int, - config_blob: dict | None = None, + config_blob: ConfigBlob | None = None, commit_message: str | None = None, ) -> ConfigVersion: """ @@ -285,11 +293,16 @@ def create_test_version( Persists the version to the database. """ if config_blob is None: - config_blob = { - "model": "gpt-4", - "temperature": 0.8, - "max_tokens": 1500, - } + config_blob = ConfigBlob( + completion=CompletionConfig( + provider="openai", + params={ + "model": "gpt-4", + "temperature": 0.8, + "max_tokens": 1500, + }, + ) + ) version_create = ConfigVersionCreate( config_blob=config_blob,