diff --git a/backend/app/db/migrations/env.py b/backend/app/db/migrations/env.py
index 169d9bc7..842fff2e 100644
--- a/backend/app/db/migrations/env.py
+++ b/backend/app/db/migrations/env.py
@@ -2,6 +2,7 @@ import sys
 from logging.config import fileConfig
 
 
+from sqlalchemy import create_engine # Import create_engine for synchronous operations
 from sqlalchemy.ext.asyncio import create_async_engine # Import create_async_engine
 from alembic import context
 
@@ -60,31 +61,50 @@ def run_migrations_online() -> None:
     and associate a connection with the context.
 
     """
-    import asyncio
-
-    async def process_migrations():
-        # Determine the database URL based on environment or settings
-        if os.getenv("TESTING") == "True" and settings.TEST_DB_NAME:
-            db_url = (
-                f"postgresql+asyncpg://{settings.TEST_DB_USER}:{settings.TEST_DB_PASSWORD}@"
-                f"{settings.TEST_DB_HOST}:{settings.TEST_DB_PORT}/{settings.TEST_DB_NAME}"
-            )
-        elif settings.DATABASE_URL.startswith("sqlite"):
-            db_url = settings.DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
-        else:
-            db_url = settings.DATABASE_URL
-
+    # Alembic's migration context is synchronous, so both engine flavours below
+    # hand it their connection through this callback.
+    def do_run_migrations(connection):
+        """Configure the Alembic context on ``connection`` and run migrations."""
+        context.configure(connection=connection, target_metadata=target_metadata)
+        with context.begin_transaction():
+            context.run_migrations()
+
+    # Determine the database URL based on environment or settings
+    if os.getenv("TESTING") == "True" and settings.TEST_DB_NAME:
+        db_url = (
+            f"postgresql+asyncpg://{settings.TEST_DB_USER}:{settings.TEST_DB_PASSWORD}@"
+            f"{settings.TEST_DB_HOST}:{settings.TEST_DB_PORT}/{settings.TEST_DB_NAME}"
+        )
+    elif settings.DATABASE_URL.startswith("sqlite"):
+        db_url = settings.DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
+    else:
+        db_url = settings.DATABASE_URL
+
+    # The CLI flag arrives in sys.argv as "--autogenerate", so an exact
+    # `"autogenerate" in sys.argv` membership test never matches. Read the parsed
+    # option instead; cmd_opts is None when Alembic is invoked programmatically,
+    # which getattr() maps to False.
+    if getattr(context.config.cmd_opts, "autogenerate", False):
+        # For autogenerate, use a synchronous engine: strip the async driver
+        # suffix so SQLAlchemy falls back to the dialect's default sync driver.
+        # NOTE(review): that default driver (e.g. psycopg2 for PostgreSQL) must
+        # be installed for this branch to work - confirm.
+        for async_driver in ("+asyncpg", "+aiosqlite"):
+            db_url = db_url.replace(async_driver, "")
+        connectable = create_engine(db_url)
+        with connectable.connect() as connection:
+            do_run_migrations(connection)
+    else:
+        # For regular migrations, use an async engine
+        import asyncio
         connectable = create_async_engine(db_url)
 
-        async with connectable.connect() as connection:
-            context.configure(
-                connection=connection, target_metadata=target_metadata
-            )
-
-            async with context.begin_transaction():
-                context.run_migrations()
+        async def process_migrations():
+            """Open an async connection and run the sync migration callback on it."""
+            async with connectable.connect() as connection:
+                await connection.run_sync(do_run_migrations)
 
-    asyncio.run(process_migrations())
+        asyncio.run(process_migrations())
 
 
 if context.is_offline_mode():
diff --git a/backend/app/db/migrations/versions/51d8ee58dab4_add_error_tracking_and_timing_columns_.py b/backend/app/db/migrations/versions/51d8ee58dab4_add_error_tracking_and_timing_columns_.py
new file mode 100644
index 00000000..e20c00bb
--- /dev/null
+++ b/backend/app/db/migrations/versions/51d8ee58dab4_add_error_tracking_and_timing_columns_.py
@@ -0,0 +1,39 @@
+"""Add error tracking and timing columns to report_state
+
+Revision ID: 51d8ee58dab4
+Revises: 
+Create Date: 2025-12-04 15:08:13.594907
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '51d8ee58dab4'
+# NOTE(review): down_revision is None, which makes this a root revision -
+# confirm there is no earlier revision it should chain from.
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Apply this revision; no schema operations are defined yet."""
+    # NOTE(review): despite the revision title, autogenerate emitted no
+    # operations. The repository code in this change writes
+    # ReportState.error_message and filters on ReportState.updated_at, and
+    # ReportStatusEnum gains TIMED_OUT - confirm whether the matching schema
+    # changes still need to be added here.
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Revert this revision; nothing to undo while upgrade() is empty."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
diff --git a/backend/app/db/models/report_state.py b/backend/app/db/models/report_state.py
index 78bf9194..f17e738d 100644
--- a/backend/app/db/models/report_state.py
+++ b/backend/app/db/models/report_state.py
@@ -18,6 +18,7 @@ class ReportStatusEnum(PyEnum):
     NLG_COMPLETED = "nlg_completed"
     GENERATING_SUMMARY = "generating_summary"
     SUMMARY_COMPLETED = "summary_completed"
+    TIMED_OUT = "timed_out"
 
 
 
diff --git a/backend/app/db/repositories/report_repository.py b/backend/app/db/repositories/report_repository.py
index 2c8f63c5..771d78ce 100644
--- a/backend/app/db/repositories/report_repository.py
+++ b/backend/app/db/repositories/report_repository.py
@@ -1,4 +1,5 @@
 from typing import Callable, Dict, Any
+from datetime import datetime, timedelta, timezone
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy import select, update
 from sqlalchemy.exc import IntegrityError
@@ -10,7 +11,7 @@ def __init__(self, session_factory: Callable[..., AsyncSession]):
         self.session_factory = session_factory
 
     async def create_report_entry(self, report_id: str) -> Report:
-        async with await self.session_factory() as session:
+        async with self.session_factory() as session:
             try:
                 report = Report(id=report_id)
                 session.add(report)
@@ -34,7 +35,7 @@ async def create_report_entry(self, report_id: str) -> Report:
                 raise
 
     async def update_report_status(self, report_id: str, status: ReportStatusEnum) -> ReportState | None:
-        async with await self.session_factory() as session:
+        async with self.session_factory() as session:
             try:
                 stmt = update(ReportState).where(ReportState.report_id == report_id).values(status=status).returning(ReportState)
                 result = await session.execute(stmt)
@@ -46,7 +47,7 @@ async def update_report_status(self, report_id: str, status: ReportStatusEnum) -
                 raise
 
     async def store_partial_report_results(self, report_id: str, partial_data: Dict[str, Any]) -> ReportState | None:
-        async with await self.session_factory() as session:
+        async with self.session_factory() as session:
             try:
                 stmt = update(ReportState).where(ReportState.report_id == report_id).values(partial_agent_output=partial_data).returning(ReportState)
                 result = await session.execute(stmt)
@@ -58,7 +59,7 @@ async def store_partial_report_results(self, report_id: str, partial_data: Dict[
                 raise
 
     async def save_final_report(self, report_id: str, data: Dict[str, Any]) -> ReportState | None:
-        async with await self.session_factory() as session:
+        async with self.session_factory() as session:
             try:
                 stmt = update(ReportState).where(ReportState.report_id == report_id).values(final_report_json=data, status=ReportStatusEnum.COMPLETED).returning(ReportState)
                 result = await session.execute(stmt)
@@ -70,13 +71,13 @@ async def save_final_report(self, report_id: str, data: Dict[str, Any]) -> Repor
                 raise
 
     async def get_report_by_id(self, report_id: str) -> ReportState | None:
-        async with await self.session_factory() as session:
+        async with self.session_factory() as session:
             stmt = select(ReportState).where(ReportState.report_id == report_id)
             result = await session.execute(stmt)
             return result.scalar_one_or_none()
 
     async def update_timing_alerts(self, report_id: str, alerts: Dict[str, Any]) -> ReportState | None:
-        async with await self.session_factory() as session:
+        async with self.session_factory() as session:
             try:
                 stmt = update(ReportState).where(ReportState.report_id == report_id).values(timing_alerts=alerts).returning(ReportState)
                 result = await session.execute(stmt)
@@ -88,7 +89,7 @@ async def update_timing_alerts(self, report_id: str, alerts: Dict[str, Any]) ->
                 raise
 
     async def update_partial(self, report_id: str, data: Dict[str, Any]) -> ReportState | None:
-        async with await self.session_factory() as session:
+        async with self.session_factory() as session:
             try:
                 stmt = update(ReportState).where(ReportState.report_id == report_id).values(**data).returning(ReportState)
                 result = await session.execute(stmt)
@@ -98,3 +99,51 @@ async def update_partial(self, report_id: str, data: Dict[str, Any]) -> ReportSt
             except Exception:
                 await session.rollback()
                 raise
+
+    async def recover_stalled_reports(self, timeout_minutes: int) -> int:
+        """Mark reports stuck in a running state as timed out.
+
+        Every ReportState whose status is one of the running states and whose
+        ``updated_at`` is older than ``timeout_minutes`` minutes is moved to
+        ``TIMED_OUT`` with an explanatory ``error_message``.
+
+        Args:
+            timeout_minutes: Age, in minutes, after which a running report
+                counts as stalled.
+
+        Returns:
+            The number of reports that were marked as timed out.
+
+        Raises:
+            Exception: Any error raised while updating is re-raised after the
+                session has been rolled back.
+        """
+        async with self.session_factory() as session:
+            try:
+                # NOTE(review): the threshold is timezone-aware (UTC); assumes
+                # updated_at is stored timezone-aware as well - confirm.
+                stalled_threshold = datetime.now(timezone.utc) - timedelta(minutes=timeout_minutes)
+
+                running_states = [
+                    ReportStatusEnum.RUNNING,
+                    ReportStatusEnum.RUNNING_AGENTS,
+                    ReportStatusEnum.GENERATING_NLG,
+                    ReportStatusEnum.GENERATING_SUMMARY,
+                ]
+
+                stmt = update(ReportState).where(
+                    ReportState.status.in_(running_states),
+                    ReportState.updated_at < stalled_threshold
+                ).values(
+                    status=ReportStatusEnum.TIMED_OUT,
+                    error_message="Report stalled in running state for too long."
+                ).returning(ReportState.report_id)
+
+                result = await session.execute(stmt)
+                updated_report_ids = result.scalars().all()
+                await session.commit()
+                return len(updated_report_ids)
+            except Exception:
+                await session.rollback()
+                raise
+
diff --git a/backend/app/db/repositories/tests/test_report_repository.py b/backend/app/db/repositories/tests/test_report_repository.py
new file mode 100644
index 00000000..fceec8bf
--- /dev/null
+++ b/backend/app/db/repositories/tests/test_report_repository.py
@@ -0,0 +1,118 @@
+import pytest
+from sqlalchemy import select
+from datetime import datetime, timedelta, timezone
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker
+from backend.app.db.base import Base
+from backend.app.db.models.report import Report
+from backend.app.db.models.report_state import ReportState, ReportStatusEnum
+from backend.app.db.repositories.report_repository import ReportRepository
+
+# Use a fixed timezone for consistency in tests
+FIXED_TZ = timezone.utc
+
+# NOTE(review): the async fixtures below use plain @pytest.fixture, which
+# pytest-asyncio only awaits in "auto" mode - confirm the project's
+# asyncio_mode setting.
+@pytest.fixture(name="async_session_factory")
+async def async_session_factory_fixture():
+    """Yield a session factory bound to a fresh in-memory SQLite database."""
+    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
+
+    AsyncSessionLocal = sessionmaker(
+        autocommit=False,
+        autoflush=False,
+        bind=engine,
+        class_=AsyncSession,
+        expire_on_commit=False,
+    )
+
+    yield AsyncSessionLocal
+
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.drop_all)
+    await engine.dispose()
+
+@pytest.fixture(name="report_repository")
+async def report_repository_fixture(async_session_factory):
+    """Return a ReportRepository backed by the test session factory."""
+    return ReportRepository(async_session_factory)
+
+@pytest.mark.asyncio
+async def test_recover_stalled_reports(report_repository, async_session_factory):
+    """Only a RUNNING report older than the timeout is moved to TIMED_OUT."""
+    async with async_session_factory() as session:
+        # Create a report that is not stalled
+        report_active = Report(id="report_active")
+        report_state_active = ReportState(
+            report_id="report_active",
+            status=ReportStatusEnum.RUNNING,
+            updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=5)
+        )
+        session.add(report_active)
+        session.add(report_state_active)
+
+        # Create a report that is stalled
+        report_stalled = Report(id="report_stalled")
+        report_state_stalled = ReportState(
+            report_id="report_stalled",
+            status=ReportStatusEnum.RUNNING,
+            updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=65) # More than 60 minutes
+        )
+        session.add(report_stalled)
+        session.add(report_state_stalled)
+
+        # Create a report that is failed (should not be recovered)
+        report_failed = Report(id="report_failed")
+        report_state_failed = ReportState(
+            report_id="report_failed",
+            status=ReportStatusEnum.FAILED,
+            updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=70)
+        )
+        session.add(report_failed)
+        session.add(report_state_failed)
+
+        # Create a report that is completed (should not be recovered)
+        report_completed = Report(id="report_completed")
+        report_state_completed = ReportState(
+            report_id="report_completed",
+            status=ReportStatusEnum.COMPLETED,
+            updated_at=datetime.now(FIXED_TZ) - timedelta(minutes=70)
+        )
+        session.add(report_completed)
+        session.add(report_state_completed)
+
+        await session.commit()
+
+    # Recover stalled reports with a timeout of 60 minutes
+    recovered_count = await report_repository.recover_stalled_reports(timeout_minutes=60)
+    assert recovered_count == 1
+
+    async with async_session_factory() as session:
+        # Verify status of active report
+        active_report_state = await session.execute(
+            select(ReportState).where(ReportState.report_id == "report_active")
+        )
+        assert active_report_state.scalar_one().status == ReportStatusEnum.RUNNING
+
+        # Verify status of stalled report
+        stalled_report_state = await session.execute(
+            select(ReportState).where(ReportState.report_id == "report_stalled")
+        )
+        recovered_stalled = stalled_report_state.scalar_one()
+        assert recovered_stalled.status == ReportStatusEnum.TIMED_OUT
+        assert recovered_stalled.error_message == "Report stalled in running state for too long."
+
+        # Verify status of failed report
+        failed_report_state = await session.execute(
+            select(ReportState).where(ReportState.report_id == "report_failed")
+        )
+        assert failed_report_state.scalar_one().status == ReportStatusEnum.FAILED
+
+        # Verify status of completed report
+        completed_report_state = await session.execute(
+            select(ReportState).where(ReportState.report_id == "report_completed")
+        )
+        assert completed_report_state.scalar_one().status == ReportStatusEnum.COMPLETED
diff --git a/sql_app.db b/sql_app.db
new file mode 100644
index 00000000..aa73729d
Binary files /dev/null and b/sql_app.db differ