diff --git a/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py b/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py new file mode 100644 index 00000000..f3c7d502 --- /dev/null +++ b/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py @@ -0,0 +1,114 @@ +"""Add prompt and version table + +Revision ID: 757f50ada8ef +Revises: e9dd35eff62c +Create Date: 2025-08-14 11:45:07.186686 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = "757f50ada8ef" +down_revision = "e9dd35eff62c" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "prompt", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column( + "description", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True + ), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("active_version", sa.Uuid(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("is_deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["active_version"], + ["prompt_version.id"], + initially="DEFERRED", + deferrable=True, + use_alter=True, + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_prompt_name"), "prompt", ["name"], unique=False) + op.create_index( + op.f("ix_prompt_project_id"), "prompt", ["project_id"], unique=False + ) + op.create_index( + "ix_prompt_project_id_is_deleted", + "prompt", + ["project_id", "is_deleted"], + unique=False, + ) + op.create_table( + "prompt_version", + sa.Column("instruction", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "commit_message", + sqlmodel.sql.sqltypes.AutoString(length=512), + nullable=True, + ), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("prompt_id", sa.Uuid(), nullable=False), + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("is_deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["prompt_id"], + ["prompt.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("prompt_id", "version"), + ) + op.create_foreign_key( + None, + "prompt", + "prompt_version", + ["active_version"], + ["id"], + initially="DEFERRED", + deferrable=True, + use_alter=True, + ) + op.create_index( + op.f("ix_prompt_version_prompt_id"), + "prompt_version", + ["prompt_id"], + unique=False, + ) + op.create_index( + "ix_prompt_version_prompt_id_is_deleted", + "prompt_version", + ["prompt_id", "is_deleted"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index("ix_prompt_version_prompt_id_is_deleted", table_name="prompt_version") + op.drop_index(op.f("ix_prompt_version_prompt_id"), table_name="prompt_version") + op.drop_table("prompt_version") + op.drop_index("ix_prompt_project_id_is_deleted", table_name="prompt") + op.drop_index(op.f("ix_prompt_project_id"), table_name="prompt") + op.drop_index(op.f("ix_prompt_name"), table_name="prompt") + op.drop_table("prompt") + # ### end Alembic commands ### diff --git a/backend/app/api/main.py b/backend/app/api/main.py index df0b1016..abf46768 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -10,6 +10,8 @@ openai_conversation, project, project_user, + prompts, + prompt_versions, responses, private, threads, @@ -32,6 +34,8 @@ api_router.include_router(organization.router) api_router.include_router(project.router) api_router.include_router(project_user.router) +api_router.include_router(prompts.router) +api_router.include_router(prompt_versions.router) api_router.include_router(responses.router) api_router.include_router(threads.router) api_router.include_router(users.router) diff --git a/backend/app/api/routes/prompt_versions.py b/backend/app/api/routes/prompt_versions.py new file mode 100644 index 00000000..a16376f3 --- /dev/null +++ b/backend/app/api/routes/prompt_versions.py @@ -0,0 +1,55 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends +from sqlmodel import Session + +from app.api.deps import CurrentUserOrgProject, get_db +from app.crud import create_prompt_version, delete_prompt_version +from app.models import PromptVersionCreate, PromptVersionPublic +from app.utils import APIResponse + + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/prompts", tags=["Prompt Versions"]) + + +@router.post( + "/{prompt_id}/versions", + response_model=APIResponse[PromptVersionPublic], + status_code=201, +) +def create_prompt_version_route( + prompt_version_in: PromptVersionCreate, + prompt_id: UUID, + current_user: CurrentUserOrgProject, + session: Session = Depends(get_db), +): + version = create_prompt_version( + session=session, + prompt_id=prompt_id, + prompt_version_in=prompt_version_in, + project_id=current_user.project_id, + ) + return APIResponse.success_response(version) + + +@router.delete("/{prompt_id}/versions/{version_id}", response_model=APIResponse) +def delete_prompt_version_route( + prompt_id: UUID, + version_id: UUID, + current_user: CurrentUserOrgProject, + session: Session = Depends(get_db), +): + """ + Delete a prompt version by ID. 
+ """ + delete_prompt_version( + session=session, + prompt_id=prompt_id, + version_id=version_id, + project_id=current_user.project_id, + ) + return APIResponse.success_response( + data={"message": "Prompt version deleted successfully."} + ) diff --git a/backend/app/api/routes/prompts.py b/backend/app/api/routes/prompts.py new file mode 100644 index 00000000..2ff6a774 --- /dev/null +++ b/backend/app/api/routes/prompts.py @@ -0,0 +1,134 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends, Path, Query +from sqlmodel import Session + +from app.api.deps import CurrentUserOrgProject, get_db +from app.crud import ( + create_prompt, + delete_prompt, + get_prompt_by_id, + get_prompts, + count_prompts_in_project, + update_prompt, +) +from app.models import ( + PromptCreate, + PromptPublic, + PromptUpdate, + PromptWithVersion, + PromptWithVersions, +) +from app.utils import APIResponse + + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/prompts", tags=["Prompts"]) + + +@router.post("/", response_model=APIResponse[PromptWithVersion], status_code=201) +def create_prompt_route( + prompt_in: PromptCreate, + current_user: CurrentUserOrgProject, + session: Session = Depends(get_db), +): + """ + Create a new prompt under the specified organization and project. + """ + prompt, version = create_prompt( + session=session, prompt_in=prompt_in, project_id=current_user.project_id + ) + prompt_with_version = PromptWithVersion(**prompt.model_dump(), version=version) + return APIResponse.success_response(prompt_with_version) + + +@router.get( + "/", + response_model=APIResponse[list[PromptPublic]], +) +def get_prompts_route( + current_user: CurrentUserOrgProject, + skip: int = Query( + 0, ge=0, description="Number of prompts to skip (for pagination)." + ), + limit: int = Query(100, gt=0, description="Maximum number of prompts to return."), + session: Session = Depends(get_db), +): + """ + Get all prompts for the specified organization and project. + """ + prompts = get_prompts( + session=session, + project_id=current_user.project_id, + skip=skip, + limit=limit, + ) + total = count_prompts_in_project( + session=session, project_id=current_user.project_id + ) + metadata = {"pagination": {"total": total, "skip": skip, "limit": limit}} + return APIResponse.success_response(prompts, metadata=metadata) + + +@router.get( + "/{prompt_id}", + response_model=APIResponse[PromptWithVersions], + summary="Get a single prompt by its ID; by default returns the active version", +) +def get_prompt_by_id_route( + current_user: CurrentUserOrgProject, + prompt_id: UUID = Path(..., description="The ID of the prompt to fetch"), + include_versions: bool = Query( + False, description="Whether to include all versions of the prompt." + ), + session: Session = Depends(get_db), +): + """ + Get a single prompt by its ID. + """ + prompt, versions = get_prompt_by_id( + session=session, + prompt_id=prompt_id, + project_id=current_user.project_id, + include_versions=include_versions, + ) + prompt_with_versions = PromptWithVersions(**prompt.model_dump(), versions=versions) + return APIResponse.success_response(prompt_with_versions) + + +@router.patch("/{prompt_id}", response_model=APIResponse[PromptPublic]) +def update_prompt_route( + current_user: CurrentUserOrgProject, + prompt_update: PromptUpdate, + prompt_id: UUID = Path(..., description="The ID of the prompt to update"), + session: Session = Depends(get_db), +): + """ + Update a prompt's name, description, or active version. 
+ """ + + prompt = update_prompt( + session=session, + prompt_id=prompt_id, + project_id=current_user.project_id, + prompt_update=prompt_update, + ) + return APIResponse.success_response(prompt) + + +@router.delete("/{prompt_id}", response_model=APIResponse) +def delete_prompt_route( + current_user: CurrentUserOrgProject, + prompt_id: UUID = Path(..., description="The ID of the prompt to delete"), + session: Session = Depends(get_db), +): + """ + Delete a prompt by ID. + """ + delete_prompt( + session=session, prompt_id=prompt_id, project_id=current_user.project_id + ) + return APIResponse.success_response( + data={"message": "Prompt deleted successfully."} + ) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index e4b973a0..31573155 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -64,3 +64,14 @@ create_conversation, delete_conversation, ) + +from .prompt_versions import create_prompt_version, delete_prompt_version + +from .prompts import ( + create_prompt, + count_prompts_in_project, + delete_prompt, + get_prompt_by_id, + get_prompts, + update_prompt, +) diff --git a/backend/app/crud/prompt_versions.py b/backend/app/crud/prompt_versions.py new file mode 100644 index 00000000..2c307b22 --- /dev/null +++ b/backend/app/crud/prompt_versions.py @@ -0,0 +1,103 @@ +import logging +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, and_, select + +from app.core.util import now +from app.crud.prompts import prompt_exists +from app.models import PromptVersion, PromptVersionCreate + +logger = logging.getLogger(__name__) + + +def get_next_prompt_version(session: Session, prompt_id: UUID) -> int: + """ + Fetch the next version number for the given prompt_id. + """ + + # Not filtering is_deleted here because we want to get the next version even if the latest version is deleted + prompt_version = session.exec( + select(PromptVersion) + .where( + PromptVersion.prompt_id == prompt_id, + ) + .order_by(PromptVersion.version.desc()) + ).first() + + return prompt_version.version + 1 if prompt_version else 1 + + +def create_prompt_version( + session: Session, + prompt_id: UUID, + prompt_version_in: PromptVersionCreate, + project_id: int, +) -> PromptVersion: + """Create a new version for an existing prompt.""" + + prompt_exists( + session=session, + prompt_id=prompt_id, + project_id=project_id, + ) + + next_version = get_next_prompt_version(session=session, prompt_id=prompt_id) + prompt_version = PromptVersion( + prompt_id=prompt_id, + version=next_version, + instruction=prompt_version_in.instruction, + commit_message=prompt_version_in.commit_message, + ) + + session.add(prompt_version) + session.commit() + session.refresh(prompt_version) + logger.info( + f"[create_prompt_version] Created new prompt version | Prompt ID: {prompt_id}, Version: {prompt_version.version}" + ) + return prompt_version + + +def delete_prompt_version( + session: Session, prompt_id: UUID, version_id: UUID, project_id: int +): + """ + Delete a prompt version by ID. 
+ """ + prompt = prompt_exists( + session=session, + prompt_id=prompt_id, + project_id=project_id, + ) + if prompt.active_version == version_id: + logger.error( + f"[delete_prompt_version] Cannot delete active version | Version ID: {version_id}, Prompt ID: {prompt_id}" + ) + raise HTTPException(status_code=409, detail="Cannot delete active version") + + stmt = select(PromptVersion).where( + and_( + PromptVersion.id == version_id, + PromptVersion.prompt_id == prompt_id, + PromptVersion.is_deleted.is_(False), + ) + ) + prompt_version = session.exec(stmt).first() + + if not prompt_version: + logger.error( + f"[delete_prompt_version] Prompt version not found | version_id={version_id}, prompt_id={prompt_id}" + ) + raise HTTPException(status_code=404, detail="Prompt version not found") + + prompt_version.is_deleted = True + prompt_version.deleted_at = now() + + session.add(prompt_version) + session.commit() + session.refresh(prompt_version) + + logger.info( + f"[delete_prompt_version] Deleted prompt version | Version ID: {version_id}, Prompt ID: {prompt_id}" + ) diff --git a/backend/app/crud/prompts.py b/backend/app/crud/prompts.py new file mode 100644 index 00000000..86b11716 --- /dev/null +++ b/backend/app/crud/prompts.py @@ -0,0 +1,191 @@ +import logging +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, and_, func, select + +from app.core.util import now +from app.models import ( + Prompt, + PromptCreate, + PromptUpdate, + PromptVersion, + PromptWithVersion, + PromptWithVersions, +) + +logger = logging.getLogger(__name__) + + +def create_prompt( + session: Session, prompt_in: PromptCreate, project_id: int +) -> tuple[Prompt, PromptVersion]: + """ + Create a new prompt and its first version. + """ + prompt = Prompt( + name=prompt_in.name, description=prompt_in.description, project_id=project_id + ) + session.add(prompt) + session.flush() + + version = PromptVersion( + prompt_id=prompt.id, + instruction=prompt_in.instruction, + commit_message=prompt_in.commit_message, + version=1, + ) + session.add(version) + session.flush() + + prompt.active_version = version.id + + session.commit() + session.refresh(prompt) + + logger.info( + f"[create_prompt] Prompt created | id={prompt.id}, name={prompt.name}, " + f"project_id={project_id}, version_id={version.id}" + ) + + return prompt, version + + +def get_prompts( + session: Session, + project_id: int, + skip: int = 0, + limit: int = 100, +) -> list[Prompt]: + """Get prompts for a project.""" + stmt = ( + select(Prompt) + .where(Prompt.project_id == project_id, Prompt.is_deleted.is_(False)) + .order_by(Prompt.updated_at.desc()) + .offset(skip) + .limit(limit) + ) + return session.exec(stmt).all() + + +def count_prompts_in_project(session: Session, project_id: int) -> int: + return session.exec( + select(func.count()).select_from( + select(Prompt) + .where(Prompt.project_id == project_id, Prompt.is_deleted.is_(False)) + .subquery() + ) + ).one() + + +def prompt_exists(session: Session, prompt_id: UUID, project_id: int) -> Prompt: + """ + Fetch the prompt for the given ID and project, raising 404 if it does not exist. + """ + stmt = select(Prompt).where( + Prompt.id == prompt_id, + Prompt.project_id == project_id, + Prompt.is_deleted.is_(False), + ) + + prompt = session.exec(stmt).first() + if not prompt: + logger.error( + f"[prompt_exists] Prompt not found | prompt_id={prompt_id}, project_id={project_id}" + ) + raise HTTPException(status_code=404, detail="Prompt not found.") + + return prompt + + +def get_prompt_by_id( + session: Session, prompt_id: UUID, project_id: int, include_versions: bool = False +) -> tuple[Prompt, list[PromptVersion]]: + """ + Get a prompt by its ID, optionally including all versions. + By default, only the active version is returned. + """ + if include_versions: + join_condition = and_( + PromptVersion.prompt_id == Prompt.id, PromptVersion.is_deleted.is_(False) + ) + order_by = PromptVersion.version.desc() + else: + join_condition = and_( + PromptVersion.id == Prompt.active_version, + PromptVersion.is_deleted.is_(False), + ) + order_by = None  # no need to order when fetching only 1 row + + stmt = ( + select(Prompt, PromptVersion) + .join(PromptVersion, join_condition, isouter=True) + .where( + Prompt.id == prompt_id, + Prompt.project_id == project_id, + Prompt.is_deleted.is_(False), + ) + ) + if order_by is not None: + stmt = stmt.order_by(order_by) + + results = session.exec(stmt).all() + if not results: + logger.error( + f"[get_prompt_by_id] Prompt not found | ID: {prompt_id}, Project ID: {project_id}" + ) + raise HTTPException(status_code=404, detail="Prompt not found") + + # Unpack tuples into variables + prompt, _ = results[0] + versions = [version for _, version in results if version is not None] + + return prompt, versions + + +def update_prompt( + session: Session, prompt_id: UUID, project_id: int, prompt_update: PromptUpdate +) -> Prompt: + prompt = prompt_exists(session=session, prompt_id=prompt_id, project_id=project_id) + update_data = prompt_update.model_dump(exclude_unset=True) + + active_version = update_data.get("active_version") + if active_version: + stmt = select(PromptVersion).where( + PromptVersion.id == active_version, + PromptVersion.prompt_id == prompt.id, + PromptVersion.is_deleted.is_(False), + ) + prompt_version = session.exec(stmt).first() + if not prompt_version: + logger.error( + f"[update_prompt] Prompt version not found | version_id={active_version}, prompt_id={prompt.id}" + ) + raise HTTPException(status_code=404, detail="Invalid Active Version Id") + + if update_data: + for field, value in update_data.items(): + setattr(prompt, field, value) + prompt.updated_at = now() + session.add(prompt) + session.commit() + session.refresh(prompt) + + logger.info( + f"[update_prompt] Prompt updated | id={prompt.id}, name={prompt.name}, project_id={project_id}" + ) + + return prompt + + +def delete_prompt(session: Session, prompt_id: UUID, project_id: int) -> None: + prompt = prompt_exists(session=session, prompt_id=prompt_id, project_id=project_id) + + prompt.is_deleted = True + prompt.deleted_at = now() + session.add(prompt) + session.commit() + session.refresh(prompt) + logger.info( + f"[delete_prompt] Prompt deleted | id={prompt.id}, name={prompt.name}, project_id={project_id}" + ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index a1c2009c..07b8354c 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -62,3 +62,18 @@ OpenAIConversationBase, OpenAIConversationCreate, ) + +from .prompt import ( + Prompt, + PromptCreate, + PromptPublic, + PromptUpdate, + PromptWithVersion, + 
PromptWithVersions, +) + +from .prompt_version import ( + PromptVersion, + PromptVersionCreate, + PromptVersionPublic, +) diff --git a/backend/app/models/prompt.py b/backend/app/models/prompt.py new file mode 100644 index 00000000..d02ac5cf --- /dev/null +++ b/backend/app/models/prompt.py @@ -0,0 +1,79 @@ +from datetime import datetime +from uuid import UUID, uuid4 + +from sqlalchemy import Column, ForeignKey +from sqlmodel import Index, SQLModel, Field, Relationship + +from app.core.util import now +from app.models.prompt_version import ( + PromptVersion, + PromptVersionCreate, + PromptVersionPublic, +) + + +class PromptBase(SQLModel): + name: str = Field(index=True, nullable=False, min_length=1, max_length=50) + description: str | None = Field(default=None, min_length=1, max_length=500) + + +class Prompt(PromptBase, table=True): + __table_args__ = ( + Index("ix_prompt_project_id_is_deleted", "project_id", "is_deleted"), + ) + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + ) + active_version: UUID = Field( + default_factory=uuid4, + sa_column=Column( + ForeignKey( + "prompt_version.id", + use_alter=True, + deferrable=True, + initially="DEFERRED", + ), + nullable=False, + ), + ) + project_id: int = Field(foreign_key="project.id", index=True, nullable=False) + inserted_at: datetime = Field(default_factory=now, nullable=False) + updated_at: datetime = Field(default_factory=now, nullable=False) + is_deleted: bool = Field(default=False) + deleted_at: datetime | None = None + + versions: list["PromptVersion"] = Relationship( + back_populates="prompt", + sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"}, + ) + + +class PromptPublic(PromptBase): + id: UUID + active_version: UUID + project_id: int + inserted_at: datetime + updated_at: datetime + + +class PromptWithVersion(PromptPublic): + version: PromptVersionPublic + + +class PromptWithVersions(PromptPublic): + versions: list[PromptVersionPublic] + + +class PromptCreate(PromptBase, PromptVersionCreate): + pass + + +class PromptUpdate(SQLModel): + name: str | None = Field(default=None, min_length=1, max_length=50) + description: str | None = Field(default=None, min_length=1, max_length=500) + active_version: UUID | None = Field(default=None) + + class Config: + from_attributes = True diff --git a/backend/app/models/prompt_version.py b/backend/app/models/prompt_version.py new file mode 100644 index 00000000..5378429a --- /dev/null +++ b/backend/app/models/prompt_version.py @@ -0,0 +1,50 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import Index, SQLModel, UniqueConstraint, Field, Relationship + +from app.core.util import now + +if TYPE_CHECKING: + from app.models.prompt import Prompt + + +class PromptVersionBase(SQLModel): + instruction: str = Field(nullable=False, min_length=1) + commit_message: str | None = Field(default=None, max_length=512) + + +class PromptVersion(PromptVersionBase, table=True): + __tablename__ = "prompt_version" + __table_args__ = ( + UniqueConstraint("prompt_id", "version"), + Index("ix_prompt_version_prompt_id_is_deleted", "prompt_id", "is_deleted"), + ) + + id: UUID = Field(default_factory=uuid4, primary_key=True) + prompt_id: UUID = Field(foreign_key="prompt.id", nullable=False, index=True) + version: int = Field(nullable=False) + + inserted_at: datetime = Field(default_factory=now, nullable=False) + updated_at: datetime = Field(default_factory=now, nullable=False) + + is_deleted: bool = Field(default=False, 
nullable=False) + deleted_at: datetime | None = Field(default=None) + + prompt: "Prompt" = Relationship( + back_populates="versions", + sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"}, + ) + + +class PromptVersionPublic(PromptVersionBase): + id: UUID + prompt_id: UUID + version: int + inserted_at: datetime + updated_at: datetime + + +class PromptVersionCreate(PromptVersionBase): + pass diff --git a/backend/app/tests/api/routes/test_prompt_versions.py b/backend/app/tests/api/routes/test_prompt_versions.py new file mode 100644 index 00000000..3d66c386 --- /dev/null +++ b/backend/app/tests/api/routes/test_prompt_versions.py @@ -0,0 +1,71 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.crud.prompts import create_prompt +from app.models import APIKeyPublic, PromptCreate, PromptVersion, PromptVersionCreate +from app.tests.utils.test_data import create_test_prompt + + +def test_create_prompt_version_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful creation of a prompt version via API route.""" + prompt, _ = create_test_prompt(db, user_api_key.project_id) + + version_in = PromptVersionCreate( + instruction="Version 2 instructions", commit_message="Second version" + ) + + response = client.post( + f"/api/v1/prompts/{prompt.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + json=version_in.model_dump(), + ) + + assert response.status_code == 201 + response_data = response.json() + + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + assert data["prompt_id"] == str(prompt.id) + assert data["instruction"] == version_in.instruction + assert data["commit_message"] == version_in.commit_message + assert ( + data["version"] == 2 + ) # First version created by create_prompt, this is second + + +def test_delete_prompt_version_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful deletion of a non-active prompt version via API route""" + prompt, _ = create_test_prompt(db, user_api_key.project_id) + + # Create a second version (non-active) + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2, + ) + db.add(second_version) + db.commit() + + response = client.delete( + f"/api/v1/prompts/{prompt.id}/versions/{second_version.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"]["message"] == "Prompt version deleted successfully." 
+ + db.refresh(second_version) + assert second_version.is_deleted + assert second_version.deleted_at is not None diff --git a/backend/app/tests/api/routes/test_prompts.py b/backend/app/tests/api/routes/test_prompts.py new file mode 100644 index 00000000..8c3d21a1 --- /dev/null +++ b/backend/app/tests/api/routes/test_prompts.py @@ -0,0 +1,277 @@ +from uuid import uuid4 + +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.models import APIKeyPublic, PromptCreate, PromptUpdate, PromptVersion +from app.tests.utils.test_data import create_test_prompt + + +def test_create_prompt_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful creation of a prompt via API route""" + project_id = user_api_key.project_id + prompt_in = PromptCreate( + name="test_prompt", + description="Test prompt description", + instruction="Test instruction", + commit_message="Initial version", + ) + + response = client.post( + "/api/v1/prompts/", + headers={"X-API-KEY": user_api_key.key}, + json=prompt_in.model_dump(), + ) + + assert response.status_code == 201 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert data["name"] == prompt_in.name + assert data["description"] == prompt_in.description + assert data["project_id"] == project_id + assert data["inserted_at"] is not None + assert data["updated_at"] is not None + + assert "version" in data + assert data["version"]["instruction"] == prompt_in.instruction + assert data["version"]["commit_message"] == prompt_in.commit_message + assert data["version"]["version"] == 1 + assert data["active_version"] == data["version"]["id"] + + +def test_get_prompts_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully retrieving prompts with pagination metadata""" + project_id = user_api_key.project_id + + create_test_prompt(db, project_id, name="prompt_1") + create_test_prompt(db, project_id, name="prompt_2") + + response = client.get( + f"/api/v1/prompts/?skip=0&limit=100", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + assert "metadata" in response_data + assert response_data["metadata"]["pagination"]["total"] == 2 + assert response_data["metadata"]["pagination"]["skip"] == 0 + assert response_data["metadata"]["pagination"]["limit"] == 100 + + prompts = response_data["data"] + assert len(prompts) == 2 + assert prompts[0]["name"] == "prompt_2" + assert prompts[1]["name"] == "prompt_1" + assert all(prompt["project_id"] == project_id for prompt in prompts) + + +def test_get_prompts_route_empty( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving an empty list when no prompts exist""" + response = client.get( + f"/api/v1/prompts/?skip=0&limit=100", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"] == [] + assert response_data["metadata"]["pagination"]["total"] == 0 + assert response_data["metadata"]["pagination"]["skip"] == 0 + assert response_data["metadata"]["pagination"]["limit"] == 100 + + +def test_get_prompts_route_pagination( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving 
prompts with specific skip and limit values""" + project_id = user_api_key.project_id + + for i in range(3): + create_test_prompt(db, project_id, name=f"prompt_{i}") + + response = client.get( + f"/api/v1/prompts/?skip=1&limit=1", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert len(response_data["data"]) == 1 + assert response_data["metadata"]["pagination"]["total"] == 3 + assert response_data["metadata"]["pagination"]["skip"] == 1 + assert response_data["metadata"]["pagination"]["limit"] == 1 + + +def test_get_prompt_by_id_route_success_active_version( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully retrieving a prompt with its active version""" + project_id = user_api_key.project_id + prompt, version = create_test_prompt(db, project_id, name="test_prompt") + + response = client.get( + f"/api/v1/prompts/{prompt.id}?include_versions=false", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert data["id"] == str(prompt.id) + assert data["name"] == "test_prompt" + assert data["description"] == "Test prompt description" + assert data["project_id"] == project_id + assert len(data["versions"]) == 1 + assert data["versions"][0]["id"] == str(version.id) + assert data["versions"][0]["instruction"] == "Test instruction" + assert data["versions"][0]["commit_message"] == "Initial version" + assert data["versions"][0]["version"] == 1 + assert data["active_version"] == str(version.id) + + +def test_get_prompt_by_id_route_with_versions( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving a prompt with all its versions""" + project_id = user_api_key.project_id + prompt, _ = create_test_prompt(db, project_id, name="test_prompt") + + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2, + ) + db.add(second_version) + db.commit() + + response = client.get( + f"/api/v1/prompts/{prompt.id}?include_versions=true", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert len(data["versions"]) == 2 + assert data["versions"][0]["version"] == 2 + assert data["versions"][1]["version"] == 1 + assert data["versions"][0]["instruction"] == "Second instruction" + assert data["versions"][1]["instruction"] == "Test instruction" + + +def test_get_prompt_by_id_route_not_found( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving a non-existent prompt returns 404""" + non_existent_prompt_id = uuid4() + + response = client.get( + f"/api/v1/prompts/{non_existent_prompt_id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 404 + response_data = response.json() + assert response_data["success"] is False + assert "not found" in response_data["error"].lower() + + +def test_update_prompt_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully updating a prompt's name and description""" + project_id = user_api_key.project_id + prompt, _ = create_test_prompt(db, 
project_id, name="test_prompt") + + prompt_version = PromptVersion( + prompt_id=prompt.id, + instruction="Test instruction", + commit_message="Initial version", + version=2, + ) + db.add(prompt_version) + db.commit() + + update_data = PromptUpdate( + name="updated_prompt", + description="Updated description", + active_version=prompt_version.id, + ) + + response = client.patch( + f"/api/v1/prompts/{prompt.id}", + headers={"X-API-KEY": user_api_key.key}, + json=update_data.model_dump(mode="json", exclude_unset=True), + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert data["id"] == str(prompt.id) + assert data["name"] == "updated_prompt" + assert data["description"] == "Updated description" + assert data["project_id"] == project_id + assert data["active_version"] == str(prompt_version.id) + + +def test_delete_prompt_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully soft-deleting a prompt""" + project_id = user_api_key.project_id + prompt, _ = create_test_prompt(db, project_id, name="test_prompt") + + response = client.delete( + f"/api/v1/prompts/{prompt.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"]["message"] == "Prompt deleted successfully." + + db.refresh(prompt) + assert prompt.is_deleted + assert prompt.deleted_at is not None diff --git a/backend/app/tests/api/routes/test_responses.py b/backend/app/tests/api/routes/test_responses.py index 483119d5..a4bf0399 100644 --- a/backend/app/tests/api/routes/test_responses.py +++ b/backend/app/tests/api/routes/test_responses.py @@ -1,662 +1,662 @@ -from unittest.mock import MagicMock, patch -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient -from sqlmodel import select -import openai - -from app.api.routes.responses import router -from app.models import Project - -# Wrap the router in a FastAPI app instance -app = FastAPI() -app.include_router(router) -client = TestClient(app) - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_success( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint for successful response creation.""" - - # Setup mock credentials - configure to return different values based on provider - def mock_get_credentials_by_provider(*args, **kwargs): - provider = kwargs.get("provider") - if provider == "openai": - return {"api_key": "test_api_key"} - elif provider == "langfuse": - return { - "public_key": "test_public_key", - "secret_key": "test_secret_key", - "host": "https://cloud.langfuse.com", - } - return None - - mock_get_credential.side_effect = mock_get_credentials_by_provider - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model 
= "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - - # Configure mock to return the assistant for any call - def return_mock_assistant(*args, **kwargs): - return mock_assistant - - mock_get_assistant.side_effect = return_mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object with proper response ID format - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_without_vector_store( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header, -): - """Test the /responses endpoint when assistant has no vector store configured.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant without vector store - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = [] # No vector store configured - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object with proper response ID format - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - 
mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Glific project ID - glific_project = db.exec(select(Project).where(Project.name == "Glific")).first() - if not glific_project: - pytest.skip("Glific project not found in the database") - - request_data = { - "assistant_id": "assistant_123", - "question": "What is Glific?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify OpenAI client was called without tools - mock_client.responses.create.assert_called_once_with( - model=mock_assistant.model, - previous_response_id=None, - instructions=mock_assistant.instructions, - temperature=mock_assistant.temperature, - input=[{"role": "user", "content": "What is Glific?"}], - ) - - -@patch("app.api.routes.responses.get_assistant_by_id") -def test_responses_endpoint_assistant_not_found( - mock_get_assistant, - db, - user_api_key_header, -): - """Test the /responses endpoint when assistant is not found.""" - # Setup mock assistant to return None (not found) - mock_get_assistant.return_value = None - - request_data = { - "assistant_id": "nonexistent_assistant", - "question": "What is this?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 404 - response_json = response.json() - assert response_json["detail"] == "Assistant not found or not active" - - -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -def test_responses_endpoint_no_openai_credentials( - mock_get_assistant, - mock_get_credential, - db, - user_api_key_header, -): - """Test the /responses endpoint when OpenAI credentials are not configured.""" - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = [] - mock_get_assistant.return_value = mock_assistant - - # Setup mock credentials to return None (no credentials) - mock_get_credential.return_value = None - - request_data = { - "assistant_id": "assistant_123", - "question": "What is this?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is False - assert "OpenAI API key not configured" in response_json["error"] - - -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -def test_responses_endpoint_missing_api_key_in_credentials( - mock_get_assistant, - mock_get_credential, - db, - user_api_key_header, -): - """Test the /responses endpoint when 
credentials exist but don't have api_key.""" - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = [] - mock_get_assistant.return_value = mock_assistant - - # Setup mock credentials without api_key - mock_get_credential.return_value = {"other_key": "value"} - - request_data = { - "assistant_id": "assistant_123", - "question": "What is this?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is False - assert "OpenAI API key not configured" in response_json["error"] - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_with_file_search_results( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header, -): - """Test the /responses endpoint with file search results in the response.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant with vector store - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup mock file search results - mock_hit1 = MagicMock() - mock_hit1.score = 0.95 - mock_hit1.text = "First search result" - - mock_hit2 = MagicMock() - mock_hit2.score = 0.85 - mock_hit2.text = "Second search result" - - mock_file_search_call = MagicMock() - mock_file_search_call.type = "file_search_call" - mock_file_search_call.results = [mock_hit1, mock_hit2] - - # Setup the mock response object with file search results and proper response ID format - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output with search results" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [mock_file_search_call] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - 
"assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify OpenAI client was called with tools - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert "tools" in call_args - assert call_args["tools"][0]["type"] == "file_search" - assert call_args["tools"][0]["vector_store_ids"] == ["vs_test"] - assert "include" in call_args - assert "file_search_call.results" in call_args["include"] - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_with_ancestor_conversation_found( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint when a conversation is found by ancestor ID.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Setup mock conversation found by ancestor ID - mock_conversation = MagicMock() - mock_conversation.response_id = "resp_latest1234567890abcdef1234567890" - mock_conversation.ancestor_response_id = "resp_ancestor1234567890abcdef1234567890" - mock_get_conversation_by_ancestor_id.return_value = mock_conversation - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": 
"http://example.com/callback", - "response_id": "resp_ancestor1234567890abcdef1234567890", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify get_conversation_by_ancestor_id was called with correct parameters - mock_get_conversation_by_ancestor_id.assert_called_once() - call_args = mock_get_conversation_by_ancestor_id.call_args - assert ( - call_args[1]["ancestor_response_id"] - == "resp_ancestor1234567890abcdef1234567890" - ) - assert call_args[1]["project_id"] == dalgo_project.id - - # Verify OpenAI client was called with the conversation's response_id as previous_response_id - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert call_args["previous_response_id"] == "resp_latest1234567890abcdef1234567890" - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_with_ancestor_conversation_not_found( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint when no conversation is found by ancestor ID.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Setup mock conversation not found by ancestor ID - mock_get_conversation_by_ancestor_id.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is 
Dalgo?", - "callback_url": "http://example.com/callback", - "response_id": "resp_ancestor1234567890abcdef1234567890", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify get_conversation_by_ancestor_id was called with correct parameters - mock_get_conversation_by_ancestor_id.assert_called_once() - call_args = mock_get_conversation_by_ancestor_id.call_args - assert ( - call_args[1]["ancestor_response_id"] - == "resp_ancestor1234567890abcdef1234567890" - ) - assert call_args[1]["project_id"] == dalgo_project.id - - # Verify OpenAI client was called with the original response_id as previous_response_id - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert ( - call_args["previous_response_id"] == "resp_ancestor1234567890abcdef1234567890" - ) - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_without_response_id( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint when no response_id is provided.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_1234567890abcdef1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - # No response_id provided - } - - response = client.post("/responses", 
json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify get_conversation_by_ancestor_id was not called since response_id is None - mock_get_conversation_by_ancestor_id.assert_not_called() - - # Verify OpenAI client was called with None as previous_response_id - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert call_args["previous_response_id"] is None +# from unittest.mock import MagicMock, patch +# import pytest +# from fastapi import FastAPI +# from fastapi.testclient import TestClient +# from sqlmodel import select +# import openai + +# from app.api.routes.responses import router +# from app.models import Project + +# # Wrap the router in a FastAPI app instance +# app = FastAPI() +# app.include_router(router) +# client = TestClient(app) + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_success( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint for successful response creation.""" + +# # Setup mock credentials - configure to return different values based on provider +# def mock_get_credentials_by_provider(*args, **kwargs): +# provider = kwargs.get("provider") +# if provider == "openai": +# return {"api_key": "test_api_key"} +# elif provider == "langfuse": +# return { +# "public_key": "test_public_key", +# "secret_key": "test_secret_key", +# "host": "https://cloud.langfuse.com", +# } +# return None + +# mock_get_credential.side_effect = mock_get_credentials_by_provider + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 + +# # Configure mock to return the assistant for any call +# def return_mock_assistant(*args, **kwargs): +# return mock_assistant + +# mock_get_assistant.side_effect = return_mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object with proper response ID format +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# 
mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_without_vector_store( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when assistant has no vector store configured.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant without vector store +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = [] # No vector store configured +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object with proper response ID format +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Glific project ID +# glific_project = db.exec(select(Project).where(Project.name == "Glific")).first() +# if not glific_project: +# pytest.skip("Glific project not found in the database") + +# request_data = { +# "assistant_id": "assistant_123", +# "question": "What is Glific?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert 
response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify OpenAI client was called without tools +# mock_client.responses.create.assert_called_once_with( +# model=mock_assistant.model, +# previous_response_id=None, +# instructions=mock_assistant.instructions, +# temperature=mock_assistant.temperature, +# input=[{"role": "user", "content": "What is Glific?"}], +# ) + + +# @patch("app.api.routes.responses.get_assistant_by_id") +# def test_responses_endpoint_assistant_not_found( +# mock_get_assistant, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when assistant is not found.""" +# # Setup mock assistant to return None (not found) +# mock_get_assistant.return_value = None + +# request_data = { +# "assistant_id": "nonexistent_assistant", +# "question": "What is this?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 404 +# response_json = response.json() +# assert response_json["detail"] == "Assistant not found or not active" + + +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# def test_responses_endpoint_no_openai_credentials( +# mock_get_assistant, +# mock_get_credential, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when OpenAI credentials are not configured.""" +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = [] +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock credentials to return None (no credentials) +# mock_get_credential.return_value = None + +# request_data = { +# "assistant_id": "assistant_123", +# "question": "What is this?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is False +# assert "OpenAI API key not configured" in response_json["error"] + + +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# def test_responses_endpoint_missing_api_key_in_credentials( +# mock_get_assistant, +# mock_get_credential, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when credentials exist but don't have api_key.""" +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = [] +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock credentials without api_key +# mock_get_credential.return_value = {"other_key": "value"} + +# request_data = { +# "assistant_id": "assistant_123", +# "question": "What is this?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is False +# assert "OpenAI API key not configured" in response_json["error"] + + +# @patch("app.api.routes.responses.OpenAI") +# 
@patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_with_file_search_results( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint with file search results in the response.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant with vector store +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup mock file search results +# mock_hit1 = MagicMock() +# mock_hit1.score = 0.95 +# mock_hit1.text = "First search result" + +# mock_hit2 = MagicMock() +# mock_hit2.score = 0.85 +# mock_hit2.text = "Second search result" + +# mock_file_search_call = MagicMock() +# mock_file_search_call.type = "file_search_call" +# mock_file_search_call.results = [mock_hit1, mock_hit2] + +# # Setup the mock response object with file search results and proper response ID format +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output with search results" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [mock_file_search_call] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify OpenAI client was called with tools +# mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert "tools" in call_args +# assert call_args["tools"][0]["type"] == "file_search" +# assert call_args["tools"][0]["vector_store_ids"] == ["vs_test"] +# assert "include" in call_args +# 
assert "file_search_call.results" in call_args["include"] + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_with_ancestor_conversation_found( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint when a conversation is found by ancestor ID.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Setup mock conversation found by ancestor ID +# mock_conversation = MagicMock() +# mock_conversation.response_id = "resp_latest1234567890abcdef1234567890" +# mock_conversation.ancestor_response_id = "resp_ancestor1234567890abcdef1234567890" +# mock_get_conversation_by_ancestor_id.return_value = mock_conversation + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# "response_id": "resp_ancestor1234567890abcdef1234567890", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify get_conversation_by_ancestor_id was called with correct parameters +# mock_get_conversation_by_ancestor_id.assert_called_once() +# call_args = mock_get_conversation_by_ancestor_id.call_args +# assert ( +# call_args[1]["ancestor_response_id"] +# == 
"resp_ancestor1234567890abcdef1234567890" +# ) +# assert call_args[1]["project_id"] == dalgo_project.id + +# # Verify OpenAI client was called with the conversation's response_id as previous_response_id +# mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert call_args["previous_response_id"] == "resp_latest1234567890abcdef1234567890" + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_with_ancestor_conversation_not_found( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint when no conversation is found by ancestor ID.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Setup mock conversation not found by ancestor ID +# mock_get_conversation_by_ancestor_id.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# "response_id": "resp_ancestor1234567890abcdef1234567890", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify get_conversation_by_ancestor_id was called with correct parameters +# mock_get_conversation_by_ancestor_id.assert_called_once() +# 
call_args = mock_get_conversation_by_ancestor_id.call_args +# assert ( +# call_args[1]["ancestor_response_id"] +# == "resp_ancestor1234567890abcdef1234567890" +# ) +# assert call_args[1]["project_id"] == dalgo_project.id + +# # Verify OpenAI client was called with the original response_id as previous_response_id +# mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert ( +# call_args["previous_response_id"] == "resp_ancestor1234567890abcdef1234567890" +# ) + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_without_response_id( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint when no response_id is provided.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_1234567890abcdef1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# # No response_id provided +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify get_conversation_by_ancestor_id was not called since response_id is None +# mock_get_conversation_by_ancestor_id.assert_not_called() + +# # Verify OpenAI client was called with None as previous_response_id +# 
mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert call_args["previous_response_id"] is None diff --git a/backend/app/tests/crud/test_prompt_versions.py b/backend/app/tests/crud/test_prompt_versions.py new file mode 100644 index 00000000..5a7a8279 --- /dev/null +++ b/backend/app/tests/crud/test_prompt_versions.py @@ -0,0 +1,239 @@ +from uuid import uuid4 + +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from app.core.util import now +from app.crud.prompt_versions import ( + create_prompt_version, + delete_prompt_version, + get_next_prompt_version, +) +from app.crud.prompts import create_prompt +from app.models import Prompt, PromptCreate, PromptVersion, PromptVersionCreate +from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_prompt + + +@pytest.fixture +def prompt(db: Session) -> Prompt: + """Fixture to create a reusable prompt""" + project = get_project(db) + prompt, _ = create_test_prompt(db, project.id) + return prompt + + +def test_create_prompt_version_success(db: Session, prompt: Prompt): + """Successfully create a new prompt version""" + project_id = prompt.project_id + version_data = PromptVersionCreate( + instruction="New instruction", commit_message="New version" + ) + + prompt_version = create_prompt_version( + db, prompt_id=prompt.id, prompt_version_in=version_data, project_id=project_id + ) + + assert isinstance(prompt_version, PromptVersion) + assert prompt_version.prompt_id == prompt.id + assert ( + prompt_version.version == 2 + ) # First version created by create_prompt, this is second + assert prompt_version.instruction == version_data.instruction + assert prompt_version.commit_message == version_data.commit_message + assert not prompt_version.is_deleted + + +def test_create_prompt_version_multiple_versions(db: Session, prompt: Prompt): + """Create multiple versions and verify correct version increment""" + project_id = prompt.project_id + version_data_1 = PromptVersionCreate( + instruction="New instruction 1", commit_message="Version 2" + ) + version_data_2 = PromptVersionCreate( + instruction="New instruction 2", commit_message="Version 3" + ) + + version_1 = create_prompt_version( + db, prompt_id=prompt.id, prompt_version_in=version_data_1, project_id=project_id + ) + version_2 = create_prompt_version( + db, prompt_id=prompt.id, prompt_version_in=version_data_2, project_id=project_id + ) + + assert version_1.version == 2 + assert version_2.version == 3 + assert version_1.instruction == "New instruction 1" + assert version_2.instruction == "New instruction 2" + + +def test_create_prompt_version_prompt_not_found(db: Session): + """Raise 404 error when prompt does not exist""" + project = get_project(db) + non_existent_id = uuid4() + version_data = PromptVersionCreate( + instruction="New instruction", commit_message="New version" + ) + + with pytest.raises(HTTPException) as exc_info: + create_prompt_version( + db, + prompt_id=non_existent_id, + prompt_version_in=version_data, + project_id=project.id, + ) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_create_prompt_version_deleted_prompt(db: Session, prompt: Prompt): + """Raise 404 error when prompt is deleted""" + project_id = prompt.project_id + prompt.is_deleted = True + db.add(prompt) + db.commit() + + version_data = PromptVersionCreate( + instruction="New instruction", commit_message="New version" + ) + + with 
pytest.raises(HTTPException) as exc_info: + create_prompt_version( + db, + prompt_id=prompt.id, + prompt_version_in=version_data, + project_id=project_id, + ) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_next_prompt_version(db: Session, prompt: Prompt): + """Return incremented version number when versions exist""" + prompt_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2, + ) + db.add(prompt_version) + db.commit() + + version = get_next_prompt_version(db, prompt_id=prompt.id) + assert version == 3 + + +def test_get_next_prompt_version_with_deleted_version(db: Session, prompt: Prompt): + """Return incremented version number even if the latest version is deleted""" + prompt_version = PromptVersion( + prompt_id=prompt.id, + instruction="Deleted instruction", + commit_message="Deleted version", + version=2, + is_deleted=True, + ) + db.add(prompt_version) + db.commit() + + version = get_next_prompt_version(db, prompt_id=prompt.id) + assert version == 3 + + +def test_delete_prompt_version_success(db: Session, prompt: Prompt): + """Successfully soft-delete a non-active prompt version""" + project_id = prompt.project_id + + # Create a second version (non-active) + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2, + ) + db.add(second_version) + db.commit() + + delete_prompt_version( + db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id + ) + + db.refresh(second_version) + assert second_version.is_deleted + assert second_version.deleted_at is not None + assert second_version.deleted_at <= now() + + +def test_delete_prompt_version_active_version(db: Session, prompt: Prompt): + """Raise 409 error when attempting to delete the active version""" + project_id = prompt.project_id + active_version_id = prompt.active_version + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version( + db, prompt_id=prompt.id, version_id=active_version_id, project_id=project_id + ) + + assert exc_info.value.status_code == 409 + assert "cannot delete active version" in exc_info.value.detail.lower() + + +def test_delete_prompt_version_not_found(db: Session, prompt: Prompt): + """Raise 404 error when version does not exist""" + project_id = prompt.project_id + non_existent_version_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version( + db, + prompt_id=prompt.id, + version_id=non_existent_version_id, + project_id=project_id, + ) + + assert exc_info.value.status_code == 404 + assert "prompt version not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_version_already_deleted(db: Session, prompt: Prompt): + """Raise 404 error when attempting to delete an already deleted version""" + project_id = prompt.project_id + + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2, + is_deleted=True, + deleted_at=now(), + ) + db.add(second_version) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version( + db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id + ) + + assert exc_info.value.status_code == 404 + assert "prompt version not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_version_prompt_not_found(db: Session): + """Raise 404 error when prompt 
does not exist""" + project = get_project(db) + non_existent_prompt_id = uuid4() + version_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version( + db, + prompt_id=non_existent_prompt_id, + version_id=version_id, + project_id=project.id, + ) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() diff --git a/backend/app/tests/crud/test_prompts.py b/backend/app/tests/crud/test_prompts.py new file mode 100644 index 00000000..d23cd0f8 --- /dev/null +++ b/backend/app/tests/crud/test_prompts.py @@ -0,0 +1,436 @@ +import pytest +from uuid import uuid4 + +from fastapi import HTTPException +from sqlmodel import Session + +from app.core.util import now +from app.crud.prompts import ( + count_prompts_in_project, + create_prompt, + delete_prompt, + get_prompt_by_id, + get_prompts, + prompt_exists, + update_prompt, +) +from app.models import Prompt, PromptCreate, PromptUpdate, PromptVersion +from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_prompt + + +@pytest.fixture +def prompt(db) -> Prompt: + """Fixture to create a reusable prompt""" + project = get_project(db) + prompt, _ = create_test_prompt(db, project.id) + return prompt + + +def test_create_prompt_success(db: Session): + """Prompt and its first version are created successfully when valid input is provided""" + project = get_project(db) + prompt_data = PromptCreate( + name="test_prompt", + description="This is a test prompt", + instruction="Test instruction", + commit_message="Initial version", + ) + + prompt, version = create_prompt(db, prompt_data, project_id=project.id) + + # Prompt checks + assert prompt.id is not None + assert prompt.name == prompt_data.name + assert prompt.description == prompt_data.description + assert prompt.project_id == project.id + assert prompt.inserted_at is not None + assert prompt.updated_at is not None + assert prompt.active_version == version.id + + # Version checks + assert version.prompt_id == prompt.id + assert version.version == 1 + assert version.instruction == prompt_data.instruction + assert version.commit_message == prompt_data.commit_message + + +def test_get_prompts_success(db: Session): + """Retrieve prompts for a project with pagination, ensuring correct filtering and ordering""" + project = get_project(db) + + create_prompt( + db, + PromptCreate( + name="prompt1", + description="First prompt", + instruction="Instruction 1", + commit_message="Initial", + ), + project_id=project.id, + ) + create_prompt( + db, + PromptCreate( + name="prompt2", + description="Second prompt", + instruction="Instruction 2", + commit_message="Initial", + ), + project_id=project.id, + ) + + prompts = get_prompts(db, project_id=project.id, skip=0, limit=100) + + assert len(prompts) == 2 + assert prompts[0].name == "prompt2" + assert prompts[1].name == "prompt1" + assert all(not prompt.is_deleted for prompt in prompts) + assert all(prompt.project_id == project.id for prompt in prompts) + + prompts_limited = get_prompts(db, project_id=project.id, skip=1, limit=1) + assert len(prompts_limited) == 1 + assert prompts_limited[0].name == "prompt1" + + +def test_get_prompts_empty(db: Session): + """Return empty list when no prompts exist for a project or project has no non-deleted prompts""" + project = get_project(db) + + prompts = get_prompts(db, project_id=project.id) + assert prompts == [] + + # Create a deleted prompt + prompt, _ = create_prompt( + db, + PromptCreate( + name="deleted_prompt", + 
description="Deleted", + instruction="Instruction", + commit_message="Initial", + ), + project_id=project.id, + ) + prompt.is_deleted = True + db.add(prompt) + db.commit() + + prompts = get_prompts(db, project_id=project.id) + assert prompts == [] + + +def test_count_prompts_in_project_success(db: Session): + """Correctly count non-deleted prompts in a project""" + project = get_project(db) + + # Create multiple prompts + create_prompt( + db, + PromptCreate( + name="prompt1", + description="First prompt", + instruction="Instruction 1", + commit_message="Initial", + ), + project_id=project.id, + ) + create_prompt( + db, + PromptCreate( + name="prompt2", + description="Second prompt", + instruction="Instruction 2", + commit_message="Initial", + ), + project_id=project.id, + ) + + count = count_prompts_in_project(db, project_id=project.id) + assert count == 2 + + +def test_count_prompts_in_project_empty_or_deleted(db: Session): + """Return 0 when no prompts exist or all prompts are deleted""" + project = get_project(db) + + # Test empty project + count = count_prompts_in_project(db, project_id=project.id) + assert count == 0 + + # Create a deleted prompt + prompt, _ = create_prompt( + db, + PromptCreate( + name="deleted_prompt", + description="Deleted", + instruction="Instruction", + commit_message="Initial", + ), + project_id=project.id, + ) + prompt.is_deleted = True + db.add(prompt) + db.commit() + + count = count_prompts_in_project(db, project_id=project.id) + assert count == 0 + + +def test_prompt_exists_success(db: Session, prompt: Prompt): + """Successfully retrieve an existing prompt by ID and project""" + project = get_project(db) # Call get_project as a function + result = prompt_exists(db, prompt_id=prompt.id, project_id=project.id) + + assert isinstance(result, Prompt) + assert result.id == prompt.id + assert result.project_id == project.id + assert result.name == prompt.name + assert result.description == prompt.description + assert not result.is_deleted + + +def test_prompt_exists_not_found(db: Session): + """Raise 404 error when prompt ID does not exist""" + project = get_project(db) # Call get_project as a function + non_existent_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + prompt_exists(db, prompt_id=non_existent_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_prompt_exists_deleted_prompt(db: Session, prompt: Prompt): + """Raise 404 error when prompt is deleted""" + project_id = prompt.project_id + prompt.is_deleted = True + db.add(prompt) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + prompt_exists(db, prompt_id=prompt.id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_prompt_by_id_success_active_version(db: Session, prompt: Prompt): + """Retrieve a prompt by ID with only its active version""" + project = get_project(db) + retrieved_prompt, versions = get_prompt_by_id( + db, prompt_id=prompt.id, project_id=project.id, include_versions=False + ) + + assert isinstance(retrieved_prompt, Prompt) + assert isinstance(versions, list) + assert len(versions) == 1 + assert retrieved_prompt.id == prompt.id + assert retrieved_prompt.name == prompt.name + assert retrieved_prompt.description == prompt.description + assert retrieved_prompt.project_id == project.id + assert not retrieved_prompt.is_deleted + assert versions[0].id == prompt.active_version + assert 
versions[0].instruction == "Test instruction" + assert versions[0].commit_message == "Initial version" + assert versions[0].version == 1 + assert not versions[0].is_deleted + + +def test_get_prompt_by_id_with_versions(db: Session, prompt: Prompt): + """Retrieve a prompt by ID with all its versions""" + project = get_project(db) + + # Add another version + new_version = PromptVersion( + prompt_id=prompt.id, + instruction="Updated instruction", + commit_message="Second version", + version=2, + ) + db.add(new_version) + db.commit() + + retrieved_prompt, versions = get_prompt_by_id( + db, prompt_id=prompt.id, project_id=project.id, include_versions=True + ) + + assert isinstance(retrieved_prompt, Prompt) + assert isinstance(versions, list) + assert len(versions) == 2 + assert retrieved_prompt.id == prompt.id + assert retrieved_prompt.name == prompt.name + assert retrieved_prompt.description == prompt.description + assert retrieved_prompt.project_id == project.id + assert not retrieved_prompt.is_deleted + assert versions[0].version == 2 # Latest version first (descending order) + assert versions[1].version == 1 + assert versions[0].instruction == "Updated instruction" + assert versions[1].instruction == "Test instruction" + assert not any(version.is_deleted for version in versions) + + +def test_get_prompt_by_id_not_found(db: Session): + """Raise 404 error when prompt ID does not exist""" + project = get_project(db) + non_existent_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + get_prompt_by_id(db, prompt_id=non_existent_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_prompt_by_id_deleted_prompt(db: Session, prompt: Prompt): + """Raise 404 error when prompt is deleted""" + project_id = prompt.project_id + prompt.is_deleted = True + db.add(prompt) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + get_prompt_by_id(db, prompt_id=prompt.id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_prompt_by_id_deleted_version(db: Session, prompt: Prompt): + """Exclude deleted versions when retrieving prompt versions""" + project_id = prompt.project_id + + deleted_version = PromptVersion( + prompt_id=prompt.id, + instruction="Deleted instruction", + commit_message="Deleted version", + version=2, + is_deleted=True, + ) + db.add(deleted_version) + db.commit() + + retrieved_prompt, versions = get_prompt_by_id( + db, prompt_id=prompt.id, project_id=project_id, include_versions=True + ) + + assert isinstance(retrieved_prompt, Prompt) + assert isinstance(versions, list) + assert len(versions) == 1 + assert versions[0].version == 1 + assert not versions[0].is_deleted + assert versions[0].instruction == "Test instruction" + + +def test_update_prompt_success_name_description(db: Session, prompt: Prompt): + """Successfully update prompt's name and description""" + project_id = prompt.project_id + update_data = PromptUpdate(name="updated_prompt", description="Updated description") + + updated_prompt = update_prompt( + db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data + ) + + assert isinstance(updated_prompt, Prompt) + assert updated_prompt.id == prompt.id + assert updated_prompt.name == "updated_prompt" + assert updated_prompt.description == "Updated description" + assert updated_prompt.project_id == project_id + assert not updated_prompt.is_deleted + + +def 
test_update_prompt_success_active_version(db: Session, prompt: Prompt): + """Successfully update prompt's active version""" + project_id = prompt.project_id + + # Create a new version + new_version = PromptVersion( + prompt_id=prompt.id, + instruction="New instruction", + commit_message="Second version", + version=2, + ) + db.add(new_version) + db.commit() + + update_data = PromptUpdate(active_version=new_version.id) + updated_prompt = update_prompt( + db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data + ) + + assert isinstance(updated_prompt, Prompt) + assert updated_prompt.id == prompt.id + assert updated_prompt.active_version == new_version.id + + +def test_update_prompt_invalid_active_version(db: Session, prompt: Prompt): + """Raise 404 error when updating with an invalid active version ID""" + project_id = prompt.project_id + invalid_version_id = uuid4() + + update_data = PromptUpdate(active_version=invalid_version_id) + + with pytest.raises(HTTPException) as exc_info: + update_prompt( + db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data + ) + + assert exc_info.value.status_code == 404 + assert "invalid active version id" in exc_info.value.detail.lower() + + +def test_update_prompt_not_found(db: Session): + """Raise 404 error when prompt does not exist""" + project = get_project(db) + non_existent_id = uuid4() + update_data = PromptUpdate(name="new_name") + + with pytest.raises(HTTPException) as exc_info: + update_prompt( + db, + prompt_id=non_existent_id, + project_id=project.id, + prompt_update=update_data, + ) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_success(db: Session, prompt: Prompt): + """Successfully soft delete a prompt""" + project_id = prompt.project_id + + delete_prompt(db, prompt_id=prompt.id, project_id=project_id) + + db.refresh(prompt) + assert prompt.is_deleted + assert prompt.deleted_at is not None + assert prompt.deleted_at <= now() + + +def test_delete_prompt_not_found(db: Session): + """Raise 404 error when deleting a non-existent prompt""" + project = get_project(db) + non_existent_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt(db, prompt_id=non_existent_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_already_deleted(db: Session, prompt: Prompt): + """Raise 404 error when attempting to delete an already deleted prompt""" + project_id = prompt.project_id + prompt.is_deleted = True + prompt.deleted_at = now() + db.add(prompt) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt(db, prompt_id=prompt.id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 0d9c5d3b..996597ce 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -7,11 +7,14 @@ Credential, OrganizationCreate, ProjectCreate, + PromptCreate, + PromptVersion, CredsCreate, ) from app.crud import ( create_organization, create_project, + create_prompt, create_api_key, set_creds_for_org, ) @@ -112,3 +115,21 @@ def create_test_credential(db: Session) -> tuple[list[Credential], Project]: }, ) return set_creds_for_org(session=db, creds_add=creds_data), project + + +def create_test_prompt( + db: Session, + project_id: 
int,
+    name: str = "test_prompt",
+    description: str = "Test prompt description",
+    instruction: str = "Test instruction",
+    commit_message: str = "Initial version",
+) -> tuple[PromptCreate, PromptVersion]:
+    """Helper function to create a prompt and its initial version in the database."""
+    prompt_in = PromptCreate(
+        name=name,
+        description=description,
+        instruction=instruction,
+        commit_message=commit_message,
+    )
+    return create_prompt(db, prompt_in=prompt_in, project_id=project_id)
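
For reference, a minimal sketch (not taken from the diff) of how the new test helper and the prompt/version CRUD functions could be combined in a test. It assumes the same db session fixture and module paths used by the test files above; the test name and assertions are illustrative only.

from sqlmodel import Session

from app.crud.prompt_versions import create_prompt_version
from app.crud.prompts import update_prompt
from app.models import PromptUpdate, PromptVersionCreate
from app.tests.utils.test_data import create_test_prompt
from app.tests.utils.utils import get_project


def test_promote_second_version_to_active(db: Session):
    """Create a prompt, add a second version, and promote it to the active version."""
    project = get_project(db)
    prompt, first_version = create_test_prompt(db, project.id)
    # create_prompt points active_version at the first version it creates.
    assert prompt.active_version == first_version.id

    # create_prompt_version increments the version counter (the helper created version 1).
    second_version = create_prompt_version(
        db,
        prompt_id=prompt.id,
        prompt_version_in=PromptVersionCreate(
            instruction="Updated instruction", commit_message="Second version"
        ),
        project_id=project.id,
    )
    assert second_version.version == 2

    # update_prompt validates the version ID and raises 404 for unknown IDs
    # (see test_update_prompt_invalid_active_version above).
    updated = update_prompt(
        db,
        prompt_id=prompt.id,
        project_id=project.id,
        prompt_update=PromptUpdate(active_version=second_version.id),
    )
    assert updated.active_version == second_version.id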