Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions backend/app/db/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from logging.config import fileConfig

from sqlalchemy import create_engine # Import create_engine for synchronous operations
from sqlalchemy.ext.asyncio import create_async_engine # Import create_async_engine

from alembic import context
Expand Down Expand Up @@ -60,31 +61,43 @@ def run_migrations_online() -> None:
and associate a connection with the context.

"""
import asyncio

async def process_migrations():
# Determine the database URL based on environment or settings
if os.getenv("TESTING") == "True" and settings.TEST_DB_NAME:
db_url = (
f"postgresql+asyncpg://{settings.TEST_DB_USER}:{settings.TEST_DB_PASSWORD}@"
f"{settings.TEST_DB_HOST}:{settings.TEST_DB_PORT}/{settings.TEST_DB_NAME}"
)
elif settings.DATABASE_URL.startswith("sqlite"):
db_url = settings.DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
else:
db_url = settings.DATABASE_URL

# this callback is used to prevent an auto-migration from running with an async connection
def do_run_migrations(connection):
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()

# Determine the database URL based on environment or settings
if os.getenv("TESTING") == "True" and settings.TEST_DB_NAME:
db_url = (
f"postgresql+asyncpg://{settings.TEST_DB_USER}:{settings.TEST_DB_PASSWORD}@"
f"{settings.TEST_DB_HOST}:{settings.TEST_DB_PORT}/{settings.TEST_DB_NAME}"
)
elif settings.DATABASE_URL.startswith("sqlite"):
db_url = settings.DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
else:
db_url = settings.DATABASE_URL

if "autogenerate" in sys.argv:
# For autogenerate, use a synchronous engine
# Convert async driver schemes to their sync equivalents for autogenerate
if "+asyncpg" in db_url:
db_url = db_url.replace("+asyncpg", "")
if "+aiosqlite" in db_url:
db_url = db_url.replace("+aiosqlite", "")
connectable = create_engine(db_url)
with connectable.connect() as connection:
do_run_migrations(connection)
else:
# For regular migrations, use an async engine
import asyncio
connectable = create_async_engine(db_url)

async with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)

async with context.begin_transaction():
context.run_migrations()
async def process_migrations():
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)

asyncio.run(process_migrations())
asyncio.run(process_migrations())


if context.is_offline_mode():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Add error tracking and timing columns to report_state

Revision ID: 51d8ee58dab4
Revises:
Create Date: 2025-12-04 15:08:13.594907

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '51d8ee58dab4'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
1 change: 1 addition & 0 deletions backend/app/db/models/report_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ReportStatusEnum(PyEnum):
NLG_COMPLETED = "nlg_completed"
GENERATING_SUMMARY = "generating_summary"
SUMMARY_COMPLETED = "summary_completed"
TIMED_OUT = "timed_out"



Expand Down
44 changes: 37 additions & 7 deletions backend/app/db/repositories/report_repository.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Dict, Any
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from sqlalchemy.exc import IntegrityError
Expand All @@ -10,7 +11,7 @@ def __init__(self, session_factory: Callable[..., AsyncSession]):
self.session_factory = session_factory

async def create_report_entry(self, report_id: str) -> Report:
async with await self.session_factory() as session:
async with self.session_factory() as session:
try:
report = Report(id=report_id)
session.add(report)
Expand All @@ -34,7 +35,7 @@ async def create_report_entry(self, report_id: str) -> Report:
raise

async def update_report_status(self, report_id: str, status: ReportStatusEnum) -> ReportState | None:
async with await self.session_factory() as session:
async with self.session_factory() as session:
try:
stmt = update(ReportState).where(ReportState.report_id == report_id).values(status=status).returning(ReportState)
result = await session.execute(stmt)
Expand All @@ -46,7 +47,7 @@ async def update_report_status(self, report_id: str, status: ReportStatusEnum) -
raise

async def store_partial_report_results(self, report_id: str, partial_data: Dict[str, Any]) -> ReportState | None:
async with await self.session_factory() as session:
async with self.session_factory() as session:
try:
stmt = update(ReportState).where(ReportState.report_id == report_id).values(partial_agent_output=partial_data).returning(ReportState)
result = await session.execute(stmt)
Expand All @@ -58,7 +59,7 @@ async def store_partial_report_results(self, report_id: str, partial_data: Dict[
raise

async def save_final_report(self, report_id: str, data: Dict[str, Any]) -> ReportState | None:
async with await self.session_factory() as session:
async with self.session_factory() as session:
try:
stmt = update(ReportState).where(ReportState.report_id == report_id).values(final_report_json=data, status=ReportStatusEnum.COMPLETED).returning(ReportState)
result = await session.execute(stmt)
Expand All @@ -70,13 +71,13 @@ async def save_final_report(self, report_id: str, data: Dict[str, Any]) -> Repor
raise

async def get_report_by_id(self, report_id: str) -> ReportState | None:
async with await self.session_factory() as session:
async with self.session_factory() as session:
stmt = select(ReportState).where(ReportState.report_id == report_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()

async def update_timing_alerts(self, report_id: str, alerts: Dict[str, Any]) -> ReportState | None:
async with await self.session_factory() as session:
async with self.session_factory() as session:
try:
stmt = update(ReportState).where(ReportState.report_id == report_id).values(timing_alerts=alerts).returning(ReportState)
result = await session.execute(stmt)
Expand All @@ -88,7 +89,7 @@ async def update_timing_alerts(self, report_id: str, alerts: Dict[str, Any]) ->
raise

async def update_partial(self, report_id: str, data: Dict[str, Any]) -> ReportState | None:
async with await self.session_factory() as session:
async with self.session_factory() as session:
try:
stmt = update(ReportState).where(ReportState.report_id == report_id).values(**data).returning(ReportState)
result = await session.execute(stmt)
Expand All @@ -98,3 +99,32 @@ async def update_partial(self, report_id: str, data: Dict[str, Any]) -> ReportSt
except Exception:
await session.rollback()
raise

async def recover_stalled_reports(self, timeout_minutes: int) -> int:
async with self.session_factory() as session:
try:
stalled_threshold = datetime.now(timezone.utc) - timedelta(minutes=timeout_minutes)

running_states = [
ReportStatusEnum.RUNNING,
ReportStatusEnum.RUNNING_AGENTS,
ReportStatusEnum.GENERATING_NLG,
ReportStatusEnum.GENERATING_SUMMARY,
]

stmt = update(ReportState).where(
ReportState.status.in_(running_states),
ReportState.updated_at < stalled_threshold
).values(
status=ReportStatusEnum.TIMED_OUT,
error_message="Report stalled in running state for too long."
).returning(ReportState.report_id)

result = await session.execute(stmt)
updated_report_ids = result.scalars().all()
await session.commit()
return len(updated_report_ids)
except Exception:
await session.rollback()
raise

112 changes: 112 additions & 0 deletions backend/app/db/repositories/tests/test_report_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import pytest
from sqlalchemy import select
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from backend.app.db.base import Base
from backend.app.db.models.report import Report
from backend.app.db.models.report_state import ReportState, ReportStatusEnum
from backend.app.db.repositories.report_repository import ReportRepository

# Use a fixed timezone for consistency in tests
FIXED_TZ = timezone.utc

@pytest.fixture(name="async_session_factory")
async def async_session_factory_fixture():
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

AsyncSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)

yield AsyncSessionLocal

async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()

@pytest.fixture(name="report_repository")
async def report_repository_fixture(async_session_factory):
return ReportRepository(async_session_factory)

@pytest.mark.asyncio
async def test_recover_stalled_reports(report_repository, async_session_factory):
async with async_session_factory() as session:
# Create a report that is not stalled
report_active = Report(id="report_active")
report_state_active = ReportState(
report_id="report_active",
status=ReportStatusEnum.RUNNING,
updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=5)
)
session.add(report_active)
session.add(report_state_active)

# Create a report that is stalled
report_stalled = Report(id="report_stalled")
report_state_stalled = ReportState(
report_id="report_stalled",
status=ReportStatusEnum.RUNNING,
updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=65) # More than 60 minutes
)
session.add(report_stalled)
session.add(report_state_stalled)

# Create a report that is failed (should not be recovered)
report_failed = Report(id="report_failed")
report_state_failed = ReportState(
report_id="report_failed",
status=ReportStatusEnum.FAILED,
updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=70)
)
session.add(report_failed)
session.add(report_state_failed)

# Create a report that is completed (should not be recovered)
report_completed = Report(id="report_completed")
report_state_completed = ReportState(
report_id="report_completed",
status=ReportStatusEnum.COMPLETED,
updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=70)
)
session.add(report_completed)
session.add(report_state_completed)

await session.commit()

# Recover stalled reports with a timeout of 60 minutes
recovered_count = await report_repository.recover_stalled_reports(timeout_minutes=60)
assert recovered_count == 1

async with async_session_factory() as session:
# Verify status of active report
active_report_state = await session.execute(
select(ReportState).where(ReportState.report_id == "report_active")
)
assert active_report_state.scalar_one().status == ReportStatusEnum.RUNNING

# Verify status of stalled report
stalled_report_state = await session.execute(
select(ReportState).where(ReportState.report_id == "report_stalled")
)
recovered_stalled = stalled_report_state.scalar_one()
assert recovered_stalled.status == ReportStatusEnum.TIMED_OUT
assert recovered_stalled.error_message == "Report stalled in running state for too long."

# Verify status of failed report
failed_report_state = await session.execute(
select(ReportState).where(ReportState.report_id == "report_failed")
)
assert failed_report_state.scalar_one().status == ReportStatusEnum.FAILED

# Verify status of completed report
completed_report_state = await session.execute(
select(ReportState).where(ReportState.report_id == "report_completed")
)
assert completed_report_state.scalar_one().status == ReportStatusEnum.COMPLETED
Binary file added sql_app.db
Binary file not shown.