From 56127c910462f194bada578ee3444eb8e64a22f9 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 12 Aug 2025 09:40:52 +0530 Subject: [PATCH 1/7] Model for prompt and prompt version --- ...88eb2c84d9_add_prompt_and_version_table.py | 59 +++++++++++++++++ backend/app/models/__init__.py | 13 ++++ backend/app/models/prompt.py | 63 +++++++++++++++++++ backend/app/models/prompt_version.py | 50 +++++++++++++++ 4 files changed, 185 insertions(+) create mode 100644 backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py create mode 100644 backend/app/models/prompt.py create mode 100644 backend/app/models/prompt_version.py diff --git a/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py b/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py new file mode 100644 index 00000000..26c95673 --- /dev/null +++ b/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py @@ -0,0 +1,59 @@ +"""Add prompt and version table + +Revision ID: 7288eb2c84d9 +Revises: e9dd35eff62c +Create Date: 2025-08-12 09:29:27.335097 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '7288eb2c84d9' +down_revision = 'e9dd35eff62c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('prompt', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('active_version', sa.Uuid(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('inserted_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('is_deleted', sa.Boolean(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['active_version'], ['prompt_version.id'], initially='DEFERRED', deferrable=True, use_alter=True), + sa.ForeignKeyConstraint(['project_id'], ['project.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_prompt_name'), 'prompt', ['name'], unique=False) + op.create_table('prompt_version', + sa.Column('instruction', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('commit_message', sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('prompt_id', sa.Uuid(), nullable=False), + sa.Column('version', sa.Integer(), nullable=False), + sa.Column('inserted_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('is_deleted', sa.Boolean(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['prompt_id'], ['prompt.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_foreign_key(None, 'prompt', 'prompt_version', ['active_version'], ['id'], initially='DEFERRED', deferrable=True, use_alter=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('prompt_version') + op.drop_index(op.f('ix_prompt_name'), table_name='prompt') + op.drop_table('prompt') + # ### end Alembic commands ### diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index a1c2009c..bf440da2 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -62,3 +62,16 @@ OpenAIConversationBase, OpenAIConversationCreate, ) + +from .prompt import ( + Prompt, + PromptCreate, + PromptPublic, + PromptUpdate, +) + +from .prompt_version import ( + PromptVersion, + PromptVersionCreate, + PromptVersionPublic, +) \ No newline at end of file diff --git a/backend/app/models/prompt.py b/backend/app/models/prompt.py new file mode 100644 index 00000000..3a719979 --- /dev/null +++ b/backend/app/models/prompt.py @@ -0,0 +1,63 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import SQLModel, Field, Relationship +from sqlalchemy import Column, ForeignKey + +from app.core.util import now +from app.models.prompt_version import PromptVersion, PromptVersionCreate + + +class PromptBase(SQLModel): + name: str = Field(index=True, nullable=False, min_length=1, max_length=50) + description: str | None = Field(default=None, min_length=1, max_length=500) + + +class Prompt(PromptBase, table=True): + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + ) + active_version: UUID = Field( + default_factory=uuid4, + sa_column=Column( + ForeignKey( + "prompt_version.id", + use_alter=True, + deferrable=True, + initially="DEFERRED" + ), + nullable=False + ) + ) + project_id: int = Field(foreign_key="project.id") + inserted_at: datetime = Field(default_factory=now, nullable=False) + updated_at: datetime = Field(default_factory=now, nullable=False) + is_deleted: bool = Field(default=False) + deleted_at: datetime | None = None + + versions: list["PromptVersion"] = Relationship( + back_populates="prompt", + 
sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"} + ) + + +class PromptPublic(PromptBase): + id: UUID + active_version: UUID + project_id: int + inserted_at: datetime + updated_at: datetime + + +class PromptCreate(PromptBase, PromptVersionCreate): + pass + + +class PromptUpdate(SQLModel): + name: str | None = Field(default=None, min_length=1, max_length=50) + description: str | None = Field(default=None, min_length=1, max_length=500) + + class Config: + from_attributes = True diff --git a/backend/app/models/prompt_version.py b/backend/app/models/prompt_version.py new file mode 100644 index 00000000..bbab278c --- /dev/null +++ b/backend/app/models/prompt_version.py @@ -0,0 +1,50 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import SQLModel, Field, Relationship + +from app.core.util import now + +if TYPE_CHECKING: + from app.models.prompt import Prompt + + +class PromptVersionBase(SQLModel): + instruction: str = Field(nullable=False, min_length=1) + commit_message: str | None = Field(default=None, max_length=512) + + +class PromptVersion(PromptVersionBase, table=True): + __tablename__ = "prompt_version" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + ) + prompt_id: UUID = Field(foreign_key="prompt.id") + version: int + + inserted_at: datetime = Field(default_factory=now) + updated_at: datetime = Field(default_factory=now) + + is_deleted: bool = Field(default=False) + deleted_at: datetime | None = Field(default=None) + + prompt: "Prompt" = Relationship( + back_populates="versions", + sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"} + ) + + + +class PromptVersionPublic(PromptVersionBase): + id: UUID + prompt_id: UUID + version: int + inserted_at: datetime + updated_at: datetime + + +class PromptVersionCreate(PromptVersionBase): + pass From 9fc65af291f8d944e94d5c129d3645df39ae0678 Mon Sep 17 00:00:00 2001 From: Aviraj 
<100823015+avirajsingh7@users.noreply.github.com> Date: Wed, 13 Aug 2025 09:38:30 +0530 Subject: [PATCH 2/7] crud/api for prompt and prompt version --- ...88eb2c84d9_add_prompt_and_version_table.py | 93 +++++--- backend/app/api/main.py | 4 + backend/app/api/routes/prompt_versions.py | 55 +++++ backend/app/api/routes/prompts.py | 144 ++++++++++++ backend/app/crud/__init__.py | 11 + backend/app/crud/prompt_versions.py | 100 +++++++++ backend/app/crud/prompts.py | 212 ++++++++++++++++++ backend/app/models/__init__.py | 4 +- backend/app/models/prompt.py | 26 ++- backend/app/models/prompt_version.py | 18 +- 10 files changed, 616 insertions(+), 51 deletions(-) create mode 100644 backend/app/api/routes/prompt_versions.py create mode 100644 backend/app/api/routes/prompts.py create mode 100644 backend/app/crud/prompt_versions.py create mode 100644 backend/app/crud/prompts.py diff --git a/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py b/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py index 26c95673..346b3f29 100644 --- a/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py +++ b/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py @@ -11,49 +11,78 @@ # revision identifiers, used by Alembic. -revision = '7288eb2c84d9' -down_revision = 'e9dd35eff62c' +revision = "7288eb2c84d9" +down_revision = "e9dd35eff62c" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - op.create_table('prompt', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), - sa.Column('description', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), - sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('active_version', sa.Uuid(), nullable=False), - sa.Column('project_id', sa.Integer(), nullable=False), - sa.Column('inserted_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.Column('is_deleted', sa.Boolean(), nullable=False), - sa.Column('deleted_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['active_version'], ['prompt_version.id'], initially='DEFERRED', deferrable=True, use_alter=True), - sa.ForeignKeyConstraint(['project_id'], ['project.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "prompt", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column( + "description", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True + ), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("active_version", sa.Uuid(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("is_deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["active_version"], + ["prompt_version.id"], + initially="DEFERRED", + deferrable=True, + use_alter=True, + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_prompt_name'), 'prompt', ['name'], unique=False) - op.create_table('prompt_version', - sa.Column('instruction', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('commit_message', sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True), - sa.Column('id', sa.Uuid(), nullable=False), - 
sa.Column('prompt_id', sa.Uuid(), nullable=False), - sa.Column('version', sa.Integer(), nullable=False), - sa.Column('inserted_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.Column('is_deleted', sa.Boolean(), nullable=False), - sa.Column('deleted_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['prompt_id'], ['prompt.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index(op.f("ix_prompt_name"), "prompt", ["name"], unique=False) + op.create_table( + "prompt_version", + sa.Column("instruction", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "commit_message", + sqlmodel.sql.sqltypes.AutoString(length=512), + nullable=True, + ), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("prompt_id", sa.Uuid(), nullable=False), + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("is_deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["prompt_id"], + ["prompt.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_foreign_key( + None, + "prompt", + "prompt_version", + ["active_version"], + ["id"], + initially="DEFERRED", + deferrable=True, + use_alter=True, ) - op.create_foreign_key(None, 'prompt', 'prompt_version', ['active_version'], ['id'], initially='DEFERRED', deferrable=True, use_alter=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('prompt_version') - op.drop_index(op.f('ix_prompt_name'), table_name='prompt') - op.drop_table('prompt') + op.drop_table("prompt_version") + op.drop_index(op.f("ix_prompt_name"), table_name="prompt") + op.drop_table("prompt") # ### end Alembic commands ### diff --git a/backend/app/api/main.py b/backend/app/api/main.py index df0b1016..abf46768 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -10,6 +10,8 @@ openai_conversation, project, project_user, + prompts, + prompt_versions, responses, private, threads, @@ -32,6 +34,8 @@ api_router.include_router(organization.router) api_router.include_router(project.router) api_router.include_router(project_user.router) +api_router.include_router(prompts.router) +api_router.include_router(prompt_versions.router) api_router.include_router(responses.router) api_router.include_router(threads.router) api_router.include_router(users.router) diff --git a/backend/app/api/routes/prompt_versions.py b/backend/app/api/routes/prompt_versions.py new file mode 100644 index 00000000..6ef9579a --- /dev/null +++ b/backend/app/api/routes/prompt_versions.py @@ -0,0 +1,55 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends +from sqlmodel import Session + +from app.api.deps import CurrentUserOrgProject, get_db +from app.crud import create_prompt_version, delete_prompt_version +from app.models import PromptVersionCreate, PromptVersionPublic +from app.utils import APIResponse + + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/prompts", tags=["Prompt Versions"]) + + +@router.post( + "/{prompt_id}/versions", + response_model=APIResponse[PromptVersionPublic], + status_code=201, +) +def create_prompt_version_route( + prompt_version_in: PromptVersionCreate, + prompt_id: UUID, + current_user: CurrentUserOrgProject, + session: Session = Depends(get_db), +): + version = create_prompt_version( + session=session, + prompt_id=prompt_id, + 
prompt_version_in=prompt_version_in, + project_id=current_user.project_id, + ) + return APIResponse.success_response(version) + + +@router.delete("/{prompt_id}/versions/{version_id}", response_model=APIResponse) +def delete_prompt_version_route( + prompt_id: UUID, + version_id: UUID, + current_user: CurrentUserOrgProject, + session: Session = Depends(get_db), +): + """ + Delete a prompt version by ID. + """ + delete_prompt_version( + session=session, + prompt_id=prompt_id, + version_id=version_id, + project_id=current_user.project_id + ) + return APIResponse.success_response( + data={"message": "Prompt version deleted successfully."} + ) \ No newline at end of file diff --git a/backend/app/api/routes/prompts.py b/backend/app/api/routes/prompts.py new file mode 100644 index 00000000..03fb5195 --- /dev/null +++ b/backend/app/api/routes/prompts.py @@ -0,0 +1,144 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends, Path, Query +from sqlmodel import Session + +from app.api.deps import CurrentUserOrgProject, get_db +from app.crud import ( + create_prompt, + delete_prompt, + get_prompt_by_id, + get_prompts, + count_prompts_in_project, + update_prompt, +) +from app.models import ( + PromptCreate, + PromptPublic, + PromptUpdate, + PromptWithVersion, + PromptWithVersions, +) +from app.utils import APIResponse + + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/prompts", tags=["Prompts"]) + + +@router.post("/", response_model=APIResponse[PromptWithVersion], status_code=201) +def create_prompt_route( + prompt_in: PromptCreate, + current_user: CurrentUserOrgProject, + session: Session = Depends(get_db), +): + """ + Create a new prompt under the specified organization and project. 
+ """ + prompt = create_prompt( + session=session, prompt_in=prompt_in, project_id=current_user.project_id + ) + + return APIResponse.success_response(prompt) + + +@router.get( + "/", + response_model=APIResponse[list[PromptPublic]], +) +def get_prompts_route( + current_user: CurrentUserOrgProject, + skip: int = Query( + 0, + ge=0, + description="Number of prompts to skip (for pagination)." + ), + limit: int = Query( + 100, + gt=0, + description="Maximum number of prompts to return." + ), + session: Session = Depends(get_db), +): + """ + Get all prompts for the specified organization and project. + """ + prompts = get_prompts( + session=session, + project_id=current_user.project_id, + skip=skip, + limit=limit, + ) + total = count_prompts_in_project(session=session, project_id=current_user.project_id) + metadata = { + "pagination": { + "total": total, + "skip": skip, + "limit": limit + } + } + return APIResponse.success_response(prompts, metadata=metadata) + + +@router.get( + "/{prompt_id}", + response_model=APIResponse[PromptWithVersions], + summary="Get a single prompt by its ID by default returns the active version", +) +def get_prompt_by_id_route( + current_user: CurrentUserOrgProject, + prompt_id: UUID = Path(..., description="The ID of the prompt to fetch"), + include_versions: bool = Query( + False, + description="Whether to include all versions of the prompt." + ), + session: Session = Depends(get_db), +): + """ + Get a single prompt by its ID. 
+ """ + prompt = get_prompt_by_id( + session=session, + prompt_id=prompt_id, + project_id=current_user.project_id, + include_versions=include_versions + ) + return APIResponse.success_response(prompt) + + +@router.patch("/{prompt_id}", response_model=APIResponse[PromptPublic]) +def update_prompt_route( + current_user: CurrentUserOrgProject, + prompt_update: PromptUpdate, + prompt_id: UUID = Path(..., description="The ID of the prompt to Update"), + session: Session = Depends(get_db), +): + """ + Update a prompt's name or description. + """ + + prompt = update_prompt( + session=session, + prompt_id=prompt_id, + project_id=current_user.project_id, + prompt_update=prompt_update, + ) + return APIResponse.success_response(prompt) + + +@router.delete("/{prompt_id}", response_model=APIResponse) +def delete_prompt_route( + current_user: CurrentUserOrgProject, + prompt_id: UUID = Path(..., description="The ID of the prompt to delete"), + session: Session = Depends(get_db), +): + """ + Delete a prompt by ID. 
+ """ + delete_prompt( + session=session, prompt_id=prompt_id, project_id=current_user.project_id + ) + return APIResponse.success_response( + data={"message": "Prompt deleted successfully."} + ) \ No newline at end of file diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index e4b973a0..31573155 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -64,3 +64,14 @@ create_conversation, delete_conversation, ) + +from .prompt_versions import create_prompt_version, delete_prompt_version + +from .prompts import ( + create_prompt, + count_prompts_in_project, + delete_prompt, + get_prompt_by_id, + get_prompts, + update_prompt, +) diff --git a/backend/app/crud/prompt_versions.py b/backend/app/crud/prompt_versions.py new file mode 100644 index 00000000..bb54c4cc --- /dev/null +++ b/backend/app/crud/prompt_versions.py @@ -0,0 +1,100 @@ +import logging +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, and_, select + +from app.core.util import now +from app.crud.prompts import prompt_exists +from app.models import PromptVersion, PromptVersionCreate + +logger = logging.getLogger(__name__) + + +def get_next_prompt_version(session: Session, prompt_id: int) -> int: + """ + fetch the next prompt version for a given prompt_id and project_id + """ + + # Not filtering is_deleted here because we want to get the next version even if the latest version is deleted + prompt_version = session.exec( + select(PromptVersion) + .where( + PromptVersion.prompt_id == prompt_id, + ) + .order_by(PromptVersion.version.desc()) + ).first() + + return prompt_version.version + 1 if prompt_version else 1 + + +def create_prompt_version( + session: Session, + prompt_id: UUID, + prompt_version_in: PromptVersionCreate, + project_id: int, +) -> PromptVersion: + """Create a new version for an existing prompt.""" + + prompt_exists( + session=session, + prompt_id=prompt_id, + project_id=project_id, + ) + + next_version = 
get_next_prompt_version(session=session, prompt_id=prompt_id) + prompt_version = PromptVersion( + prompt_id=prompt_id, + version=next_version, + instruction=prompt_version_in.instruction, + commit_message=prompt_version_in.commit_message, + ) + + session.add(prompt_version) + session.commit() + session.refresh(prompt_version) + logger.info( + f"[create_prompt_version] Created new version prompt_version | Prompt ID: {prompt_id}, Version: {prompt_version.version}" + ) + return prompt_version + + +def delete_prompt_version( + session: Session, + prompt_id: UUID, + version_id: UUID, + project_id: int +): + """ + Delete a prompt version by ID. + """ + prompt = prompt_exists( + session=session, + prompt_id=prompt_id, + project_id=project_id, + ) + if prompt.active_version == version_id: + logger.error(f"[delete_prompt_version] Cannot delete active version | Version ID: {version_id}, Prompt ID: {prompt_id}") + raise HTTPException(status_code=409, detail="Cannot delete active version") + + stmt = select(PromptVersion).where( + and_( + PromptVersion.id == version_id, + PromptVersion.prompt_id == prompt_id, + PromptVersion.is_deleted.is_(False) + ) + ) + prompt_version = session.exec(stmt).first() + + if not prompt_version: + logger.error(f"[delete_prompt_version] Prompt version not found | version_id={version_id}, prompt_id={prompt_id}") + raise HTTPException(status_code=404, detail="Prompt version not found") + + prompt_version.is_deleted = True + prompt_version.deleted_at = now() + + session.add(prompt_version) + session.commit() + session.refresh(prompt_version) + + logger.info(f"[delete_prompt_version] Deleted prompt version | Version ID: {version_id}, Prompt ID: {prompt_id}") \ No newline at end of file diff --git a/backend/app/crud/prompts.py b/backend/app/crud/prompts.py new file mode 100644 index 00000000..2c6c3105 --- /dev/null +++ b/backend/app/crud/prompts.py @@ -0,0 +1,212 @@ +import logging +from uuid import UUID + +from fastapi import HTTPException +from 
sqlmodel import Session, and_, func, select + +from app.core.util import now +from app.models import ( + Prompt, + PromptCreate, + PromptUpdate, + PromptVersion, + PromptWithVersion, + PromptWithVersions, +) + +logger = logging.getLogger(__name__) + + +def create_prompt(session: Session, prompt_in: PromptCreate, project_id: int) -> PromptWithVersion: + """ + Create a new prompt and its first version. + """ + prompt = Prompt( + name=prompt_in.name, + description=prompt_in.description, + project_id=project_id + ) + session.add(prompt) + session.flush() + + version = PromptVersion( + prompt_id=prompt.id, + instruction=prompt_in.instruction, + commit_message=prompt_in.commit_message, + version=1 + ) + session.add(version) + session.flush() + + prompt.active_version = version.id + + session.commit() + session.refresh(prompt) + + logger.info( + f"[create_prompt] Prompt created | id={prompt.id}, name={prompt.name}, " + f"project_id={project_id}, version_id={version.id}" + ) + + return PromptWithVersion( + **prompt.model_dump(), + version=version + ) + + +def get_prompts( + session: Session, + project_id: int, + skip: int = 0, + limit: int = 100, +) -> list[Prompt]: + """Get prompts for a project.""" + stmt = ( + select(Prompt) + .where( + Prompt.project_id == project_id, + Prompt.is_deleted.is_(False) + ) + .order_by(Prompt.updated_at.desc()) + .offset(skip) + .limit(limit) + ) + return session.exec(stmt).all() + + +def count_prompts_in_project(session: Session, project_id: int) -> int: + return session.exec( + select(func.count()).select_from( + select(Prompt) + .where(Prompt.project_id == project_id, Prompt.is_deleted == False) + .subquery() + ) + ).one() + + +def prompt_exists(session: Session, prompt_id: UUID, project_id: int) -> Prompt: + """ + Check if a prompt exists for the given ID and project. 
+ """ + stmt = ( + select(Prompt) + .where( + Prompt.id == prompt_id, + Prompt.project_id == project_id, + Prompt.is_deleted.is_(False) + ) + ) + + prompt = session.exec(stmt).first() + if not prompt: + logger.error( + f"[update_prompt] Prompt not found | prompt_id={prompt_id}, project_id={project_id}" + ) + raise HTTPException(status_code=404, detail="Prompt not found.") + + return prompt + + +def get_prompt_by_id( + session: Session, + prompt_id: UUID, + project_id: int, + include_versions: bool = False +) -> PromptWithVersions: + """ + Get a prompt by its ID, optionally including all versions. + By default, Always returns the active version. + """ + if include_versions: + join_condition = and_( + PromptVersion.prompt_id == Prompt.id, + PromptVersion.is_deleted.is_(False) + ) + order_by = PromptVersion.version.desc() + else: + join_condition = and_( + PromptVersion.id == Prompt.active_version, + PromptVersion.is_deleted.is_(False) + ) + order_by = None # no need to order when fetching only 1 row + + stmt = ( + select(Prompt, PromptVersion) + .join(PromptVersion, join_condition, isouter=True) + .where( + Prompt.id == prompt_id, + Prompt.project_id == project_id, + Prompt.is_deleted.is_(False) + ) + ) + if order_by is not None: + stmt = stmt.order_by(order_by) + + results = session.exec(stmt).all() + if not results: + logger.error(f"[get_prompt_by_id] Prompt not found | ID: {prompt_id}, Project ID: {project_id}") + raise HTTPException(status_code=404, detail="Prompt not found") + + # Unpack tuples into variables + prompt, _ = results[0] + versions = [version for _, version in results if version is not None] + + return PromptWithVersions( + **prompt.model_dump(), + versions=versions + ) + + +def update_prompt( + session: Session, prompt_id: UUID, project_id: int, prompt_update: PromptUpdate +) -> Prompt: + prompt = prompt_exists( + session=session, prompt_id=prompt_id, project_id=project_id + ) + update_prompt = prompt_update.model_dump(exclude_unset=True) + + 
active_version = update_prompt.get('active_version') + if active_version: + stmt = ( + select(PromptVersion) + .where( + PromptVersion.id == active_version, + PromptVersion.prompt_id == prompt.id, + PromptVersion.is_deleted.is_(False) + ) + ) + prompt_version = session.exec(stmt).first() + if not prompt_version: + logger.error( + f"[update_prompt] Prompt version not found | version_id={active_version}, prompt_id={prompt.id}" + ) + raise HTTPException(status_code=404, detail="Invalid Active Version Id") + + if update_prompt: + for field, value in update_prompt.items(): + setattr(prompt, field, value) + prompt.updated_at = now() + session.add(prompt) + session.commit() + session.refresh(prompt) + + logger.info( + f"[update_prompt] Prompt updated | id={prompt.id}, name={prompt.name}, project_id={project_id}" + ) + + return prompt + + +def delete_prompt(session: Session, prompt_id: UUID, project_id: int) -> None: + prompt = prompt_exists( + session=session, prompt_id=prompt_id, project_id=project_id + ) + + prompt.is_deleted = True + prompt.deleted_at = now() + session.add(prompt) + session.commit() + session.refresh(prompt) + logger.info( + f"[delete_prompt] Prompt deleted | id={prompt.id}, name={prompt.name}, project_id={project_id}" + ) \ No newline at end of file diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index bf440da2..07b8354c 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -68,10 +68,12 @@ PromptCreate, PromptPublic, PromptUpdate, + PromptWithVersion, + PromptWithVersions, ) from .prompt_version import ( PromptVersion, PromptVersionCreate, PromptVersionPublic, -) \ No newline at end of file +) diff --git a/backend/app/models/prompt.py b/backend/app/models/prompt.py index 3a719979..269fb45d 100644 --- a/backend/app/models/prompt.py +++ b/backend/app/models/prompt.py @@ -1,12 +1,15 @@ from datetime import datetime -from typing import TYPE_CHECKING from uuid import UUID, uuid4 -from sqlmodel 
import SQLModel, Field, Relationship from sqlalchemy import Column, ForeignKey +from sqlmodel import SQLModel, Field, Relationship from app.core.util import now -from app.models.prompt_version import PromptVersion, PromptVersionCreate +from app.models.prompt_version import ( + PromptVersion, + PromptVersionCreate, + PromptVersionPublic, +) class PromptBase(SQLModel): @@ -26,10 +29,10 @@ class Prompt(PromptBase, table=True): "prompt_version.id", use_alter=True, deferrable=True, - initially="DEFERRED" + initially="DEFERRED", ), - nullable=False - ) + nullable=False, + ), ) project_id: int = Field(foreign_key="project.id") inserted_at: datetime = Field(default_factory=now, nullable=False) @@ -39,7 +42,7 @@ class Prompt(PromptBase, table=True): versions: list["PromptVersion"] = Relationship( back_populates="prompt", - sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"} + sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"}, ) @@ -51,6 +54,14 @@ class PromptPublic(PromptBase): updated_at: datetime +class PromptWithVersion(PromptPublic): + version: PromptVersionPublic + + +class PromptWithVersions(PromptPublic): + versions: list[PromptVersionPublic] + + class PromptCreate(PromptBase, PromptVersionCreate): pass @@ -58,6 +69,7 @@ class PromptCreate(PromptBase, PromptVersionCreate): class PromptUpdate(SQLModel): name: str | None = Field(default=None, min_length=1, max_length=50) description: str | None = Field(default=None, min_length=1, max_length=500) + active_version: UUID | None = Field(default=None) class Config: from_attributes = True diff --git a/backend/app/models/prompt_version.py b/backend/app/models/prompt_version.py index bbab278c..aedaa1b8 100644 --- a/backend/app/models/prompt_version.py +++ b/backend/app/models/prompt_version.py @@ -18,26 +18,22 @@ class PromptVersionBase(SQLModel): class PromptVersion(PromptVersionBase, table=True): __tablename__ = "prompt_version" - id: UUID = Field( - default_factory=uuid4, - 
primary_key=True, - ) - prompt_id: UUID = Field(foreign_key="prompt.id") - version: int + id: UUID = Field(default_factory=uuid4, primary_key=True) + prompt_id: UUID = Field(foreign_key="prompt.id", nullable=False) + version: int = Field(nullable=False) - inserted_at: datetime = Field(default_factory=now) - updated_at: datetime = Field(default_factory=now) + inserted_at: datetime = Field(default_factory=now, nullable=False) + updated_at: datetime = Field(default_factory=now, nullable=False) - is_deleted: bool = Field(default=False) + is_deleted: bool = Field(default=False, nullable=False) deleted_at: datetime | None = Field(default=None) prompt: "Prompt" = Relationship( back_populates="versions", - sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"} + sa_relationship_kwargs={"foreign_keys": "[PromptVersion.prompt_id]"}, ) - class PromptVersionPublic(PromptVersionBase): id: UUID prompt_id: UUID From 948f4473f50e2f2aba56ebe86bc799f4af1c1a45 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:58:32 +0530 Subject: [PATCH 3/7] handle prompt and version in routes --- backend/app/api/routes/prompt_versions.py | 4 +- backend/app/api/routes/prompts.py | 40 +++++------- backend/app/crud/prompt_versions.py | 23 ++++--- backend/app/crud/prompts.py | 79 +++++++++-------------- 4 files changed, 59 insertions(+), 87 deletions(-) diff --git a/backend/app/api/routes/prompt_versions.py b/backend/app/api/routes/prompt_versions.py index 6ef9579a..a16376f3 100644 --- a/backend/app/api/routes/prompt_versions.py +++ b/backend/app/api/routes/prompt_versions.py @@ -48,8 +48,8 @@ def delete_prompt_version_route( session=session, prompt_id=prompt_id, version_id=version_id, - project_id=current_user.project_id + project_id=current_user.project_id, ) return APIResponse.success_response( data={"message": "Prompt version deleted successfully."} - ) \ No newline at end of file + ) diff --git 
a/backend/app/api/routes/prompts.py b/backend/app/api/routes/prompts.py index 03fb5195..2ff6a774 100644 --- a/backend/app/api/routes/prompts.py +++ b/backend/app/api/routes/prompts.py @@ -36,11 +36,11 @@ def create_prompt_route( """ Create a new prompt under the specified organization and project. """ - prompt = create_prompt( + prompt, version = create_prompt( session=session, prompt_in=prompt_in, project_id=current_user.project_id ) - - return APIResponse.success_response(prompt) + prompt_with_version = PromptWithVersion(**prompt.model_dump(), version=version) + return APIResponse.success_response(prompt_with_version) @router.get( @@ -50,15 +50,9 @@ def create_prompt_route( def get_prompts_route( current_user: CurrentUserOrgProject, skip: int = Query( - 0, - ge=0, - description="Number of prompts to skip (for pagination)." - ), - limit: int = Query( - 100, - gt=0, - description="Maximum number of prompts to return." + 0, ge=0, description="Number of prompts to skip (for pagination)." ), + limit: int = Query(100, gt=0, description="Maximum number of prompts to return."), session: Session = Depends(get_db), ): """ @@ -70,14 +64,10 @@ def get_prompts_route( skip=skip, limit=limit, ) - total = count_prompts_in_project(session=session, project_id=current_user.project_id) - metadata = { - "pagination": { - "total": total, - "skip": skip, - "limit": limit - } - } + total = count_prompts_in_project( + session=session, project_id=current_user.project_id + ) + metadata = {"pagination": {"total": total, "skip": skip, "limit": limit}} return APIResponse.success_response(prompts, metadata=metadata) @@ -90,21 +80,21 @@ def get_prompt_by_id_route( current_user: CurrentUserOrgProject, prompt_id: UUID = Path(..., description="The ID of the prompt to fetch"), include_versions: bool = Query( - False, - description="Whether to include all versions of the prompt." + False, description="Whether to include all versions of the prompt." 
), session: Session = Depends(get_db), ): """ Get a single prompt by its ID. """ - prompt = get_prompt_by_id( + prompt, versions = get_prompt_by_id( session=session, prompt_id=prompt_id, project_id=current_user.project_id, - include_versions=include_versions + include_versions=include_versions, ) - return APIResponse.success_response(prompt) + prompt_with_versions = PromptWithVersions(**prompt.model_dump(), versions=versions) + return APIResponse.success_response(prompt_with_versions) @router.patch("/{prompt_id}", response_model=APIResponse[PromptPublic]) @@ -141,4 +131,4 @@ def delete_prompt_route( ) return APIResponse.success_response( data={"message": "Prompt deleted successfully."} - ) \ No newline at end of file + ) diff --git a/backend/app/crud/prompt_versions.py b/backend/app/crud/prompt_versions.py index bb54c4cc..2c307b22 100644 --- a/backend/app/crud/prompt_versions.py +++ b/backend/app/crud/prompt_versions.py @@ -60,10 +60,7 @@ def create_prompt_version( def delete_prompt_version( - session: Session, - prompt_id: UUID, - version_id: UUID, - project_id: int + session: Session, prompt_id: UUID, version_id: UUID, project_id: int ): """ Delete a prompt version by ID. 
@@ -74,27 +71,33 @@ def delete_prompt_version( project_id=project_id, ) if prompt.active_version == version_id: - logger.error(f"[delete_prompt_version] Cannot delete active version | Version ID: {version_id}, Prompt ID: {prompt_id}") + logger.error( + f"[delete_prompt_version] Cannot delete active version | Version ID: {version_id}, Prompt ID: {prompt_id}" + ) raise HTTPException(status_code=409, detail="Cannot delete active version") stmt = select(PromptVersion).where( and_( PromptVersion.id == version_id, PromptVersion.prompt_id == prompt_id, - PromptVersion.is_deleted.is_(False) + PromptVersion.is_deleted.is_(False), ) ) prompt_version = session.exec(stmt).first() if not prompt_version: - logger.error(f"[delete_prompt_version] Prompt version not found | version_id={version_id}, prompt_id={prompt_id}") + logger.error( + f"[delete_prompt_version] Prompt version not found | version_id={version_id}, prompt_id={prompt_id}" + ) raise HTTPException(status_code=404, detail="Prompt version not found") prompt_version.is_deleted = True prompt_version.deleted_at = now() - + session.add(prompt_version) session.commit() session.refresh(prompt_version) - - logger.info(f"[delete_prompt_version] Deleted prompt version | Version ID: {version_id}, Prompt ID: {prompt_id}") \ No newline at end of file + + logger.info( + f"[delete_prompt_version] Deleted prompt version | Version ID: {version_id}, Prompt ID: {prompt_id}" + ) diff --git a/backend/app/crud/prompts.py b/backend/app/crud/prompts.py index 2c6c3105..86b11716 100644 --- a/backend/app/crud/prompts.py +++ b/backend/app/crud/prompts.py @@ -17,14 +17,14 @@ logger = logging.getLogger(__name__) -def create_prompt(session: Session, prompt_in: PromptCreate, project_id: int) -> PromptWithVersion: +def create_prompt( + session: Session, prompt_in: PromptCreate, project_id: int +) -> tuple[Prompt, PromptVersion]: """ Create a new prompt and its first version. 
""" prompt = Prompt( - name=prompt_in.name, - description=prompt_in.description, - project_id=project_id + name=prompt_in.name, description=prompt_in.description, project_id=project_id ) session.add(prompt) session.flush() @@ -33,7 +33,7 @@ def create_prompt(session: Session, prompt_in: PromptCreate, project_id: int) -> prompt_id=prompt.id, instruction=prompt_in.instruction, commit_message=prompt_in.commit_message, - version=1 + version=1, ) session.add(version) session.flush() @@ -48,10 +48,7 @@ def create_prompt(session: Session, prompt_in: PromptCreate, project_id: int) -> f"project_id={project_id}, version_id={version.id}" ) - return PromptWithVersion( - **prompt.model_dump(), - version=version - ) + return prompt, version def get_prompts( @@ -63,10 +60,7 @@ def get_prompts( """Get prompts for a project.""" stmt = ( select(Prompt) - .where( - Prompt.project_id == project_id, - Prompt.is_deleted.is_(False) - ) + .where(Prompt.project_id == project_id, Prompt.is_deleted.is_(False)) .order_by(Prompt.updated_at.desc()) .offset(skip) .limit(limit) @@ -88,13 +82,10 @@ def prompt_exists(session: Session, prompt_id: UUID, project_id: int) -> Prompt: """ Check if a prompt exists for the given ID and project. """ - stmt = ( - select(Prompt) - .where( - Prompt.id == prompt_id, - Prompt.project_id == project_id, - Prompt.is_deleted.is_(False) - ) + stmt = select(Prompt).where( + Prompt.id == prompt_id, + Prompt.project_id == project_id, + Prompt.is_deleted.is_(False), ) prompt = session.exec(stmt).first() @@ -108,25 +99,21 @@ def prompt_exists(session: Session, prompt_id: UUID, project_id: int) -> Prompt: def get_prompt_by_id( - session: Session, - prompt_id: UUID, - project_id: int, - include_versions: bool = False -) -> PromptWithVersions: + session: Session, prompt_id: UUID, project_id: int, include_versions: bool = False +) -> tuple[Prompt, list[PromptVersion]]: """ Get a prompt by its ID, optionally including all versions. 
By default, Always returns the active version. """ if include_versions: join_condition = and_( - PromptVersion.prompt_id == Prompt.id, - PromptVersion.is_deleted.is_(False) + PromptVersion.prompt_id == Prompt.id, PromptVersion.is_deleted.is_(False) ) order_by = PromptVersion.version.desc() else: join_condition = and_( PromptVersion.id == Prompt.active_version, - PromptVersion.is_deleted.is_(False) + PromptVersion.is_deleted.is_(False), ) order_by = None # no need to order when fetching only 1 row @@ -136,7 +123,7 @@ def get_prompt_by_id( .where( Prompt.id == prompt_id, Prompt.project_id == project_id, - Prompt.is_deleted.is_(False) + Prompt.is_deleted.is_(False), ) ) if order_by is not None: @@ -144,36 +131,30 @@ def get_prompt_by_id( results = session.exec(stmt).all() if not results: - logger.error(f"[get_prompt_by_id] Prompt not found | ID: {prompt_id}, Project ID: {project_id}") + logger.error( + f"[get_prompt_by_id] Prompt not found | ID: {prompt_id}, Project ID: {project_id}" + ) raise HTTPException(status_code=404, detail="Prompt not found") # Unpack tuples into variables - prompt, _ = results[0] + prompt, _ = results[0] versions = [version for _, version in results if version is not None] - return PromptWithVersions( - **prompt.model_dump(), - versions=versions - ) + return prompt, versions def update_prompt( session: Session, prompt_id: UUID, project_id: int, prompt_update: PromptUpdate ) -> Prompt: - prompt = prompt_exists( - session=session, prompt_id=prompt_id, project_id=project_id - ) + prompt = prompt_exists(session=session, prompt_id=prompt_id, project_id=project_id) update_prompt = prompt_update.model_dump(exclude_unset=True) - active_version = update_prompt.get('active_version') + active_version = update_prompt.get("active_version") if active_version: - stmt = ( - select(PromptVersion) - .where( - PromptVersion.id == active_version, - PromptVersion.prompt_id == prompt.id, - PromptVersion.is_deleted.is_(False) - ) + stmt = 
select(PromptVersion).where( + PromptVersion.id == active_version, + PromptVersion.prompt_id == prompt.id, + PromptVersion.is_deleted.is_(False), ) prompt_version = session.exec(stmt).first() if not prompt_version: @@ -198,9 +179,7 @@ def update_prompt( def delete_prompt(session: Session, prompt_id: UUID, project_id: int) -> None: - prompt = prompt_exists( - session=session, prompt_id=prompt_id, project_id=project_id - ) + prompt = prompt_exists(session=session, prompt_id=prompt_id, project_id=project_id) prompt.is_deleted = True prompt.deleted_at = now() @@ -209,4 +188,4 @@ def delete_prompt(session: Session, prompt_id: UUID, project_id: int) -> None: session.refresh(prompt) logger.info( f"[delete_prompt] Prompt deleted | id={prompt.id}, name={prompt.name}, project_id={project_id}" - ) \ No newline at end of file + ) From fcf40010c1ebe27a8e739311ec395e6469df895f Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Wed, 13 Aug 2025 12:03:08 +0530 Subject: [PATCH 4/7] test cases for prompt management --- .../tests/api/routes/test_prompt_versions.py | 81 ++++ backend/app/tests/api/routes/test_prompts.py | 317 ++++++++++++++ .../app/tests/crud/test_prompt_versions.py | 215 ++++++++++ backend/app/tests/crud/test_prompts.py | 395 ++++++++++++++++++ 4 files changed, 1008 insertions(+) create mode 100644 backend/app/tests/api/routes/test_prompt_versions.py create mode 100644 backend/app/tests/api/routes/test_prompts.py create mode 100644 backend/app/tests/crud/test_prompt_versions.py create mode 100644 backend/app/tests/crud/test_prompts.py diff --git a/backend/app/tests/api/routes/test_prompt_versions.py b/backend/app/tests/api/routes/test_prompt_versions.py new file mode 100644 index 00000000..8f11844b --- /dev/null +++ b/backend/app/tests/api/routes/test_prompt_versions.py @@ -0,0 +1,81 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.crud.prompts import create_prompt +from app.models 
import APIKeyPublic, PromptCreate, PromptVersion, PromptVersionCreate + + +def test_create_prompt_version_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful creation of a prompt version via API route.""" + prompt_in = PromptCreate( + name=f"Test Prompt", + description="Prompt for testing version creation route", + instruction="Initial instruction", + commit_message="Initial version" + ) + prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=user_api_key.project_id) + + version_in = PromptVersionCreate( + instruction="Version 2 instructions", + commit_message="Second version" + ) + + response = client.post( + f"/api/v1/prompts/{prompt.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + json=version_in.model_dump(), + ) + + assert response.status_code == 201 + response_data = response.json() + + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + assert data["prompt_id"] == str(prompt.id) + assert data["instruction"] == version_in.instruction + assert data["commit_message"] == version_in.commit_message + assert data["version"] == 2 # First version created by create_prompt, this is second + + +def test_delete_prompt_version_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful deletion of a non-active prompt version via API route""" + prompt_in = PromptCreate( + name=f"Test Prompt", + description="Prompt for testing version creation route", + instruction="Initial instruction", + commit_message="Initial version" + ) + prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=user_api_key.project_id) + + # Create a second version (non-active) + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2 + ) + db.add(second_version) + db.commit() + + response = client.delete( + 
f"/api/v1/prompts/{prompt.id}/versions/{second_version.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"]["message"] == "Prompt version deleted successfully." + + db.refresh(second_version) + assert second_version.is_deleted + assert second_version.deleted_at is not None \ No newline at end of file diff --git a/backend/app/tests/api/routes/test_prompts.py b/backend/app/tests/api/routes/test_prompts.py new file mode 100644 index 00000000..8c3601e1 --- /dev/null +++ b/backend/app/tests/api/routes/test_prompts.py @@ -0,0 +1,317 @@ +from uuid import uuid4 + +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.crud.prompts import create_prompt +from app.models import APIKeyPublic, PromptCreate, PromptUpdate, PromptVersion + + + +def test_create_prompt_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful creation of a prompt via API route""" + project_id = user_api_key.project_id + prompt_in = PromptCreate( + name=f"test_prompt_{uuid4()}", + description="Test prompt description", + instruction="Test instruction", + commit_message="Initial version" + ) + + response = client.post( + "/api/v1/prompts/", + headers={"X-API-KEY": user_api_key.key}, + json=prompt_in.model_dump(), + ) + + assert response.status_code == 201 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert data["name"] == prompt_in.name + assert data["description"] == prompt_in.description + assert data["project_id"] == project_id + assert data["inserted_at"] is not None + assert data["updated_at"] is not None + + assert "version" in data + assert data["version"]["instruction"] == prompt_in.instruction + assert data["version"]["commit_message"] == prompt_in.commit_message + assert 
data["version"]["version"] == 1 + assert data["active_version"] == data["version"]["id"] + + +def test_get_prompts_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully retrieving prompts with pagination metadata""" + project_id = user_api_key.project_id + + # Create multiple prompts + prompt_1_in = PromptCreate( + name=f"prompt_1_{uuid4()}", + description="First prompt description", + instruction="First instruction", + commit_message="Initial version" + ) + prompt_2_in = PromptCreate( + name=f"prompt_2_{uuid4()}", + description="Second prompt description", + instruction="Second instruction", + commit_message="Initial version" + ) + create_prompt(db, prompt_in=prompt_1_in, project_id=project_id) + create_prompt(db, prompt_in=prompt_2_in, project_id=project_id) + + response = client.get( + f"/api/v1/prompts/?skip=0&limit=100", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + assert "metadata" in response_data + assert response_data["metadata"]["pagination"]["total"] == 2 + assert response_data["metadata"]["pagination"]["skip"] == 0 + assert response_data["metadata"]["pagination"]["limit"] == 100 + + prompts = response_data["data"] + assert len(prompts) == 2 + assert prompts[0]["name"] == prompt_2_in.name + assert prompts[1]["name"] == prompt_1_in.name + assert all(prompt["project_id"] == project_id for prompt in prompts) + + +def test_get_prompts_route_empty( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving an empty list when no prompts exist""" + response = client.get( + f"/api/v1/prompts/?skip=0&limit=100", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"] == [] + assert 
response_data["metadata"]["pagination"]["total"] == 0 + assert response_data["metadata"]["pagination"]["skip"] == 0 + assert response_data["metadata"]["pagination"]["limit"] == 100 + + +def test_get_prompts_route_pagination( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving prompts with specific skip and limit values""" + project_id = user_api_key.project_id + + for i in range(3): + prompt_in = PromptCreate( + name=f"prompt_{i}", + description=f"Prompt {i} description", + instruction=f"Instruction {i}", + commit_message="Initial version" + ) + create_prompt(db, prompt_in=prompt_in, project_id=project_id) + + response = client.get( + f"/api/v1/prompts/?skip=1&limit=1", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert len(response_data["data"]) == 1 + assert response_data["metadata"]["pagination"]["total"] == 3 + assert response_data["metadata"]["pagination"]["skip"] == 1 + assert response_data["metadata"]["pagination"]["limit"] == 1 + + +def test_get_prompt_by_id_route_success_active_version( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully retrieving a prompt with its active version""" + project_id = user_api_key.project_id + prompt_in = PromptCreate( + name=f"test_prompt_{uuid4()}", + description="Test prompt description", + instruction="Test instruction", + commit_message="Initial version" + ) + prompt, version = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + + response = client.get( + f"/api/v1/prompts/{prompt.id}?include_versions=false", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert data["id"] == str(prompt.id) + assert data["name"] == prompt_in.name + 
assert data["description"] == prompt_in.description + assert data["project_id"] == project_id + assert len(data["versions"]) == 1 + assert data["versions"][0]["id"] == str(version.id) + assert data["versions"][0]["instruction"] == prompt_in.instruction + assert data["versions"][0]["commit_message"] == prompt_in.commit_message + assert data["versions"][0]["version"] == 1 + assert data["active_version"] == str(version.id) + + +def test_get_prompt_by_id_route_with_versions( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving a prompt with all its versions""" + project_id = user_api_key.project_id + prompt_in = PromptCreate( + name=f"test_prompt_{uuid4()}", + description="Test prompt description", + instruction="Test instruction", + commit_message="Initial version" + ) + prompt, version = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2 + ) + db.add(second_version) + db.commit() + + response = client.get( + f"/api/v1/prompts/{prompt.id}?include_versions=true", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert len(data["versions"]) == 2 + assert data["versions"][0]["version"] == 2 + assert data["versions"][1]["version"] == 1 + assert data["versions"][0]["instruction"] == "Second instruction" + assert data["versions"][1]["instruction"] == prompt_in.instruction + + +def test_get_prompt_by_id_route_not_found( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test retrieving a non-existent prompt returns 404""" + non_existent_prompt_id = uuid4() + + response = client.get( + f"/api/v1/prompts/{non_existent_prompt_id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + 
assert response.status_code == 404 + response_data = response.json() + assert response_data["success"] is False + assert "not found" in response_data["error"].lower() + + +def test_update_prompt_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully updating a prompt's name and description""" + project_id = user_api_key.project_id + prompt_in = PromptCreate( + name=f"test_prompt", + description="Test prompt description", + instruction="Test instruction", + commit_message="Initial version" + ) + prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + + prompt_version = PromptVersion( + prompt_id=prompt.id, + instruction="Test instruction", + commit_message="Initial version", + version=2 + ) + db.add(prompt_version) + db.commit() + + update_data = PromptUpdate(name="updated_prompt", description="Updated description", active_version=prompt_version.id) + + response = client.patch( + f"/api/v1/prompts/{prompt.id}", + headers={"X-API-KEY": user_api_key.key}, + json=update_data.model_dump(mode="json", exclude_unset=True) + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "data" in response_data + data = response_data["data"] + + assert data["id"] == str(prompt.id) + assert data["name"] == "updated_prompt" + assert data["description"] == "Updated description" + assert data["project_id"] == project_id + assert data["active_version"] == str(prompt_version.id) + + +def test_delete_prompt_route_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successfully soft-deleting a prompt""" + project_id = user_api_key.project_id + prompt_in = PromptCreate( + name=f"test_prompt", + description="Test prompt description", + instruction="Test instruction", + commit_message="Initial version" + ) + prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + + response = client.delete( + 
f"/api/v1/prompts/{prompt.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"]["message"] == "Prompt deleted successfully." + + db.refresh(prompt) + assert prompt.is_deleted + assert prompt.deleted_at is not None diff --git a/backend/app/tests/crud/test_prompt_versions.py b/backend/app/tests/crud/test_prompt_versions.py new file mode 100644 index 00000000..b9303887 --- /dev/null +++ b/backend/app/tests/crud/test_prompt_versions.py @@ -0,0 +1,215 @@ +from uuid import uuid4 + +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from app.core.util import now +from app.crud.prompt_versions import ( + create_prompt_version, + delete_prompt_version, + get_next_prompt_version, +) +from app.crud.prompts import create_prompt +from app.models import Prompt, PromptCreate, PromptVersion, PromptVersionCreate +from app.tests.utils.utils import get_project + + +@pytest.fixture +def prompt(db: Session) -> Prompt: + """Fixture to create a reusable prompt""" + project = get_project(db) + prompt_data = PromptCreate( + name=f"test_prompt_{uuid4()}", + description="This is a test prompt", + instruction="Test instruction", + commit_message="Initial version" + ) + prompt, _ = create_prompt(db, prompt_in=prompt_data, project_id=project.id) + return prompt + + +def test_create_prompt_version_success(db: Session, prompt: Prompt): + """Successfully create a new prompt version""" + project_id = prompt.project_id + version_data = PromptVersionCreate( + instruction="New instruction", + commit_message="New version" + ) + + prompt_version = create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data, project_id=project_id) + + assert isinstance(prompt_version, PromptVersion) + assert prompt_version.prompt_id == prompt.id + assert prompt_version.version == 2 # First version created by create_prompt, this is 
second + assert prompt_version.instruction == version_data.instruction + assert prompt_version.commit_message == version_data.commit_message + assert not prompt_version.is_deleted + + +def test_create_prompt_version_multiple_versions(db: Session, prompt: Prompt): + """Create multiple versions and verify correct version increment""" + project_id = prompt.project_id + version_data_1 = PromptVersionCreate( + instruction="New instruction 1", + commit_message="Version 2" + ) + version_data_2 = PromptVersionCreate( + instruction="New instruction 2", + commit_message="Version 3" + ) + + version_1 = create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data_1, project_id=project_id) + version_2 = create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data_2, project_id=project_id) + + assert version_1.version == 2 + assert version_2.version == 3 + assert version_1.instruction == "New instruction 1" + assert version_2.instruction == "New instruction 2" + + +def test_create_prompt_version_prompt_not_found(db: Session): + """Raise 404 error when prompt does not exist""" + project = get_project(db) + non_existent_id = uuid4() + version_data = PromptVersionCreate( + instruction="New instruction", + commit_message="New version" + ) + + with pytest.raises(HTTPException) as exc_info: + create_prompt_version(db, prompt_id=non_existent_id, prompt_version_in=version_data, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_create_prompt_version_deleted_prompt(db: Session, prompt: Prompt): + """Raise 404 error when prompt is deleted""" + project_id = prompt.project_id + prompt.is_deleted = True + db.add(prompt) + db.commit() + + version_data = PromptVersionCreate( + instruction="New instruction", + commit_message="New version" + ) + + with pytest.raises(HTTPException) as exc_info: + create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data, 
project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_next_prompt_version(db: Session, prompt: Prompt): + """Return incremented version number when versions exist""" + prompt_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2 + ) + db.add(prompt_version) + db.commit() + + version = get_next_prompt_version(db, prompt_id=prompt.id) + assert version == 3 + + +def test_get_next_prompt_version_with_deleted_version(db: Session, prompt: Prompt): + """Return incremented version number even if the latest version is deleted""" + prompt_version = PromptVersion( + prompt_id=prompt.id, + instruction="Deleted instruction", + commit_message="Deleted version", + version=2, + is_deleted=True + ) + db.add(prompt_version) + db.commit() + + version = get_next_prompt_version(db, prompt_id=prompt.id) + assert version == 3 + + +def test_delete_prompt_version_success(db: Session, prompt: Prompt): + """Successfully soft-delete a non-active prompt version""" + project_id = prompt.project_id + + # Create a second version (non-active) + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2 + ) + db.add(second_version) + db.commit() + + delete_prompt_version(db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id) + + db.refresh(second_version) + assert second_version.is_deleted + assert second_version.deleted_at is not None + assert second_version.deleted_at <= now() + + +def test_delete_prompt_version_active_version(db: Session, prompt: Prompt): + """Raise 409 error when attempting to delete the active version""" + project_id = prompt.project_id + active_version_id = prompt.active_version + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version(db, prompt_id=prompt.id, version_id=active_version_id, 
project_id=project_id) + + assert exc_info.value.status_code == 409 + assert "cannot delete active version" in exc_info.value.detail.lower() + + +def test_delete_prompt_version_not_found(db: Session, prompt: Prompt): + """Raise 404 error when version does not exist""" + project_id = prompt.project_id + non_existent_version_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version(db, prompt_id=prompt.id, version_id=non_existent_version_id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "prompt version not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_version_already_deleted(db: Session, prompt: Prompt): + """Raise 404 error when attempting to delete an already deleted version""" + project_id = prompt.project_id + + second_version = PromptVersion( + prompt_id=prompt.id, + instruction="Second instruction", + commit_message="Second version", + version=2, + is_deleted=True, + deleted_at=now() + ) + db.add(second_version) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version(db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "prompt version not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_version_prompt_not_found(db: Session): + """Raise 404 error when prompt does not exist""" + project = get_project(db) + non_existent_prompt_id = uuid4() + version_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt_version(db, prompt_id=non_existent_prompt_id, version_id=version_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() \ No newline at end of file diff --git a/backend/app/tests/crud/test_prompts.py b/backend/app/tests/crud/test_prompts.py new file mode 100644 index 00000000..1df9148a --- /dev/null +++ b/backend/app/tests/crud/test_prompts.py @@ -0,0 +1,395 @@ 
+import pytest +from uuid import uuid4 + +from fastapi import HTTPException +from sqlmodel import Session + +from app.core.util import now +from app.crud.prompts import ( + count_prompts_in_project, + create_prompt, + delete_prompt, + get_prompt_by_id, + get_prompts, + prompt_exists, + update_prompt, +) +from app.models import Prompt, PromptCreate, PromptUpdate, PromptVersion +from app.tests.utils.utils import get_project + + +@pytest.fixture +def prompt(db) -> Prompt: + """Fixture to create a reusable prompt""" + project = get_project(db) + prompt_data = PromptCreate( + name="test_prompt", + description="This is a test prompt", + instruction="Test instruction", + commit_message="Initial version" + ) + prompt, _ = create_prompt(db, prompt_in=prompt_data, project_id=project.id) + return prompt + + +def test_create_prompt_success(db: Session): + """Prompt and its first version are created successfully when valid input is provided""" + project = get_project(db) + prompt_data = PromptCreate( + name="test_prompt", + description="This is a test prompt", + instruction="Test instruction", + commit_message="Initial version" + ) + + prompt, version = create_prompt(db, prompt_data, project_id=project.id) + + # Prompt checks + assert prompt.id is not None + assert prompt.name == prompt_data.name + assert prompt.description == prompt_data.description + assert prompt.project_id == project.id + assert prompt.inserted_at is not None + assert prompt.updated_at is not None + assert prompt.active_version == version.id + + # Version checks + assert version.prompt_id == prompt.id + assert version.version == 1 + assert version.instruction == prompt_data.instruction + assert version.commit_message == prompt_data.commit_message + + +def test_get_prompts_success(db: Session): + """Retrieve prompts for a project with pagination, ensuring correct filtering and ordering""" + project = get_project(db) + + create_prompt( + db, + PromptCreate(name="prompt1", description="First prompt", 
instruction="Instruction 1", commit_message="Initial"), + project_id=project.id + ) + create_prompt( + db, + PromptCreate(name="prompt2", description="Second prompt", instruction="Instruction 2", commit_message="Initial"), + project_id=project.id + ) + + prompts = get_prompts(db, project_id=project.id, skip=0, limit=100) + + assert len(prompts) == 2 + assert prompts[0].name == "prompt2" + assert prompts[1].name == "prompt1" + assert all(not prompt.is_deleted for prompt in prompts) + assert all(prompt.project_id == project.id for prompt in prompts) + + + prompts_limited = get_prompts(db, project_id=project.id, skip=1, limit=1) + assert len(prompts_limited) == 1 + assert prompts_limited[0].name == "prompt1" + + +def test_get_prompts_empty(db: Session): + """Return empty list when no prompts exist for a project or project has no non-deleted prompts""" + project = get_project(db) + + prompts = get_prompts(db, project_id=project.id) + assert prompts == [] + + # Create a deleted prompt + prompt, _ = create_prompt( + db, + PromptCreate(name="deleted_prompt", description="Deleted", instruction="Instruction", commit_message="Initial"), + project_id=project.id + ) + prompt.is_deleted = True + db.add(prompt) + db.commit() + + prompts = get_prompts(db, project_id=project.id) + assert prompts == [] + + +def test_count_prompts_in_project_success(db: Session): + """Correctly count non-deleted prompts in a project""" + project = get_project(db) + + # Create multiple prompts + create_prompt( + db, + PromptCreate(name="prompt1", description="First prompt", instruction="Instruction 1", commit_message="Initial"), + project_id=project.id + ) + create_prompt( + db, + PromptCreate(name="prompt2", description="Second prompt", instruction="Instruction 2", commit_message="Initial"), + project_id=project.id + ) + + count = count_prompts_in_project(db, project_id=project.id) + assert count == 2 + + +def test_count_prompts_in_project_empty_or_deleted(db: Session): + """Return 0 when no prompts 
exist or all prompts are deleted""" + project = get_project(db) + + # Test empty project + count = count_prompts_in_project(db, project_id=project.id) + assert count == 0 + + # Create a deleted prompt + prompt, _ = create_prompt( + db, + PromptCreate(name="deleted_prompt", description="Deleted", instruction="Instruction", commit_message="Initial"), + project_id=project.id + ) + prompt.is_deleted = True + db.add(prompt) + db.commit() + + count = count_prompts_in_project(db, project_id=project.id) + assert count == 0 + + +def test_prompt_exists_success(db: Session, prompt: Prompt): + """Successfully retrieve an existing prompt by ID and project""" + project = get_project(db) # Call get_project as a function + result = prompt_exists(db, prompt_id=prompt.id, project_id=project.id) + + assert isinstance(result, Prompt) + assert result.id == prompt.id + assert result.project_id == project.id + assert result.name == "test_prompt" + assert result.description == "This is a test prompt" + assert not result.is_deleted + + +def test_prompt_exists_not_found(db: Session): + """Raise 404 error when prompt ID does not exist""" + project = get_project(db) # Call get_project as a function + non_existent_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + prompt_exists(db, prompt_id=non_existent_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_prompt_exists_deleted_prompt(db: Session, prompt: Prompt): + """Raise 404 error when prompt is deleted""" + project_id = prompt.project_id + prompt.is_deleted = True + db.add(prompt) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + prompt_exists(db, prompt_id=prompt.id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_prompt_by_id_success_active_version(db: Session, prompt: Prompt): + """Retrieve a prompt by ID with only its active 
version""" + project = get_project(db) + retrieved_prompt, versions = get_prompt_by_id(db, prompt_id=prompt.id, project_id=project.id, include_versions=False) + + assert isinstance(retrieved_prompt, Prompt) + assert isinstance(versions, list) + assert len(versions) == 1 + assert retrieved_prompt.id == prompt.id + assert retrieved_prompt.name == prompt.name + assert retrieved_prompt.description == prompt.description + assert retrieved_prompt.project_id == project.id + assert not retrieved_prompt.is_deleted + assert versions[0].id == prompt.active_version + assert versions[0].instruction == "Test instruction" + assert versions[0].commit_message == "Initial version" + assert versions[0].version == 1 + assert not versions[0].is_deleted + + +def test_get_prompt_by_id_with_versions(db: Session, prompt: Prompt): + """Retrieve a prompt by ID with all its versions""" + project = get_project(db) + + # Add another version + new_version = PromptVersion( + prompt_id=prompt.id, + instruction="Updated instruction", + commit_message="Second version", + version=2 + ) + db.add(new_version) + db.commit() + + retrieved_prompt, versions = get_prompt_by_id(db, prompt_id=prompt.id, project_id=project.id, include_versions=True) + + assert isinstance(retrieved_prompt, Prompt) + assert isinstance(versions, list) + assert len(versions) == 2 + assert retrieved_prompt.id == prompt.id + assert retrieved_prompt.name == prompt.name + assert retrieved_prompt.description == prompt.description + assert retrieved_prompt.project_id == project.id + assert not retrieved_prompt.is_deleted + assert versions[0].version == 2 # Latest version first (descending order) + assert versions[1].version == 1 + assert versions[0].instruction == "Updated instruction" + assert versions[1].instruction == "Test instruction" + assert not any(version.is_deleted for version in versions) + + +def test_get_prompt_by_id_not_found(db: Session): + """Raise 404 error when prompt ID does not exist""" + project = get_project(db) + 
non_existent_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + get_prompt_by_id(db, prompt_id=non_existent_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_prompt_by_id_deleted_prompt(db: Session, prompt: Prompt): + """Raise 404 error when prompt is deleted""" + project_id = prompt.project_id + prompt.is_deleted = True + db.add(prompt) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + get_prompt_by_id(db, prompt_id=prompt.id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_get_prompt_by_id_deleted_version(db: Session, prompt: Prompt): + """Exclude deleted versions when retrieving prompt versions""" + project_id = prompt.project_id + + deleted_version = PromptVersion( + prompt_id=prompt.id, + instruction="Deleted instruction", + commit_message="Deleted version", + version=2, + is_deleted=True + ) + db.add(deleted_version) + db.commit() + + retrieved_prompt, versions = get_prompt_by_id(db, prompt_id=prompt.id, project_id=project_id, include_versions=True) + + assert isinstance(retrieved_prompt, Prompt) + assert isinstance(versions, list) + assert len(versions) == 1 + assert versions[0].version == 1 + assert not versions[0].is_deleted + assert versions[0].instruction == "Test instruction" + + +def test_update_prompt_success_name_description(db: Session, prompt: Prompt): + """Successfully update prompt's name and description""" + project_id = prompt.project_id + update_data = PromptUpdate(name="updated_prompt", description="Updated description") + + updated_prompt = update_prompt(db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data) + + assert isinstance(updated_prompt, Prompt) + assert updated_prompt.id == prompt.id + assert updated_prompt.name == "updated_prompt" + assert updated_prompt.description == "Updated description" + assert 
updated_prompt.project_id == project_id + assert not updated_prompt.is_deleted + + +def test_update_prompt_success_active_version(db: Session, prompt: Prompt): + """Successfully update prompt's active version""" + project_id = prompt.project_id + + # Create a new version + new_version = PromptVersion( + prompt_id=prompt.id, + instruction="New instruction", + commit_message="Second version", + version=2 + ) + db.add(new_version) + db.commit() + + update_data = PromptUpdate(active_version=new_version.id) + updated_prompt = update_prompt(db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data) + + assert isinstance(updated_prompt, Prompt) + assert updated_prompt.id == prompt.id + assert updated_prompt.active_version == new_version.id + + +def test_update_prompt_invalid_active_version(db: Session, prompt: Prompt): + """Raise 404 error when updating with an invalid active version ID""" + project_id = prompt.project_id + invalid_version_id = uuid4() + + update_data = PromptUpdate(active_version=invalid_version_id) + + with pytest.raises(HTTPException) as exc_info: + update_prompt(db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data) + + assert exc_info.value.status_code == 404 + assert "invalid active version id" in exc_info.value.detail.lower() + + +def test_update_prompt_not_found(db: Session): + """Raise 404 error when prompt does not exist""" + project = get_project(db) + non_existent_id = uuid4() + update_data = PromptUpdate(name="new_name") + + with pytest.raises(HTTPException) as exc_info: + update_prompt(db, prompt_id=non_existent_id, project_id=project.id, prompt_update=update_data) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_success(db: Session, prompt: Prompt): + """Successfully soft delete a prompt""" + project_id = prompt.project_id + + delete_prompt(db, prompt_id=prompt.id, project_id=project_id) + + db.refresh(prompt) + assert 
prompt.is_deleted + assert prompt.deleted_at is not None + assert prompt.deleted_at <= now() + + +def test_delete_prompt_not_found(db: Session): + """Raise 404 error when deleting a non-existent prompt""" + project = get_project(db) + non_existent_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt(db, prompt_id=non_existent_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +def test_delete_prompt_already_deleted(db: Session, prompt: Prompt): + """Raise 404 error when attempting to delete an already deleted prompt""" + project_id = prompt.project_id + prompt.is_deleted = True + prompt.deleted_at = now() + db.add(prompt) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + delete_prompt(db, prompt_id=prompt.id, project_id=project_id) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() \ No newline at end of file From ef1933be3e185d40945a81b2d7d2f2e6bc53d67a Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Wed, 13 Aug 2025 12:18:38 +0530 Subject: [PATCH 5/7] pre commit --- .../tests/api/routes/test_prompt_versions.py | 33 ++-- backend/app/tests/api/routes/test_prompts.py | 79 ++++---- .../app/tests/crud/test_prompt_versions.py | 129 +++++++----- backend/app/tests/crud/test_prompts.py | 186 +++++++++++------- 4 files changed, 255 insertions(+), 172 deletions(-) diff --git a/backend/app/tests/api/routes/test_prompt_versions.py b/backend/app/tests/api/routes/test_prompt_versions.py index 8f11844b..15f6fd6c 100644 --- a/backend/app/tests/api/routes/test_prompt_versions.py +++ b/backend/app/tests/api/routes/test_prompt_versions.py @@ -15,13 +15,14 @@ def test_create_prompt_version_route_success( name=f"Test Prompt", description="Prompt for testing version creation route", instruction="Initial instruction", - commit_message="Initial version" + 
commit_message="Initial version", + ) + prompt, _ = create_prompt( + db, prompt_in=prompt_in, project_id=user_api_key.project_id ) - prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=user_api_key.project_id) version_in = PromptVersionCreate( - instruction="Version 2 instructions", - commit_message="Second version" + instruction="Version 2 instructions", commit_message="Second version" ) response = client.post( @@ -32,14 +33,16 @@ def test_create_prompt_version_route_success( assert response.status_code == 201 response_data = response.json() - + assert response_data["success"] is True assert "data" in response_data data = response_data["data"] assert data["prompt_id"] == str(prompt.id) assert data["instruction"] == version_in.instruction assert data["commit_message"] == version_in.commit_message - assert data["version"] == 2 # First version created by create_prompt, this is second + assert ( + data["version"] == 2 + ) # First version created by create_prompt, this is second def test_delete_prompt_version_route_success( @@ -52,30 +55,32 @@ def test_delete_prompt_version_route_success( name=f"Test Prompt", description="Prompt for testing version creation route", instruction="Initial instruction", - commit_message="Initial version" + commit_message="Initial version", ) - prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=user_api_key.project_id) - + prompt, _ = create_prompt( + db, prompt_in=prompt_in, project_id=user_api_key.project_id + ) + # Create a second version (non-active) second_version = PromptVersion( prompt_id=prompt.id, instruction="Second instruction", commit_message="Second version", - version=2 + version=2, ) db.add(second_version) db.commit() - + response = client.delete( f"/api/v1/prompts/{prompt.id}/versions/{second_version.id}", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True assert response_data["data"]["message"] == 
"Prompt version deleted successfully." - + db.refresh(second_version) assert second_version.is_deleted - assert second_version.deleted_at is not None \ No newline at end of file + assert second_version.deleted_at is not None diff --git a/backend/app/tests/api/routes/test_prompts.py b/backend/app/tests/api/routes/test_prompts.py index 8c3601e1..2177e075 100644 --- a/backend/app/tests/api/routes/test_prompts.py +++ b/backend/app/tests/api/routes/test_prompts.py @@ -7,7 +7,6 @@ from app.models import APIKeyPublic, PromptCreate, PromptUpdate, PromptVersion - def test_create_prompt_route_success( client: TestClient, db: Session, @@ -19,7 +18,7 @@ def test_create_prompt_route_success( name=f"test_prompt_{uuid4()}", description="Test prompt description", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) response = client.post( @@ -33,13 +32,13 @@ def test_create_prompt_route_success( assert response_data["success"] is True assert "data" in response_data data = response_data["data"] - + assert data["name"] == prompt_in.name assert data["description"] == prompt_in.description assert data["project_id"] == project_id assert data["inserted_at"] is not None assert data["updated_at"] is not None - + assert "version" in data assert data["version"]["instruction"] == prompt_in.instruction assert data["version"]["commit_message"] == prompt_in.commit_message @@ -54,28 +53,28 @@ def test_get_prompts_route_success( ): """Test successfully retrieving prompts with pagination metadata""" project_id = user_api_key.project_id - + # Create multiple prompts prompt_1_in = PromptCreate( name=f"prompt_1_{uuid4()}", description="First prompt description", instruction="First instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt_2_in = PromptCreate( name=f"prompt_2_{uuid4()}", description="Second prompt description", instruction="Second instruction", - commit_message="Initial version" + 
commit_message="Initial version", ) create_prompt(db, prompt_in=prompt_1_in, project_id=project_id) create_prompt(db, prompt_in=prompt_2_in, project_id=project_id) - + response = client.get( f"/api/v1/prompts/?skip=0&limit=100", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True @@ -84,7 +83,7 @@ def test_get_prompts_route_success( assert response_data["metadata"]["pagination"]["total"] == 2 assert response_data["metadata"]["pagination"]["skip"] == 0 assert response_data["metadata"]["pagination"]["limit"] == 100 - + prompts = response_data["data"] assert len(prompts) == 2 assert prompts[0]["name"] == prompt_2_in.name @@ -102,7 +101,7 @@ def test_get_prompts_route_empty( f"/api/v1/prompts/?skip=0&limit=100", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True @@ -119,21 +118,21 @@ def test_get_prompts_route_pagination( ): """Test retrieving prompts with specific skip and limit values""" project_id = user_api_key.project_id - + for i in range(3): prompt_in = PromptCreate( name=f"prompt_{i}", description=f"Prompt {i} description", instruction=f"Instruction {i}", - commit_message="Initial version" + commit_message="Initial version", ) create_prompt(db, prompt_in=prompt_in, project_id=project_id) - + response = client.get( f"/api/v1/prompts/?skip=1&limit=1", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True @@ -154,21 +153,21 @@ def test_get_prompt_by_id_route_success_active_version( name=f"test_prompt_{uuid4()}", description="Test prompt description", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt, version = create_prompt(db, prompt_in=prompt_in, project_id=project_id) - + response = client.get( 
f"/api/v1/prompts/{prompt.id}?include_versions=false", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True assert "data" in response_data data = response_data["data"] - + assert data["id"] == str(prompt.id) assert data["name"] == prompt_in.name assert data["description"] == prompt_in.description @@ -192,30 +191,30 @@ def test_get_prompt_by_id_route_with_versions( name=f"test_prompt_{uuid4()}", description="Test prompt description", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt, version = create_prompt(db, prompt_in=prompt_in, project_id=project_id) - + second_version = PromptVersion( prompt_id=prompt.id, instruction="Second instruction", commit_message="Second version", - version=2 + version=2, ) db.add(second_version) db.commit() - + response = client.get( f"/api/v1/prompts/{prompt.id}?include_versions=true", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True assert "data" in response_data data = response_data["data"] - + assert len(data["versions"]) == 2 assert data["versions"][0]["version"] == 2 assert data["versions"][1]["version"] == 1 @@ -230,12 +229,12 @@ def test_get_prompt_by_id_route_not_found( ): """Test retrieving a non-existent prompt returns 404""" non_existent_prompt_id = uuid4() - + response = client.get( f"/api/v1/prompts/{non_existent_prompt_id}", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 404 response_data = response.json() assert response_data["success"] is False @@ -253,7 +252,7 @@ def test_update_prompt_route_success( name=f"test_prompt", description="Test prompt description", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt, _ = create_prompt(db, prompt_in=prompt_in, 
project_id=project_id) @@ -261,25 +260,29 @@ def test_update_prompt_route_success( prompt_id=prompt.id, instruction="Test instruction", commit_message="Initial version", - version=2 + version=2, ) db.add(prompt_version) db.commit() - update_data = PromptUpdate(name="updated_prompt", description="Updated description", active_version=prompt_version.id) - + update_data = PromptUpdate( + name="updated_prompt", + description="Updated description", + active_version=prompt_version.id, + ) + response = client.patch( f"/api/v1/prompts/{prompt.id}", headers={"X-API-KEY": user_api_key.key}, - json=update_data.model_dump(mode="json", exclude_unset=True) + json=update_data.model_dump(mode="json", exclude_unset=True), ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True assert "data" in response_data data = response_data["data"] - + assert data["id"] == str(prompt.id) assert data["name"] == "updated_prompt" assert data["description"] == "Updated description" @@ -298,20 +301,20 @@ def test_delete_prompt_route_success( name=f"test_prompt", description="Test prompt description", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=project_id) - + response = client.delete( f"/api/v1/prompts/{prompt.id}", headers={"X-API-KEY": user_api_key.key}, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["success"] is True assert response_data["data"]["message"] == "Prompt deleted successfully." 
- + db.refresh(prompt) assert prompt.is_deleted assert prompt.deleted_at is not None diff --git a/backend/app/tests/crud/test_prompt_versions.py b/backend/app/tests/crud/test_prompt_versions.py index b9303887..c2f918c4 100644 --- a/backend/app/tests/crud/test_prompt_versions.py +++ b/backend/app/tests/crud/test_prompt_versions.py @@ -23,7 +23,7 @@ def prompt(db: Session) -> Prompt: name=f"test_prompt_{uuid4()}", description="This is a test prompt", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt, _ = create_prompt(db, prompt_in=prompt_data, project_id=project.id) return prompt @@ -33,15 +33,18 @@ def test_create_prompt_version_success(db: Session, prompt: Prompt): """Successfully create a new prompt version""" project_id = prompt.project_id version_data = PromptVersionCreate( - instruction="New instruction", - commit_message="New version" + instruction="New instruction", commit_message="New version" ) - - prompt_version = create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data, project_id=project_id) - + + prompt_version = create_prompt_version( + db, prompt_id=prompt.id, prompt_version_in=version_data, project_id=project_id + ) + assert isinstance(prompt_version, PromptVersion) assert prompt_version.prompt_id == prompt.id - assert prompt_version.version == 2 # First version created by create_prompt, this is second + assert ( + prompt_version.version == 2 + ) # First version created by create_prompt, this is second assert prompt_version.instruction == version_data.instruction assert prompt_version.commit_message == version_data.commit_message assert not prompt_version.is_deleted @@ -51,17 +54,19 @@ def test_create_prompt_version_multiple_versions(db: Session, prompt: Prompt): """Create multiple versions and verify correct version increment""" project_id = prompt.project_id version_data_1 = PromptVersionCreate( - instruction="New instruction 1", - commit_message="Version 2" + 
instruction="New instruction 1", commit_message="Version 2" ) version_data_2 = PromptVersionCreate( - instruction="New instruction 2", - commit_message="Version 3" + instruction="New instruction 2", commit_message="Version 3" + ) + + version_1 = create_prompt_version( + db, prompt_id=prompt.id, prompt_version_in=version_data_1, project_id=project_id + ) + version_2 = create_prompt_version( + db, prompt_id=prompt.id, prompt_version_in=version_data_2, project_id=project_id ) - - version_1 = create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data_1, project_id=project_id) - version_2 = create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data_2, project_id=project_id) - + assert version_1.version == 2 assert version_2.version == 3 assert version_1.instruction == "New instruction 1" @@ -73,13 +78,17 @@ def test_create_prompt_version_prompt_not_found(db: Session): project = get_project(db) non_existent_id = uuid4() version_data = PromptVersionCreate( - instruction="New instruction", - commit_message="New version" + instruction="New instruction", commit_message="New version" ) - + with pytest.raises(HTTPException) as exc_info: - create_prompt_version(db, prompt_id=non_existent_id, prompt_version_in=version_data, project_id=project.id) - + create_prompt_version( + db, + prompt_id=non_existent_id, + prompt_version_in=version_data, + project_id=project.id, + ) + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() @@ -90,15 +99,19 @@ def test_create_prompt_version_deleted_prompt(db: Session, prompt: Prompt): prompt.is_deleted = True db.add(prompt) db.commit() - + version_data = PromptVersionCreate( - instruction="New instruction", - commit_message="New version" + instruction="New instruction", commit_message="New version" ) - + with pytest.raises(HTTPException) as exc_info: - create_prompt_version(db, prompt_id=prompt.id, prompt_version_in=version_data, project_id=project_id) - + 
create_prompt_version( + db, + prompt_id=prompt.id, + prompt_version_in=version_data, + project_id=project_id, + ) + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() @@ -109,11 +122,11 @@ def test_get_next_prompt_version(db: Session, prompt: Prompt): prompt_id=prompt.id, instruction="Second instruction", commit_message="Second version", - version=2 + version=2, ) db.add(prompt_version) db.commit() - + version = get_next_prompt_version(db, prompt_id=prompt.id) assert version == 3 @@ -125,11 +138,11 @@ def test_get_next_prompt_version_with_deleted_version(db: Session, prompt: Promp instruction="Deleted instruction", commit_message="Deleted version", version=2, - is_deleted=True + is_deleted=True, ) db.add(prompt_version) db.commit() - + version = get_next_prompt_version(db, prompt_id=prompt.id) assert version == 3 @@ -137,19 +150,21 @@ def test_get_next_prompt_version_with_deleted_version(db: Session, prompt: Promp def test_delete_prompt_version_success(db: Session, prompt: Prompt): """Successfully soft-delete a non-active prompt version""" project_id = prompt.project_id - + # Create a second version (non-active) second_version = PromptVersion( prompt_id=prompt.id, instruction="Second instruction", commit_message="Second version", - version=2 + version=2, ) db.add(second_version) db.commit() - - delete_prompt_version(db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id) - + + delete_prompt_version( + db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id + ) + db.refresh(second_version) assert second_version.is_deleted assert second_version.deleted_at is not None @@ -160,10 +175,12 @@ def test_delete_prompt_version_active_version(db: Session, prompt: Prompt): """Raise 409 error when attempting to delete the active version""" project_id = prompt.project_id active_version_id = prompt.active_version - + with pytest.raises(HTTPException) as exc_info: - delete_prompt_version(db, 
prompt_id=prompt.id, version_id=active_version_id, project_id=project_id) - + delete_prompt_version( + db, prompt_id=prompt.id, version_id=active_version_id, project_id=project_id + ) + assert exc_info.value.status_code == 409 assert "cannot delete active version" in exc_info.value.detail.lower() @@ -172,10 +189,15 @@ def test_delete_prompt_version_not_found(db: Session, prompt: Prompt): """Raise 404 error when version does not exist""" project_id = prompt.project_id non_existent_version_id = uuid4() - + with pytest.raises(HTTPException) as exc_info: - delete_prompt_version(db, prompt_id=prompt.id, version_id=non_existent_version_id, project_id=project_id) - + delete_prompt_version( + db, + prompt_id=prompt.id, + version_id=non_existent_version_id, + project_id=project_id, + ) + assert exc_info.value.status_code == 404 assert "prompt version not found" in exc_info.value.detail.lower() @@ -183,21 +205,23 @@ def test_delete_prompt_version_not_found(db: Session, prompt: Prompt): def test_delete_prompt_version_already_deleted(db: Session, prompt: Prompt): """Raise 404 error when attempting to delete an already deleted version""" project_id = prompt.project_id - + second_version = PromptVersion( prompt_id=prompt.id, instruction="Second instruction", commit_message="Second version", version=2, is_deleted=True, - deleted_at=now() + deleted_at=now(), ) db.add(second_version) db.commit() - + with pytest.raises(HTTPException) as exc_info: - delete_prompt_version(db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id) - + delete_prompt_version( + db, prompt_id=prompt.id, version_id=second_version.id, project_id=project_id + ) + assert exc_info.value.status_code == 404 assert "prompt version not found" in exc_info.value.detail.lower() @@ -207,9 +231,14 @@ def test_delete_prompt_version_prompt_not_found(db: Session): project = get_project(db) non_existent_prompt_id = uuid4() version_id = uuid4() - + with pytest.raises(HTTPException) as exc_info: - 
delete_prompt_version(db, prompt_id=non_existent_prompt_id, version_id=version_id, project_id=project.id) - + delete_prompt_version( + db, + prompt_id=non_existent_prompt_id, + version_id=version_id, + project_id=project.id, + ) + assert exc_info.value.status_code == 404 - assert "not found" in exc_info.value.detail.lower() \ No newline at end of file + assert "not found" in exc_info.value.detail.lower() diff --git a/backend/app/tests/crud/test_prompts.py b/backend/app/tests/crud/test_prompts.py index 1df9148a..074ae7be 100644 --- a/backend/app/tests/crud/test_prompts.py +++ b/backend/app/tests/crud/test_prompts.py @@ -26,7 +26,7 @@ def prompt(db) -> Prompt: name="test_prompt", description="This is a test prompt", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt, _ = create_prompt(db, prompt_in=prompt_data, project_id=project.id) return prompt @@ -39,7 +39,7 @@ def test_create_prompt_success(db: Session): name="test_prompt", description="This is a test prompt", instruction="Test instruction", - commit_message="Initial version" + commit_message="Initial version", ) prompt, version = create_prompt(db, prompt_data, project_id=project.id) @@ -63,26 +63,35 @@ def test_create_prompt_success(db: Session): def test_get_prompts_success(db: Session): """Retrieve prompts for a project with pagination, ensuring correct filtering and ordering""" project = get_project(db) - + create_prompt( db, - PromptCreate(name="prompt1", description="First prompt", instruction="Instruction 1", commit_message="Initial"), - project_id=project.id + PromptCreate( + name="prompt1", + description="First prompt", + instruction="Instruction 1", + commit_message="Initial", + ), + project_id=project.id, ) create_prompt( db, - PromptCreate(name="prompt2", description="Second prompt", instruction="Instruction 2", commit_message="Initial"), - project_id=project.id + PromptCreate( + name="prompt2", + description="Second prompt", + 
instruction="Instruction 2", + commit_message="Initial", + ), + project_id=project.id, ) - + prompts = get_prompts(db, project_id=project.id, skip=0, limit=100) - + assert len(prompts) == 2 - assert prompts[0].name == "prompt2" + assert prompts[0].name == "prompt2" assert prompts[1].name == "prompt1" assert all(not prompt.is_deleted for prompt in prompts) assert all(prompt.project_id == project.id for prompt in prompts) - prompts_limited = get_prompts(db, project_id=project.id, skip=1, limit=1) assert len(prompts_limited) == 1 @@ -92,20 +101,25 @@ def test_get_prompts_success(db: Session): def test_get_prompts_empty(db: Session): """Return empty list when no prompts exist for a project or project has no non-deleted prompts""" project = get_project(db) - + prompts = get_prompts(db, project_id=project.id) assert prompts == [] - + # Create a deleted prompt prompt, _ = create_prompt( db, - PromptCreate(name="deleted_prompt", description="Deleted", instruction="Instruction", commit_message="Initial"), - project_id=project.id + PromptCreate( + name="deleted_prompt", + description="Deleted", + instruction="Instruction", + commit_message="Initial", + ), + project_id=project.id, ) prompt.is_deleted = True db.add(prompt) db.commit() - + prompts = get_prompts(db, project_id=project.id) assert prompts == [] @@ -113,19 +127,29 @@ def test_get_prompts_empty(db: Session): def test_count_prompts_in_project_success(db: Session): """Correctly count non-deleted prompts in a project""" project = get_project(db) - + # Create multiple prompts create_prompt( db, - PromptCreate(name="prompt1", description="First prompt", instruction="Instruction 1", commit_message="Initial"), - project_id=project.id + PromptCreate( + name="prompt1", + description="First prompt", + instruction="Instruction 1", + commit_message="Initial", + ), + project_id=project.id, ) create_prompt( db, - PromptCreate(name="prompt2", description="Second prompt", instruction="Instruction 2", commit_message="Initial"), - 
project_id=project.id + PromptCreate( + name="prompt2", + description="Second prompt", + instruction="Instruction 2", + commit_message="Initial", + ), + project_id=project.id, ) - + count = count_prompts_in_project(db, project_id=project.id) assert count == 2 @@ -133,21 +157,26 @@ def test_count_prompts_in_project_success(db: Session): def test_count_prompts_in_project_empty_or_deleted(db: Session): """Return 0 when no prompts exist or all prompts are deleted""" project = get_project(db) - + # Test empty project count = count_prompts_in_project(db, project_id=project.id) assert count == 0 - + # Create a deleted prompt prompt, _ = create_prompt( db, - PromptCreate(name="deleted_prompt", description="Deleted", instruction="Instruction", commit_message="Initial"), - project_id=project.id + PromptCreate( + name="deleted_prompt", + description="Deleted", + instruction="Instruction", + commit_message="Initial", + ), + project_id=project.id, ) prompt.is_deleted = True db.add(prompt) db.commit() - + count = count_prompts_in_project(db, project_id=project.id) assert count == 0 @@ -156,7 +185,7 @@ def test_prompt_exists_success(db: Session, prompt: Prompt): """Successfully retrieve an existing prompt by ID and project""" project = get_project(db) # Call get_project as a function result = prompt_exists(db, prompt_id=prompt.id, project_id=project.id) - + assert isinstance(result, Prompt) assert result.id == prompt.id assert result.project_id == project.id @@ -169,10 +198,10 @@ def test_prompt_exists_not_found(db: Session): """Raise 404 error when prompt ID does not exist""" project = get_project(db) # Call get_project as a function non_existent_id = uuid4() - + with pytest.raises(HTTPException) as exc_info: prompt_exists(db, prompt_id=non_existent_id, project_id=project.id) - + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() @@ -183,10 +212,10 @@ def test_prompt_exists_deleted_prompt(db: Session, prompt: Prompt): prompt.is_deleted = 
True db.add(prompt) db.commit() - + with pytest.raises(HTTPException) as exc_info: prompt_exists(db, prompt_id=prompt.id, project_id=project_id) - + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() @@ -194,8 +223,10 @@ def test_prompt_exists_deleted_prompt(db: Session, prompt: Prompt): def test_get_prompt_by_id_success_active_version(db: Session, prompt: Prompt): """Retrieve a prompt by ID with only its active version""" project = get_project(db) - retrieved_prompt, versions = get_prompt_by_id(db, prompt_id=prompt.id, project_id=project.id, include_versions=False) - + retrieved_prompt, versions = get_prompt_by_id( + db, prompt_id=prompt.id, project_id=project.id, include_versions=False + ) + assert isinstance(retrieved_prompt, Prompt) assert isinstance(versions, list) assert len(versions) == 1 @@ -214,19 +245,21 @@ def test_get_prompt_by_id_success_active_version(db: Session, prompt: Prompt): def test_get_prompt_by_id_with_versions(db: Session, prompt: Prompt): """Retrieve a prompt by ID with all its versions""" project = get_project(db) - + # Add another version new_version = PromptVersion( prompt_id=prompt.id, instruction="Updated instruction", commit_message="Second version", - version=2 + version=2, ) db.add(new_version) db.commit() - - retrieved_prompt, versions = get_prompt_by_id(db, prompt_id=prompt.id, project_id=project.id, include_versions=True) - + + retrieved_prompt, versions = get_prompt_by_id( + db, prompt_id=prompt.id, project_id=project.id, include_versions=True + ) + assert isinstance(retrieved_prompt, Prompt) assert isinstance(versions, list) assert len(versions) == 2 @@ -246,10 +279,10 @@ def test_get_prompt_by_id_not_found(db: Session): """Raise 404 error when prompt ID does not exist""" project = get_project(db) non_existent_id = uuid4() - + with pytest.raises(HTTPException) as exc_info: get_prompt_by_id(db, prompt_id=non_existent_id, project_id=project.id) - + assert exc_info.value.status_code == 404 
assert "not found" in exc_info.value.detail.lower() @@ -260,10 +293,10 @@ def test_get_prompt_by_id_deleted_prompt(db: Session, prompt: Prompt): prompt.is_deleted = True db.add(prompt) db.commit() - + with pytest.raises(HTTPException) as exc_info: get_prompt_by_id(db, prompt_id=prompt.id, project_id=project_id) - + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() @@ -271,19 +304,21 @@ def test_get_prompt_by_id_deleted_prompt(db: Session, prompt: Prompt): def test_get_prompt_by_id_deleted_version(db: Session, prompt: Prompt): """Exclude deleted versions when retrieving prompt versions""" project_id = prompt.project_id - + deleted_version = PromptVersion( prompt_id=prompt.id, instruction="Deleted instruction", commit_message="Deleted version", version=2, - is_deleted=True + is_deleted=True, ) db.add(deleted_version) db.commit() - - retrieved_prompt, versions = get_prompt_by_id(db, prompt_id=prompt.id, project_id=project_id, include_versions=True) - + + retrieved_prompt, versions = get_prompt_by_id( + db, prompt_id=prompt.id, project_id=project_id, include_versions=True + ) + assert isinstance(retrieved_prompt, Prompt) assert isinstance(versions, list) assert len(versions) == 1 @@ -296,9 +331,11 @@ def test_update_prompt_success_name_description(db: Session, prompt: Prompt): """Successfully update prompt's name and description""" project_id = prompt.project_id update_data = PromptUpdate(name="updated_prompt", description="Updated description") - - updated_prompt = update_prompt(db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data) - + + updated_prompt = update_prompt( + db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data + ) + assert isinstance(updated_prompt, Prompt) assert updated_prompt.id == prompt.id assert updated_prompt.name == "updated_prompt" @@ -310,20 +347,22 @@ def test_update_prompt_success_name_description(db: Session, prompt: Prompt): def 
test_update_prompt_success_active_version(db: Session, prompt: Prompt): """Successfully update prompt's active version""" project_id = prompt.project_id - + # Create a new version new_version = PromptVersion( prompt_id=prompt.id, instruction="New instruction", commit_message="Second version", - version=2 + version=2, ) db.add(new_version) db.commit() - + update_data = PromptUpdate(active_version=new_version.id) - updated_prompt = update_prompt(db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data) - + updated_prompt = update_prompt( + db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data + ) + assert isinstance(updated_prompt, Prompt) assert updated_prompt.id == prompt.id assert updated_prompt.active_version == new_version.id @@ -333,12 +372,14 @@ def test_update_prompt_invalid_active_version(db: Session, prompt: Prompt): """Raise 404 error when updating with an invalid active version ID""" project_id = prompt.project_id invalid_version_id = uuid4() - + update_data = PromptUpdate(active_version=invalid_version_id) - + with pytest.raises(HTTPException) as exc_info: - update_prompt(db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data) - + update_prompt( + db, prompt_id=prompt.id, project_id=project_id, prompt_update=update_data + ) + assert exc_info.value.status_code == 404 assert "invalid active version id" in exc_info.value.detail.lower() @@ -348,10 +389,15 @@ def test_update_prompt_not_found(db: Session): project = get_project(db) non_existent_id = uuid4() update_data = PromptUpdate(name="new_name") - + with pytest.raises(HTTPException) as exc_info: - update_prompt(db, prompt_id=non_existent_id, project_id=project.id, prompt_update=update_data) - + update_prompt( + db, + prompt_id=non_existent_id, + project_id=project.id, + prompt_update=update_data, + ) + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() @@ -359,9 +405,9 @@ def test_update_prompt_not_found(db: 
Session): def test_delete_prompt_success(db: Session, prompt: Prompt): """Successfully soft delete a prompt""" project_id = prompt.project_id - + delete_prompt(db, prompt_id=prompt.id, project_id=project_id) - + db.refresh(prompt) assert prompt.is_deleted assert prompt.deleted_at is not None @@ -372,10 +418,10 @@ def test_delete_prompt_not_found(db: Session): """Raise 404 error when deleting a non-existent prompt""" project = get_project(db) non_existent_id = uuid4() - + with pytest.raises(HTTPException) as exc_info: delete_prompt(db, prompt_id=non_existent_id, project_id=project.id) - + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() @@ -387,9 +433,9 @@ def test_delete_prompt_already_deleted(db: Session, prompt: Prompt): prompt.deleted_at = now() db.add(prompt) db.commit() - + with pytest.raises(HTTPException) as exc_info: delete_prompt(db, prompt_id=prompt.id, project_id=project_id) - + assert exc_info.value.status_code == 404 - assert "not found" in exc_info.value.detail.lower() \ No newline at end of file + assert "not found" in exc_info.value.detail.lower() From 7067511e118f8bfb84d49fa68c2f528d366b3967 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 14 Aug 2025 11:49:03 +0530 Subject: [PATCH 6/7] Add index and constraints to prompt model --- ...88eb2c84d9_add_prompt_and_version_table.py | 88 ------------------- ...7f50ada8ef_add_prompt_and_version_table.py | 77 ++++++++++++++++ backend/app/models/prompt.py | 8 +- backend/app/models/prompt_version.py | 8 +- 4 files changed, 89 insertions(+), 92 deletions(-) delete mode 100644 backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py create mode 100644 backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py diff --git a/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py b/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py deleted file mode 
100644 index 346b3f29..00000000 --- a/backend/app/alembic/versions/7288eb2c84d9_add_prompt_and_version_table.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Add prompt and version table - -Revision ID: 7288eb2c84d9 -Revises: e9dd35eff62c -Create Date: 2025-08-12 09:29:27.335097 - -""" -from alembic import op -import sqlalchemy as sa -import sqlmodel.sql.sqltypes - - -# revision identifiers, used by Alembic. -revision = "7288eb2c84d9" -down_revision = "e9dd35eff62c" -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "prompt", - sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), - sa.Column( - "description", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True - ), - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("active_version", sa.Uuid(), nullable=False), - sa.Column("project_id", sa.Integer(), nullable=False), - sa.Column("inserted_at", sa.DateTime(), nullable=False), - sa.Column("updated_at", sa.DateTime(), nullable=False), - sa.Column("is_deleted", sa.Boolean(), nullable=False), - sa.Column("deleted_at", sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint( - ["active_version"], - ["prompt_version.id"], - initially="DEFERRED", - deferrable=True, - use_alter=True, - ), - sa.ForeignKeyConstraint( - ["project_id"], - ["project.id"], - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index(op.f("ix_prompt_name"), "prompt", ["name"], unique=False) - op.create_table( - "prompt_version", - sa.Column("instruction", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column( - "commit_message", - sqlmodel.sql.sqltypes.AutoString(length=512), - nullable=True, - ), - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("prompt_id", sa.Uuid(), nullable=False), - sa.Column("version", sa.Integer(), nullable=False), - sa.Column("inserted_at", sa.DateTime(), nullable=False), - sa.Column("updated_at", sa.DateTime(), nullable=False), - 
sa.Column("is_deleted", sa.Boolean(), nullable=False), - sa.Column("deleted_at", sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint( - ["prompt_id"], - ["prompt.id"], - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_foreign_key( - None, - "prompt", - "prompt_version", - ["active_version"], - ["id"], - initially="DEFERRED", - deferrable=True, - use_alter=True, - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("prompt_version") - op.drop_index(op.f("ix_prompt_name"), table_name="prompt") - op.drop_table("prompt") - # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py b/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py new file mode 100644 index 00000000..81174fa6 --- /dev/null +++ b/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py @@ -0,0 +1,77 @@ +"""Add prompt and version table + +Revision ID: 757f50ada8ef +Revises: e9dd35eff62c +Create Date: 2025-08-14 11:45:07.186686 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '757f50ada8ef' +down_revision = 'e9dd35eff62c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('prompt', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('active_version', sa.Uuid(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('inserted_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('is_deleted', sa.Boolean(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['active_version'], ['prompt_version.id'], initially='DEFERRED', deferrable=True, use_alter=True), + sa.ForeignKeyConstraint(['project_id'], ['project.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_prompt_name'), 'prompt', ['name'], unique=False) + op.create_index(op.f('ix_prompt_project_id'), 'prompt', ['project_id'], unique=False) + op.create_index('ix_prompt_project_id_is_deleted', 'prompt', ['project_id', 'is_deleted'], unique=False) + op.create_table('prompt_version', + sa.Column('instruction', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('commit_message', sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('prompt_id', sa.Uuid(), nullable=False), + sa.Column('version', sa.Integer(), nullable=False), + sa.Column('inserted_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('is_deleted', sa.Boolean(), nullable=False), + sa.Column('deleted_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['prompt_id'], ['prompt.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('prompt_id', 'version') + ) + op.create_foreign_key( + None, + "prompt", + "prompt_version", + ["active_version"], + ["id"], + initially="DEFERRED", + deferrable=True, + use_alter=True, + ) + 
op.create_index(op.f('ix_prompt_version_prompt_id'), 'prompt_version', ['prompt_id'], unique=False) + op.create_index('ix_prompt_version_prompt_id_is_deleted', 'prompt_version', ['prompt_id', 'is_deleted'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_prompt_version_prompt_id_is_deleted', table_name='prompt_version') + op.drop_index(op.f('ix_prompt_version_prompt_id'), table_name='prompt_version') + op.drop_table('prompt_version') + op.drop_index('ix_prompt_project_id_is_deleted', table_name='prompt') + op.drop_index(op.f('ix_prompt_project_id'), table_name='prompt') + op.drop_index(op.f('ix_prompt_name'), table_name='prompt') + op.drop_table('prompt') + # ### end Alembic commands ### diff --git a/backend/app/models/prompt.py b/backend/app/models/prompt.py index 269fb45d..d02ac5cf 100644 --- a/backend/app/models/prompt.py +++ b/backend/app/models/prompt.py @@ -2,7 +2,7 @@ from uuid import UUID, uuid4 from sqlalchemy import Column, ForeignKey -from sqlmodel import SQLModel, Field, Relationship +from sqlmodel import Index, SQLModel, Field, Relationship from app.core.util import now from app.models.prompt_version import ( @@ -18,6 +18,10 @@ class PromptBase(SQLModel): class Prompt(PromptBase, table=True): + __table_args__ = ( + Index("ix_prompt_project_id_is_deleted", "project_id", "is_deleted"), + ) + id: UUID = Field( default_factory=uuid4, primary_key=True, @@ -34,7 +38,7 @@ class Prompt(PromptBase, table=True): nullable=False, ), ) - project_id: int = Field(foreign_key="project.id") + project_id: int = Field(foreign_key="project.id", index=True, nullable=False) inserted_at: datetime = Field(default_factory=now, nullable=False) updated_at: datetime = Field(default_factory=now, nullable=False) is_deleted: bool = Field(default=False) diff --git a/backend/app/models/prompt_version.py b/backend/app/models/prompt_version.py index aedaa1b8..5378429a 100644 --- 
a/backend/app/models/prompt_version.py +++ b/backend/app/models/prompt_version.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from uuid import UUID, uuid4 -from sqlmodel import SQLModel, Field, Relationship +from sqlmodel import Index, SQLModel, UniqueConstraint, Field, Relationship from app.core.util import now @@ -17,9 +17,13 @@ class PromptVersionBase(SQLModel): class PromptVersion(PromptVersionBase, table=True): __tablename__ = "prompt_version" + __table_args__ = ( + UniqueConstraint("prompt_id", "version"), + Index("ix_prompt_version_prompt_id_is_deleted", "prompt_id", "is_deleted"), + ) id: UUID = Field(default_factory=uuid4, primary_key=True) - prompt_id: UUID = Field(foreign_key="prompt.id", nullable=False) + prompt_id: UUID = Field(foreign_key="prompt.id", nullable=False, index=True) version: int = Field(nullable=False) inserted_at: datetime = Field(default_factory=now, nullable=False) From b3c5da8da57321a1cf9809fe63cf4f534c6e6a67 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 14 Aug 2025 12:07:18 +0530 Subject: [PATCH 7/7] use helper function to create prompt --- ...7f50ada8ef_add_prompt_and_version_table.py | 117 +- .../tests/api/routes/test_prompt_versions.py | 21 +- backend/app/tests/api/routes/test_prompts.py | 75 +- .../app/tests/api/routes/test_responses.py | 1324 ++++++++--------- .../app/tests/crud/test_prompt_versions.py | 9 +- backend/app/tests/crud/test_prompts.py | 13 +- backend/app/tests/utils/test_data.py | 21 + 7 files changed, 785 insertions(+), 795 deletions(-) diff --git a/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py b/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py index 81174fa6..f3c7d502 100644 --- a/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py +++ b/backend/app/alembic/versions/757f50ada8ef_add_prompt_and_version_table.py @@ -11,44 +11,71 @@ # revision identifiers, used by Alembic. 
-revision = '757f50ada8ef' -down_revision = 'e9dd35eff62c' +revision = "757f50ada8ef" +down_revision = "e9dd35eff62c" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('prompt', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), - sa.Column('description', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), - sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('active_version', sa.Uuid(), nullable=False), - sa.Column('project_id', sa.Integer(), nullable=False), - sa.Column('inserted_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.Column('is_deleted', sa.Boolean(), nullable=False), - sa.Column('deleted_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['active_version'], ['prompt_version.id'], initially='DEFERRED', deferrable=True, use_alter=True), - sa.ForeignKeyConstraint(['project_id'], ['project.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "prompt", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column( + "description", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True + ), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("active_version", sa.Uuid(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("is_deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["active_version"], + ["prompt_version.id"], + initially="DEFERRED", + deferrable=True, + use_alter=True, + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_prompt_name"), "prompt", ["name"], unique=False) + op.create_index( + 
op.f("ix_prompt_project_id"), "prompt", ["project_id"], unique=False ) - op.create_index(op.f('ix_prompt_name'), 'prompt', ['name'], unique=False) - op.create_index(op.f('ix_prompt_project_id'), 'prompt', ['project_id'], unique=False) - op.create_index('ix_prompt_project_id_is_deleted', 'prompt', ['project_id', 'is_deleted'], unique=False) - op.create_table('prompt_version', - sa.Column('instruction', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('commit_message', sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True), - sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('prompt_id', sa.Uuid(), nullable=False), - sa.Column('version', sa.Integer(), nullable=False), - sa.Column('inserted_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.Column('is_deleted', sa.Boolean(), nullable=False), - sa.Column('deleted_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['prompt_id'], ['prompt.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('prompt_id', 'version') + op.create_index( + "ix_prompt_project_id_is_deleted", + "prompt", + ["project_id", "is_deleted"], + unique=False, + ) + op.create_table( + "prompt_version", + sa.Column("instruction", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "commit_message", + sqlmodel.sql.sqltypes.AutoString(length=512), + nullable=True, + ), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("prompt_id", sa.Uuid(), nullable=False), + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("is_deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["prompt_id"], + ["prompt.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("prompt_id", "version"), ) op.create_foreign_key( None, @@ -60,18 +87,28 @@ def 
upgrade(): deferrable=True, use_alter=True, ) - op.create_index(op.f('ix_prompt_version_prompt_id'), 'prompt_version', ['prompt_id'], unique=False) - op.create_index('ix_prompt_version_prompt_id_is_deleted', 'prompt_version', ['prompt_id', 'is_deleted'], unique=False) + op.create_index( + op.f("ix_prompt_version_prompt_id"), + "prompt_version", + ["prompt_id"], + unique=False, + ) + op.create_index( + "ix_prompt_version_prompt_id_is_deleted", + "prompt_version", + ["prompt_id", "is_deleted"], + unique=False, + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_index('ix_prompt_version_prompt_id_is_deleted', table_name='prompt_version') - op.drop_index(op.f('ix_prompt_version_prompt_id'), table_name='prompt_version') - op.drop_table('prompt_version') - op.drop_index('ix_prompt_project_id_is_deleted', table_name='prompt') - op.drop_index(op.f('ix_prompt_project_id'), table_name='prompt') - op.drop_index(op.f('ix_prompt_name'), table_name='prompt') - op.drop_table('prompt') + op.drop_index("ix_prompt_version_prompt_id_is_deleted", table_name="prompt_version") + op.drop_index(op.f("ix_prompt_version_prompt_id"), table_name="prompt_version") + op.drop_table("prompt_version") + op.drop_index("ix_prompt_project_id_is_deleted", table_name="prompt") + op.drop_index(op.f("ix_prompt_project_id"), table_name="prompt") + op.drop_index(op.f("ix_prompt_name"), table_name="prompt") + op.drop_table("prompt") # ### end Alembic commands ### diff --git a/backend/app/tests/api/routes/test_prompt_versions.py b/backend/app/tests/api/routes/test_prompt_versions.py index 15f6fd6c..3d66c386 100644 --- a/backend/app/tests/api/routes/test_prompt_versions.py +++ b/backend/app/tests/api/routes/test_prompt_versions.py @@ -3,6 +3,7 @@ from app.crud.prompts import create_prompt from app.models import APIKeyPublic, PromptCreate, PromptVersion, PromptVersionCreate +from app.tests.utils.test_data import create_test_prompt def 
test_create_prompt_version_route_success( @@ -11,15 +12,7 @@ def test_create_prompt_version_route_success( user_api_key: APIKeyPublic, ): """Test successful creation of a prompt version via API route.""" - prompt_in = PromptCreate( - name=f"Test Prompt", - description="Prompt for testing version creation route", - instruction="Initial instruction", - commit_message="Initial version", - ) - prompt, _ = create_prompt( - db, prompt_in=prompt_in, project_id=user_api_key.project_id - ) + prompt, _ = create_test_prompt(db, user_api_key.project_id) version_in = PromptVersionCreate( instruction="Version 2 instructions", commit_message="Second version" @@ -51,15 +44,7 @@ def test_delete_prompt_version_route_success( user_api_key: APIKeyPublic, ): """Test successful deletion of a non-active prompt version via API route""" - prompt_in = PromptCreate( - name=f"Test Prompt", - description="Prompt for testing version creation route", - instruction="Initial instruction", - commit_message="Initial version", - ) - prompt, _ = create_prompt( - db, prompt_in=prompt_in, project_id=user_api_key.project_id - ) + prompt, _ = create_test_prompt(db, user_api_key.project_id) # Create a second version (non-active) second_version = PromptVersion( diff --git a/backend/app/tests/api/routes/test_prompts.py b/backend/app/tests/api/routes/test_prompts.py index 2177e075..8c3d21a1 100644 --- a/backend/app/tests/api/routes/test_prompts.py +++ b/backend/app/tests/api/routes/test_prompts.py @@ -3,8 +3,8 @@ from fastapi.testclient import TestClient from sqlmodel import Session -from app.crud.prompts import create_prompt from app.models import APIKeyPublic, PromptCreate, PromptUpdate, PromptVersion +from app.tests.utils.test_data import create_test_prompt def test_create_prompt_route_success( @@ -15,7 +15,7 @@ def test_create_prompt_route_success( """Test successful creation of a prompt via API route""" project_id = user_api_key.project_id prompt_in = PromptCreate( - name=f"test_prompt_{uuid4()}", + 
name="test_prompt", description="Test prompt description", instruction="Test instruction", commit_message="Initial version", @@ -54,21 +54,8 @@ def test_get_prompts_route_success( """Test successfully retrieving prompts with pagination metadata""" project_id = user_api_key.project_id - # Create multiple prompts - prompt_1_in = PromptCreate( - name=f"prompt_1_{uuid4()}", - description="First prompt description", - instruction="First instruction", - commit_message="Initial version", - ) - prompt_2_in = PromptCreate( - name=f"prompt_2_{uuid4()}", - description="Second prompt description", - instruction="Second instruction", - commit_message="Initial version", - ) - create_prompt(db, prompt_in=prompt_1_in, project_id=project_id) - create_prompt(db, prompt_in=prompt_2_in, project_id=project_id) + create_test_prompt(db, project_id, name="prompt_1") + create_test_prompt(db, project_id, name="prompt_2") response = client.get( f"/api/v1/prompts/?skip=0&limit=100", @@ -86,8 +73,8 @@ def test_get_prompts_route_success( prompts = response_data["data"] assert len(prompts) == 2 - assert prompts[0]["name"] == prompt_2_in.name - assert prompts[1]["name"] == prompt_1_in.name + assert prompts[0]["name"] == "prompt_2" + assert prompts[1]["name"] == "prompt_1" assert all(prompt["project_id"] == project_id for prompt in prompts) @@ -120,13 +107,7 @@ def test_get_prompts_route_pagination( project_id = user_api_key.project_id for i in range(3): - prompt_in = PromptCreate( - name=f"prompt_{i}", - description=f"Prompt {i} description", - instruction=f"Instruction {i}", - commit_message="Initial version", - ) - create_prompt(db, prompt_in=prompt_in, project_id=project_id) + create_test_prompt(db, project_id, name=f"prompt_{i}") response = client.get( f"/api/v1/prompts/?skip=1&limit=1", @@ -149,13 +130,7 @@ def test_get_prompt_by_id_route_success_active_version( ): """Test successfully retrieving a prompt with its active version""" project_id = user_api_key.project_id - prompt_in = 
PromptCreate( - name=f"test_prompt_{uuid4()}", - description="Test prompt description", - instruction="Test instruction", - commit_message="Initial version", - ) - prompt, version = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + prompt, version = create_test_prompt(db, project_id, name="test_prompt") response = client.get( f"/api/v1/prompts/{prompt.id}?include_versions=false", @@ -169,13 +144,13 @@ def test_get_prompt_by_id_route_success_active_version( data = response_data["data"] assert data["id"] == str(prompt.id) - assert data["name"] == prompt_in.name - assert data["description"] == prompt_in.description + assert data["name"] == "test_prompt" + assert data["description"] == "Test prompt description" assert data["project_id"] == project_id assert len(data["versions"]) == 1 assert data["versions"][0]["id"] == str(version.id) - assert data["versions"][0]["instruction"] == prompt_in.instruction - assert data["versions"][0]["commit_message"] == prompt_in.commit_message + assert data["versions"][0]["instruction"] == "Test instruction" + assert data["versions"][0]["commit_message"] == "Initial version" assert data["versions"][0]["version"] == 1 assert data["active_version"] == str(version.id) @@ -187,13 +162,7 @@ def test_get_prompt_by_id_route_with_versions( ): """Test retrieving a prompt with all its versions""" project_id = user_api_key.project_id - prompt_in = PromptCreate( - name=f"test_prompt_{uuid4()}", - description="Test prompt description", - instruction="Test instruction", - commit_message="Initial version", - ) - prompt, version = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + prompt, _ = create_test_prompt(db, project_id, name="test_prompt") second_version = PromptVersion( prompt_id=prompt.id, @@ -219,7 +188,7 @@ def test_get_prompt_by_id_route_with_versions( assert data["versions"][0]["version"] == 2 assert data["versions"][1]["version"] == 1 assert data["versions"][0]["instruction"] == "Second instruction" - assert 
data["versions"][1]["instruction"] == prompt_in.instruction + assert data["versions"][1]["instruction"] == "Test instruction" def test_get_prompt_by_id_route_not_found( @@ -248,13 +217,7 @@ def test_update_prompt_route_success( ): """Test successfully updating a prompt's name and description""" project_id = user_api_key.project_id - prompt_in = PromptCreate( - name=f"test_prompt", - description="Test prompt description", - instruction="Test instruction", - commit_message="Initial version", - ) - prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + prompt, _ = create_test_prompt(db, project_id, name="test_prompt") prompt_version = PromptVersion( prompt_id=prompt.id, @@ -297,13 +260,7 @@ def test_delete_prompt_route_success( ): """Test successfully soft-deleting a prompt""" project_id = user_api_key.project_id - prompt_in = PromptCreate( - name=f"test_prompt", - description="Test prompt description", - instruction="Test instruction", - commit_message="Initial version", - ) - prompt, _ = create_prompt(db, prompt_in=prompt_in, project_id=project_id) + prompt, _ = create_test_prompt(db, project_id, name="test_prompt") response = client.delete( f"/api/v1/prompts/{prompt.id}", diff --git a/backend/app/tests/api/routes/test_responses.py b/backend/app/tests/api/routes/test_responses.py index 483119d5..a4bf0399 100644 --- a/backend/app/tests/api/routes/test_responses.py +++ b/backend/app/tests/api/routes/test_responses.py @@ -1,662 +1,662 @@ -from unittest.mock import MagicMock, patch -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient -from sqlmodel import select -import openai - -from app.api.routes.responses import router -from app.models import Project - -# Wrap the router in a FastAPI app instance -app = FastAPI() -app.include_router(router) -client = TestClient(app) - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") 
-@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_success( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint for successful response creation.""" - - # Setup mock credentials - configure to return different values based on provider - def mock_get_credentials_by_provider(*args, **kwargs): - provider = kwargs.get("provider") - if provider == "openai": - return {"api_key": "test_api_key"} - elif provider == "langfuse": - return { - "public_key": "test_public_key", - "secret_key": "test_secret_key", - "host": "https://cloud.langfuse.com", - } - return None - - mock_get_credential.side_effect = mock_get_credentials_by_provider - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - - # Configure mock to return the assistant for any call - def return_mock_assistant(*args, **kwargs): - return mock_assistant - - mock_get_assistant.side_effect = return_mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object with proper response ID format - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens 
= 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_without_vector_store( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header, -): - """Test the /responses endpoint when assistant has no vector store configured.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock 
assistant without vector store - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = [] # No vector store configured - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object with proper response ID format - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Glific project ID - glific_project = db.exec(select(Project).where(Project.name == "Glific")).first() - if not glific_project: - pytest.skip("Glific project not found in the database") - - request_data = { - "assistant_id": "assistant_123", - "question": "What is Glific?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify OpenAI client was called without tools - mock_client.responses.create.assert_called_once_with( - 
model=mock_assistant.model, - previous_response_id=None, - instructions=mock_assistant.instructions, - temperature=mock_assistant.temperature, - input=[{"role": "user", "content": "What is Glific?"}], - ) - - -@patch("app.api.routes.responses.get_assistant_by_id") -def test_responses_endpoint_assistant_not_found( - mock_get_assistant, - db, - user_api_key_header, -): - """Test the /responses endpoint when assistant is not found.""" - # Setup mock assistant to return None (not found) - mock_get_assistant.return_value = None - - request_data = { - "assistant_id": "nonexistent_assistant", - "question": "What is this?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 404 - response_json = response.json() - assert response_json["detail"] == "Assistant not found or not active" - - -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -def test_responses_endpoint_no_openai_credentials( - mock_get_assistant, - mock_get_credential, - db, - user_api_key_header, -): - """Test the /responses endpoint when OpenAI credentials are not configured.""" - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = [] - mock_get_assistant.return_value = mock_assistant - - # Setup mock credentials to return None (no credentials) - mock_get_credential.return_value = None - - request_data = { - "assistant_id": "assistant_123", - "question": "What is this?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is False - assert "OpenAI API key not configured" in 
response_json["error"] - - -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -def test_responses_endpoint_missing_api_key_in_credentials( - mock_get_assistant, - mock_get_credential, - db, - user_api_key_header, -): - """Test the /responses endpoint when credentials exist but don't have api_key.""" - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = [] - mock_get_assistant.return_value = mock_assistant - - # Setup mock credentials without api_key - mock_get_credential.return_value = {"other_key": "value"} - - request_data = { - "assistant_id": "assistant_123", - "question": "What is this?", - "callback_url": "http://example.com/callback", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is False - assert "OpenAI API key not configured" in response_json["error"] - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_with_file_search_results( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header, -): - """Test the /responses endpoint with file search results in the response.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - 
- # Setup mock assistant with vector store - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup mock file search results - mock_hit1 = MagicMock() - mock_hit1.score = 0.95 - mock_hit1.text = "First search result" - - mock_hit2 = MagicMock() - mock_hit2.score = 0.85 - mock_hit2.text = "Second search result" - - mock_file_search_call = MagicMock() - mock_file_search_call.type = "file_search_call" - mock_file_search_call.results = [mock_hit1, mock_hit2] - - # Setup the mock response object with file search results and proper response ID format - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output with search results" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [mock_file_search_call] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - } - - response = 
client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify OpenAI client was called with tools - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert "tools" in call_args - assert call_args["tools"][0]["type"] == "file_search" - assert call_args["tools"][0]["vector_store_ids"] == ["vs_test"] - assert "include" in call_args - assert "file_search_call.results" in call_args["include"] - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_with_ancestor_conversation_found( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint when a conversation is found by ancestor ID.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value 
= mock_client - - # Setup the mock response object - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Setup mock conversation found by ancestor ID - mock_conversation = MagicMock() - mock_conversation.response_id = "resp_latest1234567890abcdef1234567890" - mock_conversation.ancestor_response_id = "resp_ancestor1234567890abcdef1234567890" - mock_get_conversation_by_ancestor_id.return_value = mock_conversation - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - "response_id": "resp_ancestor1234567890abcdef1234567890", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify get_conversation_by_ancestor_id was called with correct parameters - mock_get_conversation_by_ancestor_id.assert_called_once() - call_args = 
mock_get_conversation_by_ancestor_id.call_args - assert ( - call_args[1]["ancestor_response_id"] - == "resp_ancestor1234567890abcdef1234567890" - ) - assert call_args[1]["project_id"] == dalgo_project.id - - # Verify OpenAI client was called with the conversation's response_id as previous_response_id - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert call_args["previous_response_id"] == "resp_latest1234567890abcdef1234567890" - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_with_ancestor_conversation_not_found( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint when no conversation is found by ancestor ID.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = 
"gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value = ( - "resp_ancestor1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Setup mock conversation not found by ancestor ID - mock_get_conversation_by_ancestor_id.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - "response_id": "resp_ancestor1234567890abcdef1234567890", - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify get_conversation_by_ancestor_id was called with correct parameters - mock_get_conversation_by_ancestor_id.assert_called_once() - call_args = mock_get_conversation_by_ancestor_id.call_args - assert ( - call_args[1]["ancestor_response_id"] - == "resp_ancestor1234567890abcdef1234567890" - ) - assert call_args[1]["project_id"] == dalgo_project.id - - # Verify OpenAI client was called with the original response_id as previous_response_id - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert ( - 
call_args["previous_response_id"] == "resp_ancestor1234567890abcdef1234567890" - ) - - -@patch("app.api.routes.responses.OpenAI") -@patch("app.api.routes.responses.get_provider_credential") -@patch("app.api.routes.responses.get_assistant_by_id") -@patch("app.api.routes.responses.LangfuseTracer") -@patch("app.api.routes.responses.get_ancestor_id_from_response") -@patch("app.api.routes.responses.create_conversation") -@patch("app.api.routes.responses.get_conversation_by_ancestor_id") -def test_responses_endpoint_without_response_id( - mock_get_conversation_by_ancestor_id, - mock_create_conversation, - mock_get_ancestor_id_from_response, - mock_tracer_class, - mock_get_assistant, - mock_get_credential, - mock_openai, - db, - user_api_key_header: dict[str, str], -): - """Test the /responses endpoint when no response_id is provided.""" - # Setup mock credentials - mock_get_credential.return_value = {"api_key": "test_api_key"} - - # Setup mock assistant - mock_assistant = MagicMock() - mock_assistant.model = "gpt-4o" - mock_assistant.instructions = "Test instructions" - mock_assistant.temperature = 0.1 - mock_assistant.vector_store_ids = ["vs_test"] - mock_assistant.max_num_results = 20 - mock_get_assistant.return_value = mock_assistant - - # Setup mock OpenAI client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Setup the mock response object - mock_response = MagicMock() - mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" - mock_response.output_text = "Test output" - mock_response.model = "gpt-4o" - mock_response.usage.input_tokens = 10 - mock_response.usage.output_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_response.output = [] - mock_response.previous_response_id = None - mock_client.responses.create.return_value = mock_response - - # Setup mock tracer - mock_tracer = MagicMock() - mock_tracer_class.return_value = mock_tracer - - # Setup mock CRUD functions - mock_get_ancestor_id_from_response.return_value 
= ( - "resp_1234567890abcdef1234567890abcdef1234567890" - ) - mock_create_conversation.return_value = None - - # Get the Dalgo project ID - dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() - if not dalgo_project: - pytest.skip("Dalgo project not found in the database") - - request_data = { - "assistant_id": "assistant_dalgo", - "question": "What is Dalgo?", - "callback_url": "http://example.com/callback", - # No response_id provided - } - - response = client.post("/responses", json=request_data, headers=user_api_key_header) - - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is True - assert response_json["data"]["status"] == "processing" - assert response_json["data"]["message"] == "Response creation started" - - # Verify get_conversation_by_ancestor_id was not called since response_id is None - mock_get_conversation_by_ancestor_id.assert_not_called() - - # Verify OpenAI client was called with None as previous_response_id - mock_client.responses.create.assert_called_once() - call_args = mock_client.responses.create.call_args[1] - assert call_args["previous_response_id"] is None +# from unittest.mock import MagicMock, patch +# import pytest +# from fastapi import FastAPI +# from fastapi.testclient import TestClient +# from sqlmodel import select +# import openai + +# from app.api.routes.responses import router +# from app.models import Project + +# # Wrap the router in a FastAPI app instance +# app = FastAPI() +# app.include_router(router) +# client = TestClient(app) + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# 
def test_responses_endpoint_success( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint for successful response creation.""" + +# # Setup mock credentials - configure to return different values based on provider +# def mock_get_credentials_by_provider(*args, **kwargs): +# provider = kwargs.get("provider") +# if provider == "openai": +# return {"api_key": "test_api_key"} +# elif provider == "langfuse": +# return { +# "public_key": "test_public_key", +# "secret_key": "test_secret_key", +# "host": "https://cloud.langfuse.com", +# } +# return None + +# mock_get_credential.side_effect = mock_get_credentials_by_provider + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 + +# # Configure mock to return the assistant for any call +# def return_mock_assistant(*args, **kwargs): +# return mock_assistant + +# mock_get_assistant.side_effect = return_mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object with proper response ID format +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# 
mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_without_vector_store( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when assistant has no vector store configured.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant without vector store +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 
0.1 +# mock_assistant.vector_store_ids = [] # No vector store configured +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object with proper response ID format +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Glific project ID +# glific_project = db.exec(select(Project).where(Project.name == "Glific")).first() +# if not glific_project: +# pytest.skip("Glific project not found in the database") + +# request_data = { +# "assistant_id": "assistant_123", +# "question": "What is Glific?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify OpenAI client was called without tools +# mock_client.responses.create.assert_called_once_with( +# model=mock_assistant.model, +# previous_response_id=None, +# instructions=mock_assistant.instructions, +# 
temperature=mock_assistant.temperature, +# input=[{"role": "user", "content": "What is Glific?"}], +# ) + + +# @patch("app.api.routes.responses.get_assistant_by_id") +# def test_responses_endpoint_assistant_not_found( +# mock_get_assistant, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when assistant is not found.""" +# # Setup mock assistant to return None (not found) +# mock_get_assistant.return_value = None + +# request_data = { +# "assistant_id": "nonexistent_assistant", +# "question": "What is this?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 404 +# response_json = response.json() +# assert response_json["detail"] == "Assistant not found or not active" + + +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# def test_responses_endpoint_no_openai_credentials( +# mock_get_assistant, +# mock_get_credential, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when OpenAI credentials are not configured.""" +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = [] +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock credentials to return None (no credentials) +# mock_get_credential.return_value = None + +# request_data = { +# "assistant_id": "assistant_123", +# "question": "What is this?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is False +# assert "OpenAI API key not configured" in response_json["error"] + + +# 
@patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# def test_responses_endpoint_missing_api_key_in_credentials( +# mock_get_assistant, +# mock_get_credential, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint when credentials exist but don't have api_key.""" +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = [] +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock credentials without api_key +# mock_get_credential.return_value = {"other_key": "value"} + +# request_data = { +# "assistant_id": "assistant_123", +# "question": "What is this?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is False +# assert "OpenAI API key not configured" in response_json["error"] + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_with_file_search_results( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header, +# ): +# """Test the /responses endpoint with file search results in the response.""" +# # Setup mock credentials +# mock_get_credential.return_value 
= {"api_key": "test_api_key"} + +# # Setup mock assistant with vector store +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup mock file search results +# mock_hit1 = MagicMock() +# mock_hit1.score = 0.95 +# mock_hit1.text = "First search result" + +# mock_hit2 = MagicMock() +# mock_hit2.score = 0.85 +# mock_hit2.text = "Second search result" + +# mock_file_search_call = MagicMock() +# mock_file_search_call.type = "file_search_call" +# mock_file_search_call.results = [mock_hit1, mock_hit2] + +# # Setup the mock response object with file search results and proper response ID format +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output with search results" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [mock_file_search_call] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is 
Dalgo?", +# "callback_url": "http://example.com/callback", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify OpenAI client was called with tools +# mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert "tools" in call_args +# assert call_args["tools"][0]["type"] == "file_search" +# assert call_args["tools"][0]["vector_store_ids"] == ["vs_test"] +# assert "include" in call_args +# assert "file_search_call.results" in call_args["include"] + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_with_ancestor_conversation_found( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint when a conversation is found by ancestor ID.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# 
mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Setup mock conversation found by ancestor ID +# mock_conversation = MagicMock() +# mock_conversation.response_id = "resp_latest1234567890abcdef1234567890" +# mock_conversation.ancestor_response_id = "resp_ancestor1234567890abcdef1234567890" +# mock_get_conversation_by_ancestor_id.return_value = mock_conversation + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# "response_id": "resp_ancestor1234567890abcdef1234567890", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == 
"Response creation started" + +# # Verify get_conversation_by_ancestor_id was called with correct parameters +# mock_get_conversation_by_ancestor_id.assert_called_once() +# call_args = mock_get_conversation_by_ancestor_id.call_args +# assert ( +# call_args[1]["ancestor_response_id"] +# == "resp_ancestor1234567890abcdef1234567890" +# ) +# assert call_args[1]["project_id"] == dalgo_project.id + +# # Verify OpenAI client was called with the conversation's response_id as previous_response_id +# mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert call_args["previous_response_id"] == "resp_latest1234567890abcdef1234567890" + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_with_ancestor_conversation_not_found( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint when no conversation is found by ancestor ID.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# 
mock_openai.return_value = mock_client + +# # Setup the mock response object +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = "resp_ancestor1234567890abcdef1234567890" +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_ancestor1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Setup mock conversation not found by ancestor ID +# mock_get_conversation_by_ancestor_id.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# "response_id": "resp_ancestor1234567890abcdef1234567890", +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify get_conversation_by_ancestor_id was called with correct parameters +# mock_get_conversation_by_ancestor_id.assert_called_once() +# call_args = mock_get_conversation_by_ancestor_id.call_args +# assert ( +# call_args[1]["ancestor_response_id"] +# == 
"resp_ancestor1234567890abcdef1234567890" +# ) +# assert call_args[1]["project_id"] == dalgo_project.id + +# # Verify OpenAI client was called with the original response_id as previous_response_id +# mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert ( +# call_args["previous_response_id"] == "resp_ancestor1234567890abcdef1234567890" +# ) + + +# @patch("app.api.routes.responses.OpenAI") +# @patch("app.api.routes.responses.get_provider_credential") +# @patch("app.api.routes.responses.get_assistant_by_id") +# @patch("app.api.routes.responses.LangfuseTracer") +# @patch("app.api.routes.responses.get_ancestor_id_from_response") +# @patch("app.api.routes.responses.create_conversation") +# @patch("app.api.routes.responses.get_conversation_by_ancestor_id") +# def test_responses_endpoint_without_response_id( +# mock_get_conversation_by_ancestor_id, +# mock_create_conversation, +# mock_get_ancestor_id_from_response, +# mock_tracer_class, +# mock_get_assistant, +# mock_get_credential, +# mock_openai, +# db, +# user_api_key_header: dict[str, str], +# ): +# """Test the /responses endpoint when no response_id is provided.""" +# # Setup mock credentials +# mock_get_credential.return_value = {"api_key": "test_api_key"} + +# # Setup mock assistant +# mock_assistant = MagicMock() +# mock_assistant.model = "gpt-4o" +# mock_assistant.instructions = "Test instructions" +# mock_assistant.temperature = 0.1 +# mock_assistant.vector_store_ids = ["vs_test"] +# mock_assistant.max_num_results = 20 +# mock_get_assistant.return_value = mock_assistant + +# # Setup mock OpenAI client +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# # Setup the mock response object +# mock_response = MagicMock() +# mock_response.id = "resp_1234567890abcdef1234567890abcdef1234567890" +# mock_response.output_text = "Test output" +# mock_response.model = "gpt-4o" +# mock_response.usage.input_tokens = 10 +# 
mock_response.usage.output_tokens = 5 +# mock_response.usage.total_tokens = 15 +# mock_response.output = [] +# mock_response.previous_response_id = None +# mock_client.responses.create.return_value = mock_response + +# # Setup mock tracer +# mock_tracer = MagicMock() +# mock_tracer_class.return_value = mock_tracer + +# # Setup mock CRUD functions +# mock_get_ancestor_id_from_response.return_value = ( +# "resp_1234567890abcdef1234567890abcdef1234567890" +# ) +# mock_create_conversation.return_value = None + +# # Get the Dalgo project ID +# dalgo_project = db.exec(select(Project).where(Project.name == "Dalgo")).first() +# if not dalgo_project: +# pytest.skip("Dalgo project not found in the database") + +# request_data = { +# "assistant_id": "assistant_dalgo", +# "question": "What is Dalgo?", +# "callback_url": "http://example.com/callback", +# # No response_id provided +# } + +# response = client.post("/responses", json=request_data, headers=user_api_key_header) + +# assert response.status_code == 200 +# response_json = response.json() +# assert response_json["success"] is True +# assert response_json["data"]["status"] == "processing" +# assert response_json["data"]["message"] == "Response creation started" + +# # Verify get_conversation_by_ancestor_id was not called since response_id is None +# mock_get_conversation_by_ancestor_id.assert_not_called() + +# # Verify OpenAI client was called with None as previous_response_id +# mock_client.responses.create.assert_called_once() +# call_args = mock_client.responses.create.call_args[1] +# assert call_args["previous_response_id"] is None diff --git a/backend/app/tests/crud/test_prompt_versions.py b/backend/app/tests/crud/test_prompt_versions.py index c2f918c4..5a7a8279 100644 --- a/backend/app/tests/crud/test_prompt_versions.py +++ b/backend/app/tests/crud/test_prompt_versions.py @@ -13,19 +13,14 @@ from app.crud.prompts import create_prompt from app.models import Prompt, PromptCreate, PromptVersion, PromptVersionCreate 
from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_prompt @pytest.fixture def prompt(db: Session) -> Prompt: """Fixture to create a reusable prompt""" project = get_project(db) - prompt_data = PromptCreate( - name=f"test_prompt_{uuid4()}", - description="This is a test prompt", - instruction="Test instruction", - commit_message="Initial version", - ) - prompt, _ = create_prompt(db, prompt_in=prompt_data, project_id=project.id) + prompt, _ = create_test_prompt(db, project.id) return prompt diff --git a/backend/app/tests/crud/test_prompts.py b/backend/app/tests/crud/test_prompts.py index 074ae7be..d23cd0f8 100644 --- a/backend/app/tests/crud/test_prompts.py +++ b/backend/app/tests/crud/test_prompts.py @@ -16,19 +16,14 @@ ) from app.models import Prompt, PromptCreate, PromptUpdate, PromptVersion from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_prompt @pytest.fixture def prompt(db) -> Prompt: """Fixture to create a reusable prompt""" project = get_project(db) - prompt_data = PromptCreate( - name="test_prompt", - description="This is a test prompt", - instruction="Test instruction", - commit_message="Initial version", - ) - prompt, _ = create_prompt(db, prompt_in=prompt_data, project_id=project.id) + prompt, _ = create_test_prompt(db, project.id) return prompt @@ -189,8 +184,8 @@ def test_prompt_exists_success(db: Session, prompt: Prompt): assert isinstance(result, Prompt) assert result.id == prompt.id assert result.project_id == project.id - assert result.name == "test_prompt" - assert result.description == "This is a test prompt" + assert result.name == prompt.name + assert result.description == prompt.description assert not result.is_deleted diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 0d9c5d3b..996597ce 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -7,11 +7,14 @@ Credential, 
 OrganizationCreate,
 ProjectCreate,
+    PromptCreate,
+    PromptVersion,
 CredsCreate,
 )
 from app.crud import (
 create_organization,
 create_project,
+    create_prompt,
 create_api_key,
 set_creds_for_org,
 )
@@ -112,3 +115,21 @@ def create_test_credential(db: Session) -> tuple[list[Credential], Project]:
 },
 )
 return set_creds_for_org(session=db, creds_add=creds_data), project
+
+
+def create_test_prompt(
+    db: Session,
+    project_id: int,
+    name: str = "test_prompt",
+    description: str = "Test prompt description",
+    instruction: str = "Test instruction",
+    commit_message: str = "Initial version",
+) -> tuple[PromptCreate, PromptVersion]:  # NOTE(review): create_prompt appears to return (Prompt, PromptVersion) — confirm and fix first element of annotation
+    """Create and persist a prompt with its initial version; returns the (prompt, version) pair."""
+    prompt_in = PromptCreate(
+        name=name,
+        description=description,
+        instruction=instruction,
+        commit_message=commit_message,
+    )
+    return create_prompt(db, prompt_in=prompt_in, project_id=project_id)