Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions backend/app/alembic/versions/79e47bc3aac6_add_threads_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""add threads table

Revision ID: 79e47bc3aac6
Revises: f23675767ed2
Create Date: 2025-05-12 15:49:39.142806

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = "79e47bc3aac6"
down_revision = "f23675767ed2"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"openai_thread",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("thread_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("prompt", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("response", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("status", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("error", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("inserted_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_openai_thread_thread_id"), "openai_thread", ["thread_id"], unique=True
)
op.drop_constraint(
"credential_organization_id_fkey", "credential", type_="foreignkey"
)
op.create_foreign_key(
None, "credential", "organization", ["organization_id"], ["id"]
)
op.drop_constraint("project_organization_id_fkey", "project", type_="foreignkey")
op.create_foreign_key(None, "project", "organization", ["organization_id"], ["id"])
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "project", type_="foreignkey")
op.create_foreign_key(
"project_organization_id_fkey",
"project",
"organization",
["organization_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint(None, "credential", type_="foreignkey")
op.create_foreign_key(
"credential_organization_id_fkey",
"credential",
"organization",
["organization_id"],
["id"],
ondelete="CASCADE",
)
op.drop_index(op.f("ix_openai_thread_thread_id"), table_name="openai_thread")
op.drop_table("openai_thread")
# ### end Alembic commands ###
124 changes: 123 additions & 1 deletion backend/app/api/routes/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from app.api.deps import get_current_user_org, get_db
from app.core import logging, settings
from app.models import UserOrganization
from app.models import UserOrganization, OpenAIThreadCreate
from app.crud import upsert_thread_result, get_thread_result
from app.utils import APIResponse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -113,6 +114,24 @@
)


def run_and_poll_thread(client: OpenAI, thread_id: str, assistant_id: str):
"""Runs and polls a thread with the specified assistant using the OpenAI client."""
return client.beta.threads.runs.create_and_poll(
thread_id=thread_id,
assistant_id=assistant_id,
)


def extract_response_from_thread(
client: OpenAI, thread_id: str, remove_citation: bool = False
) -> str:
"""Fetches and processes the latest message from a thread."""
messages = client.beta.threads.messages.list(thread_id=thread_id)
latest_message = messages.data[0]
message_content = latest_message.content[0].text.value
return process_message_content(message_content, remove_citation)


@observe(as_type="generation")
def process_run(request: dict, client: OpenAI):
"""Process a run and send callback with results."""
Expand Down Expand Up @@ -159,6 +178,40 @@
send_callback(request["callback_url"], callback_response.model_dump())


def poll_run_and_prepare_response(request: dict, client: OpenAI, db: Session):
"""Handles a thread run, processes the response, and upserts the result to the database."""
thread_id = request["thread_id"]
prompt = request["question"]

try:
run = run_and_poll_thread(client, thread_id, request["assistant_id"])

status = run.status or "unknown"
response = None
error = None

if status == "completed":
response = extract_response_from_thread(
client, thread_id, request.get("remove_citation", False)
)

except openai.OpenAIError as e:
status = "failed"
error = str(e)
response = None

upsert_thread_result(
db,
OpenAIThreadCreate(
thread_id=thread_id,
prompt=prompt,
response=response,
status=status,
error=error,
),
)


@router.post("/threads")
async def threads(
request: dict,
Expand Down Expand Up @@ -240,3 +293,72 @@

except openai.OpenAIError as e:
return APIResponse.failure_response(error=handle_openai_error(e))


@router.post("/threads/start")
async def start_thread(
request: OpenAIThreadCreate,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org),
):
"""
Create a new OpenAI thread for the given question and start polling in the background.
"""
prompt = request["question"]
client = OpenAI(api_key=settings.OPENAI_API_KEY)

Check warning on line 309 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L308-L309

Added lines #L308 - L309 were not covered by tests

is_success, error = setup_thread(client, request)
if not is_success:
return APIResponse.failure_response(error=error)

Check warning on line 313 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L311-L313

Added lines #L311 - L313 were not covered by tests

thread_id = request["thread_id"]

Check warning on line 315 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L315

Added line #L315 was not covered by tests

upsert_thread_result(

Check warning on line 317 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L317

Added line #L317 was not covered by tests
db,
OpenAIThreadCreate(
thread_id=thread_id,
prompt=prompt,
response=None,
status="processing",
error=None,
),
)

background_tasks.add_task(poll_run_and_prepare_response, request, client, db)

Check warning on line 328 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L328

Added line #L328 was not covered by tests

return APIResponse.success_response(

Check warning on line 330 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L330

Added line #L330 was not covered by tests
data={
"thread_id": thread_id,
"prompt": prompt,
"status": "processing",
"message": "Thread created and polling started in background.",
}
)


@router.get("/threads/result/{thread_id}")
async def get_thread(
thread_id: str,
db: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org),
):
"""
Retrieve the result of a previously started OpenAI thread using its thread ID.
"""
result = get_thread_result(db, thread_id)

