@@ -0,0 +1,122 @@
"""evaluation update constraints

Revision ID: 633e69806207
Revises: 6fe772038a5a
Create Date: 2025-11-13 11:36:16.484694

"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "633e69806207"
down_revision = "6fe772038a5a"
branch_labels = None
depends_on = None


def upgrade():
op.alter_column(
"evaluation_run",
"config",
existing_type=postgresql.JSON(astext_type=sa.Text()),
type_=postgresql.JSONB(astext_type=sa.Text()),
existing_nullable=False,
)
op.alter_column(
"evaluation_run",
"score",
existing_type=postgresql.JSON(astext_type=sa.Text()),
type_=postgresql.JSONB(astext_type=sa.Text()),
existing_nullable=True,
)
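# Note: JSONB stores a parsed binary representation, so these columns gain GIN
# indexing and containment operators (@>); plain JSON preserves the input text
# verbatim. Postgres rewrites existing rows via the built-in json -> jsonb cast.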
# Recreate the evaluation_run -> batch_job foreign keys with an explicit
# ON DELETE SET NULL, so deleting a batch job detaches it from its evaluation
# runs rather than blocking the delete, and drop the duplicate
# openai_conversation foreign keys (the *_fkey1 copies)
op.drop_constraint(
"fk_evaluation_run_embedding_batch_job_id", "evaluation_run", type_="foreignkey"
)
op.drop_constraint(
"evaluation_run_batch_job_id_fkey", "evaluation_run", type_="foreignkey"
)
op.drop_constraint(
"openai_conversation_organization_id_fkey1",
"openai_conversation",
type_="foreignkey",
)
op.drop_constraint(
"openai_conversation_project_id_fkey1",
"openai_conversation",
type_="foreignkey",
)
op.create_foreign_key(
"evaluation_run_batch_job_id_fkey",
"evaluation_run",
"batch_job",
["batch_job_id"],
["id"],
ondelete="SET NULL",
)
op.create_foreign_key(
"fk_evaluation_run_embedding_batch_job_id",
"evaluation_run",
"batch_job",
["embedding_batch_job_id"],
["id"],
ondelete="SET NULL",
)
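# For reference, each recreated constraint is equivalent to this DDL (sketch):
#   ALTER TABLE evaluation_run
#     ADD CONSTRAINT evaluation_run_batch_job_id_fkey
#     FOREIGN KEY (batch_job_id) REFERENCES batch_job (id) ON DELETE SET NULL;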


def downgrade():
op.alter_column(
"evaluation_run",
"score",
existing_type=postgresql.JSONB(astext_type=sa.Text()),
type_=postgresql.JSON(astext_type=sa.Text()),
existing_nullable=True,
)
op.alter_column(
"evaluation_run",
"config",
existing_type=postgresql.JSONB(astext_type=sa.Text()),
type_=postgresql.JSON(astext_type=sa.Text()),
existing_nullable=False,
)
# Recreate the evaluation_run -> batch_job foreign keys and restore the
# duplicate openai_conversation foreign keys dropped in upgrade()
op.drop_constraint(
"fk_evaluation_run_embedding_batch_job_id", "evaluation_run", type_="foreignkey"
)
op.drop_constraint(
"evaluation_run_batch_job_id_fkey", "evaluation_run", type_="foreignkey"
)
op.create_foreign_key(
"evaluation_run_batch_job_id_fkey",
"evaluation_run",
"batch_job",
["batch_job_id"],
["id"],
ondelete="SET NULL",
)
op.create_foreign_key(
"fk_evaluation_run_embedding_batch_job_id",
"evaluation_run",
"batch_job",
["embedding_batch_job_id"],
["id"],
ondelete="SET NULL",
)
op.create_foreign_key(
"openai_conversation_organization_id_fkey1",
"openai_conversation",
"organization",
["organization_id"],
["id"],
)
op.create_foreign_key(
"openai_conversation_project_id_fkey1",
"openai_conversation",
"project",
["project_id"],
["id"],
)
@@ -5,11 +5,11 @@
Create Date: 2025-11-05 22:47:18.266070

"""
-from alembic import op

import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
import sqlmodel.sql.sqltypes

+from alembic import op
+from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "6fe772038a5a"
@@ -23,68 +23,26 @@ def upgrade():
op.create_table(
"batch_job",
sa.Column("id", sa.Integer(), nullable=False),
-sa.Column(
-"provider",
-sa.String(),
-nullable=False,
-comment="LLM provider name (e.g., 'openai', 'anthropic')",
-),
-sa.Column(
-"job_type",
-sa.String(),
-nullable=False,
-comment="Type of batch job (e.g., 'evaluation', 'classification', 'embedding')",
-),
+sa.Column("provider", sa.String(), nullable=False),
+sa.Column("job_type", sa.String(), nullable=False),
sa.Column(
"config",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text("'{}'::jsonb"),
comment="Complete batch configuration",
),
-sa.Column(
-"provider_batch_id",
-sa.String(),
-nullable=True,
-comment="Provider's batch job ID",
-),
-sa.Column(
-"provider_file_id",
-sa.String(),
-nullable=True,
-comment="Provider's input file ID",
-),
-sa.Column(
-"provider_output_file_id",
-sa.String(),
-nullable=True,
-comment="Provider's output file ID",
-),
-sa.Column(
-"provider_status",
-sa.String(),
-nullable=True,
-comment="Provider-specific status (e.g., OpenAI: validating, in_progress, completed, failed)",
-),
-sa.Column(
-"raw_output_url",
-sa.String(),
-nullable=True,
-comment="S3 URL of raw batch output file",
-),
+sa.Column("provider_batch_id", sa.String(), nullable=True),
+sa.Column("provider_file_id", sa.String(), nullable=True),
+sa.Column("provider_output_file_id", sa.String(), nullable=True),
+sa.Column("provider_status", sa.String(), nullable=True),
+sa.Column("raw_output_url", sa.String(), nullable=True),
sa.Column(
"total_items",
sa.Integer(),
nullable=False,
server_default=sa.text("0"),
comment="Total number of items in the batch",
),
-sa.Column(
-"error_message",
-sa.Text(),
-nullable=True,
-comment="Error message if batch failed",
-),
+sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("project_id", sa.Integer(), nullable=False),
sa.Column("inserted_at", sa.DateTime(), nullable=False),
@@ -136,9 +94,7 @@ def upgrade():
"object_store_url", sqlmodel.sql.sqltypes.AutoString(), nullable=True
),
sa.Column(
"langfuse_dataset_id",
sqlmodel.sql.sqltypes.AutoString(),
nullable=True,
"langfuse_dataset_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True
),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("project_id", sa.Integer(), nullable=False),
@@ -170,12 +126,7 @@ def upgrade():
sa.Column("dataset_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("config", sa.JSON(), nullable=False),
sa.Column("batch_job_id", sa.Integer(), nullable=True),
-sa.Column(
-"embedding_batch_job_id",
-sa.Integer(),
-nullable=True,
-comment="Reference to the batch_job for embedding-based similarity scoring",
-),
+sa.Column("embedding_batch_job_id", sa.Integer(), nullable=True),
sa.Column("dataset_id", sa.Integer(), nullable=False),
sa.Column("status", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column(
54 changes: 40 additions & 14 deletions backend/app/models/batch_job.py
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional

-from sqlalchemy import Column
+from sqlalchemy import Column, Index, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field, Relationship, SQLModel

@@ -16,55 +16,81 @@ class BatchJob(SQLModel, table=True):
"""Batch job table for tracking async LLM batch operations."""

__tablename__ = "batch_job"
__table_args__ = (
Index("idx_batch_job_status_org", "provider_status", "organization_id"),
Index("idx_batch_job_status_project", "provider_status", "project_id"),
)
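# Composite indexes to serve status-filtered listings scoped to an organization
# or project (assumed access pattern, e.g. polling in-progress jobs per org).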

id: int | None = Field(default=None, primary_key=True)

# Provider and job type
-provider: str = Field(description="LLM provider name (e.g., 'openai', 'anthropic')")
+provider: str = Field(
+description="LLM provider name (e.g., 'openai', 'anthropic')",
+)
job_type: str = Field(
description="Type of batch job (e.g., 'evaluation', 'classification', 'embedding')"
index=True,
description=(
"Type of batch job (e.g., 'evaluation', 'classification', 'embedding')"
),
)

# Batch configuration - stores all provider-specific config
config: dict[str, Any] = Field(
default_factory=dict,
-sa_column=Column(JSONB()),
-description="Complete batch configuration including model, temperature, instructions, tools, etc.",
+sa_column=Column(JSONB, nullable=False),
+description=(
+"Complete batch configuration including model, temperature, "
+"instructions, tools, etc."
+),
)

# Provider-specific batch tracking
provider_batch_id: str | None = Field(
default=None, description="Provider's batch job ID (e.g., OpenAI batch_id)"
default=None,
description="Provider's batch job ID (e.g., OpenAI batch_id)",
)
provider_file_id: str | None = Field(
default=None, description="Provider's input file ID"
default=None,
description="Provider's input file ID",
)
provider_output_file_id: str | None = Field(
default=None, description="Provider's output file ID"
default=None,
description="Provider's output file ID",
)

# Provider status tracking
provider_status: str | None = Field(
default=None,
description="Provider-specific status (e.g., OpenAI: validating, in_progress, finalizing, completed, failed, expired, cancelling, cancelled)",
description=(
"Provider-specific status (e.g., OpenAI: validating, in_progress, "
"finalizing, completed, failed, expired, cancelling, cancelled)"
),
)

# Raw results (before parent-specific processing)
raw_output_url: str | None = Field(
default=None, description="S3 URL of raw batch output file"
default=None,
description="S3 URL of raw batch output file",
)
total_items: int = Field(
[Review comment from a collaborator on total_items:]
For all the Fields,

You don’t need to use sa_column here. SQLModel already infers the SQLAlchemy column from the type annotation + Field() metadata.

total_items: int = Field( default=0, description="Total number of items in the batch" )
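
A minimal sketch of the suggested pattern (hypothetical standalone model, assuming SQLModel's default column inference):

from sqlmodel import Field, SQLModel

class ExampleBatchJob(SQLModel, table=True):  # hypothetical table, for illustration only
    id: int | None = Field(default=None, primary_key=True)
    # The annotation plus Field() metadata is enough: SQLModel derives an
    # INTEGER NOT NULL column from `int` and applies default=0; no sa_column needed.
    total_items: int = Field(default=0, description="Total number of items in the batch")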

-default=0, description="Total number of items in the batch"
+default=0,
+description="Total number of items in the batch",
)

# Error handling
error_message: str | None = Field(
default=None, description="Error message if batch failed"
default=None,
sa_column=Column(Text, nullable=True),
description="Error message if batch failed",
)

# Foreign keys
-organization_id: int = Field(foreign_key="organization.id")
-project_id: int = Field(foreign_key="project.id")
+organization_id: int = Field(
+foreign_key="organization.id", nullable=False, ondelete="CASCADE", index=True
+)
+project_id: int = Field(
+foreign_key="project.id", nullable=False, ondelete="CASCADE", index=True
+)

# Timestamps
inserted_at: datetime = Field(