diff --git a/.env.example b/.env.example index a6df893df..7c2cd9f02 100644 --- a/.env.example +++ b/.env.example @@ -23,6 +23,12 @@ FIRST_SUPERUSER=superuser@example.com FIRST_SUPERUSER_PASSWORD=changethis EMAIL_TEST_USER="test@example.com" +# API Base URL for cron scripts (defaults to http://localhost:8000 if not set) +API_BASE_URL=http://localhost:8000 + +# Cron interval in minutes (defaults to 5 minutes if not set) +CRON_INTERVAL_MINUTES=5 + # Postgres POSTGRES_SERVER=localhost POSTGRES_PORT=5432 diff --git a/backend/app/alembic/versions/6fe772038a5a_create_evaluation_run_table.py b/backend/app/alembic/versions/6fe772038a5a_create_evaluation_run_table.py new file mode 100644 index 000000000..c9fd595aa --- /dev/null +++ b/backend/app/alembic/versions/6fe772038a5a_create_evaluation_run_table.py @@ -0,0 +1,249 @@ +"""create_evaluation_run_table, batch_job_table, and evaluation_dataset_table + +Revision ID: 6fe772038a5a +Revises: 219033c644de +Create Date: 2025-11-05 22:47:18.266070 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = "6fe772038a5a" +down_revision = "219033c644de" +branch_labels = None +depends_on = None + + +def upgrade(): + # Create batch_job table first (as evaluation_run will reference it) + op.create_table( + "batch_job", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "provider", + sa.String(), + nullable=False, + comment="LLM provider name (e.g., 'openai', 'anthropic')", + ), + sa.Column( + "job_type", + sa.String(), + nullable=False, + comment="Type of batch job (e.g., 'evaluation', 'classification', 'embedding')", + ), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'{}'::jsonb"), + comment="Complete batch configuration", + ), + sa.Column( + "provider_batch_id", + sa.String(), + nullable=True, + comment="Provider's batch job ID", + ), + sa.Column( + "provider_file_id", + sa.String(), + nullable=True, + comment="Provider's input file ID", + ), + sa.Column( + "provider_output_file_id", + sa.String(), + nullable=True, + comment="Provider's output file ID", + ), + sa.Column( + "provider_status", + sa.String(), + nullable=True, + comment="Provider-specific status (e.g., OpenAI: validating, in_progress, completed, failed)", + ), + sa.Column( + "raw_output_url", + sa.String(), + nullable=True, + comment="S3 URL of raw batch output file", + ), + sa.Column( + "total_items", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + comment="Total number of items in the batch", + ), + sa.Column( + "error_message", + sa.Text(), + nullable=True, + comment="Error message if batch failed", + ), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_batch_job_job_type"), "batch_job", ["job_type"], unique=False + ) + op.create_index( + op.f("ix_batch_job_organization_id"), + "batch_job", + ["organization_id"], + unique=False, + ) + op.create_index( + op.f("ix_batch_job_project_id"), "batch_job", ["project_id"], unique=False + ) + op.create_index( + "idx_batch_job_status_org", + "batch_job", + ["provider_status", "organization_id"], + unique=False, + ) + op.create_index( + "idx_batch_job_status_project", + "batch_job", + ["provider_status", "project_id"], + unique=False, + ) + + # Create evaluation_dataset table + op.create_table( + "evaluation_dataset", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column( + "dataset_metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'{}'::jsonb"), + ), + sa.Column( + "object_store_url", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + sa.Column( + "langfuse_dataset_id", + sqlmodel.sql.sqltypes.AutoString(), + nullable=True, + ), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "name", + "organization_id", + "project_id", + name="uq_evaluation_dataset_name_org_project", + ), + ) + op.create_index( + op.f("ix_evaluation_dataset_name"), + "evaluation_dataset", + ["name"], + unique=False, + ) + + # Create evaluation_run table with all columns and foreign key references + op.create_table( + "evaluation_run", + sa.Column("run_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("dataset_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("config", sa.JSON(), nullable=False), + sa.Column("batch_job_id", sa.Integer(), nullable=True), + sa.Column( + "embedding_batch_job_id", + sa.Integer(), + nullable=True, + comment="Reference to the batch_job for embedding-based similarity scoring", + ), + sa.Column("dataset_id", sa.Integer(), nullable=False), + sa.Column("status", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "object_store_url", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + sa.Column("total_items", sa.Integer(), nullable=False), + sa.Column("score", sa.JSON(), nullable=True), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["batch_job_id"], + ["batch_job.id"], + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["embedding_batch_job_id"], + ["batch_job.id"], + name="fk_evaluation_run_embedding_batch_job_id", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["evaluation_dataset.id"], + name="fk_evaluation_run_dataset_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_evaluation_run_run_name"), "evaluation_run", ["run_name"], unique=False + ) + op.create_index( + "idx_eval_run_status_org", + "evaluation_run", + ["status", "organization_id"], + unique=False, + ) + op.create_index( + "idx_eval_run_status_project", + "evaluation_run", + ["status", "project_id"], + unique=False, + ) + + +def downgrade(): + # Drop evaluation_run table first (has foreign keys to batch_job and evaluation_dataset) + op.drop_index("idx_eval_run_status_project", table_name="evaluation_run") + op.drop_index("idx_eval_run_status_org", table_name="evaluation_run") + op.drop_index(op.f("ix_evaluation_run_run_name"), table_name="evaluation_run") + op.drop_table("evaluation_run") + + # Drop evaluation_dataset table + op.drop_index(op.f("ix_evaluation_dataset_name"), table_name="evaluation_dataset") + op.drop_table("evaluation_dataset") + + # Drop batch_job table + op.drop_index("idx_batch_job_status_project", table_name="batch_job") + op.drop_index("idx_batch_job_status_org", table_name="batch_job") + op.drop_index(op.f("ix_batch_job_project_id"), table_name="batch_job") + op.drop_index(op.f("ix_batch_job_organization_id"), table_name="batch_job") + op.drop_index(op.f("ix_batch_job_job_type"), table_name="batch_job") + op.drop_table("batch_job") diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 59678d2f9..73cb77427 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -70,7 +70,7 @@ def get_current_user( if not user: raise HTTPException(status_code=404, detail="User not found") if not user.is_active: - raise HTTPException(status_code=400, detail="Inactive user") + raise HTTPException(status_code=403, detail="Inactive user") return user # Return only User object diff --git a/backend/app/api/docs/evaluation/create_evaluation.md b/backend/app/api/docs/evaluation/create_evaluation.md new file mode 100644 index 000000000..313ad0079 --- /dev/null +++ b/backend/app/api/docs/evaluation/create_evaluation.md @@ -0,0 +1,80 @@ +Start an evaluation using OpenAI Batch API. + +This endpoint: +1. Fetches the dataset from database and validates it has Langfuse dataset ID +2. Creates an EvaluationRun record in the database +3. Fetches dataset items from Langfuse +4. Builds JSONL for batch processing (config is used as-is) +5. Creates a batch job via the generic batch infrastructure +6. Returns the evaluation run details with batch_job_id + +The batch will be processed asynchronously by Celery Beat (every 60s). +Use GET /evaluations/{evaluation_id} to check progress. + +## Request Body + +- **dataset_id** (required): ID of the evaluation dataset (from /evaluations/datasets) +- **experiment_name** (required): Name for this evaluation experiment/run +- **config** (optional): Configuration dict that will be used as-is in JSONL generation. Can include any OpenAI Responses API parameters like: + - model: str (e.g., "gpt-4o", "gpt-5") + - instructions: str + - tools: list (e.g., [{"type": "file_search", "vector_store_ids": [...]}]) + - reasoning: dict (e.g., {"effort": "low"}) + - text: dict (e.g., {"verbosity": "low"}) + - temperature: float + - include: list (e.g., ["file_search_call.results"]) + - Note: "input" will be added automatically from the dataset +- **assistant_id** (optional): Assistant ID to fetch configuration from. If provided, configuration will be fetched from the assistant in the database. Config can be passed as empty dict {} when using assistant_id. + +## Example with config + +```json +{ + "dataset_id": 123, + "experiment_name": "test_run", + "config": { + "model": "gpt-4.1", + "instructions": "You are a helpful FAQ assistant.", + "tools": [ + { + "type": "file_search", + "vector_store_ids": ["vs_12345"], + "max_num_results": 3 + } + ], + "include": ["file_search_call.results"] + } +} +``` + +## Example with assistant_id + +```json +{ + "dataset_id": 123, + "experiment_name": "test_run", + "config": {}, + "assistant_id": "asst_xyz" +} +``` + +## Returns + +EvaluationRunPublic with batch details and status: +- id: Evaluation run ID +- run_name: Name of the evaluation run +- dataset_name: Name of the dataset used +- dataset_id: ID of the dataset used +- config: Configuration used for the evaluation +- batch_job_id: ID of the batch job processing this evaluation +- status: Current status (pending, running, completed, failed) +- total_items: Total number of items being evaluated +- completed_items: Number of items completed so far +- results: Evaluation results (when completed) +- error_message: Error message if failed + +## Error Responses + +- **404**: Dataset or assistant not found or not accessible +- **400**: Missing required credentials (OpenAI or Langfuse), dataset missing Langfuse ID, or config missing required fields +- **500**: Failed to configure API clients or start batch evaluation diff --git a/backend/app/api/docs/evaluation/delete_dataset.md b/backend/app/api/docs/evaluation/delete_dataset.md new file mode 100644 index 000000000..461c30fce --- /dev/null +++ b/backend/app/api/docs/evaluation/delete_dataset.md @@ -0,0 +1,18 @@ +Delete a dataset by ID. + +This will remove the dataset record from the database. The CSV file in object store (if exists) will remain for audit purposes, but the dataset will no longer be accessible for creating new evaluations. + +## Path Parameters + +- **dataset_id**: ID of the dataset to delete + +## Returns + +Success message with deleted dataset details: +- message: Confirmation message +- dataset_id: ID of the deleted dataset + +## Error Responses + +- **404**: Dataset not found or not accessible to your organization/project +- **400**: Dataset cannot be deleted (e.g., has active evaluation runs) diff --git a/backend/app/api/docs/evaluation/get_dataset.md b/backend/app/api/docs/evaluation/get_dataset.md new file mode 100644 index 000000000..02e1e73aa --- /dev/null +++ b/backend/app/api/docs/evaluation/get_dataset.md @@ -0,0 +1,22 @@ +Get details of a specific dataset by ID. + +Retrieves comprehensive information about a dataset including metadata, object store URL, and Langfuse integration details. + +## Path Parameters + +- **dataset_id**: ID of the dataset to retrieve + +## Returns + +DatasetUploadResponse with dataset details: +- dataset_id: Unique identifier for the dataset +- dataset_name: Name of the dataset (sanitized) +- total_items: Total number of items including duplication +- original_items: Number of original items before duplication +- duplication_factor: Factor by which items were duplicated +- langfuse_dataset_id: ID of the dataset in Langfuse +- object_store_url: URL to the CSV file in object storage + +## Error Responses + +- **404**: Dataset not found or not accessible to your organization/project diff --git a/backend/app/api/docs/evaluation/get_evaluation.md b/backend/app/api/docs/evaluation/get_evaluation.md new file mode 100644 index 000000000..509e27640 --- /dev/null +++ b/backend/app/api/docs/evaluation/get_evaluation.md @@ -0,0 +1,32 @@ +Get the current status of a specific evaluation run. + +Retrieves comprehensive information about an evaluation run including its current processing status, results (if completed), and error details (if failed). + +## Path Parameters + +- **evaluation_id**: ID of the evaluation run + +## Returns + +EvaluationRunPublic with current status and results: +- id: Evaluation run ID +- run_name: Name of the evaluation run +- dataset_name: Name of the dataset used +- dataset_id: ID of the dataset used +- config: Configuration used for the evaluation +- batch_job_id: ID of the batch job processing this evaluation +- status: Current status (pending, running, completed, failed) +- total_items: Total number of items being evaluated +- completed_items: Number of items completed so far +- results: Evaluation results (when completed) +- error_message: Error message if failed +- created_at: Timestamp when the evaluation was created +- updated_at: Timestamp when the evaluation was last updated + +## Usage + +Use this endpoint to poll for evaluation progress. The evaluation is processed asynchronously by Celery Beat (every 60s), so you should poll periodically to check if the status has changed to "completed" or "failed". + +## Error Responses + +- **404**: Evaluation run not found or not accessible to this organization/project diff --git a/backend/app/api/docs/evaluation/list_datasets.md b/backend/app/api/docs/evaluation/list_datasets.md new file mode 100644 index 000000000..bd5576efc --- /dev/null +++ b/backend/app/api/docs/evaluation/list_datasets.md @@ -0,0 +1,19 @@ +List all datasets for the current organization and project. + +Returns a paginated list of dataset records ordered by most recent first. + +## Query Parameters + +- **limit**: Maximum number of datasets to return (default 50, max 100) +- **offset**: Number of datasets to skip for pagination (default 0) + +## Returns + +List of DatasetUploadResponse objects, each containing: +- dataset_id: Unique identifier for the dataset +- dataset_name: Name of the dataset (sanitized) +- total_items: Total number of items including duplication +- original_items: Number of original items before duplication +- duplication_factor: Factor by which items were duplicated +- langfuse_dataset_id: ID of the dataset in Langfuse +- object_store_url: URL to the CSV file in object storage diff --git a/backend/app/api/docs/evaluation/list_evaluations.md b/backend/app/api/docs/evaluation/list_evaluations.md new file mode 100644 index 000000000..64c667726 --- /dev/null +++ b/backend/app/api/docs/evaluation/list_evaluations.md @@ -0,0 +1,25 @@ +List all evaluation runs for the current organization and project. + +Returns a paginated list of evaluation runs ordered by most recent first. Each evaluation run represents a batch processing job evaluating a dataset against a specific configuration. + +## Query Parameters + +- **limit**: Maximum number of runs to return (default 50) +- **offset**: Number of runs to skip (for pagination, default 0) + +## Returns + +List of EvaluationRunPublic objects, each containing: +- id: Evaluation run ID +- run_name: Name of the evaluation run +- dataset_name: Name of the dataset used +- dataset_id: ID of the dataset used +- config: Configuration used for the evaluation +- batch_job_id: ID of the batch job processing this evaluation +- status: Current status (pending, running, completed, failed) +- total_items: Total number of items being evaluated +- completed_items: Number of items completed so far +- results: Evaluation results (when completed) +- error_message: Error message if failed +- created_at: Timestamp when the evaluation was created +- updated_at: Timestamp when the evaluation was last updated diff --git a/backend/app/api/docs/evaluation/upload_dataset.md b/backend/app/api/docs/evaluation/upload_dataset.md new file mode 100644 index 000000000..b73902860 --- /dev/null +++ b/backend/app/api/docs/evaluation/upload_dataset.md @@ -0,0 +1,42 @@ +Upload a CSV file containing Golden Q&A pairs. + +This endpoint: +1. Sanitizes the dataset name (removes spaces, special characters) +2. Validates and parses the CSV file +3. Uploads CSV to object store (if credentials configured) +4. Uploads dataset to Langfuse (for immediate use) +5. Stores metadata in database + +## Dataset Name + +- Will be sanitized for Langfuse compatibility +- Spaces replaced with underscores +- Special characters removed +- Converted to lowercase +- Example: "My Dataset 01!" becomes "my_dataset_01" + +## CSV Format + +- Must contain 'question' and 'answer' columns +- Can have additional columns (will be ignored) +- Missing values in 'question' or 'answer' rows will be skipped + +## Duplication Factor + +- Minimum: 1 (no duplication) +- Maximum: 5 +- Default: 5 +- Each item in the dataset will be duplicated this many times +- Used to ensure statistical significance in evaluation results + +## Example CSV + +``` +question,answer +"What is the capital of France?","Paris" +"What is 2+2?","4" +``` + +## Returns + +DatasetUploadResponse with dataset_id, object_store_url, and Langfuse details (dataset_name in response will be the sanitized version) diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 62d5db5b9..e2b473f92 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -18,6 +18,8 @@ utils, onboarding, credentials, + cron, + evaluation, fine_tuning, model_evaluation, collection_job, @@ -30,8 +32,10 @@ api_router.include_router(collections.router) api_router.include_router(collection_job.router) api_router.include_router(credentials.router) +api_router.include_router(cron.router) api_router.include_router(documents.router) api_router.include_router(doc_transformation_job.router) +api_router.include_router(evaluation.router) api_router.include_router(llm.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) diff --git a/backend/app/api/routes/cron.py b/backend/app/api/routes/cron.py new file mode 100644 index 000000000..a9e7b66ed --- /dev/null +++ b/backend/app/api/routes/cron.py @@ -0,0 +1,64 @@ +import logging + +from app.api.permissions import Permission, require_permission +from fastapi import APIRouter, Depends +from sqlmodel import Session + +from app.api.deps import SessionDep, AuthContextDep +from app.crud.evaluations import process_all_pending_evaluations_sync +from app.models import User + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["cron"]) + + +@router.get( + "/cron/evaluations", + include_in_schema=False, + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) +def evaluation_cron_job( + session: SessionDep, +) -> dict: + """ + Cron job endpoint for periodic evaluation tasks. + + This endpoint: + 1. Gets all organizations + 2. For each org, polls their pending evaluations + 3. Processes completed batches automatically + 4. Returns aggregated results + + Hidden from Swagger documentation. + Requires authentication via FIRST_SUPERUSER credentials. + """ + logger.info("[evaluation_cron_job] Cron job invoked") + + try: + # Process all pending evaluations across all organizations + result = process_all_pending_evaluations_sync(session=session) + + logger.info( + f"[evaluation_cron_job] Completed: " + f"orgs={result.get('organizations_processed', 0)}, " + f"processed={result.get('total_processed', 0)}, " + f"failed={result.get('total_failed', 0)}, " + f"still_processing={result.get('total_still_processing', 0)}" + ) + + return result + + except Exception as e: + logger.error( + f"[evaluation_cron_job] Error executing cron job: {e}", + exc_info=True, + ) + return { + "status": "error", + "error": str(e), + "organizations_processed": 0, + "total_processed": 0, + "total_failed": 0, + "total_still_processing": 0, + } diff --git a/backend/app/api/routes/evaluation.py b/backend/app/api/routes/evaluation.py new file mode 100644 index 000000000..a62048d13 --- /dev/null +++ b/backend/app/api/routes/evaluation.py @@ -0,0 +1,627 @@ +import csv +import io +import logging +import re +from pathlib import Path + +from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile + +from app.api.deps import AuthContextDep, SessionDep +from app.core.cloud import get_cloud_storage +from app.crud.assistants import get_assistant_by_id +from app.crud.evaluations import ( + create_evaluation_dataset, + create_evaluation_run, + get_dataset_by_id, + get_evaluation_run_by_id, + list_datasets, + start_evaluation_batch, + upload_csv_to_object_store, + upload_dataset_to_langfuse_from_csv, +) +from app.crud.evaluations import list_evaluation_runs as list_evaluation_runs_crud +from app.crud.evaluations.dataset import delete_dataset as delete_dataset_crud +from app.models.evaluation import ( + DatasetUploadResponse, + EvaluationRunPublic, +) +from app.utils import get_langfuse_client, get_openai_client, load_description + +logger = logging.getLogger(__name__) + +# File upload security constants +MAX_FILE_SIZE = 1024 * 1024 # 1 MB +ALLOWED_EXTENSIONS = {".csv"} +ALLOWED_MIME_TYPES = { + "text/csv", + "application/csv", + "text/plain", # Some systems report CSV as text/plain +} + +router = APIRouter(tags=["evaluation"]) + + +def sanitize_dataset_name(name: str) -> str: + """ + Sanitize dataset name for Langfuse compatibility. + + Langfuse has issues with spaces and special characters in dataset names. + This function ensures the name can be both created and fetched. + + Rules: + - Replace spaces with underscores + - Replace hyphens with underscores + - Keep only alphanumeric characters and underscores + - Convert to lowercase for consistency + - Remove leading/trailing underscores + - Collapse multiple consecutive underscores into one + + Args: + name: Original dataset name + + Returns: + Sanitized dataset name safe for Langfuse + + Examples: + "testing 0001" -> "testing_0001" + "My Dataset!" -> "my_dataset" + "Test--Data__Set" -> "test_data_set" + """ + # Convert to lowercase + sanitized = name.lower() + + # Replace spaces and hyphens with underscores + sanitized = sanitized.replace(" ", "_").replace("-", "_") + + # Keep only alphanumeric characters and underscores + sanitized = re.sub(r"[^a-z0-9_]", "", sanitized) + + # Collapse multiple underscores into one + sanitized = re.sub(r"_+", "_", sanitized) + + # Remove leading/trailing underscores + sanitized = sanitized.strip("_") + + # Ensure name is not empty + if not sanitized: + raise ValueError("Dataset name cannot be empty after sanitization") + + return sanitized + + +@router.post( + "/evaluations/datasets", + description=load_description("evaluation/upload_dataset.md"), + response_model=DatasetUploadResponse, +) +async def upload_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + file: UploadFile = File( + ..., description="CSV file with 'question' and 'answer' columns" + ), + dataset_name: str = Form(..., description="Name for the dataset"), + description: str | None = Form(None, description="Optional dataset description"), + duplication_factor: int = Form( + default=5, + ge=1, + le=5, + description="Number of times to duplicate each item (min: 1, max: 5)", + ), +) -> DatasetUploadResponse: + # Sanitize dataset name for Langfuse compatibility + original_name = dataset_name + try: + dataset_name = sanitize_dataset_name(dataset_name) + except ValueError as e: + raise HTTPException(status_code=422, detail=f"Invalid dataset name: {str(e)}") + + if original_name != dataset_name: + logger.info( + f"[upload_dataset] Dataset name sanitized | '{original_name}' -> '{dataset_name}'" + ) + + logger.info( + f"[upload_dataset] Uploading dataset | dataset={dataset_name} | " + f"duplication_factor={duplication_factor} | org_id={auth_context.organization.id} | " + f"project_id={auth_context.project.id}" + ) + + # Security validation: Check file extension + file_ext = Path(file.filename).suffix.lower() + if file_ext not in ALLOWED_EXTENSIONS: + raise HTTPException( + status_code=422, + detail=f"Invalid file type. Only CSV files are allowed. Got: {file_ext}", + ) + + # Security validation: Check MIME type + content_type = file.content_type + if content_type not in ALLOWED_MIME_TYPES: + raise HTTPException( + status_code=422, + detail=f"Invalid content type. Expected CSV, got: {content_type}", + ) + + # Security validation: Check file size + file.file.seek(0, 2) # Seek to end + file_size = file.file.tell() + file.file.seek(0) # Reset to beginning + + if file_size > MAX_FILE_SIZE: + raise HTTPException( + status_code=413, + detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024):.0f}MB", + ) + + if file_size == 0: + raise HTTPException(status_code=422, detail="Empty file uploaded") + + # Read CSV content + csv_content = await file.read() + + # Step 1: Parse and validate CSV + try: + csv_text = csv_content.decode("utf-8") + csv_reader = csv.DictReader(io.StringIO(csv_text)) + csv_reader.fieldnames = [name.strip() for name in csv_reader.fieldnames] + + # Validate headers + if ( + "question" not in csv_reader.fieldnames + or "answer" not in csv_reader.fieldnames + ): + raise HTTPException( + status_code=422, + detail=f"CSV must contain 'question' and 'answer' columns. " + f"Found columns: {csv_reader.fieldnames}", + ) + + # Count original items + original_items = [] + for row in csv_reader: + question = row.get("question", "").strip() + answer = row.get("answer", "").strip() + if question and answer: + original_items.append({"question": question, "answer": answer}) + + if not original_items: + raise HTTPException( + status_code=422, detail="No valid items found in CSV file" + ) + + original_items_count = len(original_items) + total_items_count = original_items_count * duplication_factor + + logger.info( + f"[upload_dataset] Parsed items from CSV | original={original_items_count} | " + f"total_with_duplication={total_items_count}" + ) + + except Exception as e: + logger.error(f"[upload_dataset] Failed to parse CSV | {e}", exc_info=True) + raise HTTPException(status_code=422, detail=f"Invalid CSV file: {e}") + + # Step 2: Upload to object store (if credentials configured) + object_store_url = None + try: + storage = get_cloud_storage( + session=_session, project_id=auth_context.project.id + ) + object_store_url = upload_csv_to_object_store( + storage=storage, csv_content=csv_content, dataset_name=dataset_name + ) + if object_store_url: + logger.info( + f"[upload_dataset] Successfully uploaded CSV to object store | {object_store_url}" + ) + else: + logger.info( + "[upload_dataset] Object store upload returned None | continuing without object store storage" + ) + except Exception as e: + logger.warning( + f"[upload_dataset] Failed to upload CSV to object store (continuing without object store) | {e}", + exc_info=True, + ) + object_store_url = None + + # Step 3: Upload to Langfuse + langfuse_dataset_id = None + try: + # Get Langfuse client + langfuse = get_langfuse_client( + session=_session, + org_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + # Upload to Langfuse + langfuse_dataset_id, _ = upload_dataset_to_langfuse_from_csv( + langfuse=langfuse, + csv_content=csv_content, + dataset_name=dataset_name, + duplication_factor=duplication_factor, + ) + + logger.info( + f"[upload_dataset] Successfully uploaded dataset to Langfuse | " + f"dataset={dataset_name} | id={langfuse_dataset_id}" + ) + + except Exception as e: + logger.error( + f"[upload_dataset] Failed to upload dataset to Langfuse | {e}", + exc_info=True, + ) + raise HTTPException( + status_code=500, detail=f"Failed to upload dataset to Langfuse: {e}" + ) + + # Step 4: Store metadata in database + metadata = { + "original_items_count": original_items_count, + "total_items_count": total_items_count, + "duplication_factor": duplication_factor, + } + + dataset = create_evaluation_dataset( + session=_session, + name=dataset_name, + description=description, + dataset_metadata=metadata, + object_store_url=object_store_url, + langfuse_dataset_id=langfuse_dataset_id, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + logger.info( + f"[upload_dataset] Successfully created dataset record in database | " + f"id={dataset.id} | name={dataset_name}" + ) + + # Return response + return DatasetUploadResponse( + dataset_id=dataset.id, + dataset_name=dataset_name, + total_items=total_items_count, + original_items=original_items_count, + duplication_factor=duplication_factor, + langfuse_dataset_id=langfuse_dataset_id, + object_store_url=object_store_url, + ) + + +@router.get( + "/evaluations/datasets", + description=load_description("evaluation/list_datasets.md"), + response_model=list[DatasetUploadResponse], +) +def list_datasets_endpoint( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = 50, + offset: int = 0, +) -> list[DatasetUploadResponse]: + # Enforce maximum limit + if limit > 100: + limit = 100 + + datasets = list_datasets( + session=_session, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + limit=limit, + offset=offset, + ) + + # Convert to response format + response = [] + for dataset in datasets: + response.append( + DatasetUploadResponse( + dataset_id=dataset.id, + dataset_name=dataset.name, + total_items=dataset.dataset_metadata.get("total_items_count", 0), + original_items=dataset.dataset_metadata.get("original_items_count", 0), + duplication_factor=dataset.dataset_metadata.get( + "duplication_factor", 1 + ), + langfuse_dataset_id=dataset.langfuse_dataset_id, + object_store_url=dataset.object_store_url, + ) + ) + + return response + + +@router.get( + "/evaluations/datasets/{dataset_id}", + description=load_description("evaluation/get_dataset.md"), + response_model=DatasetUploadResponse, +) +def get_dataset( + dataset_id: int, + _session: SessionDep, + auth_context: AuthContextDep, +) -> DatasetUploadResponse: + logger.info( + f"[get_dataset] Fetching dataset | id={dataset_id} | " + f"org_id={auth_context.organization.id} | " + f"project_id={auth_context.project.id}" + ) + + dataset = get_dataset_by_id( + session=_session, + dataset_id=dataset_id, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + if not dataset: + raise HTTPException( + status_code=404, detail=f"Dataset {dataset_id} not found or not accessible" + ) + + return DatasetUploadResponse( + dataset_id=dataset.id, + dataset_name=dataset.name, + total_items=dataset.dataset_metadata.get("total_items_count", 0), + original_items=dataset.dataset_metadata.get("original_items_count", 0), + duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), + langfuse_dataset_id=dataset.langfuse_dataset_id, + object_store_url=dataset.object_store_url, + ) + + +@router.delete( + "/evaluations/datasets/{dataset_id}", + description=load_description("evaluation/delete_dataset.md"), +) +def delete_dataset( + dataset_id: int, + _session: SessionDep, + auth_context: AuthContextDep, +) -> dict: + logger.info( + f"[delete_dataset] Deleting dataset | id={dataset_id} | " + f"org_id={auth_context.organization.id} | " + f"project_id={auth_context.project.id}" + ) + + success, message = delete_dataset_crud( + session=_session, + dataset_id=dataset_id, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + if not success: + # Check if it's a not found error or other error type + if "not found" in message.lower(): + raise HTTPException(status_code=404, detail=message) + else: + raise HTTPException(status_code=400, detail=message) + + logger.info(f"[delete_dataset] Successfully deleted dataset | id={dataset_id}") + return {"message": message, "dataset_id": dataset_id} + + +@router.post( + "/evaluations", + description=load_description("evaluation/create_evaluation.md"), + response_model=EvaluationRunPublic, +) +def evaluate( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int = Body(..., description="ID of the evaluation dataset"), + experiment_name: str = Body( + ..., description="Name for this evaluation experiment/run" + ), + config: dict = Body(default_factory=dict, description="Evaluation configuration"), + assistant_id: str + | None = Body( + None, description="Optional assistant ID to fetch configuration from" + ), +) -> EvaluationRunPublic: + logger.info( + f"[evaluate] Starting evaluation | experiment_name={experiment_name} | " + f"dataset_id={dataset_id} | " + f"org_id={auth_context.organization.id} | " + f"assistant_id={assistant_id} | " + f"config_keys={list(config.keys())}" + ) + + # Step 1: Fetch dataset from database + dataset = get_dataset_by_id( + session=_session, + dataset_id=dataset_id, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset {dataset_id} not found or not accessible to this " + f"organization/project", + ) + + logger.info( + f"[evaluate] Found dataset | id={dataset.id} | name={dataset.name} | " + f"object_store_url={'present' if dataset.object_store_url else 'None'} | " + f"langfuse_id={dataset.langfuse_dataset_id}" + ) + + dataset_name = dataset.name + + # Get API clients + openai_client = get_openai_client( + session=_session, + org_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + langfuse = get_langfuse_client( + session=_session, + org_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + # Validate dataset has Langfuse ID (should have been set during dataset creation) + if not dataset.langfuse_dataset_id: + raise HTTPException( + status_code=400, + detail=f"Dataset {dataset_id} does not have a Langfuse dataset ID. " + "Please ensure Langfuse credentials were configured when the dataset was created.", + ) + + # Handle assistant_id if provided + if assistant_id: + # Fetch assistant details from database + assistant = get_assistant_by_id( + session=_session, + assistant_id=assistant_id, + project_id=auth_context.project.id, + ) + + if not assistant: + raise HTTPException( + status_code=404, detail=f"Assistant {assistant_id} not found" + ) + + logger.info( + f"[evaluate] Found assistant in DB | id={assistant.id} | " + f"model={assistant.model} | instructions=" + f"{assistant.instructions[:50] if assistant.instructions else 'None'}..." + ) + + # Build config from assistant (use provided config values to override + # if present) + config = { + "model": config.get("model", assistant.model), + "instructions": config.get("instructions", assistant.instructions), + "temperature": config.get("temperature", assistant.temperature), + } + + # Add tools if vector stores are available + vector_store_ids = config.get( + "vector_store_ids", assistant.vector_store_ids or [] + ) + if vector_store_ids and len(vector_store_ids) > 0: + config["tools"] = [ + { + "type": "file_search", + "vector_store_ids": vector_store_ids, + } + ] + + logger.info("[evaluate] Using config from assistant") + else: + logger.info("[evaluate] Using provided config directly") + # Validate that config has minimum required fields + if not config.get("model"): + raise HTTPException( + status_code=400, + detail="Config must include 'model' when assistant_id is not provided", + ) + + # Create EvaluationRun record + eval_run = create_evaluation_run( + session=_session, + run_name=experiment_name, + dataset_name=dataset_name, + dataset_id=dataset_id, + config=config, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + # Start the batch evaluation + try: + eval_run = start_evaluation_batch( + langfuse=langfuse, + openai_client=openai_client, + session=_session, + eval_run=eval_run, + config=config, + ) + + logger.info( + f"[evaluate] Evaluation started successfully | " + f"batch_job_id={eval_run.batch_job_id} | total_items={eval_run.total_items}" + ) + + return eval_run + + except Exception as e: + logger.error( + f"[evaluate] Failed to start evaluation | run_id={eval_run.id} | {e}", + exc_info=True, + ) + # Error is already handled in start_evaluation_batch + _session.refresh(eval_run) + return eval_run + + +@router.get( + "/evaluations", + description=load_description("evaluation/list_evaluations.md"), + response_model=list[EvaluationRunPublic], +) +def list_evaluation_runs( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = 50, + offset: int = 0, +) -> list[EvaluationRunPublic]: + logger.info( + f"[list_evaluation_runs] Listing evaluation runs | " + f"org_id={auth_context.organization.id} | " + f"project_id={auth_context.project.id} | limit={limit} | offset={offset}" + ) + + return list_evaluation_runs_crud( + session=_session, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + limit=limit, + offset=offset, + ) + + +@router.get( + "/evaluations/{evaluation_id}", + description=load_description("evaluation/get_evaluation.md"), + response_model=EvaluationRunPublic, +) +def get_evaluation_run_status( + evaluation_id: int, + _session: SessionDep, + auth_context: AuthContextDep, +) -> EvaluationRunPublic: + logger.info( + f"[get_evaluation_run_status] Fetching status for evaluation run | " + f"evaluation_id={evaluation_id} | " + f"org_id={auth_context.organization.id} | " + f"project_id={auth_context.project.id}" + ) + + eval_run = get_evaluation_run_by_id( + session=_session, + evaluation_id=evaluation_id, + organization_id=auth_context.organization.id, + project_id=auth_context.project.id, + ) + + if not eval_run: + raise HTTPException( + status_code=404, + detail=( + f"Evaluation run {evaluation_id} not found or not accessible " + "to this organization" + ), + ) + + return eval_run diff --git a/backend/app/api/routes/llm.py b/backend/app/api/routes/llm.py index 4eed7c1bc..26c9ee423 100644 --- a/backend/app/api/routes/llm.py +++ b/backend/app/api/routes/llm.py @@ -35,7 +35,7 @@ def llm_callback_notification(body: APIResponse[LLMCallResponse]): response_model=APIResponse[Message], callbacks=llm_callback_router.routes, ) -async def llm_call( +def llm_call( _current_user: AuthContextDep, _session: SessionDep, request: LLMCallRequest ): """ diff --git a/backend/app/celery/celery_app.py b/backend/app/celery/celery_app.py index d67acdbcd..81fba8cb2 100644 --- a/backend/app/celery/celery_app.py +++ b/backend/app/celery/celery_app.py @@ -1,5 +1,6 @@ from celery import Celery -from kombu import Queue, Exchange +from kombu import Exchange, Queue + from app.core.config import settings # Create Celery instance @@ -7,7 +8,9 @@ "ai_platform", broker=settings.RABBITMQ_URL, backend=settings.REDIS_URL, - include=["app.celery.tasks.job_execution"], + include=[ + "app.celery.tasks.job_execution", + ], ) # Define exchanges and queues with priority @@ -82,14 +85,6 @@ # Connection settings from environment broker_connection_retry_on_startup=True, broker_pool_limit=settings.CELERY_BROKER_POOL_LIMIT, - # Beat configuration (for future cron jobs) - beat_schedule={ - # Example cron job (commented out) - # "example-cron": { - # "task": "app.celery.tasks.example_cron_task", - # "schedule": 60.0, # Every 60 seconds - # }, - }, ) # Auto-discover tasks diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py new file mode 100644 index 000000000..9f7cd88d5 --- /dev/null +++ b/backend/app/core/batch/__init__.py @@ -0,0 +1,5 @@ +"""Batch processing infrastructure for LLM providers.""" + +from .base import BatchProvider + +__all__ = ["BatchProvider"] diff --git a/backend/app/core/batch/base.py b/backend/app/core/batch/base.py new file mode 100644 index 000000000..94e316e21 --- /dev/null +++ b/backend/app/core/batch/base.py @@ -0,0 +1,105 @@ +"""Abstract interface for LLM batch providers.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class BatchProvider(ABC): + """Abstract base class for LLM batch providers (OpenAI, Anthropic, etc.).""" + + @abstractmethod + def create_batch( + self, jsonl_data: list[dict[str, Any]], config: dict[str, Any] + ) -> dict[str, Any]: + """ + Upload JSONL data and create a batch job with the provider. + + Args: + jsonl_data: List of dictionaries representing JSONL lines + config: Provider-specific configuration (model, temperature, etc.) + + Returns: + Dictionary containing: + - provider_batch_id: Provider's batch job ID + - provider_file_id: Provider's input file ID + - provider_status: Initial status from provider + - total_items: Number of items in the batch + - Any other provider-specific metadata + + Raises: + Exception: If batch creation fails + """ + pass + + @abstractmethod + def get_batch_status(self, batch_id: str) -> dict[str, Any]: + """ + Poll the provider for batch job status. + + Args: + batch_id: Provider's batch job ID + + Returns: + Dictionary containing: + - provider_status: Current status from provider + - provider_output_file_id: Output file ID (if completed) + - error_message: Error message (if failed) + - Any other provider-specific status info + + Raises: + Exception: If status check fails + """ + pass + + @abstractmethod + def download_batch_results(self, output_file_id: str) -> list[dict[str, Any]]: + """ + Download and parse batch results from the provider. + + Args: + output_file_id: Provider's output file ID + + Returns: + List of result dictionaries, each containing: + - custom_id: Item identifier from input + - response: Provider's response data + - error: Error info (if item failed) + - Any other provider-specific result data + + Raises: + Exception: If download or parsing fails + """ + pass + + @abstractmethod + def upload_file(self, content: str, purpose: str = "batch") -> str: + """ + Upload a file to the provider's file storage. + + Args: + content: File content (typically JSONL string) + purpose: Purpose of the file (e.g., "batch") + + Returns: + Provider's file ID + + Raises: + Exception: If upload fails + """ + pass + + @abstractmethod + def download_file(self, file_id: str) -> str: + """ + Download a file from the provider's file storage. + + Args: + file_id: Provider's file ID + + Returns: + File content as string + + Raises: + Exception: If download fails + """ + pass diff --git a/backend/app/core/batch/openai.py b/backend/app/core/batch/openai.py new file mode 100644 index 000000000..8bb4abe6a --- /dev/null +++ b/backend/app/core/batch/openai.py @@ -0,0 +1,254 @@ +"""OpenAI batch provider implementation.""" + +import json +import logging +from typing import Any + +from openai import OpenAI + +from .base import BatchProvider + +logger = logging.getLogger(__name__) + + +class OpenAIBatchProvider(BatchProvider): + """OpenAI implementation of the BatchProvider interface.""" + + def __init__(self, client: OpenAI): + """ + Initialize the OpenAI batch provider. + + Args: + client: Configured OpenAI client + """ + self.client = client + + def create_batch( + self, jsonl_data: list[dict[str, Any]], config: dict[str, Any] + ) -> dict[str, Any]: + """ + Upload JSONL data and create a batch job with OpenAI. + + Args: + jsonl_data: List of dictionaries representing JSONL lines + config: Provider-specific configuration with: + - endpoint: OpenAI endpoint (e.g., "/v1/responses") + - description: Optional batch description + - completion_window: Optional completion window (default "24h") + + Returns: + Dictionary containing: + - provider_batch_id: OpenAI batch ID + - provider_file_id: OpenAI input file ID + - provider_status: Initial status from OpenAI + - total_items: Number of items in the batch + + Raises: + Exception: If batch creation fails + """ + endpoint = config.get("endpoint", "/v1/responses") + description = config.get("description", "LLM batch job") + completion_window = config.get("completion_window", "24h") + + logger.info( + f"[create_batch] Creating OpenAI batch | items={len(jsonl_data)} | endpoint={endpoint}" + ) + + try: + # Step 1: Upload file + file_id = self.upload_file( + content="\n".join([json.dumps(line) for line in jsonl_data]), + purpose="batch", + ) + + # Step 2: Create batch job + batch = self.client.batches.create( + input_file_id=file_id, + endpoint=endpoint, + completion_window=completion_window, + metadata={"description": description}, + ) + + result = { + "provider_batch_id": batch.id, + "provider_file_id": file_id, + "provider_status": batch.status, + "total_items": len(jsonl_data), + } + + logger.info( + f"[create_batch] Created OpenAI batch | batch_id={batch.id} | status={batch.status} | items={len(jsonl_data)}" + ) + + return result + + except Exception as e: + logger.error(f"[create_batch] Failed to create OpenAI batch | {e}") + raise + + def get_batch_status(self, batch_id: str) -> dict[str, Any]: + """ + Poll OpenAI for batch job status. + + Args: + batch_id: OpenAI batch ID + + Returns: + Dictionary containing: + - provider_status: Current OpenAI status + - provider_output_file_id: Output file ID (if completed) + - error_message: Error message (if failed) + - request_counts: Dict with total/completed/failed counts + + Raises: + Exception: If status check fails + """ + logger.info( + f"[get_batch_status] Polling OpenAI batch status | batch_id={batch_id}" + ) + + try: + batch = self.client.batches.retrieve(batch_id) + + result = { + "provider_status": batch.status, + "provider_output_file_id": batch.output_file_id, + "error_file_id": batch.error_file_id, + "request_counts": { + "total": batch.request_counts.total, + "completed": batch.request_counts.completed, + "failed": batch.request_counts.failed, + }, + } + + # Add error message if batch failed + if batch.status in ["failed", "expired", "cancelled"]: + error_msg = f"Batch {batch.status}" + if batch.error_file_id: + error_msg += f" (error_file_id: {batch.error_file_id})" + result["error_message"] = error_msg + + logger.info( + f"[get_batch_status] OpenAI batch status | batch_id={batch_id} | status={batch.status} | completed={batch.request_counts.completed}/{batch.request_counts.total}" + ) + + return result + + except Exception as e: + logger.error( + f"[get_batch_status] Failed to poll OpenAI batch status | batch_id={batch_id} | {e}" + ) + raise + + def download_batch_results(self, output_file_id: str) -> list[dict[str, Any]]: + """ + Download and parse batch results from OpenAI. + + Args: + output_file_id: OpenAI output file ID + + Returns: + List of result dictionaries, each containing: + - custom_id: Item identifier from input + - response: OpenAI response data (body, status_code, request_id) + - error: Error info (if item failed) + + Raises: + Exception: If download or parsing fails + """ + logger.info( + f"[download_batch_results] Downloading OpenAI batch results | output_file_id={output_file_id}" + ) + + try: + # Download file content + jsonl_content = self.download_file(output_file_id) + + # Parse JSONL into list of dicts + results = [] + lines = jsonl_content.strip().split("\n") + + for line_num, line in enumerate(lines, 1): + try: + result = json.loads(line) + results.append(result) + except json.JSONDecodeError as e: + logger.error( + f"[download_batch_results] Failed to parse JSON | line={line_num} | {e}" + ) + continue + + logger.info( + f"[download_batch_results] Downloaded and parsed results from OpenAI batch output | results={len(results)}" + ) + + return results + + except Exception as e: + logger.error( + f"[download_batch_results] Failed to download OpenAI batch results | {e}" + ) + raise + + def upload_file(self, content: str, purpose: str = "batch") -> str: + """ + Upload a file to OpenAI file storage. + + Args: + content: File content (typically JSONL string) + purpose: Purpose of the file (e.g., "batch") + + Returns: + OpenAI file ID + + Raises: + Exception: If upload fails + """ + logger.info(f"[upload_file] Uploading file to OpenAI | bytes={len(content)}") + + try: + file_response = self.client.files.create( + file=("batch_input.jsonl", content.encode("utf-8")), + purpose=purpose, + ) + + logger.info( + f"[upload_file] Uploaded file to OpenAI | file_id={file_response.id}" + ) + + return file_response.id + + except Exception as e: + logger.error(f"[upload_file] Failed to upload file to OpenAI | {e}") + raise + + def download_file(self, file_id: str) -> str: + """ + Download a file from OpenAI file storage. + + Args: + file_id: OpenAI file ID + + Returns: + File content as string + + Raises: + Exception: If download fails + """ + logger.info(f"[download_file] Downloading file from OpenAI | file_id={file_id}") + + try: + file_content = self.client.files.content(file_id) + content = file_content.read().decode("utf-8") + + logger.info( + f"[download_file] Downloaded file from OpenAI | file_id={file_id} | bytes={len(content)}" + ) + + return content + + except Exception as e: + logger.error( + f"[download_file] Failed to download file from OpenAI | file_id={file_id} | {e}" + ) + raise diff --git a/backend/app/core/storage_utils.py b/backend/app/core/storage_utils.py new file mode 100644 index 000000000..63830d7d0 --- /dev/null +++ b/backend/app/core/storage_utils.py @@ -0,0 +1,167 @@ +""" +Shared storage utilities for uploading files to object store. + +This module provides common functions for uploading various file types +to cloud object storage, abstracting away provider-specific details. +""" + +import io +import json +import logging +from datetime import datetime +from io import BytesIO +from pathlib import Path + +from starlette.datastructures import Headers, UploadFile + +from app.core.cloud.storage import CloudStorage, CloudStorageError + +logger = logging.getLogger(__name__) + + +def upload_csv_to_object_store( + storage: CloudStorage, + csv_content: bytes, + filename: str, + subdirectory: str = "datasets", +) -> str | None: + """ + Upload CSV content to object store. + + Args: + storage: CloudStorage instance + csv_content: Raw CSV content as bytes + filename: Name of the file (can include timestamp) + subdirectory: Subdirectory path in object store (default: "datasets") + + Returns: + Object store URL as string if successful, None if failed + + Note: + This function handles errors gracefully and returns None on failure. + Callers should continue without object store URL when this returns None. + """ + logger.info( + f"[upload_csv_to_object_store] Preparing to upload '{filename}' | " + f"size={len(csv_content)} bytes, subdirectory='{subdirectory}'" + ) + + try: + # Create file path + file_path = Path(subdirectory) / filename + + # Create a mock UploadFile-like object for the storage put method + class CSVFile: + def __init__(self, content: bytes): + self.file = io.BytesIO(content) + self.content_type = "text/csv" + + csv_file = CSVFile(csv_content) + + # Upload to object store + destination = storage.put(source=csv_file, file_path=file_path) + object_store_url = str(destination) + + logger.info( + f"[upload_csv_to_object_store] Upload successful | " + f"filename='{filename}', url='{object_store_url}'" + ) + return object_store_url + + except CloudStorageError as e: + logger.warning( + f"[upload_csv_to_object_store] Upload failed for '{filename}': {e}. " + "Continuing without object store storage." + ) + return None + except Exception as e: + logger.warning( + f"[upload_csv_to_object_store] Unexpected error uploading '{filename}': {e}. " + "Continuing without object store storage.", + exc_info=True, + ) + return None + + +def upload_jsonl_to_object_store( + storage: CloudStorage, + results: list[dict], + filename: str, + subdirectory: str, +) -> str | None: + """ + Upload JSONL (JSON Lines) content to object store. + + Args: + storage: CloudStorage instance + results: List of dictionaries to be converted to JSONL + filename: Name of the file + subdirectory: Subdirectory path in object store (e.g., "evaluation/batch-123") + + Returns: + Object store URL as string if successful, None if failed + + Note: + This function handles errors gracefully and returns None on failure. + Callers should continue without object store URL when this returns None. + """ + logger.info( + f"[upload_jsonl_to_object_store] Preparing to upload '{filename}' | " + f"items={len(results)}, subdirectory='{subdirectory}'" + ) + + try: + # Create file path + file_path = Path(subdirectory) / filename + + # Convert results to JSONL + jsonl_content = "\n".join([json.dumps(result) for result in results]) + content_bytes = jsonl_content.encode("utf-8") + + # Create UploadFile-like object + headers = Headers({"content-type": "application/jsonl"}) + upload_file = UploadFile( + filename=filename, + file=BytesIO(content_bytes), + headers=headers, + ) + + # Upload to object store + destination = storage.put(source=upload_file, file_path=file_path) + object_store_url = str(destination) + + logger.info( + f"[upload_jsonl_to_object_store] Upload successful | " + f"filename='{filename}', url='{object_store_url}', " + f"size={len(content_bytes)} bytes" + ) + return object_store_url + + except CloudStorageError as e: + logger.warning( + f"[upload_jsonl_to_object_store] Upload failed for '{filename}': {e}. " + "Continuing without object store storage." + ) + return None + except Exception as e: + logger.warning( + f"[upload_jsonl_to_object_store] Unexpected error uploading '{filename}': {e}. " + "Continuing without object store storage.", + exc_info=True, + ) + return None + + +def generate_timestamped_filename(base_name: str, extension: str = "csv") -> str: + """ + Generate a filename with timestamp. + + Args: + base_name: Base name for the file (e.g., "dataset_name" or "batch-123") + extension: File extension without dot (default: "csv") + + Returns: + Filename with timestamp (e.g., "dataset_name_20250114_153045.csv") + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return f"{base_name}_{timestamp}.{extension}" diff --git a/backend/app/crud/batch_job.py b/backend/app/crud/batch_job.py new file mode 100644 index 000000000..c121a90cd --- /dev/null +++ b/backend/app/crud/batch_job.py @@ -0,0 +1,222 @@ +"""CRUD operations for batch_job table.""" + +import logging + +from sqlmodel import Session, select + +from app.core.util import now +from app.models.batch_job import BatchJob, BatchJobCreate, BatchJobUpdate + +logger = logging.getLogger(__name__) + + +def create_batch_job( + session: Session, + batch_job_create: BatchJobCreate, +) -> BatchJob: + """ + Create a new batch job record. + + Args: + session: Database session + batch_job_create: BatchJobCreate schema with all required fields + + Returns: + Created BatchJob object + + Raises: + Exception: If creation fails + """ + logger.info( + f"[create_batch_job] Creating batch job | " + f"provider={batch_job_create.provider} | " + f"job_type={batch_job_create.job_type} | " + f"org_id={batch_job_create.organization_id} | " + f"project_id={batch_job_create.project_id}" + ) + + try: + batch_job = BatchJob.model_validate(batch_job_create) + batch_job.inserted_at = now() + batch_job.updated_at = now() + + session.add(batch_job) + session.commit() + session.refresh(batch_job) + + logger.info(f"[create_batch_job] Created batch job | id={batch_job.id}") + + return batch_job + + except Exception as e: + logger.error( + f"[create_batch_job] Failed to create batch job | {e}", exc_info=True + ) + session.rollback() + raise + + +def get_batch_job(session: Session, batch_job_id: int) -> BatchJob | None: + """ + Get a batch job by ID. + + Args: + session: Database session + batch_job_id: Batch job ID + + Returns: + BatchJob object if found, None otherwise + """ + statement = select(BatchJob).where(BatchJob.id == batch_job_id) + batch_job = session.exec(statement).first() + + return batch_job + + +def update_batch_job( + session: Session, + batch_job: BatchJob, + batch_job_update: BatchJobUpdate, +) -> BatchJob: + """ + Update a batch job record. + + Args: + session: Database session + batch_job: BatchJob object to update + batch_job_update: BatchJobUpdate schema with fields to update + + Returns: + Updated BatchJob object + + Raises: + Exception: If update fails + """ + logger.info(f"[update_batch_job] Updating batch job | id={batch_job.id}") + + try: + # Update fields if provided + update_data = batch_job_update.model_dump(exclude_unset=True) + + for key, value in update_data.items(): + setattr(batch_job, key, value) + + batch_job.updated_at = now() + + session.add(batch_job) + session.commit() + session.refresh(batch_job) + + logger.info(f"[update_batch_job] Updated batch job | id={batch_job.id}") + + return batch_job + + except Exception as e: + logger.error( + f"[update_batch_job] Failed to update batch job | id={batch_job.id} | {e}", + exc_info=True, + ) + session.rollback() + raise + + +def get_batch_jobs_by_ids( + session: Session, + batch_job_ids: list[int], +) -> list[BatchJob]: + """ + Get batch jobs by their IDs. + + This is used by parent tables to get their associated batch jobs for polling. + + Args: + session: Database session + batch_job_ids: List of batch job IDs + + Returns: + List of BatchJob objects + """ + if not batch_job_ids: + return [] + + statement = select(BatchJob).where(BatchJob.id.in_(batch_job_ids)) + results = session.exec(statement).all() + + logger.info( + f"[get_batch_jobs_by_ids] Found batch jobs | found={len(results)} | requested={len(batch_job_ids)}" + ) + + return list(results) + + +def get_batches_by_type( + session: Session, + job_type: str, + organization_id: int | None = None, + project_id: int | None = None, + provider_status: str | None = None, +) -> list[BatchJob]: + """ + Get batch jobs by type with optional filters. + + Args: + session: Database session + job_type: Job type (e.g., "evaluation", "classification") + organization_id: Optional filter by organization ID + project_id: Optional filter by project ID + provider_status: Optional filter by provider status + + Returns: + List of BatchJob objects matching filters + """ + statement = select(BatchJob).where(BatchJob.job_type == job_type) + + if organization_id: + statement = statement.where(BatchJob.organization_id == organization_id) + + if project_id: + statement = statement.where(BatchJob.project_id == project_id) + + if provider_status: + statement = statement.where(BatchJob.provider_status == provider_status) + + results = session.exec(statement).all() + + logger.info( + f"[get_batches_by_type] Found batch jobs | " + f"count={len(results)} | " + f"job_type={job_type} | " + f"org_id={organization_id} | " + f"project_id={project_id} | " + f"provider_status={provider_status}" + ) + + return list(results) + + +def delete_batch_job(session: Session, batch_job: BatchJob) -> None: + """ + Delete a batch job record. + + Args: + session: Database session + batch_job: BatchJob object to delete + + Raises: + Exception: If deletion fails + """ + logger.info(f"[delete_batch_job] Deleting batch job | id={batch_job.id}") + + try: + session.delete(batch_job) + session.commit() + + logger.info(f"[delete_batch_job] Deleted batch job | id={batch_job.id}") + + except Exception as e: + logger.error( + f"[delete_batch_job] Failed to delete batch job | id={batch_job.id} | {e}", + exc_info=True, + ) + session.rollback() + raise diff --git a/backend/app/crud/batch_operations.py b/backend/app/crud/batch_operations.py new file mode 100644 index 000000000..f2bb332e8 --- /dev/null +++ b/backend/app/crud/batch_operations.py @@ -0,0 +1,232 @@ +"""Generic batch operations orchestrator.""" + +import logging +from typing import Any + +from sqlmodel import Session + +from app.core.batch.base import BatchProvider +from app.core.cloud import get_cloud_storage +from app.core.storage_utils import upload_jsonl_to_object_store as shared_upload_jsonl +from app.crud.batch_job import ( + create_batch_job, + update_batch_job, +) +from app.models.batch_job import BatchJob, BatchJobCreate, BatchJobUpdate + +logger = logging.getLogger(__name__) + + +def start_batch_job( + session: Session, + provider: BatchProvider, + provider_name: str, + job_type: str, + organization_id: int, + project_id: int, + jsonl_data: list[dict[str, Any]], + config: dict[str, Any], +) -> BatchJob: + """ + Create and start a batch job with the specified provider. + + Creates a batch_job record, calls the provider to create the batch, + and updates the record with provider IDs. + + Returns: + BatchJob with provider IDs populated + """ + logger.info( + f"[start_batch_job] Starting | provider={provider_name} | type={job_type} | " + f"org={organization_id} | project={project_id} | items={len(jsonl_data)}" + ) + + batch_job_create = BatchJobCreate( + provider=provider_name, + job_type=job_type, + organization_id=organization_id, + project_id=project_id, + config=config, + total_items=len(jsonl_data), + ) + + batch_job = create_batch_job(session=session, batch_job_create=batch_job_create) + + try: + batch_result = provider.create_batch(jsonl_data=jsonl_data, config=config) + + batch_job_update = BatchJobUpdate( + provider_batch_id=batch_result["provider_batch_id"], + provider_file_id=batch_result["provider_file_id"], + provider_status=batch_result["provider_status"], + total_items=batch_result.get("total_items", len(jsonl_data)), + ) + + batch_job = update_batch_job( + session=session, batch_job=batch_job, batch_job_update=batch_job_update + ) + + logger.info( + f"[start_batch_job] Success | id={batch_job.id} | " + f"provider_batch_id={batch_job.provider_batch_id}" + ) + + return batch_job + + except Exception as e: + logger.error(f"[start_batch_job] Failed | {e}", exc_info=True) + + batch_job_update = BatchJobUpdate( + error_message=f"Batch creation failed: {str(e)}" + ) + update_batch_job( + session=session, batch_job=batch_job, batch_job_update=batch_job_update + ) + + raise + + +def poll_batch_status( + session: Session, provider: BatchProvider, batch_job: BatchJob +) -> dict[str, Any]: + """Poll provider for batch status and update database.""" + logger.info( + f"[poll_batch_status] Polling | id={batch_job.id} | " + f"provider_batch_id={batch_job.provider_batch_id}" + ) + + try: + status_result = provider.get_batch_status(batch_job.provider_batch_id) + + provider_status = status_result["provider_status"] + if provider_status != batch_job.provider_status: + update_data = {"provider_status": provider_status} + + if status_result.get("provider_output_file_id"): + update_data["provider_output_file_id"] = status_result[ + "provider_output_file_id" + ] + + if status_result.get("error_message"): + update_data["error_message"] = status_result["error_message"] + + batch_job_update = BatchJobUpdate(**update_data) + batch_job = update_batch_job( + session=session, batch_job=batch_job, batch_job_update=batch_job_update + ) + + logger.info( + f"[poll_batch_status] Updated | id={batch_job.id} | " + f"{batch_job.provider_status} -> {provider_status}" + ) + + return status_result + + except Exception as e: + logger.error(f"[poll_batch_status] Failed | {e}", exc_info=True) + raise + + +def download_batch_results( + provider: BatchProvider, batch_job: BatchJob +) -> list[dict[str, Any]]: + """Download raw batch results from provider.""" + if not batch_job.provider_output_file_id: + raise ValueError( + f"Batch job {batch_job.id} does not have provider_output_file_id" + ) + + logger.info( + f"[download_batch_results] Downloading | id={batch_job.id} | " + f"output_file_id={batch_job.provider_output_file_id}" + ) + + try: + results = provider.download_batch_results(batch_job.provider_output_file_id) + + logger.info( + f"[download_batch_results] Downloaded | batch_job_id={batch_job.id} | " + f"results={len(results)}" + ) + + return results + + except Exception as e: + logger.error(f"[download_batch_results] Failed | {e}", exc_info=True) + raise + + +def process_completed_batch( + session: Session, + provider: BatchProvider, + batch_job: BatchJob, + upload_to_object_store: bool = True, +) -> tuple[list[dict[str, Any]], str | None]: + """ + Process a completed batch: download results and optionally upload to object store. + + Returns: + Tuple of (results, object_store_url) + """ + logger.info(f"[process_completed_batch] Processing | id={batch_job.id}") + + try: + results = download_batch_results(provider=provider, batch_job=batch_job) + + object_store_url = None + if upload_to_object_store: + try: + object_store_url = upload_batch_results_to_object_store( + session=session, batch_job=batch_job, results=results + ) + logger.info( + f"[process_completed_batch] Uploaded to object store | {object_store_url}" + ) + except Exception as store_error: + logger.warning( + f"[process_completed_batch] Object store upload failed " + f"(credentials may not be configured) | {store_error}", + exc_info=True, + ) + + if object_store_url: + batch_job_update = BatchJobUpdate(raw_output_url=object_store_url) + update_batch_job( + session=session, batch_job=batch_job, batch_job_update=batch_job_update + ) + + return results, object_store_url + + except Exception as e: + logger.error(f"[process_completed_batch] Failed | {e}", exc_info=True) + raise + + +def upload_batch_results_to_object_store( + session: Session, batch_job: BatchJob, results: list[dict[str, Any]] +) -> str | None: + """Upload batch results to object store.""" + logger.info( + f"[upload_batch_results_to_object_store] Uploading | batch_job_id={batch_job.id}" + ) + + try: + storage = get_cloud_storage(session=session, project_id=batch_job.project_id) + + subdirectory = f"{batch_job.job_type}/batch-{batch_job.id}" + filename = "results.jsonl" + + object_store_url = shared_upload_jsonl( + storage=storage, + results=results, + filename=filename, + subdirectory=subdirectory, + ) + + return object_store_url + + except Exception as e: + logger.error( + f"[upload_batch_results_to_object_store] Failed | {e}", exc_info=True + ) + raise diff --git a/backend/app/crud/evaluations/__init__.py b/backend/app/crud/evaluations/__init__.py new file mode 100644 index 000000000..d07cf8676 --- /dev/null +++ b/backend/app/crud/evaluations/__init__.py @@ -0,0 +1,66 @@ +"""Evaluation-related CRUD operations.""" + +from app.crud.evaluations.batch import start_evaluation_batch +from app.crud.evaluations.core import ( + create_evaluation_run, + get_evaluation_run_by_id, + list_evaluation_runs, +) +from app.crud.evaluations.cron import ( + process_all_pending_evaluations, + process_all_pending_evaluations_sync, +) +from app.crud.evaluations.dataset import ( + create_evaluation_dataset, + delete_dataset, + get_dataset_by_id, + list_datasets, + upload_csv_to_object_store, +) +from app.crud.evaluations.embeddings import ( + calculate_average_similarity, + calculate_cosine_similarity, + start_embedding_batch, +) +from app.crud.evaluations.langfuse import ( + create_langfuse_dataset_run, + update_traces_with_cosine_scores, + upload_dataset_to_langfuse_from_csv, +) +from app.crud.evaluations.processing import ( + check_and_process_evaluation, + poll_all_pending_evaluations, + process_completed_embedding_batch, + process_completed_evaluation, +) + +__all__ = [ + # Core + "create_evaluation_run", + "get_evaluation_run_by_id", + "list_evaluation_runs", + # Cron + "process_all_pending_evaluations", + "process_all_pending_evaluations_sync", + # Dataset + "create_evaluation_dataset", + "delete_dataset", + "get_dataset_by_id", + "list_datasets", + "upload_csv_to_object_store", + # Batch + "start_evaluation_batch", + # Processing + "check_and_process_evaluation", + "poll_all_pending_evaluations", + "process_completed_embedding_batch", + "process_completed_evaluation", + # Embeddings + "calculate_average_similarity", + "calculate_cosine_similarity", + "start_embedding_batch", + # Langfuse + "create_langfuse_dataset_run", + "update_traces_with_cosine_scores", + "upload_dataset_to_langfuse_from_csv", +] diff --git a/backend/app/crud/evaluations/batch.py b/backend/app/crud/evaluations/batch.py new file mode 100644 index 000000000..7e8b69043 --- /dev/null +++ b/backend/app/crud/evaluations/batch.py @@ -0,0 +1,212 @@ +""" +Evaluation-specific batch preparation and orchestration. + +This module handles: +1. Fetching dataset items from Langfuse +2. Building evaluation-specific JSONL for batch processing +3. Starting evaluation batches using generic batch infrastructure +""" + +import logging +from typing import Any + +from langfuse import Langfuse +from openai import OpenAI +from sqlmodel import Session + +from app.core.batch.openai import OpenAIBatchProvider +from app.crud.batch_operations import start_batch_job +from app.models import EvaluationRun + +logger = logging.getLogger(__name__) + + +def fetch_dataset_items(langfuse: Langfuse, dataset_name: str) -> list[dict[str, Any]]: + """ + Fetch all items from a Langfuse dataset. + + Args: + langfuse: Configured Langfuse client + dataset_name: Name of the dataset to fetch + + Returns: + List of dataset items with input and expected_output + + Raises: + ValueError: If dataset not found or empty + """ + try: + dataset = langfuse.get_dataset(dataset_name) + except Exception as e: + logger.error( + f"[fetch_dataset_items] Failed to fetch dataset | dataset={dataset_name} | {e}" + ) + raise ValueError(f"Dataset '{dataset_name}' not found: {e}") + + if not dataset.items: + raise ValueError(f"Dataset '{dataset_name}' is empty") + + items = [] + for item in dataset.items: + items.append( + { + "id": item.id, + "input": item.input, + "expected_output": item.expected_output, + "metadata": item.metadata if hasattr(item, "metadata") else {}, + } + ) + return items + + +def build_evaluation_jsonl( + dataset_items: list[dict[str, Any]], config: dict[str, Any] +) -> list[dict[str, Any]]: + """ + Build JSONL data for evaluation batch using OpenAI Responses API. + + Each line is a dict with: + - custom_id: Unique identifier for the request (dataset item ID) + - method: POST + - url: /v1/responses + - body: Response request using config as-is with input from dataset + + Args: + dataset_items: List of dataset items from Langfuse + config: Evaluation configuration dict with OpenAI Responses API parameters. + This config is used as-is in the body, with only "input" being added + from the dataset. Config can include any fields like: + - model (required) + - instructions + - tools + - reasoning + - text + - temperature + - include + etc. + + Returns: + List of dictionaries (JSONL data) + """ + jsonl_data = [] + + for item in dataset_items: + # Extract question from input + question = item["input"].get("question", "") + if not question: + logger.warning( + f"[build_evaluation_jsonl] Skipping item - no question found | item_id={item['id']}" + ) + continue + + # Build the batch request object for Responses API + # Use config as-is and only add the input field + batch_request = { + "custom_id": item["id"], + "method": "POST", + "url": "/v1/responses", + "body": { + **config, # Use config as-is + "input": question, # Add input from dataset + }, + } + + jsonl_data.append(batch_request) + return jsonl_data + + +def start_evaluation_batch( + langfuse: Langfuse, + openai_client: OpenAI, + session: Session, + eval_run: EvaluationRun, + config: dict[str, Any], +) -> EvaluationRun: + """ + Fetch data, build JSONL, and start evaluation batch. + + This function orchestrates the evaluation-specific logic and delegates + to the generic batch infrastructure for actual batch creation. + + Args: + langfuse: Configured Langfuse client + openai_client: Configured OpenAI client + session: Database session + eval_run: EvaluationRun database object (with run_name, dataset_name, config) + config: Evaluation configuration dict with llm, instructions, vector_store_ids + + Returns: + Updated EvaluationRun with batch_job_id populated + + Raises: + Exception: If any step fails + """ + try: + # Step 1: Fetch dataset items from Langfuse + logger.info( + f"[start_evaluation_batch] Starting evaluation batch | run={eval_run.run_name}" + ) + dataset_items = fetch_dataset_items( + langfuse=langfuse, dataset_name=eval_run.dataset_name + ) + + # Step 2: Build evaluation-specific JSONL + jsonl_data = build_evaluation_jsonl(dataset_items=dataset_items, config=config) + + if not jsonl_data: + raise ValueError( + "Evaluation dataset did not produce any JSONL entries (missing questions?)." + ) + + # Step 3: Create batch provider + provider = OpenAIBatchProvider(client=openai_client) + + # Step 4: Prepare batch configuration + batch_config = { + "endpoint": "/v1/responses", + "description": f"Evaluation: {eval_run.run_name}", + "completion_window": "24h", + # Store complete config for reference + "evaluation_config": config, + } + + # Step 5: Start batch job using generic infrastructure + batch_job = start_batch_job( + session=session, + provider=provider, + provider_name="openai", + job_type="evaluation", + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + jsonl_data=jsonl_data, + config=batch_config, + ) + + # Step 6: Link batch_job to evaluation_run + eval_run.batch_job_id = batch_job.id + eval_run.status = "processing" + eval_run.total_items = batch_job.total_items + + session.add(eval_run) + session.commit() + session.refresh(eval_run) + + logger.info( + f"[start_evaluation_batch] Successfully started evaluation batch | " + f"batch_job_id={batch_job.id} | " + f"provider_batch_id={batch_job.provider_batch_id} | " + f"run={eval_run.run_name} | items={batch_job.total_items}" + ) + + return eval_run + + except Exception as e: + logger.error( + f"[start_evaluation_batch] Failed to start evaluation batch | {e}", + exc_info=True, + ) + eval_run.status = "failed" + eval_run.error_message = str(e) + session.add(eval_run) + session.commit() + raise diff --git a/backend/app/crud/evaluations/core.py b/backend/app/crud/evaluations/core.py new file mode 100644 index 000000000..a964f26b9 --- /dev/null +++ b/backend/app/crud/evaluations/core.py @@ -0,0 +1,308 @@ +import csv +import io +import logging + +from fastapi import HTTPException +from sqlmodel import Session, select + +from app.core.util import now +from app.models import EvaluationRun, UserProjectOrg +from app.models.evaluation import DatasetUploadResponse +from app.utils import get_langfuse_client + +logger = logging.getLogger(__name__) + + +async def upload_dataset_to_langfuse( + csv_content: bytes, + dataset_name: str, + dataset_id: int, + duplication_factor: int, + _session: Session, + _current_user: UserProjectOrg, +) -> tuple[bool, DatasetUploadResponse | None, str | None]: + """ + Upload a CSV dataset to Langfuse with duplication for flakiness testing. + + Args: + csv_content: Raw CSV file content as bytes + dataset_name: Name for the dataset in Langfuse + dataset_id: Database ID of the created dataset + duplication_factor: Number of times to duplicate each item (default 5) + _session: Database session + _current_user: Current user organization + + Returns: + Tuple of (success, dataset_response, error_message) + """ + try: + # Get Langfuse client + try: + langfuse = get_langfuse_client( + session=_session, + org_id=_current_user.organization_id, + project_id=_current_user.project_id, + ) + except HTTPException as http_exc: + return False, None, http_exc.detail + + # Parse CSV content + csv_text = csv_content.decode("utf-8") + csv_reader = csv.DictReader(io.StringIO(csv_text)) + + # Validate CSV headers + if ( + "question" not in csv_reader.fieldnames + or "answer" not in csv_reader.fieldnames + ): + return ( + False, + None, + "CSV must contain 'question' and 'answer' columns. " + f"Found columns: {csv_reader.fieldnames}", + ) + + # Read all rows from CSV + original_items = [] + for row in csv_reader: + question = row.get("question", "").strip() + answer = row.get("answer", "").strip() + + if not question or not answer: + logger.warning(f"Skipping row with empty question or answer: {row}") + continue + + original_items.append({"question": question, "answer": answer}) + + if not original_items: + return False, None, "No valid items found in CSV file." + + logger.info( + f"Parsed {len(original_items)} items from CSV. " + f"Will duplicate {duplication_factor}x for a total of {len(original_items) * duplication_factor} items." + ) + + # Create or get dataset in Langfuse + dataset = langfuse.create_dataset(name=dataset_name) + + # Upload items with duplication + total_uploaded = 0 + for item in original_items: + # Duplicate each item N times + for duplicate_num in range(duplication_factor): + try: + langfuse.create_dataset_item( + dataset_name=dataset_name, + input={"question": item["question"]}, + expected_output={"answer": item["answer"]}, + metadata={ + "original_question": item["question"], + "duplicate_number": duplicate_num + 1, + "duplication_factor": duplication_factor, + }, + ) + total_uploaded += 1 + except Exception as e: + logger.error( + f"Failed to upload item (duplicate {duplicate_num + 1}): {item['question'][:50]}... Error: {e}" + ) + + # Flush to ensure all items are uploaded + langfuse.flush() + + logger.info( + f"Successfully uploaded {total_uploaded} items to dataset '{dataset_name}' " + f"({len(original_items)} original × {duplication_factor} duplicates)" + ) + + return ( + True, + DatasetUploadResponse( + dataset_id=dataset_id, + dataset_name=dataset_name, + total_items=total_uploaded, + original_items=len(original_items), + duplication_factor=duplication_factor, + langfuse_dataset_id=dataset.id if hasattr(dataset, "id") else None, + ), + None, + ) + + except Exception as e: + logger.error(f"Error uploading dataset: {str(e)}", exc_info=True) + return False, None, f"Failed to upload dataset: {str(e)}" + + +def create_evaluation_run( + session: Session, + run_name: str, + dataset_name: str, + dataset_id: int, + config: dict, + organization_id: int, + project_id: int, +) -> EvaluationRun: + """ + Create a new evaluation run record in the database. + + Args: + session: Database session + run_name: Name of the evaluation run/experiment + dataset_name: Name of the dataset being used + dataset_id: ID of the dataset + config: Configuration dict for the evaluation + organization_id: Organization ID + project_id: Project ID + + Returns: + The created EvaluationRun instance + """ + eval_run = EvaluationRun( + run_name=run_name, + dataset_name=dataset_name, + dataset_id=dataset_id, + config=config, + status="pending", + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + session.add(eval_run) + session.commit() + session.refresh(eval_run) + + logger.info(f"Created EvaluationRun record: id={eval_run.id}, run_name={run_name}") + + return eval_run + + +def list_evaluation_runs( + session: Session, + organization_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> list[EvaluationRun]: + """ + List all evaluation runs for an organization and project. + + Args: + session: Database session + organization_id: Organization ID to filter by + project_id: Project ID to filter by + limit: Maximum number of runs to return (default 50) + offset: Number of runs to skip (for pagination) + + Returns: + List of EvaluationRun objects, ordered by most recent first + """ + statement = ( + select(EvaluationRun) + .where(EvaluationRun.organization_id == organization_id) + .where(EvaluationRun.project_id == project_id) + .order_by(EvaluationRun.inserted_at.desc()) + .limit(limit) + .offset(offset) + ) + + runs = session.exec(statement).all() + + logger.info( + f"Found {len(runs)} evaluation runs for org_id={organization_id}, " + f"project_id={project_id}" + ) + + return list(runs) + + +def get_evaluation_run_by_id( + session: Session, + evaluation_id: int, + organization_id: int, + project_id: int, +) -> EvaluationRun | None: + """ + Get a specific evaluation run by ID. + + Args: + session: Database session + evaluation_id: ID of the evaluation run + organization_id: Organization ID (for access control) + project_id: Project ID (for access control) + + Returns: + EvaluationRun if found and accessible, None otherwise + """ + statement = ( + select(EvaluationRun) + .where(EvaluationRun.id == evaluation_id) + .where(EvaluationRun.organization_id == organization_id) + .where(EvaluationRun.project_id == project_id) + ) + + eval_run = session.exec(statement).first() + + if eval_run: + logger.info( + f"Found evaluation run {evaluation_id}: status={eval_run.status}, " + f"batch_job_id={eval_run.batch_job_id}" + ) + else: + logger.warning( + f"Evaluation run {evaluation_id} not found or not accessible " + f"for org_id={organization_id}, project_id={project_id}" + ) + + return eval_run + + +def update_evaluation_run( + session: Session, + eval_run: EvaluationRun, + status: str | None = None, + error_message: str | None = None, + object_store_url: str | None = None, + score: dict | None = None, + embedding_batch_job_id: int | None = None, +) -> EvaluationRun: + """ + Update an evaluation run with new values and persist to database. + + This helper function ensures consistency when updating evaluation runs + by always updating the timestamp and properly committing changes. + + Args: + session: Database session + eval_run: EvaluationRun instance to update + status: New status value (optional) + error_message: New error message (optional) + object_store_url: New object store URL (optional) + score: New score dict (optional) + embedding_batch_job_id: New embedding batch job ID (optional) + + Returns: + Updated and refreshed EvaluationRun instance + """ + # Update provided fields + if status is not None: + eval_run.status = status + if error_message is not None: + eval_run.error_message = error_message + if object_store_url is not None: + eval_run.object_store_url = object_store_url + if score is not None: + eval_run.score = score + if embedding_batch_job_id is not None: + eval_run.embedding_batch_job_id = embedding_batch_job_id + + # Always update timestamp + eval_run.updated_at = now() + + # Persist to database + session.add(eval_run) + session.commit() + session.refresh(eval_run) + + return eval_run diff --git a/backend/app/crud/evaluations/cron.py b/backend/app/crud/evaluations/cron.py new file mode 100644 index 000000000..ca6bd2af2 --- /dev/null +++ b/backend/app/crud/evaluations/cron.py @@ -0,0 +1,158 @@ +""" +CRUD operations for evaluation cron jobs. + +This module provides functions that can be invoked periodically to process +pending evaluations across all organizations. +""" + +import asyncio +import logging +from typing import Any + +from sqlmodel import Session, select + +from app.crud.evaluations.processing import poll_all_pending_evaluations +from app.models import Organization + +logger = logging.getLogger(__name__) + + +async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: + """ + Process all pending evaluations across all organizations. + + This function: + 1. Gets all organizations + 2. For each org, polls their pending evaluations + 3. Processes completed batches automatically + 4. Returns aggregated results + + This is the main function that should be called by the cron endpoint. + + Args: + session: Database session + + Returns: + Dict with aggregated results: + { + "status": "success", + "organizations_processed": 3, + "total_processed": 5, + "total_failed": 1, + "total_still_processing": 2, + "results": [ + { + "org_id": 1, + "org_name": "Org 1", + "summary": {...} + }, + ... + ] + } + """ + logger.info("[process_all_pending_evaluations] Starting evaluation processing") + + try: + # Get all organizations + orgs = session.exec(select(Organization)).all() + + if not orgs: + logger.info("[process_all_pending_evaluations] No organizations found") + return { + "status": "success", + "organizations_processed": 0, + "total_processed": 0, + "total_failed": 0, + "total_still_processing": 0, + "message": "No organizations to process", + "results": [], + } + + logger.info( + f"[process_all_pending_evaluations] Found {len(orgs)} organizations to process" + ) + + results = [] + total_processed = 0 + total_failed = 0 + total_still_processing = 0 + + # Process each organization + for org in orgs: + try: + logger.info( + f"[process_all_pending_evaluations] Processing org_id={org.id} ({org.name})" + ) + + # Poll all pending evaluations for this org + summary = await poll_all_pending_evaluations( + session=session, org_id=org.id + ) + + results.append( + { + "org_id": org.id, + "org_name": org.name, + "summary": summary, + } + ) + + total_processed += summary.get("processed", 0) + total_failed += summary.get("failed", 0) + total_still_processing += summary.get("still_processing", 0) + + except Exception as e: + logger.error( + f"[process_all_pending_evaluations] Error processing org_id={org.id}: {e}", + exc_info=True, + ) + session.rollback() + results.append( + {"org_id": org.id, "org_name": org.name, "error": str(e)} + ) + total_failed += 1 + + logger.info( + f"[process_all_pending_evaluations] Completed: " + f"{total_processed} processed, {total_failed} failed, " + f"{total_still_processing} still processing" + ) + + return { + "status": "success", + "organizations_processed": len(orgs), + "total_processed": total_processed, + "total_failed": total_failed, + "total_still_processing": total_still_processing, + "results": results, + } + + except Exception as e: + logger.error( + f"[process_all_pending_evaluations] Fatal error: {e}", + exc_info=True, + ) + return { + "status": "error", + "organizations_processed": 0, + "total_processed": 0, + "total_failed": 0, + "total_still_processing": 0, + "error": str(e), + "results": [], + } + + +def process_all_pending_evaluations_sync(session: Session) -> dict[str, Any]: + """ + Synchronous wrapper for process_all_pending_evaluations. + + This function can be called from synchronous contexts (like FastAPI endpoints). + + Args: + session: Database session + + Returns: + Dict with aggregated results (same as process_all_pending_evaluations) + """ + return asyncio.run(process_all_pending_evaluations(session=session)) diff --git a/backend/app/crud/evaluations/dataset.py b/backend/app/crud/evaluations/dataset.py new file mode 100644 index 000000000..7efa03d46 --- /dev/null +++ b/backend/app/crud/evaluations/dataset.py @@ -0,0 +1,387 @@ +""" +CRUD operations for evaluation datasets. + +This module handles database operations for evaluation datasets including: +1. Creating new datasets +2. Fetching datasets by ID or name +3. Listing datasets with pagination +4. Uploading CSV files to AWS S3 +""" + +import logging +from typing import Any + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.core.cloud.storage import CloudStorage +from app.core.storage_utils import ( + generate_timestamped_filename, +) +from app.core.storage_utils import ( + upload_csv_to_object_store as shared_upload_csv, +) +from app.core.util import now +from app.models import EvaluationDataset, EvaluationRun + +logger = logging.getLogger(__name__) + + +def create_evaluation_dataset( + session: Session, + name: str, + dataset_metadata: dict[str, Any], + organization_id: int, + project_id: int, + description: str | None = None, + object_store_url: str | None = None, + langfuse_dataset_id: str | None = None, +) -> EvaluationDataset: + """ + Create a new evaluation dataset record in the database. + + Args: + session: Database session + name: Name of the dataset + dataset_metadata: Dataset metadata (original_items_count, + total_items_count, duplication_factor) + organization_id: Organization ID + project_id: Project ID + description: Optional dataset description + object_store_url: Optional object store URL where CSV is stored + langfuse_dataset_id: Optional Langfuse dataset ID + + Returns: + Created EvaluationDataset object + + Raises: + HTTPException: 409 if dataset with same name exists, 500 for other errors + """ + try: + dataset = EvaluationDataset( + name=name, + description=description, + dataset_metadata=dataset_metadata, + object_store_url=object_store_url, + langfuse_dataset_id=langfuse_dataset_id, + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + session.add(dataset) + session.commit() + session.refresh(dataset) + + logger.info( + f"[create_evaluation_dataset] Created evaluation dataset | id={dataset.id} | name={name} | org_id={organization_id} | project_id={project_id}" + ) + + return dataset + + except IntegrityError as e: + session.rollback() + logger.error( + f"[create_evaluation_dataset] Database integrity error creating dataset | name={name} | {e}", + exc_info=True, + ) + raise HTTPException( + status_code=409, + detail=f"Dataset with name '{name}' already exists in this " + "organization and project. Please choose a different name.", + ) + + except Exception as e: + session.rollback() + logger.error( + f"[create_evaluation_dataset] Failed to create dataset record in database | {e}", + exc_info=True, + ) + raise HTTPException( + status_code=500, detail=f"Failed to save dataset metadata: {e}" + ) + + +def get_dataset_by_id( + session: Session, dataset_id: int, organization_id: int, project_id: int +) -> EvaluationDataset | None: + """ + Fetch an evaluation dataset by ID with organization and project validation. + + Args: + session: Database session + dataset_id: Dataset ID + organization_id: Organization ID for validation + project_id: Project ID for validation + + Returns: + EvaluationDataset if found and belongs to the org/project, None otherwise + """ + statement = ( + select(EvaluationDataset) + .where(EvaluationDataset.id == dataset_id) + .where(EvaluationDataset.organization_id == organization_id) + .where(EvaluationDataset.project_id == project_id) + ) + + dataset = session.exec(statement).first() + + if dataset: + logger.info( + f"[get_dataset_by_id] Found dataset | id={dataset_id} | name={dataset.name} | org_id={organization_id} | project_id={project_id}" + ) + else: + logger.warning( + f"[get_dataset_by_id] Dataset not found or not accessible | id={dataset_id} | org_id={organization_id} | project_id={project_id}" + ) + + return dataset + + +def get_dataset_by_name( + session: Session, name: str, organization_id: int, project_id: int +) -> EvaluationDataset | None: + """ + Fetch an evaluation dataset by name with organization and project validation. + + Args: + session: Database session + name: Dataset name + organization_id: Organization ID for validation + project_id: Project ID for validation + + Returns: + EvaluationDataset if found and belongs to the org/project, None otherwise + """ + statement = ( + select(EvaluationDataset) + .where(EvaluationDataset.name == name) + .where(EvaluationDataset.organization_id == organization_id) + .where(EvaluationDataset.project_id == project_id) + ) + + dataset = session.exec(statement).first() + + if dataset: + logger.info( + f"[get_dataset_by_name] Found dataset by name | name={name} | id={dataset.id} | org_id={organization_id} | project_id={project_id}" + ) + + return dataset + + +def list_datasets( + session: Session, + organization_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> list[EvaluationDataset]: + """ + List all evaluation datasets for an organization and project with pagination. + + Args: + session: Database session + organization_id: Organization ID + project_id: Project ID + limit: Maximum number of datasets to return (default 50) + offset: Number of datasets to skip (for pagination) + + Returns: + List of EvaluationDataset objects, ordered by most recent first + """ + statement = ( + select(EvaluationDataset) + .where(EvaluationDataset.organization_id == organization_id) + .where(EvaluationDataset.project_id == project_id) + .order_by(EvaluationDataset.inserted_at.desc()) + .limit(limit) + .offset(offset) + ) + + datasets = session.exec(statement).all() + + logger.info( + f"[list_datasets] Listed datasets | count={len(datasets)} | org_id={organization_id} | project_id={project_id} | limit={limit} | offset={offset}" + ) + + return list(datasets) + + +def upload_csv_to_object_store( + storage: CloudStorage, + csv_content: bytes, + dataset_name: str, +) -> str | None: + """ + Upload CSV file to object store. + + This is a wrapper around the shared storage utility function, + providing dataset-specific file naming. + + Args: + storage: CloudStorage instance + csv_content: Raw CSV content as bytes + dataset_name: Name of the dataset (used for file naming) + + Returns: + Object store URL as string if successful, None if failed + + Note: + This function handles errors gracefully and returns None on failure. + Callers should continue without object store URL when this returns None. + """ + # Generate timestamped filename + filename = generate_timestamped_filename(dataset_name, extension="csv") + + # Use shared utility for upload + return shared_upload_csv( + storage=storage, + csv_content=csv_content, + filename=filename, + subdirectory="datasets", + ) + + +# Backward compatibility alias +upload_csv_to_s3 = upload_csv_to_object_store + + +def download_csv_from_object_store( + storage: CloudStorage, object_store_url: str +) -> bytes: + """ + Download CSV file from object store. + + Args: + storage: CloudStorage instance + object_store_url: Object store URL of the CSV file + + Returns: + CSV content as bytes + + Raises: + CloudStorageError: If download fails + ValueError: If object_store_url is None or empty + """ + if not object_store_url: + raise ValueError("object_store_url cannot be None or empty") + + try: + logger.info( + f"[download_csv_from_object_store] Downloading CSV from object store | {object_store_url}" + ) + body = storage.stream(object_store_url) + csv_content = body.read() + logger.info( + f"[download_csv_from_object_store] Successfully downloaded CSV from object store | bytes={len(csv_content)}" + ) + return csv_content + except Exception as e: + logger.error( + f"[download_csv_from_object_store] Failed to download CSV from object store | {object_store_url} | {e}", + exc_info=True, + ) + raise + + +# Backward compatibility alias +download_csv_from_s3 = download_csv_from_object_store + + +def update_dataset_langfuse_id( + session: Session, dataset_id: int, langfuse_dataset_id: str +) -> None: + """ + Update the langfuse_dataset_id for an existing dataset. + + Args: + session: Database session + dataset_id: Dataset ID + langfuse_dataset_id: Langfuse dataset ID to store + + Returns: + None + """ + dataset = session.get(EvaluationDataset, dataset_id) + if dataset: + dataset.langfuse_dataset_id = langfuse_dataset_id + dataset.updated_at = now() + session.add(dataset) + session.commit() + logger.info( + f"[update_dataset_langfuse_id] Updated langfuse_dataset_id | dataset_id={dataset_id} | langfuse_dataset_id={langfuse_dataset_id}" + ) + else: + logger.warning( + f"[update_dataset_langfuse_id] Dataset not found for langfuse_id update | dataset_id={dataset_id}" + ) + + +def delete_dataset( + session: Session, dataset_id: int, organization_id: int, project_id: int +) -> tuple[bool, str]: + """ + Delete an evaluation dataset by ID. + + This performs a hard delete from the database. The CSV file in object store (if exists) + will remain for audit purposes. + + Args: + session: Database session + dataset_id: Dataset ID to delete + organization_id: Organization ID for validation + project_id: Project ID for validation + + Returns: + Tuple of (success: bool, message: str) + """ + # First, fetch the dataset to ensure it exists and belongs to the org/project + dataset = get_dataset_by_id( + session=session, + dataset_id=dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + if not dataset: + return ( + False, + f"Dataset {dataset_id} not found or not accessible", + ) + + # Check if dataset is being used by any evaluation runs + statement = select(EvaluationRun).where(EvaluationRun.dataset_id == dataset_id) + evaluation_runs = session.exec(statement).all() + + if evaluation_runs: + return ( + False, + f"Cannot delete dataset {dataset_id}: it is being used by " + f"{len(evaluation_runs)} evaluation run(s). Please delete " + f"the evaluation runs first.", + ) + + # Delete the dataset + try: + session.delete(dataset) + session.commit() + + logger.info( + f"[delete_dataset] Deleted dataset | id={dataset_id} | name={dataset.name} | org_id={organization_id} | project_id={project_id}" + ) + + return ( + True, + f"Successfully deleted dataset '{dataset.name}' (id={dataset_id})", + ) + + except Exception as e: + session.rollback() + logger.error( + f"[delete_dataset] Failed to delete dataset | dataset_id={dataset_id} | {e}", + exc_info=True, + ) + return (False, f"Failed to delete dataset: {e}") diff --git a/backend/app/crud/evaluations/embeddings.py b/backend/app/crud/evaluations/embeddings.py new file mode 100644 index 000000000..70e374211 --- /dev/null +++ b/backend/app/crud/evaluations/embeddings.py @@ -0,0 +1,434 @@ +""" +Embedding-based similarity scoring for evaluation runs. + +This module handles: +1. Building JSONL for embedding batch requests +2. Parsing embedding results from batch API +3. Calculating cosine similarity between embeddings +4. Orchestrating embedding batch creation and processing +""" + +import logging +from typing import Any + +import numpy as np +from openai import OpenAI +from sqlmodel import Session + +from app.core.batch.openai import OpenAIBatchProvider +from app.core.util import now +from app.crud.batch_operations import start_batch_job +from app.models import EvaluationRun + +logger = logging.getLogger(__name__) + +# Valid embedding models with their dimensions +VALID_EMBEDDING_MODELS = { + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, +} + + +def validate_embedding_model(model: str) -> None: + """ + Validate that the embedding model is supported. + + Args: + model: The embedding model name + + Raises: + ValueError: If the model is not supported + """ + if model not in VALID_EMBEDDING_MODELS: + valid_models = ", ".join(VALID_EMBEDDING_MODELS.keys()) + raise ValueError( + f"Invalid embedding model '{model}'. " f"Supported models: {valid_models}" + ) + + +def build_embedding_jsonl( + results: list[dict[str, Any]], + trace_id_mapping: dict[str, str], + embedding_model: str = "text-embedding-3-large", +) -> list[dict[str, Any]]: + """ + Build JSONL data for embedding batch using OpenAI Embeddings API. + + Each line is a dict with: + - custom_id: Langfuse trace_id (for direct score updates) + - method: POST + - url: /v1/embeddings + - body: Embedding request with input array [output, ground_truth] + + Args: + results: List of evaluation results from parse_evaluation_output() + Format: [ + { + "item_id": "item_123", + "question": "What is 2+2?", + "generated_output": "The answer is 4", + "ground_truth": "4" + }, + ... + ] + trace_id_mapping: Mapping of item_id to Langfuse trace_id + embedding_model: OpenAI embedding model to use (default: text-embedding-3-large) + + Returns: + List of dictionaries (JSONL data) + """ + # Validate embedding model + validate_embedding_model(embedding_model) + + logger.info( + f"Building embedding JSONL for {len(results)} items with model {embedding_model}" + ) + + jsonl_data = [] + + for result in results: + item_id = result.get("item_id") + generated_output = result.get("generated_output", "") + ground_truth = result.get("ground_truth", "") + + if not item_id: + logger.warning("Skipping result with no item_id") + continue + + # Get trace_id from mapping + trace_id = trace_id_mapping.get(item_id) + if not trace_id: + logger.warning(f"Skipping item {item_id} - no trace_id found") + continue + + # Skip if either output or ground_truth is empty + if not generated_output or not ground_truth: + logger.warning(f"Skipping item {item_id} - empty output or ground_truth") + continue + + # Build the batch request object for Embeddings API + # Use trace_id as custom_id for direct score updates + batch_request = { + "custom_id": trace_id, + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": embedding_model, + "input": [ + generated_output, # Index 0 + ground_truth, # Index 1 + ], + "encoding_format": "float", + }, + } + + jsonl_data.append(batch_request) + + logger.info(f"Built {len(jsonl_data)} embedding JSONL lines") + return jsonl_data + + +def parse_embedding_results(raw_results: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Parse embedding batch output into structured embedding pairs. + + Args: + raw_results: Raw results from batch provider (list of JSONL lines) + + Returns: + List of embedding pairs in format: + [ + { + "trace_id": "trace-uuid-123", + "output_embedding": [0.1, 0.2, ...], + "ground_truth_embedding": [0.15, 0.22, ...] + }, + ... + ] + """ + logger.info(f"Parsing embedding results from {len(raw_results)} lines") + + embedding_pairs = [] + + for line_num, response in enumerate(raw_results, 1): + try: + # Extract custom_id (which is now the Langfuse trace_id) + trace_id = response.get("custom_id") + if not trace_id: + logger.warning(f"Line {line_num}: No custom_id found, skipping") + continue + + # Handle errors in batch processing + if response.get("error"): + error_msg = response["error"].get("message", "Unknown error") + logger.error(f"Trace {trace_id} had error: {error_msg}") + continue + + # Extract the response body + response_body = response.get("response", {}).get("body", {}) + embedding_data = response_body.get("data", []) + + if len(embedding_data) < 2: + logger.warning( + f"Trace {trace_id}: Expected 2 embeddings, got {len(embedding_data)}" + ) + continue + + # Extract embeddings by index + # Index 0 = generated_output embedding + # Index 1 = ground_truth embedding + output_embedding = None + ground_truth_embedding = None + + for emb_obj in embedding_data: + index = emb_obj.get("index") + embedding = emb_obj.get("embedding") + + if embedding is None: + continue + + if index == 0: + output_embedding = embedding + elif index == 1: + ground_truth_embedding = embedding + + if output_embedding is None or ground_truth_embedding is None: + logger.warning( + f"Trace {trace_id}: Missing embeddings (output={output_embedding is not None}, " + f"ground_truth={ground_truth_embedding is not None})" + ) + continue + + embedding_pairs.append( + { + "trace_id": trace_id, + "output_embedding": output_embedding, + "ground_truth_embedding": ground_truth_embedding, + } + ) + + except Exception as e: + logger.error(f"Line {line_num}: Unexpected error: {e}", exc_info=True) + continue + + logger.info( + f"Parsed {len(embedding_pairs)} embedding pairs from {len(raw_results)} lines" + ) + return embedding_pairs + + +def calculate_cosine_similarity(vec1: list[float], vec2: list[float]) -> float: + """ + Calculate cosine similarity between two vectors using numpy. + + Formula: similarity = dot(vec1, vec2) / (||vec1|| * ||vec2||) + + Args: + vec1: First embedding vector + vec2: Second embedding vector + + Returns: + Cosine similarity score (range: -1 to 1, typically 0 to 1 for embeddings) + """ + # Convert to numpy arrays + v1 = np.array(vec1) + v2 = np.array(vec2) + + # Calculate dot product + dot_product = np.dot(v1, v2) + + # Calculate norms + norm_v1 = np.linalg.norm(v1) + norm_v2 = np.linalg.norm(v2) + + # Handle edge case of zero vectors + if norm_v1 == 0 or norm_v2 == 0: + return 0.0 + + # Calculate cosine similarity + similarity = dot_product / (norm_v1 * norm_v2) + + return float(similarity) + + +def calculate_average_similarity( + embedding_pairs: list[dict[str, Any]] +) -> dict[str, Any]: + """ + Calculate cosine similarity statistics for all embedding pairs. + + Args: + embedding_pairs: List of embedding pairs from parse_embedding_results() + + Returns: + Dictionary with similarity statistics: + { + "cosine_similarity_avg": 0.87, + "cosine_similarity_std": 0.12, + "total_pairs": 50, + "per_item_scores": [...] # Individual scores with trace_ids + } + """ + logger.info(f"Calculating similarity for {len(embedding_pairs)} pairs") + + if not embedding_pairs: + return { + "cosine_similarity_avg": 0.0, + "cosine_similarity_std": 0.0, + "total_pairs": 0, + "per_item_scores": [], + } + + similarities = [] + per_item_scores = [] + + for pair in embedding_pairs: + try: + output_emb = pair["output_embedding"] + ground_truth_emb = pair["ground_truth_embedding"] + + similarity = calculate_cosine_similarity(output_emb, ground_truth_emb) + similarities.append(similarity) + + per_item_scores.append( + { + "trace_id": pair["trace_id"], + "cosine_similarity": similarity, + } + ) + + except Exception as e: + logger.error( + f"Error calculating similarity for trace {pair.get('trace_id')}: {e}" + ) + continue + + if not similarities: + logger.warning("No valid similarities calculated") + return { + "cosine_similarity_avg": 0.0, + "cosine_similarity_std": 0.0, + "total_pairs": 0, + "per_item_scores": [], + } + + # Calculate statistics + similarities_array = np.array(similarities) + + stats = { + "cosine_similarity_avg": float(np.mean(similarities_array)), + "cosine_similarity_std": float(np.std(similarities_array)), + "total_pairs": len(similarities), + "per_item_scores": per_item_scores, + } + + logger.info( + f"Calculated similarity stats: avg={stats['cosine_similarity_avg']:.3f}, " + f"std={stats['cosine_similarity_std']:.3f}" + ) + + return stats + + +def start_embedding_batch( + session: Session, + openai_client: OpenAI, + eval_run: EvaluationRun, + results: list[dict[str, Any]], + trace_id_mapping: dict[str, str], +) -> EvaluationRun: + """ + Start embedding batch for similarity scoring. + + This function orchestrates the embedding batch creation: + 1. Builds embedding JSONL from evaluation results with trace_ids + 2. Creates batch via generic infrastructure (job_type="embedding") + 3. Links embedding_batch_job_id to eval_run + 4. Keeps status as "processing" + + Args: + session: Database session + openai_client: Configured OpenAI client + eval_run: EvaluationRun database object + results: Parsed evaluation results (output + ground_truth pairs) + trace_id_mapping: Mapping of item_id to Langfuse trace_id + + Returns: + Updated EvaluationRun with embedding_batch_job_id populated + + Raises: + Exception: If any step fails + """ + try: + logger.info(f"Starting embedding batch for evaluation run {eval_run.id}") + + # Get embedding model from config (default: text-embedding-3-large) + embedding_model = eval_run.config.get( + "embedding_model", "text-embedding-3-large" + ) + + # Validate and fallback to default if invalid + try: + validate_embedding_model(embedding_model) + except ValueError as e: + logger.warning( + f"Invalid embedding model '{embedding_model}' in config: {e}. " + f"Falling back to text-embedding-3-large" + ) + embedding_model = "text-embedding-3-large" + + # Step 1: Build embedding JSONL with trace_ids + jsonl_data = build_embedding_jsonl( + results=results, + trace_id_mapping=trace_id_mapping, + embedding_model=embedding_model, + ) + + if not jsonl_data: + raise ValueError("No valid items to create embeddings for") + + # Step 2: Create batch provider + provider = OpenAIBatchProvider(client=openai_client) + + # Step 3: Prepare batch configuration + batch_config = { + "endpoint": "/v1/embeddings", + "description": f"Embeddings for evaluation: {eval_run.run_name}", + "completion_window": "24h", + "embedding_model": embedding_model, + } + + # Step 4: Start batch job using generic infrastructure + batch_job = start_batch_job( + session=session, + provider=provider, + provider_name="openai", + job_type="embedding", + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + jsonl_data=jsonl_data, + config=batch_config, + ) + + # Step 5: Link embedding_batch_job to evaluation_run + eval_run.embedding_batch_job_id = batch_job.id + # Keep status as "processing" - will change to "completed" after embeddings + eval_run.updated_at = now() + + session.add(eval_run) + session.commit() + session.refresh(eval_run) + + logger.info( + f"Successfully started embedding batch: batch_job_id={batch_job.id}, " + f"provider_batch_id={batch_job.provider_batch_id} " + f"for evaluation run {eval_run.id} with {batch_job.total_items} items" + ) + + return eval_run + + except Exception as e: + logger.error(f"Failed to start embedding batch: {e}", exc_info=True) + # Don't update eval_run status here - let caller decide + raise diff --git a/backend/app/crud/evaluations/langfuse.py b/backend/app/crud/evaluations/langfuse.py new file mode 100644 index 000000000..8117210c8 --- /dev/null +++ b/backend/app/crud/evaluations/langfuse.py @@ -0,0 +1,297 @@ +""" +Langfuse integration for evaluation runs. + +This module handles: +1. Creating dataset runs in Langfuse +2. Creating traces for each evaluation item +3. Uploading results to Langfuse for visualization +""" + +import logging +from typing import Any + +from langfuse import Langfuse + +logger = logging.getLogger(__name__) + + +def create_langfuse_dataset_run( + langfuse: Langfuse, + dataset_name: str, + run_name: str, + results: list[dict[str, Any]], +) -> dict[str, str]: + """ + Create a dataset run in Langfuse with traces for each evaluation item. + + This function: + 1. Gets the dataset from Langfuse (which already exists) + 2. For each result, creates a trace linked to the dataset item + 3. Logs input (question), output (generated_output), and expected (ground_truth) + 4. Returns a mapping of item_id -> trace_id for later score updates + + Args: + langfuse: Configured Langfuse client + dataset_name: Name of the dataset in Langfuse + run_name: Name for this evaluation run + results: List of evaluation results from parse_batch_output() + Format: [ + { + "item_id": "item_123", + "question": "What is 2+2?", + "generated_output": "4", + "ground_truth": "4", + "response_id": "resp_0b99aadfead1fb62006908e7f540c48197bd110183a347c1d8" + }, + ... + ] + + Returns: + dict[str, str]: Mapping of item_id to Langfuse trace_id + + Raises: + Exception: If Langfuse operations fail + """ + logger.info( + f"[create_langfuse_dataset_run] Creating Langfuse dataset run | " + f"run_name={run_name} | dataset={dataset_name} | items={len(results)}" + ) + + try: + # Get the dataset + dataset = langfuse.get_dataset(dataset_name) + dataset_items_map = {item.id: item for item in dataset.items} + + trace_id_mapping = {} + + # Create a trace for each result + for result in results: + item_id = result["item_id"] + question = result["question"] + generated_output = result["generated_output"] + ground_truth = result["ground_truth"] + response_id = result.get("response_id") + + dataset_item = dataset_items_map.get(item_id) + if not dataset_item: + logger.warning( + f"[create_langfuse_dataset_run] Dataset item not found, skipping | " + f"item_id={item_id}" + ) + continue + + try: + with dataset_item.observe(run_name=run_name) as trace_id: + metadata = { + "ground_truth": ground_truth, + "item_id": item_id, + } + if response_id: + metadata["response_id"] = response_id + + langfuse.trace( + id=trace_id, + input={"question": question}, + output={"answer": generated_output}, + metadata=metadata, + ) + trace_id_mapping[item_id] = trace_id + + except Exception as e: + logger.error( + f"[create_langfuse_dataset_run] Failed to create trace | " + f"item_id={item_id} | {e}", + exc_info=True, + ) + continue + + langfuse.flush() + logger.info( + f"[create_langfuse_dataset_run] Created Langfuse dataset run | " + f"run_name={run_name} | traces={len(trace_id_mapping)}" + ) + + return trace_id_mapping + + except Exception as e: + logger.error( + f"[create_langfuse_dataset_run] Failed to create Langfuse dataset run | " + f"run_name={run_name} | {e}", + exc_info=True, + ) + raise + + +def update_traces_with_cosine_scores( + langfuse: Langfuse, + per_item_scores: list[dict[str, Any]], +) -> None: + """ + Update Langfuse traces with cosine similarity scores. + + This function adds custom "cosine_similarity" scores to traces at the trace level, + allowing them to be visualized in the Langfuse UI. + + Args: + langfuse: Configured Langfuse client + per_item_scores: List of per-item score dictionaries from + calculate_average_similarity() + Format: [ + { + "trace_id": "trace-uuid-123", + "cosine_similarity": 0.95 + }, + ... + ] + + Note: + This function logs errors but does not raise exceptions to avoid blocking + evaluation completion if Langfuse updates fail. + """ + for score_item in per_item_scores: + trace_id = score_item.get("trace_id") + cosine_score = score_item.get("cosine_similarity") + + if not trace_id: + logger.warning( + "[update_traces_with_cosine_scores] Score item missing trace_id, skipping" + ) + continue + + try: + langfuse.score( + trace_id=trace_id, + name="cosine_similarity", + value=cosine_score, + comment=( + "Cosine similarity between generated output and " + "ground truth embeddings" + ), + ) + except Exception as e: + logger.error( + f"[update_traces_with_cosine_scores] Failed to add score | " + f"trace_id={trace_id} | {e}", + exc_info=True, + ) + + langfuse.flush() + + +def upload_dataset_to_langfuse_from_csv( + langfuse: Langfuse, + csv_content: bytes, + dataset_name: str, + duplication_factor: int, +) -> tuple[str, int]: + """ + Upload a dataset to Langfuse from CSV content. + + This function parses CSV content and uploads it to Langfuse with duplication. + Used when re-uploading datasets from S3 storage. + + Args: + langfuse: Configured Langfuse client + csv_content: Raw CSV content as bytes + dataset_name: Name for the dataset in Langfuse + duplication_factor: Number of times to duplicate each item + + Returns: + Tuple of (langfuse_dataset_id, total_items_uploaded) + + Raises: + ValueError: If CSV is invalid or empty + Exception: If Langfuse operations fail + """ + import csv + import io + + logger.info( + f"[upload_dataset_to_langfuse_from_csv] Uploading dataset to Langfuse from CSV | " + f"dataset={dataset_name} | duplication_factor={duplication_factor}" + ) + + try: + # Parse CSV content + csv_text = csv_content.decode("utf-8") + csv_reader = csv.DictReader(io.StringIO(csv_text)) + csv_reader.fieldnames = [name.strip() for name in csv_reader.fieldnames] + + # Validate CSV headers + if ( + "question" not in csv_reader.fieldnames + or "answer" not in csv_reader.fieldnames + ): + raise ValueError( + f"CSV must contain 'question' and 'answer' columns. " + f"Found columns: {csv_reader.fieldnames}" + ) + + # Read all rows from CSV + original_items = [] + for row in csv_reader: + question = row.get("question", "").strip() + answer = row.get("answer", "").strip() + + if not question or not answer: + logger.warning( + f"[upload_dataset_to_langfuse_from_csv] Skipping row with empty question or answer | {row}" + ) + continue + + original_items.append({"question": question, "answer": answer}) + + if not original_items: + raise ValueError("No valid items found in CSV file") + + logger.info( + f"[upload_dataset_to_langfuse_from_csv] Parsed items from CSV | " + f"original={len(original_items)} | duplication_factor={duplication_factor} | " + f"total={len(original_items) * duplication_factor}" + ) + + # Create or get dataset in Langfuse + dataset = langfuse.create_dataset(name=dataset_name) + + # Upload items with duplication + total_uploaded = 0 + for item in original_items: + # Duplicate each item N times + for duplicate_num in range(duplication_factor): + try: + langfuse.create_dataset_item( + dataset_name=dataset_name, + input={"question": item["question"]}, + expected_output={"answer": item["answer"]}, + metadata={ + "original_question": item["question"], + "duplicate_number": duplicate_num + 1, + "duplication_factor": duplication_factor, + }, + ) + total_uploaded += 1 + except Exception as e: + logger.error( + f"[upload_dataset_to_langfuse_from_csv] Failed to upload item | " + f"duplicate={duplicate_num + 1} | question={item['question'][:50]}... | {e}" + ) + + # Flush to ensure all items are uploaded + langfuse.flush() + + langfuse_dataset_id = dataset.id if hasattr(dataset, "id") else None + + logger.info( + f"[upload_dataset_to_langfuse_from_csv] Successfully uploaded items to Langfuse dataset | " + f"items={total_uploaded} | dataset={dataset_name} | id={langfuse_dataset_id}" + ) + + return langfuse_dataset_id, total_uploaded + + except Exception as e: + logger.error( + f"[upload_dataset_to_langfuse_from_csv] Failed to upload dataset to Langfuse | " + f"dataset={dataset_name} | {e}", + exc_info=True, + ) + raise diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py new file mode 100644 index 000000000..50698d00c --- /dev/null +++ b/backend/app/crud/evaluations/processing.py @@ -0,0 +1,799 @@ +""" +Evaluation batch processing orchestrator. + +This module coordinates the evaluation-specific workflow: +1. Monitoring batch_job status for evaluations +2. Parsing evaluation results from batch output +3. Creating Langfuse dataset runs with traces +4. Updating evaluation_run with final status and scores +""" + +import ast +import json +import logging +from collections import defaultdict +from typing import Any + +from fastapi import HTTPException +from langfuse import Langfuse +from openai import OpenAI +from sqlmodel import Session, select + +from app.core.batch.openai import OpenAIBatchProvider +from app.crud.batch_job import get_batch_job +from app.crud.batch_operations import ( + download_batch_results, + upload_batch_results_to_object_store, +) +from app.crud.evaluations.batch import fetch_dataset_items +from app.crud.evaluations.core import update_evaluation_run +from app.crud.evaluations.embeddings import ( + calculate_average_similarity, + parse_embedding_results, + start_embedding_batch, +) +from app.crud.evaluations.langfuse import ( + create_langfuse_dataset_run, + update_traces_with_cosine_scores, +) +from app.models import EvaluationRun +from app.utils import get_langfuse_client, get_openai_client + +logger = logging.getLogger(__name__) + + +def parse_evaluation_output( + raw_results: list[dict[str, Any]], dataset_items: list[dict[str, Any]] +) -> list[dict[str, Any]]: + """ + Parse batch output into evaluation results. + + This function extracts the generated output from the batch results + and matches it with the ground truth from the dataset. + + Args: + raw_results: Raw results from batch provider (list of JSONL lines) + dataset_items: Original dataset items (for matching ground truth) + + Returns: + List of results in format: + [ + { + "item_id": "item_123", + "question": "What is 2+2?", + "generated_output": "4", + "ground_truth": "4", + "response_id": "resp_0b99aadfead1fb62006908e7f540c48197bd110183a347c1d8" + }, + ... + ] + """ + # Create lookup map for dataset items by ID + dataset_map = {item["id"]: item for item in dataset_items} + + results = [] + + for line_num, response in enumerate(raw_results, 1): + try: + # Extract custom_id (which is our dataset item ID) + item_id = response.get("custom_id") + if not item_id: + logger.warning( + f"[parse_evaluation_output] No custom_id found, skipping | line={line_num}" + ) + continue + + # Get original dataset item + dataset_item = dataset_map.get(item_id) + if not dataset_item: + logger.warning( + f"[parse_evaluation_output] No dataset item found | line={line_num} | item_id={item_id}" + ) + continue + + # Extract the response body + response_body = response.get("response", {}).get("body", {}) + + # Extract response ID from response.body.id + response_id = response_body.get("id") + + # Handle errors in batch processing + if response.get("error"): + error_msg = response["error"].get("message", "Unknown error") + logger.error( + f"[parse_evaluation_output] Item had error | item_id={item_id} | {error_msg}" + ) + generated_output = f"ERROR: {error_msg}" + else: + # Extract text from output (can be string, list, or complex structure) + output = response_body.get("output", "") + + # If string, try to parse it (may be JSON or Python repr of list) + if isinstance(output, str): + try: + output = json.loads(output) + except (json.JSONDecodeError, ValueError): + try: + output = ast.literal_eval(output) + except (ValueError, SyntaxError): + # Keep as string if parsing fails + generated_output = output + output = None + + # If we have a list structure, extract text from message items + if isinstance(output, list): + generated_output = "" + for item in output: + if isinstance(item, dict) and item.get("type") == "message": + for content in item.get("content", []): + if ( + isinstance(content, dict) + and content.get("type") == "output_text" + ): + generated_output = content.get("text", "") + break + if generated_output: + break + elif output is not None: + # output was not a string and not a list + generated_output = "" + logger.warning( + f"[parse_evaluation_output] Unexpected output type | item_id={item_id} | type={type(output)}" + ) + + # Extract question and ground truth from dataset item + question = dataset_item["input"].get("question", "") + ground_truth = dataset_item["expected_output"].get("answer", "") + + results.append( + { + "item_id": item_id, + "question": question, + "generated_output": generated_output, + "ground_truth": ground_truth, + "response_id": response_id, + } + ) + + except Exception as e: + logger.error( + f"[parse_evaluation_output] Unexpected error | line={line_num} | {e}" + ) + continue + + logger.info( + f"[parse_evaluation_output] Parsed evaluation results | results={len(results)} | output_lines={len(raw_results)}" + ) + return results + + +async def process_completed_evaluation( + eval_run: EvaluationRun, + session: Session, + openai_client: OpenAI, + langfuse: Langfuse, +) -> EvaluationRun: + """ + Process a completed evaluation batch. + + This function: + 1. Downloads batch output from provider + 2. Parses results into question/output/ground_truth format + 3. Creates Langfuse dataset run with traces + 4. Starts embedding batch for similarity scoring (keeps status as "processing") + + Args: + eval_run: EvaluationRun database object + session: Database session + openai_client: Configured OpenAI client + langfuse: Configured Langfuse client + + Returns: + Updated EvaluationRun object (with embedding_batch_job_id set) + + Raises: + Exception: If processing fails + """ + log_prefix = f"[org={eval_run.organization_id}][project={eval_run.project_id}][eval={eval_run.id}]" + logger.info( + f"[process_completed_evaluation] {log_prefix} Processing completed evaluation" + ) + + try: + # Step 1: Get batch_job + if not eval_run.batch_job_id: + raise ValueError(f"EvaluationRun {eval_run.id} has no batch_job_id") + + batch_job = get_batch_job(session=session, batch_job_id=eval_run.batch_job_id) + if not batch_job: + raise ValueError( + f"BatchJob {eval_run.batch_job_id} not found for evaluation {eval_run.id}" + ) + + # Step 2: Create provider and download results + logger.info( + f"[process_completed_evaluation] {log_prefix} Downloading batch results | batch_job_id={batch_job.id}" + ) + provider = OpenAIBatchProvider(client=openai_client) + raw_results = download_batch_results(provider=provider, batch_job=batch_job) + + # Step 2a: Upload raw results to object store for evaluation_run + object_store_url = None + try: + object_store_url = upload_batch_results_to_object_store( + session=session, batch_job=batch_job, results=raw_results + ) + except Exception as store_error: + logger.warning( + f"[process_completed_evaluation] {log_prefix} Object store upload failed | {store_error}" + ) + + # Step 3: Fetch dataset items (needed for matching ground truth) + logger.info( + f"[process_completed_evaluation] {log_prefix} Fetching dataset items | dataset={eval_run.dataset_name}" + ) + dataset_items = fetch_dataset_items( + langfuse=langfuse, dataset_name=eval_run.dataset_name + ) + + # Step 4: Parse evaluation results + results = parse_evaluation_output( + raw_results=raw_results, dataset_items=dataset_items + ) + + if not results: + raise ValueError("No valid results found in batch output") + + # Step 5: Create Langfuse dataset run with traces + trace_id_mapping = create_langfuse_dataset_run( + langfuse=langfuse, + dataset_name=eval_run.dataset_name, + run_name=eval_run.run_name, + results=results, + ) + + # Store object store URL in database + if object_store_url: + eval_run.object_store_url = object_store_url + session.add(eval_run) + session.commit() + + # Step 6: Start embedding batch for similarity scoring + # Pass trace_id_mapping directly without storing in DB + try: + eval_run = start_embedding_batch( + session=session, + openai_client=openai_client, + eval_run=eval_run, + results=results, + trace_id_mapping=trace_id_mapping, + ) + # Note: Status remains "processing" until embeddings complete + + except Exception as e: + logger.error( + f"[process_completed_evaluation] {log_prefix} Failed to start embedding batch | {e}", + exc_info=True, + ) + # Don't fail the entire evaluation, just mark as completed without embeddings + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + status="completed", + error_message=f"Embeddings failed: {str(e)}", + ) + + logger.info( + f"[process_completed_evaluation] {log_prefix} Processed evaluation | items={len(results)}" + ) + + return eval_run + + except Exception as e: + logger.error( + f"[process_completed_evaluation] {log_prefix} Failed to process completed evaluation | {e}", + exc_info=True, + ) + # Mark as failed + return update_evaluation_run( + session=session, + eval_run=eval_run, + status="failed", + error_message=f"Processing failed: {str(e)}", + ) + + +async def process_completed_embedding_batch( + eval_run: EvaluationRun, + session: Session, + openai_client: OpenAI, + langfuse: Langfuse, +) -> EvaluationRun: + """ + Process a completed embedding batch and calculate similarity scores. + + This function: + 1. Downloads embedding batch results + 2. Parses embeddings (output + ground_truth pairs) + 3. Calculates cosine similarity for each pair + 4. Calculates average and statistics + 5. Updates eval_run.score with results + 6. Updates Langfuse traces with per-item cosine similarity scores + 7. Marks evaluation as completed + + Args: + eval_run: EvaluationRun database object + session: Database session + openai_client: Configured OpenAI client + langfuse: Configured Langfuse client + + Returns: + Updated EvaluationRun object with similarity scores + + Raises: + Exception: If processing fails + """ + log_prefix = f"[org={eval_run.organization_id}][project={eval_run.project_id}][eval={eval_run.id}]" + logger.info( + f"[process_completed_embedding_batch] {log_prefix} Processing completed embedding batch" + ) + + try: + # Step 1: Get embedding_batch_job + if not eval_run.embedding_batch_job_id: + raise ValueError( + f"EvaluationRun {eval_run.id} has no embedding_batch_job_id" + ) + + embedding_batch_job = get_batch_job( + session=session, batch_job_id=eval_run.embedding_batch_job_id + ) + if not embedding_batch_job: + raise ValueError( + f"Embedding BatchJob {eval_run.embedding_batch_job_id} not found for evaluation {eval_run.id}" + ) + + # Step 2: Create provider and download results + provider = OpenAIBatchProvider(client=openai_client) + raw_results = download_batch_results( + provider=provider, batch_job=embedding_batch_job + ) + + # Step 3: Parse embedding results + embedding_pairs = parse_embedding_results(raw_results=raw_results) + + if not embedding_pairs: + raise ValueError("No valid embedding pairs found in batch output") + + # Step 4: Calculate similarity scores + similarity_stats = calculate_average_similarity(embedding_pairs=embedding_pairs) + + # Step 5: Update evaluation_run with scores + if eval_run.score is None: + eval_run.score = {} + + eval_run.score["cosine_similarity"] = { + "avg": similarity_stats["cosine_similarity_avg"], + "std": similarity_stats["cosine_similarity_std"], + "total_pairs": similarity_stats["total_pairs"], + } + + # Optionally store per-item scores if not too large + if len(similarity_stats.get("per_item_scores", [])) <= 100: + eval_run.score["cosine_similarity"]["per_item_scores"] = similarity_stats[ + "per_item_scores" + ] + + # Step 6: Update Langfuse traces with cosine similarity scores + logger.info( + f"[process_completed_embedding_batch] {log_prefix} Updating Langfuse traces with cosine similarity scores" + ) + per_item_scores = similarity_stats.get("per_item_scores", []) + if per_item_scores: + try: + update_traces_with_cosine_scores( + langfuse=langfuse, + per_item_scores=per_item_scores, + ) + except Exception as e: + # Log error but don't fail the evaluation + logger.error( + f"[process_completed_embedding_batch] {log_prefix} Failed to update Langfuse traces with scores | {e}", + exc_info=True, + ) + + # Step 7: Mark evaluation as completed + eval_run = update_evaluation_run( + session=session, eval_run=eval_run, status="completed", score=eval_run.score + ) + + logger.info( + f"[process_completed_embedding_batch] {log_prefix} Completed evaluation | avg_similarity={similarity_stats['cosine_similarity_avg']:.3f}" + ) + + return eval_run + + except Exception as e: + logger.error( + f"[process_completed_embedding_batch] {log_prefix} Failed to process completed embedding batch | {e}", + exc_info=True, + ) + # Mark as completed anyway, but with error message + return update_evaluation_run( + session=session, + eval_run=eval_run, + status="completed", + error_message=f"Embedding processing failed: {str(e)}", + ) + + +async def check_and_process_evaluation( + eval_run: EvaluationRun, + session: Session, + openai_client: OpenAI, + langfuse: Langfuse, +) -> dict[str, Any]: + """ + Check evaluation batch status and process if completed. + + This function handles both the response batch and embedding batch: + 1. If embedding_batch_job_id exists, checks and processes embedding batch first + 2. Otherwise, checks and processes the main response batch + 3. Triggers appropriate processing based on batch completion status + + Args: + eval_run: EvaluationRun database object + session: Database session + openai_client: Configured OpenAI client + langfuse: Configured Langfuse client + + Returns: + Dict with status information: + { + "run_id": 123, + "run_name": "test_run", + "previous_status": "processing", + "current_status": "completed", + "batch_status": "completed", + "action": "processed" | "embeddings_completed" | "embeddings_failed" | "failed" | "no_change" + } + """ + log_prefix = f"[org={eval_run.organization_id}][project={eval_run.project_id}][eval={eval_run.id}]" + previous_status = eval_run.status + + try: + # Check if we need to process embedding batch first + if eval_run.embedding_batch_job_id and eval_run.status == "processing": + embedding_batch_job = get_batch_job( + session=session, batch_job_id=eval_run.embedding_batch_job_id + ) + + if embedding_batch_job: + # Poll embedding batch status + provider = OpenAIBatchProvider(client=openai_client) + + # Local import to avoid circular dependency with batch_operations + from app.crud.batch_operations import poll_batch_status + + poll_batch_status( + session=session, provider=provider, batch_job=embedding_batch_job + ) + session.refresh(embedding_batch_job) + + embedding_status = embedding_batch_job.provider_status + + if embedding_status == "completed": + logger.info( + f"[check_and_process_evaluation] {log_prefix} Processing embedding batch | provider_batch_id={embedding_batch_job.provider_batch_id}" + ) + + await process_completed_embedding_batch( + eval_run=eval_run, + session=session, + openai_client=openai_client, + langfuse=langfuse, + ) + + return { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "previous_status": previous_status, + "current_status": eval_run.status, + "provider_status": embedding_status, + "action": "embeddings_completed", + } + + elif embedding_status in ["failed", "expired", "cancelled"]: + logger.error( + f"[check_and_process_evaluation] {log_prefix} Embedding batch failed | provider_batch_id={embedding_batch_job.provider_batch_id} | {embedding_batch_job.error_message}" + ) + # Mark as completed without embeddings + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + status="completed", + error_message=f"Embedding batch failed: {embedding_batch_job.error_message}", + ) + + return { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "previous_status": previous_status, + "current_status": "completed", + "provider_status": embedding_status, + "action": "embeddings_failed", + } + + else: + # Embedding batch still processing + return { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "previous_status": previous_status, + "current_status": eval_run.status, + "provider_status": embedding_status, + "action": "no_change", + } + + # Get batch_job (main response batch) + if not eval_run.batch_job_id: + raise ValueError(f"EvaluationRun {eval_run.id} has no batch_job_id") + + batch_job = get_batch_job(session=session, batch_job_id=eval_run.batch_job_id) + if not batch_job: + raise ValueError( + f"BatchJob {eval_run.batch_job_id} not found for evaluation {eval_run.id}" + ) + + # IMPORTANT: Poll OpenAI to get the latest status before checking + provider = OpenAIBatchProvider(client=openai_client) + from app.crud.batch_operations import poll_batch_status + + poll_batch_status(session=session, provider=provider, batch_job=batch_job) + + # Refresh batch_job to get the updated provider_status + session.refresh(batch_job) + provider_status = batch_job.provider_status + + # Handle different provider statuses + if provider_status == "completed": + # Process the completed evaluation + await process_completed_evaluation( + eval_run=eval_run, + session=session, + openai_client=openai_client, + langfuse=langfuse, + ) + + return { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "previous_status": previous_status, + "current_status": eval_run.status, + "provider_status": provider_status, + "action": "processed", + } + + elif provider_status in ["failed", "expired", "cancelled"]: + # Mark evaluation as failed based on provider status + error_msg = batch_job.error_message or f"Provider batch {provider_status}" + + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + status="failed", + error_message=error_msg, + ) + + logger.error( + f"[check_and_process_evaluation] {log_prefix} Batch failed | provider_batch_id={batch_job.provider_batch_id} | {error_msg}" + ) + + return { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "previous_status": previous_status, + "current_status": "failed", + "provider_status": provider_status, + "action": "failed", + "error": error_msg, + } + + else: + # Still in progress (validating, in_progress, finalizing) + return { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "previous_status": previous_status, + "current_status": eval_run.status, + "provider_status": provider_status, + "action": "no_change", + } + + except Exception as e: + logger.error( + f"[check_and_process_evaluation] {log_prefix} Error checking evaluation | {e}", + exc_info=True, + ) + + # Mark as failed + update_evaluation_run( + session=session, + eval_run=eval_run, + status="failed", + error_message=f"Checking failed: {str(e)}", + ) + + return { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "previous_status": previous_status, + "current_status": "failed", + "provider_status": "unknown", + "action": "failed", + "error": str(e), + } + + +async def poll_all_pending_evaluations(session: Session, org_id: int) -> dict[str, Any]: + """ + Poll all pending evaluations for an organization. + + Args: + session: Database session + org_id: Organization ID + + Returns: + Summary dict: + { + "total": 5, + "processed": 2, + "failed": 1, + "still_processing": 2, + "details": [...] + } + """ + # Get pending evaluations (status = "processing") + statement = select(EvaluationRun).where( + EvaluationRun.status == "processing", + EvaluationRun.organization_id == org_id, + ) + pending_runs = session.exec(statement).all() + + if not pending_runs: + return { + "total": 0, + "processed": 0, + "failed": 0, + "still_processing": 0, + "details": [], + } + # Group evaluations by project_id since credentials are per project + evaluations_by_project = defaultdict(list) + for run in pending_runs: + evaluations_by_project[run.project_id].append(run) + + # Process each project separately + all_results = [] + total_processed_count = 0 + total_failed_count = 0 + total_still_processing_count = 0 + + for project_id, project_runs in evaluations_by_project.items(): + try: + # Get API clients for this project + try: + openai_client = get_openai_client( + session=session, + org_id=org_id, + project_id=project_id, + ) + langfuse = get_langfuse_client( + session=session, + org_id=org_id, + project_id=project_id, + ) + except HTTPException as http_exc: + logger.error( + f"[poll_all_pending_evaluations] Failed to get API clients | org_id={org_id} | project_id={project_id} | error={http_exc.detail}" + ) + # Mark all runs in this project as failed due to client configuration error + for eval_run in project_runs: + # Persist failure status to database + update_evaluation_run( + session=session, + eval_run=eval_run, + status="failed", + error_message=http_exc.detail, + ) + + all_results.append( + { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "action": "failed", + "error": http_exc.detail, + } + ) + total_failed_count += 1 + continue + + # Process each evaluation in this project + for eval_run in project_runs: + try: + result = await check_and_process_evaluation( + eval_run=eval_run, + session=session, + openai_client=openai_client, + langfuse=langfuse, + ) + all_results.append(result) + + if result["action"] == "processed": + total_processed_count += 1 + elif result["action"] == "failed": + total_failed_count += 1 + else: + total_still_processing_count += 1 + + except Exception as e: + logger.error( + f"[poll_all_pending_evaluations] Failed to check evaluation run | run_id={eval_run.id} | {e}", + exc_info=True, + ) + # Persist failure status to database + update_evaluation_run( + session=session, + eval_run=eval_run, + status="failed", + error_message=f"Check failed: {str(e)}", + ) + + all_results.append( + { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "action": "failed", + "error": str(e), + } + ) + total_failed_count += 1 + + except Exception as e: + logger.error( + f"[poll_all_pending_evaluations] Failed to process project | project_id={project_id} | {e}", + exc_info=True, + ) + # Mark all runs in this project as failed + for eval_run in project_runs: + # Persist failure status to database + update_evaluation_run( + session=session, + eval_run=eval_run, + status="failed", + error_message=f"Project processing failed: {str(e)}", + ) + + all_results.append( + { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "action": "failed", + "error": f"Project processing failed: {str(e)}", + } + ) + total_failed_count += 1 + + summary = { + "total": len(pending_runs), + "processed": total_processed_count, + "failed": total_failed_count, + "still_processing": total_still_processing_count, + "details": all_results, + } + + logger.info( + f"[poll_all_pending_evaluations] Polling summary | org_id={org_id} | processed={total_processed_count} | failed={total_failed_count} | still_processing={total_still_processing_count}" + ) + + return summary diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 791da4ba1..6a9f41852 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -37,6 +37,22 @@ ) from .document_collection import DocumentCollection +from .batch_job import ( + BatchJob, + BatchJobCreate, + BatchJobPublic, + BatchJobUpdate, +) + +from .evaluation import ( + EvaluationDataset, + EvaluationDatasetCreate, + EvaluationDatasetPublic, + EvaluationRun, + EvaluationRunCreate, + EvaluationRunPublic, +) + from .fine_tuning import ( FineTuningJobBase, Fine_Tuning, diff --git a/backend/app/models/batch_job.py b/backend/app/models/batch_job.py new file mode 100644 index 000000000..3ef07f7f1 --- /dev/null +++ b/backend/app/models/batch_job.py @@ -0,0 +1,129 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional + +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field, Relationship, SQLModel + +from app.core.util import now + +if TYPE_CHECKING: + from .organization import Organization + from .project import Project + + +class BatchJob(SQLModel, table=True): + """Batch job table for tracking async LLM batch operations.""" + + __tablename__ = "batch_job" + + id: int | None = Field(default=None, primary_key=True) + + # Provider and job type + provider: str = Field(description="LLM provider name (e.g., 'openai', 'anthropic')") + job_type: str = Field( + description="Type of batch job (e.g., 'evaluation', 'classification', 'embedding')" + ) + + # Batch configuration - stores all provider-specific config + config: dict[str, Any] = Field( + default_factory=dict, + sa_column=Column(JSONB()), + description="Complete batch configuration including model, temperature, instructions, tools, etc.", + ) + + # Provider-specific batch tracking + provider_batch_id: str | None = Field( + default=None, description="Provider's batch job ID (e.g., OpenAI batch_id)" + ) + provider_file_id: str | None = Field( + default=None, description="Provider's input file ID" + ) + provider_output_file_id: str | None = Field( + default=None, description="Provider's output file ID" + ) + + # Provider status tracking + provider_status: str | None = Field( + default=None, + description="Provider-specific status (e.g., OpenAI: validating, in_progress, finalizing, completed, failed, expired, cancelling, cancelled)", + ) + + # Raw results (before parent-specific processing) + raw_output_url: str | None = Field( + default=None, description="S3 URL of raw batch output file" + ) + total_items: int = Field( + default=0, description="Total number of items in the batch" + ) + + # Error handling + error_message: str | None = Field( + default=None, description="Error message if batch failed" + ) + + # Foreign keys + organization_id: int = Field(foreign_key="organization.id") + project_id: int = Field(foreign_key="project.id") + + # Timestamps + inserted_at: datetime = Field( + default_factory=now, description="The timestamp when the batch job was started" + ) + updated_at: datetime = Field( + default_factory=now, + description="The timestamp when the batch job was last updated", + ) + + # Relationships + organization: Optional["Organization"] = Relationship() + project: Optional["Project"] = Relationship() + + +class BatchJobCreate(SQLModel): + """Schema for creating a new batch job.""" + + provider: str + job_type: str + config: dict[str, Any] = Field(default_factory=dict) + provider_batch_id: str | None = None + provider_file_id: str | None = None + provider_output_file_id: str | None = None + provider_status: str | None = None + raw_output_url: str | None = None + total_items: int = 0 + error_message: str | None = None + organization_id: int + project_id: int + + +class BatchJobUpdate(SQLModel): + """Schema for updating a batch job.""" + + provider_batch_id: str | None = None + provider_file_id: str | None = None + provider_output_file_id: str | None = None + provider_status: str | None = None + raw_output_url: str | None = None + total_items: int | None = None + error_message: str | None = None + + +class BatchJobPublic(SQLModel): + """Public schema for batch job responses.""" + + id: int + provider: str + job_type: str + config: dict[str, Any] + provider_batch_id: str | None + provider_file_id: str | None + provider_output_file_id: str | None + provider_status: str | None + raw_output_url: str | None + total_items: int + error_message: str | None + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py new file mode 100644 index 000000000..57a83d35d --- /dev/null +++ b/backend/app/models/evaluation.py @@ -0,0 +1,299 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional + +from pydantic import BaseModel, Field +from sqlalchemy import JSON, Column, Text, UniqueConstraint +from sqlmodel import Field as SQLField +from sqlmodel import Relationship, SQLModel + +from app.core.util import now + +if TYPE_CHECKING: + from .batch_job import BatchJob + from .organization import Organization + from .project import Project + + +class DatasetItem(BaseModel): + """Model for a single dataset item (Q&A pair).""" + + question: str = Field(..., description="The question/input") + answer: str = Field(..., description="The expected answer/output") + + +class DatasetUploadResponse(BaseModel): + """Response model for dataset upload.""" + + dataset_id: int = Field(..., description="Database ID of the created dataset") + dataset_name: str = Field(..., description="Name of the created dataset") + total_items: int = Field( + ..., description="Total number of items uploaded (after duplication)" + ) + original_items: int = Field( + ..., description="Number of original items before duplication" + ) + duplication_factor: int = Field( + default=5, description="Number of times each item was duplicated" + ) + langfuse_dataset_id: str | None = Field( + None, description="Langfuse dataset ID if available" + ) + object_store_url: str | None = Field( + None, description="Object store URL if uploaded" + ) + + +class EvaluationResult(BaseModel): + """Model for a single evaluation result.""" + + input: str = Field(..., description="The input question/prompt used for evaluation") + output: str = Field(..., description="The actual output from the assistant") + expected: str = Field(..., description="The expected output from the dataset") + response_id: str | None = Field(None, description="ID from the batch response body") + + +class Experiment(BaseModel): + """Model for the complete experiment evaluation response.""" + + experiment_name: str = Field(..., description="Name of the experiment") + dataset_name: str = Field( + ..., description="Name of the dataset used for evaluation" + ) + results: list[EvaluationResult] = Field( + ..., description="List of evaluation results" + ) + total_items: int = Field(..., description="Total number of items evaluated") + note: str = Field(..., description="Additional notes about the evaluation process") + + +# Database Models + + +class EvaluationDataset(SQLModel, table=True): + """Database table for evaluation datasets.""" + + __tablename__ = "evaluation_dataset" + __table_args__ = ( + UniqueConstraint( + "name", + "organization_id", + "project_id", + name="uq_evaluation_dataset_name_org_project", + ), + ) + + id: int = SQLField(default=None, primary_key=True) + + # Dataset information + name: str = SQLField(index=True, description="Name of the dataset") + description: str | None = SQLField( + default=None, description="Optional description of the dataset" + ) + + # Dataset metadata stored as JSON + dataset_metadata: dict[str, Any] = SQLField( + default_factory=dict, + sa_column=Column(JSON), + description=( + "Dataset metadata (original_items_count, total_items_count, " + "duplication_factor)" + ), + ) + + # Storage references + object_store_url: str | None = SQLField( + default=None, description="Object store URL where CSV is stored" + ) + langfuse_dataset_id: str | None = SQLField( + default=None, description="Langfuse dataset ID for reference" + ) + + # Foreign keys + organization_id: int = SQLField( + foreign_key="organization.id", nullable=False, ondelete="CASCADE" + ) + project_id: int = SQLField( + foreign_key="project.id", nullable=False, ondelete="CASCADE" + ) + + # Timestamps + inserted_at: datetime = SQLField(default_factory=now, nullable=False) + updated_at: datetime = SQLField(default_factory=now, nullable=False) + + # Relationships + project: "Project" = Relationship() + organization: "Organization" = Relationship() + evaluation_runs: list["EvaluationRun"] = Relationship( + back_populates="evaluation_dataset" + ) + + +class EvaluationRun(SQLModel, table=True): + """Database table for evaluation runs.""" + + __tablename__ = "evaluation_run" + + id: int = SQLField(default=None, primary_key=True) + + # Input fields (provided by user) + run_name: str = SQLField(index=True, description="Name of the evaluation run") + dataset_name: str = SQLField(description="Name of the Langfuse dataset") + + # Config field - dict requires sa_column + config: dict[str, Any] = SQLField( + default_factory=dict, + sa_column=Column(JSON), + description="Evaluation configuration", + ) + + # Dataset reference + dataset_id: int = SQLField( + foreign_key="evaluation_dataset.id", + nullable=False, + ondelete="CASCADE", + description="Reference to the evaluation_dataset used for this run", + ) + + # Batch job references + batch_job_id: int | None = SQLField( + default=None, + foreign_key="batch_job.id", + description=( + "Reference to the batch_job that processes this evaluation " "(responses)" + ), + ) + embedding_batch_job_id: int | None = SQLField( + default=None, + foreign_key="batch_job.id", + description="Reference to the batch_job for embedding-based similarity scoring", + ) + + # Output/Status fields (updated by system during processing) + status: str = SQLField( + default="pending", + description="Overall evaluation status: pending, processing, completed, failed", + ) + object_store_url: str | None = SQLField( + default=None, + description="Object store URL of processed evaluation results for future reference", + ) + total_items: int = SQLField( + default=0, description="Total number of items evaluated (set during processing)" + ) + + # Score field - dict requires sa_column + score: dict[str, Any] | None = SQLField( + default=None, + sa_column=Column(JSON, nullable=True), + description="Evaluation scores (e.g., correctness, cosine_similarity, etc.)", + ) + + # Error message field + error_message: str | None = SQLField( + default=None, + sa_column=Column(Text, nullable=True), + description="Error message if failed", + ) + + # Foreign keys + organization_id: int = SQLField( + foreign_key="organization.id", nullable=False, ondelete="CASCADE" + ) + project_id: int = SQLField( + foreign_key="project.id", nullable=False, ondelete="CASCADE" + ) + + # Timestamps + inserted_at: datetime = Field( + default_factory=now, + description="The timestamp when the evaluation run was started", + ) + updated_at: datetime = Field( + default_factory=now, + description="The timestamp when the evaluation run was last updated", + ) + + # Relationships + project: "Project" = Relationship() + organization: "Organization" = Relationship() + evaluation_dataset: "EvaluationDataset" = Relationship( + back_populates="evaluation_runs" + ) + batch_job: Optional["BatchJob"] = Relationship( + sa_relationship_kwargs={"foreign_keys": "[EvaluationRun.batch_job_id]"} + ) + embedding_batch_job: Optional["BatchJob"] = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[EvaluationRun.embedding_batch_job_id]" + } + ) + + +class EvaluationRunCreate(SQLModel): + """Model for creating an evaluation run.""" + + run_name: str = Field(description="Name of the evaluation run", min_length=3) + dataset_id: int = Field(description="ID of the evaluation dataset") + config: dict[str, Any] = Field( + default_factory=dict, + description=( + "Evaluation configuration (flexible dict with llm, instructions, " + "vector_store_ids, etc.)" + ), + ) + + +class EvaluationRunPublic(SQLModel): + """Public model for evaluation runs.""" + + id: int + run_name: str + dataset_name: str + config: dict[str, Any] + dataset_id: int + batch_job_id: int | None + embedding_batch_job_id: int | None + status: str + object_store_url: str | None + total_items: int + score: dict[str, Any] | None + error_message: str | None + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class EvaluationDatasetCreate(SQLModel): + """Model for creating an evaluation dataset.""" + + name: str = Field(description="Name of the dataset", min_length=1) + description: str | None = Field(None, description="Optional dataset description") + dataset_metadata: dict[str, Any] = Field( + default_factory=dict, + description=( + "Dataset metadata (original_items_count, total_items_count, " + "duplication_factor)" + ), + ) + object_store_url: str | None = Field( + None, description="Object store URL where CSV is stored" + ) + langfuse_dataset_id: str | None = Field( + None, description="Langfuse dataset ID for reference" + ) + + +class EvaluationDatasetPublic(SQLModel): + """Public model for evaluation datasets.""" + + id: int + name: str + description: str | None + dataset_metadata: dict[str, Any] + object_store_url: str | None + langfuse_dataset_id: str | None + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 4c729cdd1..db660891a 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -1,16 +1,16 @@ from datetime import datetime -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING + from sqlmodel import Field, Relationship, SQLModel from app.core.util import now if TYPE_CHECKING: - from .credentials import Credential - from .project import Project - from .api_key import APIKey from .assistants import Assistant from .collection import Collection + from .credentials import Credential from .openai_conversation import OpenAIConversation + from .project import Project # Shared properties for an Organization diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 2a1d346ae..c0d8a87ac 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -1,10 +1,19 @@ -from uuid import UUID, uuid4 from datetime import datetime -from typing import Optional, List +from typing import TYPE_CHECKING, Optional +from uuid import UUID, uuid4 + from sqlmodel import Field, Relationship, SQLModel, UniqueConstraint from app.core.util import now +if TYPE_CHECKING: + from .assistants import Assistant + from .collection import Collection + from .credentials import Credential + from .fine_tuning import Fine_Tuning + from .openai_conversation import OpenAIConversation + from .organization import Organization + # Shared properties for a Project class ProjectBase(SQLModel): diff --git a/backend/app/tests/api/routes/test_evaluation.py b/backend/app/tests/api/routes/test_evaluation.py new file mode 100644 index 000000000..5538ab146 --- /dev/null +++ b/backend/app/tests/api/routes/test_evaluation.py @@ -0,0 +1,688 @@ +import io +from unittest.mock import Mock, patch + +import pytest +from sqlmodel import select + +from app.crud.evaluations.batch import build_evaluation_jsonl +from app.models import EvaluationDataset + + +# Helper function to create CSV file-like object +def create_csv_file(content: str) -> tuple[str, io.BytesIO]: + """Create a CSV file-like object for testing.""" + file_obj = io.BytesIO(content.encode("utf-8")) + return ("test.csv", file_obj) + + +@pytest.fixture +def valid_csv_content(): + """Valid CSV content with question and answer columns.""" + return """question,answer +"Who is known as the strongest jujutsu sorcerer?","Satoru Gojo" +"What is the name of Gojo’s Domain Expansion?","Infinite Void" +"Who is known as the King of Curses?","Ryomen Sukuna" +""" + + +@pytest.fixture +def invalid_csv_missing_columns(): + """CSV content missing required columns.""" + return """query,response +"Who is known as the strongest jujutsu sorcerer?","Satoru Gojo" +""" + + +@pytest.fixture +def csv_with_empty_rows(): + """CSV content with some empty rows.""" + return """question,answer +"Who is known as the strongest jujutsu sorcerer?","Satoru Gojo" +"","4" +"Who wrote Romeo and Juliet?","" +"Valid question","Valid answer" +""" + + +class TestDatasetUploadValidation: + """Test CSV validation and parsing.""" + + def test_upload_dataset_valid_csv( + self, client, user_api_key_header, valid_csv_content, db + ): + """Test uploading a valid CSV file.""" + with ( + patch("app.core.cloud.get_cloud_storage") as _mock_storage, + patch( + "app.api.routes.evaluation.upload_csv_to_object_store" + ) as mock_store_upload, + patch( + "app.api.routes.evaluation.get_langfuse_client" + ) as mock_get_langfuse_client, + patch( + "app.api.routes.evaluation.upload_dataset_to_langfuse_from_csv" + ) as mock_langfuse_upload, + ): + # Mock object store upload + mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" + + # Mock Langfuse client + mock_get_langfuse_client.return_value = Mock() + + # Mock Langfuse upload + mock_langfuse_upload.return_value = ("test_dataset_id", 9) + + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "description": "Test dataset description", + "duplication_factor": 3, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + data = response.json() + + assert data["dataset_name"] == "test_dataset" + assert data["original_items"] == 3 + assert data["total_items"] == 9 # 3 items * 3 duplication + assert data["duplication_factor"] == 3 + assert data["langfuse_dataset_id"] == "test_dataset_id" + assert data["object_store_url"] == "s3://bucket/datasets/test_dataset.csv" + assert "dataset_id" in data + + # Verify object store upload was called + mock_store_upload.assert_called_once() + + # Verify Langfuse upload was called + mock_langfuse_upload.assert_called_once() + + def test_upload_dataset_missing_columns( + self, + client, + user_api_key_header, + invalid_csv_missing_columns, + ): + """Test uploading CSV with missing required columns.""" + filename, file_obj = create_csv_file(invalid_csv_missing_columns) + + # The CSV validation happens before any mocked functions are called + # so this test checks the actual validation logic + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 5, + }, + headers=user_api_key_header, + ) + + # Check that the response indicates unprocessable entity + assert response.status_code == 422 + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("message", str(response_data)) + ) + assert "question" in error_str.lower() or "answer" in error_str.lower() + + def test_upload_dataset_empty_rows( + self, client, user_api_key_header, csv_with_empty_rows + ): + """Test uploading CSV with empty rows (should skip them).""" + with ( + patch("app.core.cloud.get_cloud_storage") as _mock_storage, + patch( + "app.api.routes.evaluation.upload_csv_to_object_store" + ) as mock_store_upload, + patch( + "app.api.routes.evaluation.get_langfuse_client" + ) as mock_get_langfuse_client, + patch( + "app.api.routes.evaluation.upload_dataset_to_langfuse_from_csv" + ) as mock_langfuse_upload, + ): + # Mock object store and Langfuse uploads + mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" + mock_get_langfuse_client.return_value = Mock() + mock_langfuse_upload.return_value = ("test_dataset_id", 4) + + filename, file_obj = create_csv_file(csv_with_empty_rows) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 2, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + data = response.json() + + # Should only have 2 valid items (first and last rows) + assert data["original_items"] == 2 + assert data["total_items"] == 4 # 2 items * 2 duplication + + +class TestDatasetUploadDuplication: + """Test duplication logic.""" + + def test_upload_with_default_duplication( + self, client, user_api_key_header, valid_csv_content + ): + """Test uploading with default duplication factor (5).""" + with ( + patch("app.core.cloud.get_cloud_storage") as _mock_storage, + patch( + "app.api.routes.evaluation.upload_csv_to_object_store" + ) as mock_store_upload, + patch( + "app.api.routes.evaluation.get_langfuse_client" + ) as mock_get_langfuse_client, + patch( + "app.api.routes.evaluation.upload_dataset_to_langfuse_from_csv" + ) as mock_langfuse_upload, + ): + mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" + mock_get_langfuse_client.return_value = Mock() + mock_langfuse_upload.return_value = ("test_dataset_id", 15) + + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + # duplication_factor not provided, should default to 5 + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + data = response.json() + + assert data["duplication_factor"] == 5 + assert data["original_items"] == 3 + assert data["total_items"] == 15 # 3 items * 5 duplication + + def test_upload_with_custom_duplication( + self, client, user_api_key_header, valid_csv_content + ): + """Test uploading with custom duplication factor.""" + with ( + patch("app.core.cloud.get_cloud_storage") as _mock_storage, + patch( + "app.api.routes.evaluation.upload_csv_to_object_store" + ) as mock_store_upload, + patch( + "app.api.routes.evaluation.get_langfuse_client" + ) as mock_get_langfuse_client, + patch( + "app.api.routes.evaluation.upload_dataset_to_langfuse_from_csv" + ) as mock_langfuse_upload, + ): + mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" + mock_get_langfuse_client.return_value = Mock() + mock_langfuse_upload.return_value = ("test_dataset_id", 12) + + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 4, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + data = response.json() + + assert data["duplication_factor"] == 4 + assert data["original_items"] == 3 + assert data["total_items"] == 12 # 3 items * 4 duplication + + def test_upload_with_description( + self, client, user_api_key_header, valid_csv_content, db + ): + """Test uploading with a description.""" + with ( + patch("app.core.cloud.get_cloud_storage") as _mock_storage, + patch( + "app.api.routes.evaluation.upload_csv_to_object_store" + ) as mock_store_upload, + patch( + "app.api.routes.evaluation.get_langfuse_client" + ) as mock_get_langfuse_client, + patch( + "app.api.routes.evaluation.upload_dataset_to_langfuse_from_csv" + ) as mock_langfuse_upload, + ): + mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" + mock_get_langfuse_client.return_value = Mock() + mock_langfuse_upload.return_value = ("test_dataset_id", 9) + + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset_with_description", + "description": "This is a test dataset for evaluation", + "duplication_factor": 3, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + data = response.json() + + # Verify the description is stored + dataset = db.exec( + select(EvaluationDataset).where( + EvaluationDataset.id == data["dataset_id"] + ) + ).first() + + assert dataset is not None + assert dataset.description == "This is a test dataset for evaluation" + + def test_upload_with_duplication_factor_below_minimum( + self, client, user_api_key_header, valid_csv_content + ): + """Test uploading with duplication factor below minimum (0).""" + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 0, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 422 + response_data = response.json() + # Check that the error mentions validation and minimum value + assert "error" in response_data + assert "greater than or equal to 1" in response_data["error"] + + def test_upload_with_duplication_factor_above_maximum( + self, client, user_api_key_header, valid_csv_content + ): + """Test uploading with duplication factor above maximum (6).""" + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 6, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 422 + response_data = response.json() + # Check that the error mentions validation and maximum value + assert "error" in response_data + assert "less than or equal to 5" in response_data["error"] + + def test_upload_with_duplication_factor_boundary_minimum( + self, client, user_api_key_header, valid_csv_content + ): + """Test uploading with duplication factor at minimum boundary (1).""" + with ( + patch("app.core.cloud.get_cloud_storage") as _mock_storage, + patch( + "app.api.routes.evaluation.upload_csv_to_object_store" + ) as mock_store_upload, + patch( + "app.api.routes.evaluation.get_langfuse_client" + ) as mock_get_langfuse_client, + patch( + "app.api.routes.evaluation.upload_dataset_to_langfuse_from_csv" + ) as mock_langfuse_upload, + ): + mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" + mock_get_langfuse_client.return_value = Mock() + mock_langfuse_upload.return_value = ("test_dataset_id", 3) + + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 1, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + data = response.json() + + assert data["duplication_factor"] == 1 + assert data["original_items"] == 3 + assert data["total_items"] == 3 # 3 items * 1 duplication + + +class TestDatasetUploadErrors: + """Test error handling.""" + + def test_upload_langfuse_configuration_fails( + self, client, user_api_key_header, valid_csv_content + ): + """Test when Langfuse client configuration fails.""" + with ( + patch("app.core.cloud.get_cloud_storage") as _mock_storage, + patch( + "app.api.routes.evaluation.upload_csv_to_object_store" + ) as mock_store_upload, + patch("app.crud.credentials.get_provider_credential") as mock_get_cred, + ): + # Mock object store upload succeeds + mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" + # Mock Langfuse credentials not found + mock_get_cred.return_value = None + + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 5, + }, + headers=user_api_key_header, + ) + + # Accept either 400 (credentials not configured) or 500 (configuration/auth fails) + assert response.status_code in [400, 500] + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("message", str(response_data)) + ) + assert ( + "langfuse" in error_str.lower() + or "credential" in error_str.lower() + or "unauthorized" in error_str.lower() + ) + + def test_upload_invalid_csv_format(self, client, user_api_key_header): + """Test uploading invalid CSV format.""" + invalid_csv = "not,a,valid\ncsv format here!!!" + filename, file_obj = create_csv_file(invalid_csv) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 5, + }, + headers=user_api_key_header, + ) + + # Should fail validation - check error contains expected message + assert response.status_code == 422 + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("message", str(response_data)) + ) + assert ( + "question" in error_str.lower() + or "answer" in error_str.lower() + or "invalid" in error_str.lower() + ) + + def test_upload_without_authentication(self, client, valid_csv_content): + """Test uploading without authentication.""" + filename, file_obj = create_csv_file(valid_csv_content) + + response = client.post( + "/api/v1/evaluations/datasets", + files={"file": (filename, file_obj, "text/csv")}, + data={ + "dataset_name": "test_dataset", + "duplication_factor": 5, + }, + ) + + assert response.status_code == 401 # Unauthorized + + +class TestBatchEvaluation: + """Test batch evaluation endpoint using OpenAI Batch API.""" + + @pytest.fixture + def sample_evaluation_config(self): + """Sample evaluation configuration.""" + return { + "model": "gpt-4o", + "temperature": 0.2, + "instructions": "You are a helpful assistant", + } + + def test_start_batch_evaluation_invalid_dataset_id( + self, client, user_api_key_header, sample_evaluation_config + ): + """Test batch evaluation fails with invalid dataset_id.""" + # Try to start evaluation with non-existent dataset_id + response = client.post( + "/api/v1/evaluations", + json={ + "experiment_name": "test_evaluation_run", + "dataset_id": 99999, # Non-existent + "config": sample_evaluation_config, + }, + headers=user_api_key_header, + ) + + assert response.status_code == 404 + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("message", str(response_data)) + ) + assert "not found" in error_str.lower() or "not accessible" in error_str.lower() + + def test_start_batch_evaluation_missing_model(self, client, user_api_key_header): + """Test batch evaluation fails when model is missing from config.""" + # We don't need a real dataset for this test - the validation should happen + # before dataset lookup. Use any dataset_id and expect config validation error + invalid_config = { + "instructions": "You are a helpful assistant", + "temperature": 0.5, + } + + response = client.post( + "/api/v1/evaluations", + json={ + "experiment_name": "test_no_model", + "dataset_id": 1, # Dummy ID, error should come before this is checked + "config": invalid_config, + }, + headers=user_api_key_header, + ) + + # Should fail with either 400 (model missing) or 404 (dataset not found) + assert response.status_code in [400, 404] + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("message", str(response_data)) + ) + # Should fail with either "model" missing or "dataset not found" (both acceptable) + assert "model" in error_str.lower() or "not found" in error_str.lower() + + def test_start_batch_evaluation_without_authentication( + self, client, sample_evaluation_config + ): + """Test batch evaluation requires authentication.""" + response = client.post( + "/api/v1/evaluations", + json={ + "experiment_name": "test_evaluation_run", + "dataset_id": 1, + "config": sample_evaluation_config, + }, + ) + + assert response.status_code == 401 # Unauthorized + + +class TestBatchEvaluationJSONLBuilding: + """Test JSONL building logic for batch evaluation.""" + + def test_build_batch_jsonl_basic(self): + """Test basic JSONL building with minimal config.""" + dataset_items = [ + { + "id": "item1", + "input": {"question": "What is 2+2?"}, + "expected_output": {"answer": "4"}, + "metadata": {}, + } + ] + + config = { + "model": "gpt-4o", + "temperature": 0.2, + "instructions": "You are a helpful assistant", + } + + jsonl_data = build_evaluation_jsonl(dataset_items, config) + + assert len(jsonl_data) == 1 + assert isinstance(jsonl_data[0], dict) + + request = jsonl_data[0] + assert request["custom_id"] == "item1" + assert request["method"] == "POST" + assert request["url"] == "/v1/responses" + assert request["body"]["model"] == "gpt-4o" + assert request["body"]["temperature"] == 0.2 + assert request["body"]["instructions"] == "You are a helpful assistant" + assert request["body"]["input"] == "What is 2+2?" + + def test_build_batch_jsonl_with_tools(self): + """Test JSONL building with tools configuration.""" + dataset_items = [ + { + "id": "item1", + "input": {"question": "Search the docs"}, + "expected_output": {"answer": "Answer from docs"}, + "metadata": {}, + } + ] + + config = { + "model": "gpt-4o-mini", + "instructions": "Search documents", + "tools": [ + { + "type": "file_search", + "vector_store_ids": ["vs_abc123"], + } + ], + } + + jsonl_data = build_evaluation_jsonl(dataset_items, config) + + assert len(jsonl_data) == 1 + request = jsonl_data[0] + assert request["body"]["tools"][0]["type"] == "file_search" + assert "vs_abc123" in request["body"]["tools"][0]["vector_store_ids"] + + def test_build_batch_jsonl_minimal_config(self): + """Test JSONL building with minimal config (only model required).""" + dataset_items = [ + { + "id": "item1", + "input": {"question": "Test question"}, + "expected_output": {"answer": "Test answer"}, + "metadata": {}, + } + ] + + config = {"model": "gpt-4o"} # Only model provided + + jsonl_data = build_evaluation_jsonl(dataset_items, config) + + assert len(jsonl_data) == 1 + request = jsonl_data[0] + assert request["body"]["model"] == "gpt-4o" + assert request["body"]["input"] == "Test question" + + def test_build_batch_jsonl_skips_empty_questions(self): + """Test that items with empty questions are skipped.""" + dataset_items = [ + { + "id": "item1", + "input": {"question": "Valid question"}, + "expected_output": {"answer": "Answer"}, + "metadata": {}, + }, + { + "id": "item2", + "input": {"question": ""}, # Empty question + "expected_output": {"answer": "Answer"}, + "metadata": {}, + }, + { + "id": "item3", + "input": {}, # Missing question key + "expected_output": {"answer": "Answer"}, + "metadata": {}, + }, + ] + + config = {"model": "gpt-4o", "instructions": "Test"} + + jsonl_data = build_evaluation_jsonl(dataset_items, config) + + # Should only have 1 valid item + assert len(jsonl_data) == 1 + assert jsonl_data[0]["custom_id"] == "item1" + + def test_build_batch_jsonl_multiple_items(self): + """Test JSONL building with multiple items.""" + dataset_items = [ + { + "id": f"item{i}", + "input": {"question": f"Question {i}"}, + "expected_output": {"answer": f"Answer {i}"}, + "metadata": {}, + } + for i in range(5) + ] + + config = { + "model": "gpt-4o", + "instructions": "Answer questions", + } + + jsonl_data = build_evaluation_jsonl(dataset_items, config) + + assert len(jsonl_data) == 5 + + for i, request_dict in enumerate(jsonl_data): + assert request_dict["custom_id"] == f"item{i}" + assert request_dict["body"]["input"] == f"Question {i}" + assert request_dict["body"]["model"] == "gpt-4o" diff --git a/backend/app/tests/crud/evaluations/__init__.py b/backend/app/tests/crud/evaluations/__init__.py new file mode 100644 index 000000000..e99ebf01b --- /dev/null +++ b/backend/app/tests/crud/evaluations/__init__.py @@ -0,0 +1 @@ +"""Tests for evaluation-related CRUD operations.""" diff --git a/backend/app/tests/crud/evaluations/test_dataset.py b/backend/app/tests/crud/evaluations/test_dataset.py new file mode 100644 index 000000000..ccd2e4f34 --- /dev/null +++ b/backend/app/tests/crud/evaluations/test_dataset.py @@ -0,0 +1,413 @@ +""" +Tests for evaluation_dataset CRUD operations. +""" + +from unittest.mock import MagicMock + +import pytest +from sqlmodel import Session, select + +from app.core.cloud.storage import CloudStorageError +from app.crud.evaluations.dataset import ( + create_evaluation_dataset, + download_csv_from_object_store, + get_dataset_by_id, + get_dataset_by_name, + list_datasets, + update_dataset_langfuse_id, + upload_csv_to_object_store, +) +from app.models import Organization, Project + + +class TestCreateEvaluationDataset: + """Test creating evaluation datasets.""" + + def test_create_evaluation_dataset_minimal(self, db: Session): + """Test creating a dataset with minimal required fields.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + dataset = create_evaluation_dataset( + session=db, + name="test_dataset", + dataset_metadata={"original_items_count": 10, "total_items_count": 50}, + organization_id=org.id, + project_id=project.id, + ) + + assert dataset.id is not None + assert dataset.name == "test_dataset" + assert dataset.dataset_metadata["original_items_count"] == 10 + assert dataset.dataset_metadata["total_items_count"] == 50 + assert dataset.organization_id == org.id + assert dataset.project_id == project.id + assert dataset.description is None + assert dataset.object_store_url is None + assert dataset.langfuse_dataset_id is None + + def test_create_evaluation_dataset_complete(self, db: Session): + """Test creating a dataset with all fields.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + dataset = create_evaluation_dataset( + session=db, + name="complete_dataset", + description="A complete test dataset", + dataset_metadata={ + "original_items_count": 5, + "total_items_count": 25, + "duplication_factor": 5, + }, + object_store_url="s3://bucket/datasets/complete_dataset.csv", + langfuse_dataset_id="langfuse_123", + organization_id=org.id, + project_id=project.id, + ) + + assert dataset.id is not None + assert dataset.name == "complete_dataset" + assert dataset.description == "A complete test dataset" + assert dataset.dataset_metadata["duplication_factor"] == 5 + assert dataset.object_store_url == "s3://bucket/datasets/complete_dataset.csv" + assert dataset.langfuse_dataset_id == "langfuse_123" + assert dataset.inserted_at is not None + assert dataset.updated_at is not None + + +class TestGetDatasetById: + """Test fetching datasets by ID.""" + + def test_get_dataset_by_id_success(self, db: Session): + """Test fetching an existing dataset by ID.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + # Create a dataset + dataset = create_evaluation_dataset( + session=db, + name="test_dataset", + dataset_metadata={"original_items_count": 10}, + organization_id=org.id, + project_id=project.id, + ) + + # Fetch it by ID + fetched = get_dataset_by_id( + session=db, + dataset_id=dataset.id, + organization_id=org.id, + project_id=project.id, + ) + + assert fetched is not None + assert fetched.id == dataset.id + assert fetched.name == "test_dataset" + + def test_get_dataset_by_id_not_found(self, db: Session): + """Test fetching a non-existent dataset.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + fetched = get_dataset_by_id( + session=db, + dataset_id=99999, + organization_id=org.id, + project_id=project.id, + ) + + assert fetched is None + + def test_get_dataset_by_id_wrong_org(self, db: Session): + """Test that datasets from other orgs can't be fetched.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + # Create a dataset + dataset = create_evaluation_dataset( + session=db, + name="test_dataset", + dataset_metadata={"original_items_count": 10}, + organization_id=org.id, + project_id=project.id, + ) + + # Try to fetch it with wrong org_id + fetched = get_dataset_by_id( + session=db, + dataset_id=dataset.id, + organization_id=99999, # Wrong org + project_id=project.id, + ) + + assert fetched is None + + +class TestGetDatasetByName: + """Test fetching datasets by name.""" + + def test_get_dataset_by_name_success(self, db: Session): + """Test fetching an existing dataset by name.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + # Create a dataset + create_evaluation_dataset( + session=db, + name="unique_dataset", + dataset_metadata={"original_items_count": 10}, + organization_id=org.id, + project_id=project.id, + ) + + # Fetch it by name + fetched = get_dataset_by_name( + session=db, + name="unique_dataset", + organization_id=org.id, + project_id=project.id, + ) + + assert fetched is not None + assert fetched.name == "unique_dataset" + + def test_get_dataset_by_name_not_found(self, db: Session): + """Test fetching a non-existent dataset by name.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + fetched = get_dataset_by_name( + session=db, + name="nonexistent_dataset", + organization_id=org.id, + project_id=project.id, + ) + + assert fetched is None + + +class TestListDatasets: + """Test listing datasets.""" + + def test_list_datasets_empty(self, db: Session): + """Test listing datasets when none exist.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + datasets = list_datasets( + session=db, organization_id=org.id, project_id=project.id + ) + + assert len(datasets) == 0 + + def test_list_datasets_multiple(self, db: Session): + """Test listing multiple datasets.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + # Create multiple datasets + for i in range(5): + create_evaluation_dataset( + session=db, + name=f"dataset_{i}", + dataset_metadata={"original_items_count": i}, + organization_id=org.id, + project_id=project.id, + ) + + datasets = list_datasets( + session=db, organization_id=org.id, project_id=project.id + ) + + assert len(datasets) == 5 + # Should be ordered by most recent first + assert datasets[0].name == "dataset_4" + assert datasets[4].name == "dataset_0" + + def test_list_datasets_pagination(self, db: Session): + """Test pagination of datasets.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + # Create 10 datasets + for i in range(10): + create_evaluation_dataset( + session=db, + name=f"dataset_{i}", + dataset_metadata={"original_items_count": i}, + organization_id=org.id, + project_id=project.id, + ) + + # Get first page + page1 = list_datasets( + session=db, organization_id=org.id, project_id=project.id, limit=5, offset=0 + ) + + # Get second page + page2 = list_datasets( + session=db, organization_id=org.id, project_id=project.id, limit=5, offset=5 + ) + + assert len(page1) == 5 + assert len(page2) == 5 + # Ensure no overlap + page1_names = [d.name for d in page1] + page2_names = [d.name for d in page2] + assert len(set(page1_names) & set(page2_names)) == 0 + + +class TestUploadCsvToObjectStore: + """Test CSV upload to object store.""" + + def test_upload_csv_to_object_store_success(self): + """Test successful object store upload.""" + mock_storage = MagicMock() + mock_storage.put.return_value = "s3://bucket/datasets/test_dataset.csv" + + csv_content = b"question,answer\nWhat is 2+2?,4\n" + + object_store_url = upload_csv_to_object_store( + storage=mock_storage, csv_content=csv_content, dataset_name="test_dataset" + ) + + assert object_store_url == "s3://bucket/datasets/test_dataset.csv" + mock_storage.put.assert_called_once() + + def test_upload_csv_to_object_store_cloud_storage_error(self): + """Test object store upload with CloudStorageError.""" + mock_storage = MagicMock() + mock_storage.put.side_effect = CloudStorageError( + "Object store bucket not found" + ) + + csv_content = b"question,answer\nWhat is 2+2?,4\n" + + # Should return None on error + object_store_url = upload_csv_to_object_store( + storage=mock_storage, csv_content=csv_content, dataset_name="test_dataset" + ) + + assert object_store_url is None + + def test_upload_csv_to_object_store_unexpected_error(self): + """Test object store upload with unexpected error.""" + mock_storage = MagicMock() + mock_storage.put.side_effect = Exception("Unexpected error") + + csv_content = b"question,answer\nWhat is 2+2?,4\n" + + # Should return None on error + object_store_url = upload_csv_to_object_store( + storage=mock_storage, csv_content=csv_content, dataset_name="test_dataset" + ) + + assert object_store_url is None + + +class TestDownloadCsvFromObjectStore: + """Test CSV download from object store.""" + + def test_download_csv_from_object_store_success(self): + """Test successful object store download.""" + mock_storage = MagicMock() + mock_body = MagicMock() + mock_body.read.return_value = b"question,answer\nWhat is 2+2?,4\n" + mock_storage.stream.return_value = mock_body + + csv_content = download_csv_from_object_store( + storage=mock_storage, object_store_url="s3://bucket/datasets/test.csv" + ) + + assert csv_content == b"question,answer\nWhat is 2+2?,4\n" + mock_storage.stream.assert_called_once_with("s3://bucket/datasets/test.csv") + + def test_download_csv_from_object_store_empty_url(self): + """Test download with empty URL.""" + mock_storage = MagicMock() + + with pytest.raises( + ValueError, match="object_store_url cannot be None or empty" + ): + download_csv_from_object_store(storage=mock_storage, object_store_url=None) + + def test_download_csv_from_object_store_error(self): + """Test download with storage error.""" + mock_storage = MagicMock() + mock_storage.stream.side_effect = Exception("Object store download failed") + + with pytest.raises(Exception, match="Object store download failed"): + download_csv_from_object_store( + storage=mock_storage, object_store_url="s3://bucket/datasets/test.csv" + ) + + +class TestUpdateDatasetLangfuseId: + """Test updating Langfuse ID.""" + + def test_update_dataset_langfuse_id(self, db: Session): + """Test updating Langfuse dataset ID.""" + # Get organization and project from seeded data + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + # Create a dataset without Langfuse ID + dataset = create_evaluation_dataset( + session=db, + name="test_dataset", + dataset_metadata={"original_items_count": 10}, + organization_id=org.id, + project_id=project.id, + ) + + assert dataset.langfuse_dataset_id is None + + # Update Langfuse ID + update_dataset_langfuse_id( + session=db, dataset_id=dataset.id, langfuse_dataset_id="langfuse_123" + ) + + # Refresh and verify + db.refresh(dataset) + assert dataset.langfuse_dataset_id == "langfuse_123" + + def test_update_dataset_langfuse_id_nonexistent(self, db: Session): + """Test updating Langfuse ID for non-existent dataset.""" + # Should not raise an error, just do nothing + update_dataset_langfuse_id( + session=db, dataset_id=99999, langfuse_dataset_id="langfuse_123" + ) + # No assertion needed, just ensuring it doesn't crash diff --git a/backend/app/tests/crud/evaluations/test_embeddings.py b/backend/app/tests/crud/evaluations/test_embeddings.py new file mode 100644 index 000000000..c06d78250 --- /dev/null +++ b/backend/app/tests/crud/evaluations/test_embeddings.py @@ -0,0 +1,384 @@ +"""Tests for evaluation embeddings functionality.""" + +import numpy as np +import pytest + +from app.crud.evaluations.embeddings import ( + build_embedding_jsonl, + calculate_average_similarity, + calculate_cosine_similarity, + parse_embedding_results, +) + + +class TestBuildEmbeddingJsonl: + """Tests for build_embedding_jsonl function.""" + + def test_build_embedding_jsonl_basic(self): + """Test building JSONL for basic evaluation results.""" + results = [ + { + "item_id": "item_1", + "question": "What is 2+2?", + "generated_output": "The answer is 4", + "ground_truth": "4", + }, + { + "item_id": "item_2", + "question": "What is the capital of France?", + "generated_output": "Paris", + "ground_truth": "Paris", + }, + ] + + trace_id_mapping = { + "item_1": "trace_1", + "item_2": "trace_2", + } + + jsonl_data = build_embedding_jsonl(results, trace_id_mapping) + + assert len(jsonl_data) == 2 + + # Check first item - uses trace_id as custom_id + assert jsonl_data[0]["custom_id"] == "trace_1" + assert jsonl_data[0]["method"] == "POST" + assert jsonl_data[0]["url"] == "/v1/embeddings" + assert jsonl_data[0]["body"]["model"] == "text-embedding-3-large" + assert jsonl_data[0]["body"]["input"] == ["The answer is 4", "4"] + assert jsonl_data[0]["body"]["encoding_format"] == "float" + + def test_build_embedding_jsonl_custom_model(self): + """Test building JSONL with custom embedding model.""" + results = [ + { + "item_id": "item_1", + "question": "Test?", + "generated_output": "Output", + "ground_truth": "Truth", + } + ] + + trace_id_mapping = {"item_1": "trace_1"} + + jsonl_data = build_embedding_jsonl( + results, trace_id_mapping, embedding_model="text-embedding-3-small" + ) + + assert len(jsonl_data) == 1 + assert jsonl_data[0]["body"]["model"] == "text-embedding-3-small" + + def test_build_embedding_jsonl_skips_empty(self): + """Test that items with empty output or ground_truth are skipped.""" + results = [ + { + "item_id": "item_1", + "question": "Test?", + "generated_output": "", # Empty + "ground_truth": "Truth", + }, + { + "item_id": "item_2", + "question": "Test?", + "generated_output": "Output", + "ground_truth": "", # Empty + }, + { + "item_id": "item_3", + "question": "Test?", + "generated_output": "Output", + "ground_truth": "Truth", + }, + ] + + trace_id_mapping = { + "item_1": "trace_1", + "item_2": "trace_2", + "item_3": "trace_3", + } + + jsonl_data = build_embedding_jsonl(results, trace_id_mapping) + + # Only item_3 should be included + assert len(jsonl_data) == 1 + assert jsonl_data[0]["custom_id"] == "trace_3" + + def test_build_embedding_jsonl_missing_item_id(self): + """Test that items without item_id or trace_id are skipped.""" + results = [ + { + # Missing item_id + "question": "Test?", + "generated_output": "Output", + "ground_truth": "Truth", + }, + { + "item_id": "item_2", + "question": "Test?", + "generated_output": "Output", + "ground_truth": "Truth", + }, + ] + + # Only item_2 has a mapping + trace_id_mapping = {"item_2": "trace_2"} + + jsonl_data = build_embedding_jsonl(results, trace_id_mapping) + + # Only item_2 should be included + assert len(jsonl_data) == 1 + assert jsonl_data[0]["custom_id"] == "trace_2" + + +class TestParseEmbeddingResults: + """Tests for parse_embedding_results function.""" + + def test_parse_embedding_results_basic(self): + """Test parsing basic embedding results.""" + raw_results = [ + { + "custom_id": "trace_1", + "response": { + "body": { + "data": [ + {"index": 0, "embedding": [0.1, 0.2, 0.3]}, + {"index": 1, "embedding": [0.15, 0.22, 0.32]}, + ] + } + }, + }, + { + "custom_id": "trace_2", + "response": { + "body": { + "data": [ + {"index": 0, "embedding": [0.5, 0.6, 0.7]}, + {"index": 1, "embedding": [0.55, 0.65, 0.75]}, + ] + } + }, + }, + ] + + embedding_pairs = parse_embedding_results(raw_results) + + assert len(embedding_pairs) == 2 + + # Check first pair - now uses trace_id + assert embedding_pairs[0]["trace_id"] == "trace_1" + assert embedding_pairs[0]["output_embedding"] == [0.1, 0.2, 0.3] + assert embedding_pairs[0]["ground_truth_embedding"] == [0.15, 0.22, 0.32] + + # Check second pair + assert embedding_pairs[1]["trace_id"] == "trace_2" + assert embedding_pairs[1]["output_embedding"] == [0.5, 0.6, 0.7] + assert embedding_pairs[1]["ground_truth_embedding"] == [0.55, 0.65, 0.75] + + def test_parse_embedding_results_with_error(self): + """Test parsing results with errors.""" + raw_results = [ + { + "custom_id": "trace_1", + "error": {"message": "Rate limit exceeded"}, + }, + { + "custom_id": "trace_2", + "response": { + "body": { + "data": [ + {"index": 0, "embedding": [0.1, 0.2]}, + {"index": 1, "embedding": [0.15, 0.22]}, + ] + } + }, + }, + ] + + embedding_pairs = parse_embedding_results(raw_results) + + # Only trace_2 should be included (trace_1 had error) + assert len(embedding_pairs) == 1 + assert embedding_pairs[0]["trace_id"] == "trace_2" + + def test_parse_embedding_results_missing_embedding(self): + """Test parsing results with missing embeddings.""" + raw_results = [ + { + "custom_id": "trace_1", + "response": { + "body": { + "data": [ + {"index": 0, "embedding": [0.1, 0.2]}, + # Missing index 1 + ] + } + }, + }, + { + "custom_id": "trace_2", + "response": { + "body": { + "data": [ + {"index": 0, "embedding": [0.1, 0.2]}, + {"index": 1, "embedding": [0.15, 0.22]}, + ] + } + }, + }, + ] + + embedding_pairs = parse_embedding_results(raw_results) + + # Only trace_2 should be included (trace_1 missing index 1) + assert len(embedding_pairs) == 1 + assert embedding_pairs[0]["trace_id"] == "trace_2" + + +class TestCalculateCosineSimilarity: + """Tests for calculate_cosine_similarity function.""" + + def test_calculate_cosine_similarity_identical(self): + """Test cosine similarity of identical vectors.""" + vec1 = [1.0, 0.0, 0.0] + vec2 = [1.0, 0.0, 0.0] + + similarity = calculate_cosine_similarity(vec1, vec2) + + assert similarity == pytest.approx(1.0) + + def test_calculate_cosine_similarity_orthogonal(self): + """Test cosine similarity of orthogonal vectors.""" + vec1 = [1.0, 0.0, 0.0] + vec2 = [0.0, 1.0, 0.0] + + similarity = calculate_cosine_similarity(vec1, vec2) + + assert similarity == pytest.approx(0.0) + + def test_calculate_cosine_similarity_opposite(self): + """Test cosine similarity of opposite vectors.""" + vec1 = [1.0, 0.0, 0.0] + vec2 = [-1.0, 0.0, 0.0] + + similarity = calculate_cosine_similarity(vec1, vec2) + + assert similarity == pytest.approx(-1.0) + + def test_calculate_cosine_similarity_partial(self): + """Test cosine similarity of partially similar vectors.""" + vec1 = [1.0, 1.0, 0.0] + vec2 = [1.0, 0.0, 0.0] + + similarity = calculate_cosine_similarity(vec1, vec2) + + # cos(45°) ≈ 0.707 + assert similarity == pytest.approx(0.707, abs=0.01) + + def test_calculate_cosine_similarity_zero_vector(self): + """Test cosine similarity with zero vector.""" + vec1 = [0.0, 0.0, 0.0] + vec2 = [1.0, 0.0, 0.0] + + similarity = calculate_cosine_similarity(vec1, vec2) + + assert similarity == 0.0 + + +class TestCalculateAverageSimilarity: + """Tests for calculate_average_similarity function.""" + + def test_calculate_average_similarity_basic(self): + """Test calculating average similarity for basic embedding pairs.""" + embedding_pairs = [ + { + "trace_id": "trace_1", + "output_embedding": [1.0, 0.0, 0.0], + "ground_truth_embedding": [1.0, 0.0, 0.0], # Similarity = 1.0 + }, + { + "trace_id": "trace_2", + "output_embedding": [1.0, 0.0, 0.0], + "ground_truth_embedding": [0.0, 1.0, 0.0], # Similarity = 0.0 + }, + { + "trace_id": "trace_3", + "output_embedding": [1.0, 1.0, 0.0], + "ground_truth_embedding": [1.0, 0.0, 0.0], # Similarity ≈ 0.707 + }, + ] + + stats = calculate_average_similarity(embedding_pairs) + + assert stats["total_pairs"] == 3 + # Average of [1.0, 0.0, 0.707] ≈ 0.569 + assert stats["cosine_similarity_avg"] == pytest.approx(0.569, abs=0.01) + assert "cosine_similarity_std" in stats + assert len(stats["per_item_scores"]) == 3 + + def test_calculate_average_similarity_empty(self): + """Test calculating average similarity for empty list.""" + embedding_pairs = [] + + stats = calculate_average_similarity(embedding_pairs) + + assert stats["total_pairs"] == 0 + assert stats["cosine_similarity_avg"] == 0.0 + assert stats["cosine_similarity_std"] == 0.0 + assert stats["per_item_scores"] == [] + + def test_calculate_average_similarity_per_item_scores(self): + """Test that per-item scores are correctly calculated.""" + embedding_pairs = [ + { + "trace_id": "trace_1", + "output_embedding": [1.0, 0.0], + "ground_truth_embedding": [1.0, 0.0], + }, + { + "trace_id": "trace_2", + "output_embedding": [0.0, 1.0], + "ground_truth_embedding": [0.0, 1.0], + }, + ] + + stats = calculate_average_similarity(embedding_pairs) + + assert len(stats["per_item_scores"]) == 2 + assert stats["per_item_scores"][0]["trace_id"] == "trace_1" + assert stats["per_item_scores"][0]["cosine_similarity"] == pytest.approx(1.0) + assert stats["per_item_scores"][1]["trace_id"] == "trace_2" + assert stats["per_item_scores"][1]["cosine_similarity"] == pytest.approx(1.0) + + def test_calculate_average_similarity_statistics(self): + """Test that all statistics are calculated correctly.""" + # Create pairs with known similarities + embedding_pairs = [ + { + "trace_id": "trace_1", + "output_embedding": [1.0, 0.0], + "ground_truth_embedding": [1.0, 0.0], # sim = 1.0 + }, + { + "trace_id": "trace_2", + "output_embedding": [1.0, 0.0], + "ground_truth_embedding": [0.0, 1.0], # sim = 0.0 + }, + { + "trace_id": "trace_3", + "output_embedding": [1.0, 0.0], + "ground_truth_embedding": [1.0, 0.0], # sim = 1.0 + }, + { + "trace_id": "trace_4", + "output_embedding": [1.0, 0.0], + "ground_truth_embedding": [0.0, 1.0], # sim = 0.0 + }, + ] + + stats = calculate_average_similarity(embedding_pairs) + + # Similarities = [1.0, 0.0, 1.0, 0.0] + assert stats["cosine_similarity_avg"] == pytest.approx(0.5) + # Standard deviation of [1, 0, 1, 0] = 0.5 + assert stats["cosine_similarity_std"] == pytest.approx(0.5) + assert stats["total_pairs"] == 4 diff --git a/backend/app/tests/crud/evaluations/test_langfuse.py b/backend/app/tests/crud/evaluations/test_langfuse.py new file mode 100644 index 000000000..4717ca6c8 --- /dev/null +++ b/backend/app/tests/crud/evaluations/test_langfuse.py @@ -0,0 +1,414 @@ +""" +Tests for evaluation_langfuse CRUD operations. +""" + +from unittest.mock import MagicMock + +import pytest + +from app.crud.evaluations.langfuse import ( + create_langfuse_dataset_run, + update_traces_with_cosine_scores, + upload_dataset_to_langfuse_from_csv, +) + + +class TestCreateLangfuseDatasetRun: + """Test creating Langfuse dataset runs.""" + + def test_create_langfuse_dataset_run_success(self): + """Test successfully creating a dataset run with traces.""" + # Mock Langfuse client + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + + # Mock dataset items + mock_item1 = MagicMock() + mock_item1.id = "item_1" + mock_item1.observe.return_value.__enter__.return_value = "trace_id_1" + + mock_item2 = MagicMock() + mock_item2.id = "item_2" + mock_item2.observe.return_value.__enter__.return_value = "trace_id_2" + + mock_dataset.items = [mock_item1, mock_item2] + mock_langfuse.get_dataset.return_value = mock_dataset + + # Test data + results = [ + { + "item_id": "item_1", + "question": "What is 2+2?", + "generated_output": "4", + "ground_truth": "4", + }, + { + "item_id": "item_2", + "question": "What is the capital of France?", + "generated_output": "Paris", + "ground_truth": "Paris", + }, + ] + + # Call function + trace_id_mapping = create_langfuse_dataset_run( + langfuse=mock_langfuse, + dataset_name="test_dataset", + run_name="test_run", + results=results, + ) + + # Verify results + assert len(trace_id_mapping) == 2 + assert trace_id_mapping["item_1"] == "trace_id_1" + assert trace_id_mapping["item_2"] == "trace_id_2" + + # Verify Langfuse calls + mock_langfuse.get_dataset.assert_called_once_with("test_dataset") + mock_langfuse.flush.assert_called_once() + assert mock_langfuse.trace.call_count == 2 + + def test_create_langfuse_dataset_run_skips_missing_items(self): + """Test that missing dataset items are skipped.""" + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + + # Only one item exists + mock_item1 = MagicMock() + mock_item1.id = "item_1" + mock_item1.observe.return_value.__enter__.return_value = "trace_id_1" + + mock_dataset.items = [mock_item1] + mock_langfuse.get_dataset.return_value = mock_dataset + + # Results include an item that doesn't exist in dataset + results = [ + { + "item_id": "item_1", + "question": "What is 2+2?", + "generated_output": "4", + "ground_truth": "4", + }, + { + "item_id": "item_nonexistent", + "question": "Invalid question", + "generated_output": "Invalid", + "ground_truth": "Invalid", + }, + ] + + trace_id_mapping = create_langfuse_dataset_run( + langfuse=mock_langfuse, + dataset_name="test_dataset", + run_name="test_run", + results=results, + ) + + # Only the valid item should be in the mapping + assert len(trace_id_mapping) == 1 + assert "item_1" in trace_id_mapping + assert "item_nonexistent" not in trace_id_mapping + + def test_create_langfuse_dataset_run_handles_trace_error(self): + """Test that trace creation errors are handled gracefully.""" + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + + # First item succeeds + mock_item1 = MagicMock() + mock_item1.id = "item_1" + mock_item1.observe.return_value.__enter__.return_value = "trace_id_1" + + # Second item fails + mock_item2 = MagicMock() + mock_item2.id = "item_2" + mock_item2.observe.side_effect = Exception("Trace creation failed") + + mock_dataset.items = [mock_item1, mock_item2] + mock_langfuse.get_dataset.return_value = mock_dataset + + results = [ + { + "item_id": "item_1", + "question": "What is 2+2?", + "generated_output": "4", + "ground_truth": "4", + }, + { + "item_id": "item_2", + "question": "What is the capital?", + "generated_output": "Paris", + "ground_truth": "Paris", + }, + ] + + trace_id_mapping = create_langfuse_dataset_run( + langfuse=mock_langfuse, + dataset_name="test_dataset", + run_name="test_run", + results=results, + ) + + # Only successful item should be in mapping + assert len(trace_id_mapping) == 1 + assert "item_1" in trace_id_mapping + assert "item_2" not in trace_id_mapping + + def test_create_langfuse_dataset_run_empty_results(self): + """Test with empty results list.""" + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + mock_dataset.items = [] + mock_langfuse.get_dataset.return_value = mock_dataset + + trace_id_mapping = create_langfuse_dataset_run( + langfuse=mock_langfuse, + dataset_name="test_dataset", + run_name="test_run", + results=[], + ) + + assert len(trace_id_mapping) == 0 + mock_langfuse.flush.assert_called_once() + + +class TestUpdateTracesWithCosineScores: + """Test updating Langfuse traces with cosine similarity scores.""" + + def test_update_traces_with_cosine_scores_success(self): + """Test successfully updating traces with scores.""" + mock_langfuse = MagicMock() + + per_item_scores = [ + {"trace_id": "trace_1", "cosine_similarity": 0.95}, + {"trace_id": "trace_2", "cosine_similarity": 0.87}, + {"trace_id": "trace_3", "cosine_similarity": 0.92}, + ] + + update_traces_with_cosine_scores( + langfuse=mock_langfuse, per_item_scores=per_item_scores + ) + + # Verify score was called for each item + assert mock_langfuse.score.call_count == 3 + + # Verify the score calls + calls = mock_langfuse.score.call_args_list + assert calls[0].kwargs["trace_id"] == "trace_1" + assert calls[0].kwargs["name"] == "cosine_similarity" + assert calls[0].kwargs["value"] == 0.95 + assert "cosine similarity" in calls[0].kwargs["comment"].lower() + + assert calls[1].kwargs["trace_id"] == "trace_2" + assert calls[1].kwargs["value"] == 0.87 + + mock_langfuse.flush.assert_called_once() + + def test_update_traces_with_cosine_scores_missing_trace_id(self): + """Test that items without trace_id are skipped.""" + mock_langfuse = MagicMock() + + per_item_scores = [ + {"trace_id": "trace_1", "cosine_similarity": 0.95}, + {"cosine_similarity": 0.87}, # Missing trace_id + {"trace_id": "trace_3", "cosine_similarity": 0.92}, + ] + + update_traces_with_cosine_scores( + langfuse=mock_langfuse, per_item_scores=per_item_scores + ) + + # Should only call score for items with trace_id + assert mock_langfuse.score.call_count == 2 + + def test_update_traces_with_cosine_scores_error_handling(self): + """Test that score errors don't stop processing.""" + mock_langfuse = MagicMock() + + # First call succeeds, second fails, third succeeds + mock_langfuse.score.side_effect = [None, Exception("Score failed"), None] + + per_item_scores = [ + {"trace_id": "trace_1", "cosine_similarity": 0.95}, + {"trace_id": "trace_2", "cosine_similarity": 0.87}, + {"trace_id": "trace_3", "cosine_similarity": 0.92}, + ] + + # Should not raise exception + update_traces_with_cosine_scores( + langfuse=mock_langfuse, per_item_scores=per_item_scores + ) + + # All three should have been attempted + assert mock_langfuse.score.call_count == 3 + mock_langfuse.flush.assert_called_once() + + def test_update_traces_with_cosine_scores_empty_list(self): + """Test with empty scores list.""" + mock_langfuse = MagicMock() + + update_traces_with_cosine_scores(langfuse=mock_langfuse, per_item_scores=[]) + + mock_langfuse.score.assert_not_called() + mock_langfuse.flush.assert_called_once() + + +class TestUploadDatasetToLangfuseFromCsv: + """Test uploading datasets to Langfuse from CSV content.""" + + @pytest.fixture + def valid_csv_content(self): + """Valid CSV content.""" + csv_string = """question,answer +"What is 2+2?","4" +"What is the capital of France?","Paris" +"Who wrote Romeo and Juliet?","Shakespeare" +""" + return csv_string.encode("utf-8") + + def test_upload_dataset_to_langfuse_from_csv_success(self, valid_csv_content): + """Test successful upload with duplication.""" + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + mock_dataset.id = "dataset_123" + mock_langfuse.create_dataset.return_value = mock_dataset + + langfuse_id, total_items = upload_dataset_to_langfuse_from_csv( + langfuse=mock_langfuse, + csv_content=valid_csv_content, + dataset_name="test_dataset", + duplication_factor=5, + ) + + assert langfuse_id == "dataset_123" + assert total_items == 15 # 3 items * 5 duplication + + # Verify dataset creation + mock_langfuse.create_dataset.assert_called_once_with(name="test_dataset") + + # Verify dataset items were created (3 original * 5 duplicates = 15) + assert mock_langfuse.create_dataset_item.call_count == 15 + + mock_langfuse.flush.assert_called_once() + + def test_upload_dataset_to_langfuse_from_csv_duplication_metadata( + self, valid_csv_content + ): + """Test that duplication metadata is included.""" + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + mock_dataset.id = "dataset_123" + mock_langfuse.create_dataset.return_value = mock_dataset + + upload_dataset_to_langfuse_from_csv( + langfuse=mock_langfuse, + csv_content=valid_csv_content, + dataset_name="test_dataset", + duplication_factor=3, + ) + + # Check metadata in create_dataset_item calls + calls = mock_langfuse.create_dataset_item.call_args_list + + # Each original item should have 3 duplicates + duplicate_numbers = [] + for call_args in calls: + metadata = call_args.kwargs.get("metadata", {}) + duplicate_numbers.append(metadata.get("duplicate_number")) + assert metadata.get("duplication_factor") == 3 + + # Should have 3 sets of duplicates (1, 2, 3) + assert duplicate_numbers.count(1) == 3 # 3 original items, each with dup #1 + assert duplicate_numbers.count(2) == 3 # 3 original items, each with dup #2 + assert duplicate_numbers.count(3) == 3 # 3 original items, each with dup #3 + + def test_upload_dataset_to_langfuse_from_csv_missing_columns(self): + """Test with CSV missing required columns.""" + mock_langfuse = MagicMock() + + invalid_csv = b"query,response\nWhat is 2+2?,4\n" + + with pytest.raises(ValueError, match="question.*answer"): + upload_dataset_to_langfuse_from_csv( + langfuse=mock_langfuse, + csv_content=invalid_csv, + dataset_name="test_dataset", + duplication_factor=1, + ) + + def test_upload_dataset_to_langfuse_from_csv_empty_rows(self): + """Test that empty rows are skipped.""" + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + mock_dataset.id = "dataset_123" + mock_langfuse.create_dataset.return_value = mock_dataset + + # CSV with some empty rows + csv_with_empty = b"""question,answer +"Valid question 1","Valid answer 1" +"","Empty answer" +"Valid question 2","" +"Valid question 3","Valid answer 3" +""" + + langfuse_id, total_items = upload_dataset_to_langfuse_from_csv( + langfuse=mock_langfuse, + csv_content=csv_with_empty, + dataset_name="test_dataset", + duplication_factor=2, + ) + + # Should only process 2 valid items (first and last) + assert total_items == 4 # 2 valid items * 2 duplication + assert mock_langfuse.create_dataset_item.call_count == 4 + + def test_upload_dataset_to_langfuse_from_csv_empty_dataset(self): + """Test with CSV that has no valid items.""" + mock_langfuse = MagicMock() + + empty_csv = b"""question,answer +"","" +"","answer without question" +""" + + with pytest.raises(ValueError, match="No valid items found"): + upload_dataset_to_langfuse_from_csv( + langfuse=mock_langfuse, + csv_content=empty_csv, + dataset_name="test_dataset", + duplication_factor=1, + ) + + def test_upload_dataset_to_langfuse_from_csv_invalid_encoding(self): + """Test with invalid CSV encoding.""" + mock_langfuse = MagicMock() + + # Invalid UTF-8 bytes + invalid_csv = b"\xff\xfe Invalid UTF-8" + + with pytest.raises((ValueError, Exception)): + upload_dataset_to_langfuse_from_csv( + langfuse=mock_langfuse, + csv_content=invalid_csv, + dataset_name="test_dataset", + duplication_factor=1, + ) + + def test_upload_dataset_to_langfuse_from_csv_default_duplication( + self, valid_csv_content + ): + """Test upload with duplication factor of 1.""" + mock_langfuse = MagicMock() + mock_dataset = MagicMock() + mock_dataset.id = "dataset_123" + mock_langfuse.create_dataset.return_value = mock_dataset + + langfuse_id, total_items = upload_dataset_to_langfuse_from_csv( + langfuse=mock_langfuse, + csv_content=valid_csv_content, + dataset_name="test_dataset", + duplication_factor=1, + ) + + assert total_items == 3 # 3 items * 1 duplication + assert mock_langfuse.create_dataset_item.call_count == 3 diff --git a/backend/app/utils.py b/backend/app/utils.py index 360054dc8..094c36829 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -206,9 +206,7 @@ def get_openai_client(session: Session, org_id: int, project_id: int) -> OpenAI: ) -def get_langfuse_client( - session: Session, org_id: int, project_id: int -) -> Langfuse | None: +def get_langfuse_client(session: Session, org_id: int, project_id: int) -> Langfuse: """ Fetch Langfuse credentials for the current org/project and return a configured client. """ @@ -219,21 +217,32 @@ def get_langfuse_client( project_id=project_id, ) - has_credentials = ( - credentials - and "public_key" in credentials - and "secret_key" in credentials - and "host" in credentials - ) + if not credentials or not all( + key in credentials for key in ["public_key", "secret_key", "host"] + ): + logger.error( + f"[get_langfuse_client] Langfuse credentials not found or incomplete. | project_id: {project_id}" + ) + raise HTTPException( + status_code=400, + detail="Langfuse credentials not configured for this organization/project.", + ) - if has_credentials: + try: return Langfuse( public_key=credentials["public_key"], secret_key=credentials["secret_key"], host=credentials["host"], ) - - return None + except Exception as e: + logger.error( + f"[get_langfuse_client] Failed to configure Langfuse client. | project_id: {project_id} | error: {str(e)}", + exc_info=True, + ) + raise HTTPException( + status_code=500, + detail=f"Failed to configure Langfuse client: {str(e)}", + ) def handle_openai_error(e: openai.OpenAIError) -> str: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 54de78378..98e87d7c3 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "asgi-correlation-id>=4.3.4", "py-zerox>=0.0.7,<1.0.0", "pandas>=2.3.2", + "numpy>=1.24.0", "scikit-learn>=1.7.1", "celery>=5.3.0,<6.0.0", "redis>=5.0.0,<6.0.0", diff --git a/backend/uv.lock b/backend/uv.lock index 7f9cb9107..45f551a1a 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -194,6 +194,7 @@ dependencies = [ { name = "jinja2" }, { name = "langfuse" }, { name = "moto", extra = ["s3"] }, + { name = "numpy" }, { name = "openai" }, { name = "openai-responses" }, { name = "pandas" }, @@ -238,7 +239,8 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.4,<4.0.0" }, { name = "langfuse", specifier = "==2.60.3" }, { name = "moto", extras = ["s3"], specifier = ">=5.1.1" }, - { name = "openai", specifier = ">=1.100.1" }, + { name = "numpy", specifier = ">=1.24.0" }, + { name = "openai", specifier = ">=1.67.0" }, { name = "openai-responses" }, { name = "pandas", specifier = ">=2.3.2" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4,<2.0.0" }, diff --git a/scripts/python/invoke-cron.py b/scripts/python/invoke-cron.py new file mode 100644 index 000000000..cbc42a82b --- /dev/null +++ b/scripts/python/invoke-cron.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +""" +Cron script to invoke an API endpoint periodically. +Uses async HTTP client to be resource-efficient. +""" + +import asyncio +import logging +import os +import sys +from datetime import datetime +from pathlib import Path + +import httpx +from dotenv import load_dotenv + +# Configuration +ENDPOINT = "/api/v1/cron/evaluations" # Endpoint to invoke +REQUEST_TIMEOUT = 30 # Timeout for requests in seconds + +# Setup logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +class EndpointInvoker: + """Handles periodic endpoint invocation with authentication.""" + + def __init__(self): + # Load BASE_URL from environment with default fallback + base_url = os.getenv("API_BASE_URL", "http://localhost:8000") + self.base_url = base_url.rstrip("/") + self.endpoint = ENDPOINT + + # Load interval from environment with default of 5 minutes + self.interval_minutes = int(os.getenv("CRON_INTERVAL_MINUTES", "5")) + self.interval_seconds = self.interval_minutes * 60 + self.access_token = None + self.token_expiry = None + + # Load credentials from environment + self.email = os.getenv("FIRST_SUPERUSER") + self.password = os.getenv("FIRST_SUPERUSER_PASSWORD") + + if not self.email or not self.password: + raise ValueError( + "FIRST_SUPERUSER and FIRST_SUPERUSER_PASSWORD must be set in environment" + ) + + async def authenticate(self, client: httpx.AsyncClient) -> str: + """Authenticate and get access token.""" + logger.info("Authenticating with API...") + + login_data = { + "username": self.email, + "password": self.password, + } + + try: + response = await client.post( + f"{self.base_url}/api/v1/login/access-token", + data=login_data, + timeout=REQUEST_TIMEOUT, + ) + response.raise_for_status() + + data = response.json() + self.access_token = data.get("access_token") + + if not self.access_token: + raise ValueError("No access token in response") + + logger.info("Authentication successful") + return self.access_token + + except httpx.HTTPStatusError as e: + logger.error(f"Authentication failed with status {e.response.status_code}") + raise + except Exception as e: + logger.error(f"Authentication error: {e}") + raise + + async def invoke_endpoint(self, client: httpx.AsyncClient) -> dict: + """Invoke the configured endpoint.""" + if not self.access_token: + await self.authenticate(client) + + headers = {"Authorization": f"Bearer {self.access_token}"} + + # Debug: Log what we're sending + logger.debug(f"Request URL: {self.base_url}{self.endpoint}") + logger.debug(f"Request headers: {headers}") + + try: + response = await client.get( + f"{self.base_url}{self.endpoint}", + headers=headers, + timeout=REQUEST_TIMEOUT, + ) + + # Debug: Log response headers and first part of body + logger.debug(f"Response status: {response.status_code}") + logger.debug(f"Response headers: {dict(response.headers)}") + + # If unauthorized, re-authenticate and retry once + if response.status_code == 401: + logger.info("Token expired, re-authenticating...") + await self.authenticate(client) + headers = {"Authorization": f"Bearer {self.access_token}"} + response = await client.get( + f"{self.base_url}{self.endpoint}", + headers=headers, + timeout=REQUEST_TIMEOUT, + ) + + response.raise_for_status() + return response.json() + + except httpx.HTTPStatusError as e: + logger.error( + f"Endpoint invocation failed with status {e.response.status_code}: {e.response.text}" + ) + raise + except Exception as e: + logger.error(f"Endpoint invocation error: {e}") + raise + + async def run(self): + """Main loop to invoke endpoint periodically.""" + logger.info(f"Using API Base URL: {self.base_url}") + logger.info( + f"Starting cron job - invoking {self.endpoint} every {self.interval_minutes} minutes" + ) + + # Use async context manager to ensure proper cleanup + async with httpx.AsyncClient() as client: + # Authenticate once at startup + await self.authenticate(client) + + while True: + try: + start_time = datetime.now() + logger.info(f"Invoking endpoint at {start_time}") + + result = await self.invoke_endpoint(client) + logger.info(f"Endpoint invoked successfully: {result}") + + # Calculate next invocation time + elapsed = (datetime.now() - start_time).total_seconds() + sleep_time = max(0, self.interval_seconds - elapsed) + + if sleep_time > 0: + logger.info( + f"Sleeping for {sleep_time:.1f} seconds until next invocation" + ) + await asyncio.sleep(sleep_time) + + except KeyboardInterrupt: + logger.info("Shutting down gracefully...") + break + except Exception as e: + logger.error(f"Error during invocation: {e}") + # Wait before retrying on error + logger.info(f"Waiting {self.interval_seconds} seconds before retry") + await asyncio.sleep(self.interval_seconds) + + +def main(): + """Entry point for the script.""" + # Load environment variables + env_path = Path(__file__).parent.parent.parent / ".env" + if env_path.exists(): + load_dotenv(env_path) + logger.info(f"Loaded environment from {env_path}") + else: + logger.warning(f"No .env file found at {env_path}") + + try: + invoker = EndpointInvoker() + asyncio.run(invoker.run()) + except KeyboardInterrupt: + logger.info("Interrupted by user") + sys.exit(0) + except Exception as e: + logger.error(f"Fatal error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main()