Check warning on line 349 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L349

Added line #L349 was not covered by tests

if not result:
return APIResponse.failure_response(error="Thread not found.")

Check warning on line 352 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L351-L352

Added lines #L351 - L352 were not covered by tests

status = result.status or ("success" if result.response else "processing")

Check warning on line 354 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L354

Added line #L354 was not covered by tests

return APIResponse.success_response(

Check warning on line 356 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L356

Added line #L356 was not covered by tests
data={
"thread_id": result.thread_id,
"prompt": result.prompt,
"status": status,
"response": result.response,
"error": result.error,
}
)
2 changes: 2 additions & 0 deletions backend/app/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@
get_api_keys_by_organization,
delete_api_key,
)

from .thread_results import upsert_thread_result, get_thread_result
25 changes: 25 additions & 0 deletions backend/app/crud/thread_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from sqlmodel import Session, select
from datetime import datetime
from app.models import OpenAIThreadCreate, OpenAI_Thread


def upsert_thread_result(session: Session, data: OpenAIThreadCreate):
statement = select(OpenAI_Thread).where(OpenAI_Thread.thread_id == data.thread_id)
existing = session.exec(statement).first()

if existing:
existing.prompt = data.prompt
existing.response = data.response
existing.status = data.status
existing.error = data.error
existing.updated_at = datetime.utcnow()
else:
new_thread = OpenAI_Thread(**data.dict())
session.add(new_thread)

session.commit()


def get_thread_result(session: Session, thread_id: str) -> OpenAI_Thread | None:
statement = select(OpenAI_Thread).where(OpenAI_Thread.thread_id == thread_id)
return session.exec(statement).first()
2 changes: 2 additions & 0 deletions backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@
CredsPublic,
CredsUpdate,
)

from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate
21 changes: 21 additions & 0 deletions backend/app/models/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from sqlmodel import SQLModel, Field
from typing import Optional
from datetime import datetime


class OpenAIThreadBase(SQLModel):
thread_id: str = Field(index=True, unique=True)
prompt: str
response: Optional[str] = None
status: Optional[str] = None
error: Optional[str] = None


class OpenAIThreadCreate(OpenAIThreadBase):
pass # Used for requests, no `id` or timestamps


class OpenAI_Thread(OpenAIThreadBase, table=True):
id: int = Field(default=None, primary_key=True)
inserted_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
20 changes: 7 additions & 13 deletions backend/app/tests/api/routes/test_creds.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,16 @@ def create_organization_and_creds(db: Session, superuser_token_headers: dict[str


def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]):
unique_org_id = 2
existing_org = (
db.query(Organization).filter(Organization.id == unique_org_id).first()
)
unique_name = "Test Organization " + generate_random_string(5)

if not existing_org:
new_org = Organization(
id=unique_org_id, name="Test Organization", is_active=True
)
db.add(new_org)
db.commit()
new_org = Organization(name=unique_name, is_active=True)
db.add(new_org)
db.commit()
db.refresh(new_org)

api_key = "sk-" + generate_random_string(10)
creds_data = {
"organization_id": unique_org_id,
"organization_id": new_org.id,
"is_active": True,
"credential": {"openai": {"api_key": api_key}},
}
Expand All @@ -69,10 +64,9 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str])
)

assert response.status_code == 200

created_creds = response.json()
assert "data" in created_creds
assert created_creds["data"]["organization_id"] == unique_org_id
assert created_creds["data"]["organization_id"] == new_org.id
assert created_creds["data"]["credential"]["openai"]["api_key"] == api_key


Expand Down
Loading