24 changes: 24 additions & 0 deletions backend/app/alembic/versions/219033c644de_add_llm_im_jobs_table.py
@@ -0,0 +1,24 @@
"""Add LLM in jobs table

Revision ID: 219033c644de
Revises: e7c68e43ce6f
Create Date: 2025-10-17 15:38:33.565674

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "219033c644de"
down_revision = "e7c68e43ce6f"
branch_labels = None
depends_on = None


def upgrade():
op.execute("ALTER TYPE jobtype ADD VALUE IF NOT EXISTS 'LLM_API'")


def downgrade():
    # PostgreSQL cannot drop a value from an existing enum type,
    # so this migration is intentionally irreversible.
    pass
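Note: since upgrade() only appends an enum value, a quick sanity check after `alembic upgrade head` is to read the `jobtype` enum back. A minimal sketch, assuming SQLAlchemy and a placeholder Postgres DSN:

# Verification sketch; the DSN is a placeholder, substitute your own.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/app")
with engine.connect() as conn:
    values = conn.execute(
        text("SELECT unnest(enum_range(NULL::jobtype))")
    ).scalars().all()
    assert "LLM_API" in values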
2 changes: 2 additions & 0 deletions backend/app/api/main.py
@@ -7,6 +7,7 @@
documents,
doc_transformation_job,
    llm,
    login,
organization,
openai_conversation,
project,
@@ -31,6 +32,7 @@
api_router.include_router(credentials.router)
api_router.include_router(documents.router)
api_router.include_router(doc_transformation_job.router)
api_router.include_router(llm.router)
api_router.include_router(login.router)
api_router.include_router(onboarding.router)
api_router.include_router(openai_conversation.router)
36 changes: 36 additions & 0 deletions backend/app/api/routes/llm.py
@@ -0,0 +1,36 @@
import logging

from fastapi import APIRouter

from app.api.deps import AuthContextDep, SessionDep
from app.models import LLMCallRequest, Message
from app.services.llm.jobs import start_job
from app.utils import APIResponse


logger = logging.getLogger(__name__)
router = APIRouter(tags=["LLM"])


@router.post("/llm/call", response_model=APIResponse[Message])
async def llm_call(
    current_user: AuthContextDep, session: SessionDep, request: LLMCallRequest
):
"""
Endpoint to initiate an LLM call as a background job.
"""
    project_id = current_user.project.id
    organization_id = current_user.organization.id

    start_job(
        db=session,
        request=request,
        project_id=project_id,
        organization_id=organization_id,
    )

return APIResponse.success_response(
data=Message(
message=f"Your response is being generated and will be delivered via callback."
),
)
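For reference, a minimal client sketch for this endpoint; the `/api/v1` prefix, port, and bearer-token header are assumptions, since router mounting and auth are configured elsewhere:

# Hypothetical client call; base URL, API prefix, and auth header are assumed.
import httpx

payload = {
    "query": {"input": "Summarize our Q3 report"},
    "config": {"completion": {"provider": "openai", "params": {"model": "gpt-4o"}}},
    "callback_url": "https://example.com/webhooks/llm",
}
resp = httpx.post(
    "http://localhost:8000/api/v1/llm/call",
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)
print(resp.json())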
5 changes: 5 additions & 0 deletions backend/app/models/__init__.py
@@ -48,6 +48,11 @@

from .job import Job, JobType, JobStatus, JobUpdate

from .llm import (
LLMCallRequest,
LLMCallResponse,
)

from .message import Message
from .model_evaluation import (
ModelEvaluation,
1 change: 1 addition & 0 deletions backend/app/models/job.py
@@ -15,6 +15,7 @@ class JobStatus(str, Enum):

class JobType(str, Enum):
RESPONSE = "RESPONSE"
LLM_API = "LLM_API"


class Job(SQLModel, table=True):
2 changes: 2 additions & 0 deletions backend/app/models/llm/__init__.py
@@ -0,0 +1,2 @@
from app.models.llm.request import LLMCallRequest, CompletionConfig, QueryParams
from app.models.llm.response import LLMCallResponse
48 changes: 48 additions & 0 deletions backend/app/models/llm/request.py
@@ -0,0 +1,48 @@
from typing import Any, Literal

from sqlmodel import Field, SQLModel


# Query Parameters (dynamic per request)
class QueryParams(SQLModel):
"""Query-specific parameters for each LLM call."""

input: str = Field(..., min_length=1, description="User input text/prompt")
conversation_id: str | None = Field(
default=None,
description="Optional conversation ID. If not provided, a new conversation will be created.",
)


class CompletionConfig(SQLModel):
"""Completion configuration with provider and parameters."""

provider: Literal["openai"] = Field(
default="openai", description="LLM provider to use"
)
params: dict[str, Any] = Field(
..., description="Provider-specific parameters (schema varies by provider)"
)


class LLMCallConfig(SQLModel):
"""Complete configuration for LLM call including all processing stages."""

completion: CompletionConfig = Field(..., description="Completion configuration")
# Future additions:
# classifier: ClassifierConfig | None = None
# pre_filter: PreFilterConfig | None = None


class LLMCallRequest(SQLModel):
"""User-facing API request for LLM completion."""

query: QueryParams = Field(..., description="Query-specific parameters")
config: LLMCallConfig = Field(..., description="Configuration for the LLM call")
callback_url: str | None = Field(
default=None, description="Webhook URL for async response delivery"
)
include_provider_response: bool = Field(
default=False,
description="Whether to include the raw LLM provider response in the output",
)
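A sketch of how these models nest when constructed directly; the contents of `params` are provider-specific and invented here for illustration:

# Illustrative construction; the params shape is whatever the provider expects.
request = LLMCallRequest(
    query=QueryParams(input="What is our refund policy?"),
    config=LLMCallConfig(
        completion=CompletionConfig(
            provider="openai",
            params={"model": "gpt-4o", "temperature": 0.2},  # assumed example params
        )
    ),
    callback_url="https://example.com/webhooks/llm",
)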
23 changes: 23 additions & 0 deletions backend/app/models/llm/response.py
@@ -0,0 +1,23 @@
"""LLM response models.

This module contains response models for LLM API calls.
"""
from sqlmodel import SQLModel, Field


class Diagnostics(SQLModel):
input_tokens: int
output_tokens: int
total_tokens: int
model: str
provider: str


class LLMCallResponse(SQLModel):
id: str = Field(..., description="Unique id provided by the LLM provider.")
conversation_id: str | None = None
output: str
usage: Diagnostics
    llm_response: dict | None = Field(
        default=None, description="Raw response from the LLM provider."
    )
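For orientation, a sketch of the callback body a webhook would receive once jobs.py wraps this model in `APIResponse.success_response`; the envelope keys and all values here are illustrative assumptions:

# Assumed callback payload shape (values invented for illustration).
{
    "success": True,
    "data": {
        "id": "resp_abc123",
        "conversation_id": "conv_xyz789",
        "output": "Here is the summary ...",
        "usage": {
            "input_tokens": 42,
            "output_tokens": 128,
            "total_tokens": 170,
            "model": "gpt-4o",
            "provider": "openai",
        },
        "llm_response": None,
    },
}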
10 changes: 10 additions & 0 deletions backend/app/services/llm/__init__.py
@@ -0,0 +1,10 @@
# Providers
from app.services.llm.providers import (
    PROVIDER_REGISTRY,
    BaseProvider,
    OpenAIProvider,
    get_llm_provider,
    get_supported_providers,
)
144 changes: 144 additions & 0 deletions backend/app/services/llm/jobs.py
@@ -0,0 +1,144 @@
import logging
from uuid import UUID

from asgi_correlation_id import correlation_id
from fastapi import HTTPException
from sqlmodel import Session

from app.core.db import engine
from app.crud.jobs import JobCrud
from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMCallResponse
from app.utils import APIResponse, send_callback
from app.celery.utils import start_high_priority_job
from app.services.llm.providers.registry import get_llm_provider


logger = logging.getLogger(__name__)


def start_job(
db: Session, request: LLMCallRequest, project_id: int, organization_id: int
) -> UUID:
"""Create an LLM job and schedule Celery task."""
trace_id = correlation_id.get() or "N/A"
job_crud = JobCrud(session=db)
job = job_crud.create(job_type=JobType.LLM_API, trace_id=trace_id)

try:
task_id = start_high_priority_job(
function_path="app.services.llm.jobs.execute_job",
project_id=project_id,
job_id=str(job.id),
trace_id=trace_id,
request_data=request.model_dump(),
organization_id=organization_id,
)
except Exception as e:
logger.error(
f"[start_job] Error starting Celery task: {str(e)} | job_id={job.id}, project_id={project_id}",
exc_info=True,
)
job_update = JobUpdate(status=JobStatus.FAILED, error_message=str(e))
job_crud.update(job_id=job.id, job_update=job_update)
raise HTTPException(
status_code=500, detail="Internal server error while executing LLM call"
)

logger.info(
f"[start_job] Job scheduled for LLM call | job_id={job.id}, project_id={project_id}, task_id={task_id}"
)
return job.id


def handle_job_error(job_id: UUID, callback_url: str | None, error: str):
"""Handle job failure uniformly callback, and DB update."""
with Session(engine) as session:
job_crud = JobCrud(session=session)

callback = APIResponse.failure_response(error=error)
if callback_url:
send_callback(
callback_url=callback_url,
data=callback.model_dump(),
)

job_crud.update(
job_id=job_id,
job_update=JobUpdate(status=JobStatus.FAILED, error_message=error),
)

return callback.model_dump()


def execute_job(
request_data: dict,
project_id: int,
organization_id: int,
job_id: str,
task_id: str,
task_instance,
) -> dict:
"""Celery task to process an LLM request asynchronously."""

request = LLMCallRequest(**request_data)
job_id: UUID = UUID(job_id)

config = request.config
provider = config.completion.provider

logger.info(
f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}, "
f"provider={provider}"
)

try:
# Update job status to PROCESSING
with Session(engine) as session:
job_crud = JobCrud(session=session)
job_crud.update(
job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING)
)

provider_instance = get_llm_provider(
session=session,
provider_type=provider,
project_id=project_id,
organization_id=organization_id,
)

response, error = provider_instance.execute(
completion_config=config.completion,
query=request.query,
include_provider_response=request.include_provider_response,
)

        if response:
            callback = APIResponse.success_response(data=response)
            # Mirror handle_job_error: callback_url is optional, so guard it.
            if request.callback_url:
                send_callback(
                    callback_url=request.callback_url,
                    data=callback.model_dump(),
                )

with Session(engine) as session:
job_crud = JobCrud(session=session)

job_crud.update(
job_id=job_id, job_update=JobUpdate(status=JobStatus.SUCCESS)
)
logger.info(
f"[execute_job] Successfully completed LLM job | job_id={job_id}, "
f"response_id={response.id}, tokens={response.usage.total_tokens}"
)
return callback.model_dump()

return handle_job_error(
job_id, request.callback_url, error=error or "Unknown error occurred"
)

except Exception as e:
error = f"Unexpected error in LLM job execution: {str(e)}"
logger.error(
f"[execute_job] {error} | job_id={job_id}, task_id={task_id}",
exc_info=True,
)
return handle_job_error(job_id, request.callback_url, error=error)
7 changes: 7 additions & 0 deletions backend/app/services/llm/providers/__init__.py
@@ -0,0 +1,7 @@
from app.services.llm.providers.base import BaseProvider
from app.services.llm.providers.openai import OpenAIProvider
from app.services.llm.providers.registry import (
PROVIDER_REGISTRY,
get_llm_provider,
get_supported_providers,
)
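base.py itself is not part of this diff; judging from how jobs.py calls `provider_instance.execute(...)` and unpacks `response, error`, the contract is roughly the following. A hedged sketch only; the real base class may differ:

# Assumed shape of BaseProvider, inferred from jobs.py usage.
from abc import ABC, abstractmethod

from app.models.llm import CompletionConfig, LLMCallResponse, QueryParams


class BaseProvider(ABC):
    @abstractmethod
    def execute(
        self,
        completion_config: CompletionConfig,
        query: QueryParams,
        include_provider_response: bool = False,
    ) -> tuple[LLMCallResponse | None, str | None]:
        """Return (response, None) on success, or (None, error_message) on failure."""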