diff --git a/backend/app/alembic/versions/e9dd35eff62c_add_openai_conversation_table.py b/backend/app/alembic/versions/e9dd35eff62c_add_openai_conversation_table.py new file mode 100644 index 00000000..9ff047ec --- /dev/null +++ b/backend/app/alembic/versions/e9dd35eff62c_add_openai_conversation_table.py @@ -0,0 +1,87 @@ +"""add_openai_conversation_table + +Revision ID: e9dd35eff62c +Revises: e8ee93526b37 +Create Date: 2025-07-25 18:26:38.132146 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = "e9dd35eff62c" +down_revision = "e8ee93526b37" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "openai_conversation", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("response_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "ancestor_response_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column( + "previous_response_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + sa.Column("user_question", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("response", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("model", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("assistant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("is_deleted", sa.Boolean(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + ) + op.create_index( + 
op.f("ix_openai_conversation_ancestor_response_id"), + "openai_conversation", + ["ancestor_response_id"], + unique=False, + ) + op.create_index( + op.f("ix_openai_conversation_previous_response_id"), + "openai_conversation", + ["previous_response_id"], + unique=False, + ) + op.create_index( + op.f("ix_openai_conversation_response_id"), + "openai_conversation", + ["response_id"], + unique=False, + ) + op.create_foreign_key( + "fk_openai_conversation_project_id", "openai_conversation", "project", ["project_id"], ["id"] + ) + op.create_foreign_key( + "fk_openai_conversation_organization_id", "openai_conversation", "organization", ["organization_id"], ["id"] + ) + + +def downgrade(): + op.drop_constraint("fk_openai_conversation_project_id", "openai_conversation", type_="foreignkey") + op.drop_constraint("fk_openai_conversation_organization_id", "openai_conversation", type_="foreignkey") + op.drop_index( + op.f("ix_openai_conversation_response_id"), table_name="openai_conversation" + ) + op.drop_index( + op.f("ix_openai_conversation_previous_response_id"), + table_name="openai_conversation", + ) + op.drop_index( + op.f("ix_openai_conversation_ancestor_response_id"), + table_name="openai_conversation", + ) + op.drop_table("openai_conversation") diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 7db3c3d5..df0b1016 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -7,6 +7,7 @@ documents, login, organization, + openai_conversation, project, project_user, responses, @@ -27,6 +28,7 @@ api_router.include_router(documents.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) +api_router.include_router(openai_conversation.router) api_router.include_router(organization.router) api_router.include_router(project.router) api_router.include_router(project_user.router) diff --git a/backend/app/api/routes/openai_conversation.py b/backend/app/api/routes/openai_conversation.py new file mode 100644 index 00000000..5d4f8f8d --- /dev/null +++ b/backend/app/api/routes/openai_conversation.py @@ -0,0 +1,152 @@ +from typing import Annotated + +from
fastapi import APIRouter, Depends, Path, HTTPException, Query +from sqlmodel import Session + +from app.api.deps import get_db, get_current_user_org_project +from app.crud import ( + get_conversation_by_id, + get_conversation_by_response_id, + get_conversation_by_ancestor_id, + get_conversations_by_project, + get_conversations_count_by_project, + create_conversation, + delete_conversation, +) +from app.models import ( + UserProjectOrg, + OpenAIConversationPublic, +) +from app.utils import APIResponse + +router = APIRouter(prefix="/openai-conversation", tags=["OpenAI Conversations"]) + + +@router.get( + "/{conversation_id}", + response_model=APIResponse[OpenAIConversationPublic], + summary="Get a single conversation by its ID", +) +def get_conversation_route( + conversation_id: int = Path(..., description="The conversation ID to fetch"), + session: Session = Depends(get_db), + current_user: UserProjectOrg = Depends(get_current_user_org_project), +): + """ + Fetch a single conversation by its ID. + """ + conversation = get_conversation_by_id( + session, conversation_id, current_user.project_id + ) + if not conversation: + raise HTTPException( + status_code=404, detail=f"Conversation with ID {conversation_id} not found." + ) + return APIResponse.success_response(conversation) + + +@router.get( + "/response/{response_id}", + response_model=APIResponse[OpenAIConversationPublic], + summary="Get a conversation by its OpenAI response ID", +) +def get_conversation_by_response_id_route( + response_id: str = Path(..., description="The OpenAI response ID to fetch"), + session: Session = Depends(get_db), + current_user: UserProjectOrg = Depends(get_current_user_org_project), +): + """ + Fetch a conversation by its OpenAI response ID. 
+ """ + conversation = get_conversation_by_response_id( + session, response_id, current_user.project_id + ) + if not conversation: + raise HTTPException( + status_code=404, + detail=f"Conversation with response ID {response_id} not found.", + ) + return APIResponse.success_response(conversation) + + +@router.get( + "/ancestor/{ancestor_response_id}", + response_model=APIResponse[OpenAIConversationPublic], + summary="Get a conversation by its ancestor response ID", +) +def get_conversation_by_ancestor_id_route( + ancestor_response_id: str = Path( + ..., description="The ancestor response ID to fetch" + ), + session: Session = Depends(get_db), + current_user: UserProjectOrg = Depends(get_current_user_org_project), +): + """ + Fetch a conversation by its ancestor response ID. + """ + conversation = get_conversation_by_ancestor_id( + session, ancestor_response_id, current_user.project_id + ) + if not conversation: + raise HTTPException( + status_code=404, + detail=f"Conversation with ancestor response ID {ancestor_response_id} not found.", + ) + return APIResponse.success_response(conversation) + + +@router.get( + "/", + response_model=APIResponse[list[OpenAIConversationPublic]], + summary="List all conversations in the current project", +) +def list_conversations_route( + session: Session = Depends(get_db), + current_user: UserProjectOrg = Depends(get_current_user_org_project), + skip: int = Query(0, ge=0, description="How many items to skip"), + limit: int = Query(100, ge=1, le=100, description="Maximum items to return"), +): + """ + List all conversations in the current project. 
+ """ + conversations = get_conversations_by_project( + session=session, + project_id=current_user.project_id, + skip=skip, # ← Pagination offset + limit=limit, # ← Page size + ) + + # Get total count for pagination metadata + total = get_conversations_count_by_project( + session=session, + project_id=current_user.project_id, + ) + + return APIResponse.success_response( + data=conversations, metadata={"skip": skip, "limit": limit, "total": total} + ) + + +@router.delete("/{conversation_id}", response_model=APIResponse) +def delete_conversation_route( + conversation_id: Annotated[int, Path(description="Conversation ID to delete")], + session: Session = Depends(get_db), + current_user: UserProjectOrg = Depends(get_current_user_org_project), +): + """ + Soft delete a conversation by marking it as deleted. + """ + deleted_conversation = delete_conversation( + session=session, + conversation_id=conversation_id, + project_id=current_user.project_id, + ) + + if not deleted_conversation: + raise HTTPException( + status_code=404, detail=f"Conversation with ID {conversation_id} not found." 
+ ) + + return APIResponse.success_response( + data={"message": "Conversation deleted successfully."} + ) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 49b09f56..e4b973a0 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -54,3 +54,13 @@ get_assistants_by_project, delete_assistant, ) + +from .openai_conversation import ( + get_conversation_by_id, + get_conversation_by_response_id, + get_conversation_by_ancestor_id, + get_conversations_by_project, + get_conversations_count_by_project, + create_conversation, + delete_conversation, +) diff --git a/backend/app/crud/openai_conversation.py b/backend/app/crud/openai_conversation.py new file mode 100644 index 00000000..505e5e40 --- /dev/null +++ b/backend/app/crud/openai_conversation.py @@ -0,0 +1,147 @@ +import logging +from typing import Optional +from sqlmodel import Session, select, func +from app.models import OpenAIConversation, OpenAIConversationCreate +from app.core.util import now + +logger = logging.getLogger(__name__) + + +def get_conversation_by_id( + session: Session, conversation_id: int, project_id: int +) -> OpenAIConversation | None: + """ + Return a conversation by its ID and project. + """ + statement = select(OpenAIConversation).where( + OpenAIConversation.id == conversation_id, + OpenAIConversation.project_id == project_id, + OpenAIConversation.is_deleted == False, + ) + result = session.exec(statement).first() + return result + + +def get_conversation_by_response_id( + session: Session, response_id: str, project_id: int +) -> OpenAIConversation | None: + """ + Return a conversation by its OpenAI response ID and project. 
+ """ + statement = select(OpenAIConversation).where( + OpenAIConversation.response_id == response_id, + OpenAIConversation.project_id == project_id, + OpenAIConversation.is_deleted == False, + ) + result = session.exec(statement).first() + return result + + +def get_conversation_by_ancestor_id( + session: Session, ancestor_response_id: str, project_id: int +) -> OpenAIConversation | None: + """ + Return the latest conversation by its ancestor response ID and project. + """ + statement = ( + select(OpenAIConversation) + .where( + OpenAIConversation.ancestor_response_id == ancestor_response_id, + OpenAIConversation.project_id == project_id, + OpenAIConversation.is_deleted == False, + ) + .order_by(OpenAIConversation.inserted_at.desc()) + .limit(1) + ) + result = session.exec(statement).first() + return result + + +def get_conversations_count_by_project( + session: Session, + project_id: int, +) -> int: + """ + Return the total count of conversations for a given project. + """ + statement = select(func.count(OpenAIConversation.id)).where( + OpenAIConversation.project_id == project_id, + OpenAIConversation.is_deleted == False, + ) + result = session.exec(statement).one() + return result + + +def get_conversations_by_project( + session: Session, + project_id: int, + skip: int = 0, + limit: int = 100, +) -> list[OpenAIConversation]: + """ + Return all conversations for a given project, with optional pagination. + """ + statement = ( + select(OpenAIConversation) + .where( + OpenAIConversation.project_id == project_id, + OpenAIConversation.is_deleted == False, + ) + .order_by(OpenAIConversation.inserted_at.desc()) + .offset(skip) + .limit(limit) + ) + results = session.exec(statement).all() + return results + + +def create_conversation( + session: Session, + conversation: OpenAIConversationCreate, + project_id: int, + organization_id: int, +) -> OpenAIConversation: + """ + Create a new conversation in the database. 
+ """ + db_conversation = OpenAIConversation( + **conversation.model_dump(), + project_id=project_id, + organization_id=organization_id, + ) + session.add(db_conversation) + session.commit() + session.refresh(db_conversation) + + logger.info( + f"Created conversation with response_id={db_conversation.response_id}, " + f"assistant_id={db_conversation.assistant_id}, project_id={project_id}" + ) + + return db_conversation + + +def delete_conversation( + session: Session, + conversation_id: int, + project_id: int, +) -> OpenAIConversation | None: + """ + Soft delete a conversation by marking it as deleted. + """ + db_conversation = get_conversation_by_id(session, conversation_id, project_id) + if not db_conversation: + return None + + db_conversation.is_deleted = True + db_conversation.deleted_at = now() + session.add(db_conversation) + session.commit() + session.refresh(db_conversation) + + logger.info( + f"Deleted conversation with id={conversation_id}, " + f"response_id={db_conversation.response_id}, project_id={project_id}" + ) + + return db_conversation diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2c4c87e0..a1c2009c 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -55,3 +55,10 @@ from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate from .assistants import Assistant, AssistantBase, AssistantCreate, AssistantUpdate + +from .openai_conversation import ( + OpenAIConversationPublic, + OpenAIConversation, + OpenAIConversationBase, + OpenAIConversationCreate, +) diff --git a/backend/app/models/openai_conversation.py b/backend/app/models/openai_conversation.py new file mode 100644 index 00000000..93b1106a --- /dev/null +++ b/backend/app/models/openai_conversation.py @@ -0,0 +1,111 @@ +from datetime import datetime +from typing import Optional +import re + +from pydantic import field_validator +from sqlmodel import Field, Relationship, SQLModel + +from app.core.util import now + + 
+def validate_response_id_pattern(v: str) -> str: + """Shared validation function for response ID patterns""" + if v is None: + return v + if not re.match(r"^resp_[a-zA-Z0-9]{10,}$", v): + raise ValueError( + "response_id fields must follow pattern: resp_ followed by at least 10 alphanumeric characters" + ) + return v + + +class OpenAIConversationBase(SQLModel): + # usually follow the pattern of resp_688704e41190819db512c30568xxxxxxx + response_id: str = Field(index=True, min_length=10) + ancestor_response_id: str = Field( + index=True, + description="Ancestor response ID for conversation threading", + ) + previous_response_id: Optional[str] = Field( + default=None, index=True, description="Previous response ID in the conversation" + ) + user_question: str = Field(description="User's question/input") + response: Optional[str] = Field(default=None, description="AI response") + # there are models with small name like o1 and usually fine tuned models have long names + model: str = Field( + description="The model used for the response", min_length=1, max_length=150 + ) + # usually follow the pattern of asst_WD9bumYqTtpSvxxxxx + assistant_id: Optional[str] = Field( + default=None, + description="The assistant ID used", + min_length=10, + max_length=50, + ) + project_id: int = Field( + foreign_key="project.id", nullable=False, ondelete="CASCADE" + ) + organization_id: int = Field( + foreign_key="organization.id", nullable=False, ondelete="CASCADE" + ) + + @field_validator("response_id", "ancestor_response_id", "previous_response_id") + @classmethod + def validate_response_ids(cls, v): + return validate_response_id_pattern(v) + + +class OpenAIConversation(OpenAIConversationBase, table=True): + __tablename__ = "openai_conversation" + + id: int = Field(default=None, primary_key=True) + inserted_at: datetime = Field(default_factory=now, nullable=False) + updated_at: datetime = Field(default_factory=now, nullable=False) + is_deleted: bool = Field(default=False, 
nullable=False) + deleted_at: Optional[datetime] = Field(default=None, nullable=True) + + # Relationships + project: "Project" = Relationship(back_populates="openai_conversations") + organization: "Organization" = Relationship(back_populates="openai_conversations") + + +class OpenAIConversationCreate(SQLModel): + # usually follow the pattern of resp_688704e41190819db512c30568dcaebc0a42e02be2c2c49b + response_id: str = Field(min_length=10) + ancestor_response_id: str = Field( + description="Ancestor response ID for conversation threading" + ) + previous_response_id: Optional[str] = Field( + default=None, description="Previous response ID in the conversation" + ) + user_question: str = Field(description="User's question/input", min_length=1) + response: Optional[str] = Field(default=None, description="AI response") + # there are models with small name like o1 and usually fine tuned models have long names + model: str = Field( + description="The model used for the response", min_length=1, max_length=150 + ) + # usually follow the pattern of asst_WD9bumYqTtpSvxxxxx + assistant_id: Optional[str] = Field( + default=None, + description="The assistant ID used", + min_length=10, + max_length=50, + ) + + @field_validator("response_id", "ancestor_response_id", "previous_response_id") + @classmethod + def validate_response_ids(cls, v): + return validate_response_id_pattern(v) + + +class OpenAIConversationPublic(OpenAIConversationBase): + """Public model for OpenAIConversation without sensitive fields""" + + id: int + inserted_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + populate_by_name = True + use_enum_values = True diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 90eed18b..e854b11e 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -11,6 +11,7 @@ from .api_key import APIKey from .assistants import Assistant from .collection import Collection + from 
.openai_conversation import OpenAIConversation # Shared properties for an Organization @@ -52,6 +53,9 @@ class Organization(OrganizationBase, table=True): collections: list["Collection"] = Relationship( back_populates="organization", cascade_delete=True ) + openai_conversations: list["OpenAIConversation"] = Relationship( + back_populates="organization", cascade_delete=True + ) # Properties to return via API diff --git a/backend/app/models/project.py b/backend/app/models/project.py index de2ceb3c..442b740a 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -49,6 +49,9 @@ class Project(ProjectBase, table=True): collections: list["Collection"] = Relationship( back_populates="project", cascade_delete=True ) + openai_conversations: list["OpenAIConversation"] = Relationship( + back_populates="project", cascade_delete=True + ) # Properties to return via API diff --git a/backend/app/tests/api/routes/test_openai_conversation.py b/backend/app/tests/api/routes/test_openai_conversation.py new file mode 100644 index 00000000..55d309ed --- /dev/null +++ b/backend/app/tests/api/routes/test_openai_conversation.py @@ -0,0 +1,530 @@ +from sqlmodel import Session +from fastapi.testclient import TestClient + +from app.models import APIKeyPublic +from app.crud.openai_conversation import create_conversation +from app.models import OpenAIConversationCreate +from app.tests.utils.openai import generate_openai_id + + +def test_get_conversation_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful conversation retrieval.""" + + response_id = generate_openai_id("resp_", 40) + conversation_data = OpenAIConversationCreate( + response_id=response_id, + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question="What is the capital of France?", + response="The capital of France is Paris.", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + conversation = 
create_conversation( + session=db, + conversation=conversation_data, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + response = client.get( + f"/api/v1/openai-conversation/{conversation.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"]["id"] == conversation.id + assert response_data["data"]["response_id"] == conversation.response_id + + +def test_get_conversation_not_found( + client: TestClient, + user_api_key: APIKeyPublic, +): + """Test conversation retrieval with non-existent ID.""" + response = client.get( + "/api/v1/openai-conversation/99999", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 404 + response_data = response.json() + assert "not found" in response_data["error"] + + +def test_get_conversation_by_response_id_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful conversation retrieval by response ID.""" + response_id = generate_openai_id("resp_", 40) + conversation_data = OpenAIConversationCreate( + response_id=response_id, + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question="What is the capital of France?", + response="The capital of France is Paris.", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + conversation = create_conversation( + session=db, + conversation=conversation_data, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + response = client.get( + f"/api/v1/openai-conversation/response/{response_id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"]["response_id"] == response_id + assert response_data["data"]["id"] 
== conversation.id + + +def test_get_conversation_by_response_id_not_found( + client: TestClient, + user_api_key: APIKeyPublic, +): + """Test conversation retrieval with non-existent response ID.""" + response = client.get( + "/api/v1/openai-conversation/response/nonexistent_response_id", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 404 + response_data = response.json() + assert "not found" in response_data["error"] + + +def test_get_conversation_by_ancestor_id_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful conversation retrieval by ancestor ID.""" + ancestor_response_id = generate_openai_id("resp_", 40) + conversation_data = OpenAIConversationCreate( + response_id=generate_openai_id("resp_", 40), + ancestor_response_id=ancestor_response_id, + previous_response_id=None, + user_question="What is the capital of France?", + response="The capital of France is Paris.", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + conversation = create_conversation( + session=db, + conversation=conversation_data, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + response = client.get( + f"/api/v1/openai-conversation/ancestor/{ancestor_response_id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert response_data["data"]["ancestor_response_id"] == ancestor_response_id + assert response_data["data"]["id"] == conversation.id + + +def test_get_conversation_by_ancestor_id_not_found( + client: TestClient, + user_api_key: APIKeyPublic, +): + """Test conversation retrieval with non-existent ancestor ID.""" + response = client.get( + "/api/v1/openai-conversation/ancestor/nonexistent_ancestor_id", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 404 + response_data = response.json() + 
assert "not found" in response_data["error"] + + +def test_list_conversations_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful conversation listing.""" + conversation_data = OpenAIConversationCreate( + response_id=generate_openai_id("resp_", 40), + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question="What is the capital of France?", + response="The capital of France is Paris.", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + # Actually create the conversation in the database + create_conversation( + session=db, + conversation=conversation_data, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + response = client.get( + "/api/v1/openai-conversation", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert isinstance(response_data["data"], list) + assert len(response_data["data"]) > 0 + + +def test_list_conversations_with_pagination( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test conversation listing with pagination.""" + # Create multiple conversations + conversation_data_1 = OpenAIConversationCreate( + response_id=generate_openai_id("resp_", 40), + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question="What is the capital of Japan?", + response="The capital of Japan is Tokyo.", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + conversation_data_2 = OpenAIConversationCreate( + response_id=generate_openai_id("resp_", 40), + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question="What is the capital of Brazil?", + response="The capital of Brazil is Brasília.", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + # Actually 
create the conversations in the database + create_conversation( + session=db, + conversation=conversation_data_1, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + create_conversation( + session=db, + conversation=conversation_data_2, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + response = client.get( + "/api/v1/openai-conversation?skip=1&limit=2", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert isinstance(response_data["data"], list) + assert len(response_data["data"]) <= 2 + + # Check pagination metadata + assert "metadata" in response_data + metadata = response_data["metadata"] + assert metadata["skip"] == 1 + assert metadata["limit"] == 2 + assert "total" in metadata + assert isinstance(metadata["total"], int) + assert metadata["total"] >= 2 + + +def test_list_conversations_pagination_metadata( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test conversation listing pagination metadata.""" + # Create 5 conversations + for i in range(5): + conversation_data = OpenAIConversationCreate( + response_id=generate_openai_id("resp_", 40), + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question=f"Test question {i}", + response=f"Test response {i}", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + create_conversation( + session=db, + conversation=conversation_data, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + # Test first page + response = client.get( + "/api/v1/openai-conversation?skip=0&limit=3", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + + metadata = response_data["metadata"] + assert 
metadata["skip"] == 0 + assert metadata["limit"] == 3 + assert ( + metadata["total"] >= 5 + ) # Should include the 5 we created plus any existing ones + + # Test second page + response = client.get( + "/api/v1/openai-conversation?skip=3&limit=3", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + + metadata = response_data["metadata"] + assert metadata["skip"] == 3 + assert metadata["limit"] == 3 + assert metadata["total"] >= 5 + + +def test_list_conversations_default_pagination( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test conversation listing with default pagination parameters.""" + # Create a conversation + conversation_data = OpenAIConversationCreate( + response_id=generate_openai_id("resp_", 40), + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question="Test question", + response="Test response", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + create_conversation( + session=db, + conversation=conversation_data, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + # Test without pagination parameters (should use defaults) + response = client.get( + "/api/v1/openai-conversation", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + + metadata = response_data["metadata"] + assert metadata["skip"] == 0 # Default skip + assert metadata["limit"] == 100 # Default limit + assert "total" in metadata + assert isinstance(metadata["total"], int) + + +def test_list_conversations_edge_cases( + client: TestClient, + user_api_key: APIKeyPublic, +): + """Test conversation listing edge cases for pagination.""" + # Test with skip larger than total + response = client.get( + 
"/api/v1/openai-conversation?skip=1000&limit=10", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert len(response_data["data"]) == 0 # Should return empty list + + metadata = response_data["metadata"] + assert metadata["skip"] == 1000 + assert metadata["limit"] == 10 + assert "total" in metadata + + # Test with maximum limit + response = client.get( + "/api/v1/openai-conversation?skip=0&limit=100", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + + metadata = response_data["metadata"] + assert metadata["skip"] == 0 + assert metadata["limit"] == 100 + assert "total" in metadata + + +def test_list_conversations_invalid_pagination( + client: TestClient, + user_api_key: APIKeyPublic, +): + """Test conversation listing with invalid pagination parameters.""" + response = client.get( + "/api/v1/openai-conversation?skip=-1&limit=0", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 422 + + +def test_delete_conversation_success( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): + """Test successful conversation deletion.""" + conversation_data = OpenAIConversationCreate( + response_id=generate_openai_id("resp_", 40), + ancestor_response_id=generate_openai_id("resp_", 40), + previous_response_id=None, + user_question="What is the capital of Japan?", + response="The capital of Japan is Tokyo.", + model="gpt-4o", + assistant_id=generate_openai_id("asst_", 20), + ) + + # Create the conversation in the database and get the created object with ID + conversation = create_conversation( + session=db, + conversation=conversation_data, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) + + conversation_id = conversation.id + + response = client.delete( + 
f"/api/v1/openai-conversation/{conversation_id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert "deleted successfully" in response_data["data"]["message"] + + # Verify the conversation is marked as deleted + response = client.get( + f"/api/v1/openai-conversation/{conversation_id}", + headers={"X-API-KEY": user_api_key.key}, + ) + assert response.status_code == 404 + + +def test_delete_conversation_not_found( + client: TestClient, + user_api_key: APIKeyPublic, +): + """Test conversation deletion with non-existent ID.""" + response = client.delete( + "/api/v1/openai-conversation/99999", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 404 + response_data = response.json() + assert "not found" in response_data["error"] + + +def test_get_conversation_unauthorized_no_api_key( + client: TestClient, + db: Session, +): + """Test conversation retrieval without API key.""" + response = client.get("/api/v1/openai-conversation/1") + assert response.status_code == 401 + + +def test_get_conversation_unauthorized_invalid_api_key( + client: TestClient, + db: Session, +): + """Test conversation retrieval with invalid API key.""" + response = client.get( + "/api/v1/openai-conversation/1", + headers={"X-API-KEY": "invalid_api_key"}, + ) + assert response.status_code == 401 + + +def test_list_conversations_unauthorized_no_api_key( + client: TestClient, + db: Session, +): + """Test conversation listing without API key.""" + response = client.get("/api/v1/openai-conversation") + assert response.status_code == 401 + + +def test_list_conversations_unauthorized_invalid_api_key( + client: TestClient, + db: Session, +): + """Test conversation listing with invalid API key.""" + response = client.get( + "/api/v1/openai-conversation", + headers={"X-API-KEY": "invalid_api_key"}, + ) + assert response.status_code == 401 + + +def 
def test_delete_conversation_unauthorized_no_api_key(
    client: TestClient,
    db: Session,
):
    """Deleting a conversation without any API key is rejected with 401."""
    assert client.delete("/api/v1/openai-conversation/1").status_code == 401


def test_delete_conversation_unauthorized_invalid_api_key(
    client: TestClient,
    db: Session,
):
    """Deleting a conversation with a bogus API key is rejected with 401."""
    resp = client.delete(
        "/api/v1/openai-conversation/1",
        headers={"X-API-KEY": "invalid_api_key"},
    )
    assert resp.status_code == 401
def test_get_conversation_by_id_not_found(db: Session):
    """Looking up a non-existent conversation ID returns None."""
    project = get_project(db)

    assert (
        get_conversation_by_id(
            session=db,
            conversation_id=99999,
            project_id=project.id,
        )
        is None
    )


def test_get_conversation_by_response_id_success(db: Session):
    """A stored conversation can be fetched back by its OpenAI response ID."""
    project = get_project(db)
    organization = get_organization(db)

    payload = OpenAIConversationCreate(
        response_id=generate_openai_id("resp_", 40),
        ancestor_response_id=generate_openai_id("resp_", 40),
        previous_response_id=None,
        user_question="What is the capital of Japan?",
        response="The capital of Japan is Tokyo.",
        model="gpt-4o",
        assistant_id=generate_openai_id("asst_", 20),
    )

    created = create_conversation(
        session=db,
        conversation=payload,
        project_id=project.id,
        organization_id=organization.id,
    )

    # Round-trip: the row must come back under the response_id it was stored with.
    found = get_conversation_by_response_id(
        session=db,
        response_id=created.response_id,
        project_id=project.id,
    )

    assert found is not None
    assert found.id == created.id
    assert found.response_id == created.response_id
def test_get_conversation_by_ancestor_id_success(db: Session):
    """Test successful conversation retrieval by ancestor ID."""
    project = get_project(db)
    organization = get_organization(db)

    # Store a conversation whose ancestor pointer we will query by.
    root_id = generate_openai_id("resp_", 40)
    created = create_conversation(
        session=db,
        conversation=OpenAIConversationCreate(
            response_id=generate_openai_id("resp_", 40),
            ancestor_response_id=root_id,
            previous_response_id=None,
            user_question="What is the capital of France?",
            response="The capital of France is Paris.",
            model="gpt-4o",
            assistant_id=generate_openai_id("asst_", 20),
        ),
        project_id=project.id,
        organization_id=organization.id,
    )

    found = get_conversation_by_ancestor_id(
        session=db,
        ancestor_response_id=root_id,
        project_id=project.id,
    )

    assert found is not None
    assert found.id == created.id
    assert found.ancestor_response_id == root_id


def test_get_conversation_by_ancestor_id_not_found(db: Session):
    """Looking up an ancestor ID with no matching row returns None."""
    project = get_project(db)

    assert (
        get_conversation_by_ancestor_id(
            session=db,
            ancestor_response_id="nonexistent_ancestor_id",
            project_id=project.id,
        )
        is None
    )
def test_get_conversations_by_project_with_pagination(db: Session):
    """Skip/limit arguments bound the number of rows a listing returns."""
    project = get_project(db)
    organization = get_organization(db)

    # Seed five rows so a skip=1 / limit=2 window is meaningful.
    for idx in range(5):
        create_conversation(
            session=db,
            conversation=OpenAIConversationCreate(
                response_id=generate_openai_id("resp_", 40),
                ancestor_response_id=generate_openai_id("resp_", 40),
                previous_response_id=None,
                user_question=f"Test question {idx}",
                response=f"Test response {idx}",
                model="gpt-4o",
                assistant_id=generate_openai_id("asst_", 20),
            ),
            project_id=project.id,
            organization_id=organization.id,
        )

    page = get_conversations_by_project(
        session=db,
        project_id=project.id,
        skip=1,
        limit=2,
    )

    assert len(page) <= 2
def test_delete_conversation_not_found(db: Session):
    """Soft-deleting a non-existent conversation ID returns None."""
    project = get_project(db)

    result = delete_conversation(
        session=db,
        conversation_id=99999,
        project_id=project.id,
    )

    assert result is None
def test_get_conversations_count_by_project_success(db: Session):
    """Creating rows bumps the per-project conversation count by exactly that many."""
    project = get_project(db)
    organization = get_organization(db)

    # Baseline before seeding, so the final assertion is an exact delta.
    before = get_conversations_count_by_project(
        session=db,
        project_id=project.id,
    )

    for idx in range(3):
        create_conversation(
            session=db,
            conversation=OpenAIConversationCreate(
                response_id=generate_openai_id("resp_", 40),
                ancestor_response_id=generate_openai_id("resp_", 40),
                previous_response_id=None,
                user_question=f"Test question {idx}",
                response=f"Test response {idx}",
                model="gpt-4o",
                assistant_id=generate_openai_id("asst_", 20),
            ),
            project_id=project.id,
            organization_id=organization.id,
        )

    after = get_conversations_count_by_project(
        session=db,
        project_id=project.id,
    )

    assert after == before + 3
def test_get_conversations_count_by_project_different_projects(db: Session):
    """Test that conversation counts are isolated by project.

    Seeds two conversations into one project and three into another, then
    verifies each project's count grew by exactly its own number of rows.
    When the fixture data has no second project named "Dalgo", both handles
    resolve to the same project and the combined delta is asserted instead
    (the original `>=` assertions silently passed in that case without
    actually testing isolation).
    """
    project1 = get_project(db)
    organization = get_organization(db)

    # Resolve the second project with a single lookup; the original code
    # called get_project(db, "Dalgo") twice (two identical DB queries).
    # NOTE(review): assumes get_project returns None for a missing name,
    # as the original ternary did — confirm against the helper.
    project2 = get_project(db, "Dalgo")
    if project2 is None:
        project2 = get_project(db)

    # Baselines so the assertions below are exact rather than lower bounds.
    base1 = get_conversations_count_by_project(session=db, project_id=project1.id)
    base2 = get_conversations_count_by_project(session=db, project_id=project2.id)

    def _seed(project_id: int, count: int) -> None:
        # Insert `count` fresh conversations into the given project.
        for i in range(count):
            create_conversation(
                session=db,
                conversation=OpenAIConversationCreate(
                    response_id=generate_openai_id("resp_", 40),
                    ancestor_response_id=generate_openai_id("resp_", 40),
                    previous_response_id=None,
                    user_question=f"Test question {i}",
                    response=f"Test response {i}",
                    model="gpt-4o",
                    assistant_id=generate_openai_id("asst_", 20),
                ),
                project_id=project_id,
                organization_id=organization.id,
            )

    _seed(project1.id, 2)
    _seed(project2.id, 3)

    count1 = get_conversations_count_by_project(session=db, project_id=project1.id)
    count2 = get_conversations_count_by_project(session=db, project_id=project2.id)

    if project1.id == project2.id:
        # Fallback: only one project available, so all five rows land in it.
        assert count1 == base1 + 5
    else:
        assert count1 == base1 + 2
        assert count2 == base2 + 3
def test_response_id_validation_pattern(db: Session):
    """Test that response ID validation pattern is enforced."""
    project = get_project(db)
    organization = get_organization(db)

    # A well-formed "resp_"-prefixed ID passes validation and persists intact.
    good_id = "resp_1234567890abcdef"
    created = create_conversation(
        session=db,
        conversation=OpenAIConversationCreate(
            response_id=good_id,
            ancestor_response_id="resp_abcdef1234567890",
            previous_response_id=None,
            user_question="Test question",
            response="Test response",
            model="gpt-4o",
            assistant_id=generate_openai_id("asst_", 20),
        ),
        project_id=project.id,
        organization_id=organization.id,
    )
    assert created is not None
    assert created.response_id == good_id

    # Too short: the minimum-length constraint fires.
    with pytest.raises(ValueError, match="String should have at least 10 characters"):
        OpenAIConversationCreate(
            response_id="resp_123",
            ancestor_response_id="resp_abcdef1234567890",
            previous_response_id=None,
            user_question="Test question",
            response="Test response",
            model="gpt-4o",
            assistant_id=generate_openai_id("asst_", 20),
        )

    # Long enough but wrong prefix: the pattern validator rejects it.
    with pytest.raises(ValueError, match="response_id fields must follow pattern"):
        OpenAIConversationCreate(
            response_id="msg_1234567890abcdef",
            ancestor_response_id="resp_abcdef1234567890",
            previous_response_id=None,
            user_question="Test question",
            response="Test response",
            model="gpt-4o",
            assistant_id=generate_openai_id("asst_", 20),
        )
def generate_openai_id(prefix: str, length: int = 40) -> str:
    """Build a mock OpenAI-style identifier for tests.

    Returns `prefix` followed by `length` cryptographically random
    characters drawn from lowercase letters and digits only, mirroring
    the shape of real OpenAI IDs such as `resp_...` or `asst_...`.
    """
    alphabet = string.ascii_lowercase + string.digits
    suffix = [secrets.choice(alphabet) for _ in range(length)]
    return prefix + "".join(suffix)