diff --git a/backend/app/db/repositories/report_repository.py b/backend/app/db/repositories/report_repository.py index 771d78ce..e9664621 100644 --- a/backend/app/db/repositories/report_repository.py +++ b/backend/app/db/repositories/report_repository.py @@ -1,15 +1,114 @@ -from typing import Callable, Dict, Any +from typing import Callable, Dict, Any, Optional from datetime import datetime, timedelta, timezone from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, update from sqlalchemy.exc import IntegrityError from backend.app.db.models.report import Report from backend.app.db.models.report_state import ReportState, ReportStatusEnum - class ReportRepository: + FINAL_STATUSES = [ReportStatusEnum.COMPLETED, ReportStatusEnum.FAILED, ReportStatusEnum.TIMED_OUT] def __init__(self, session_factory: Callable[..., AsyncSession]): self.session_factory = session_factory + async def save_report_initial_state(self, report_id: str) -> ReportState: + """ + Saves the initial state of a new report to the database. + The report will be created with a PENDING status. + """ + async with self.session_factory() as session: + try: + # Ensure a Report entry exists + report = Report(id=report_id) + session.add(report) + + # Create the initial ReportState + report_state = ReportState(report_id=report_id, status=ReportStatusEnum.PENDING) + session.add(report_state) + + await session.commit() + await session.refresh(report_state) + return report_state + except IntegrityError: + await session.rollback() + # If a Report or ReportState with this ID already exists, fetch and return its state + existing_state = await self.get_report_state(report_id) + if existing_state: + return existing_state + raise # Re-raise if not found or other IntegrityError + except Exception: + await session.rollback() + raise + + async def update_report_partial_results(self, report_id: str, partial_data: Dict[str, Any]) -> ReportState | None: + """ + Updates the partial results of a report and sets its status to RUNNING if it's PENDING. + """ + async with self.session_factory() as session: + try: + # Check current status + current_state_result = await session.execute(select(ReportState.status).where(ReportState.report_id == report_id)) + current_status = current_state_result.scalar_one_or_none() + + if current_status in self.FINAL_STATUSES: + return await self.get_report_by_id(report_id) + + values_to_update = { + "partial_agent_output": partial_data, + "updated_at": datetime.now(timezone.utc) + } + + if current_status == ReportStatusEnum.PENDING: + values_to_update["status"] = ReportStatusEnum.RUNNING + + stmt = update(ReportState).where( + ReportState.report_id == report_id, + ReportState.status.notin_(self.FINAL_STATUSES) + ).values(**values_to_update).returning(ReportState) + result = await session.execute(stmt) + updated_report_state = result.scalar_one_or_none() + await session.commit() + return updated_report_state + except Exception: + await session.rollback() + raise + + async def update_report_final_report( + self, + report_id: str, + final_report_data: Optional[Dict[str, Any]], + status: ReportStatusEnum, + error_message: Optional[str] = None + ) -> ReportState | None: + """ + Updates the final report data, status, and optional error message. + """ + async with self.session_factory() as session: + try: + values_to_update = { + "status": status, + "final_report_json": final_report_data, + "error_message": error_message, + "updated_at": datetime.now(timezone.utc) + } + stmt = update(ReportState).where( + ReportState.report_id == report_id, + ReportState.status.notin_(self.FINAL_STATUSES) + ).values(**values_to_update).returning(ReportState) + result = await session.execute(stmt) + updated_report_state = result.scalar_one_or_none() + await session.commit() + return updated_report_state + except Exception: + await session.rollback() + raise + + async def get_report_state(self, report_id: str) -> ReportState | None: + """ + Retrieves the complete state of a report by its ID. + """ + return await self.get_report_by_id(report_id) + + async def create_report_entry(self, report_id: str) -> Report: async with self.session_factory() as session: try: diff --git a/backend/app/tests/state_management/test_state_management.py b/backend/app/tests/state_management/test_state_management.py new file mode 100644 index 00000000..b4d7f8e9 --- /dev/null +++ b/backend/app/tests/state_management/test_state_management.py @@ -0,0 +1,233 @@ +import pytest +from sqlalchemy import select +from datetime import datetime, timezone, timedelta +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from backend.app.db.base import Base +from backend.app.db.models.report import Report +from backend.app.db.models.report_state import ReportState, ReportStatusEnum +from backend.app.db.repositories.report_repository import ReportRepository + +# Use a fixed timezone for consistency in tests +FIXED_TZ = timezone.utc + +@pytest.fixture(name="async_session_factory") +async def async_session_factory_fixture(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + AsyncSessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + yield AsyncSessionLocal + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await engine.dispose() + +@pytest.fixture(name="report_repository") +async def report_repository_fixture(async_session_factory): + return ReportRepository(async_session_factory) + +@pytest.mark.asyncio +async def test_save_new_report_state(report_repository, async_session_factory): + report_id = "test_report_1" + + # Test saving a new report + initial_state = await report_repository.save_report_initial_state(report_id) + assert initial_state.report_id == report_id + assert initial_state.status == ReportStatusEnum.PENDING + assert initial_state.partial_agent_output is None + assert initial_state.final_report_json is None + assert initial_state.error_message is None + + async with async_session_factory() as session: + # Verify it's in the database + db_state = await session.execute( + select(ReportState).where(ReportState.report_id == report_id) + ) + assert db_state.scalar_one().status == ReportStatusEnum.PENDING + +@pytest.mark.asyncio +async def test_update_report_partial_results(report_repository, async_session_factory): + report_id = "test_report_2" + await report_repository.save_report_initial_state(report_id) + + partial_result_1 = {"step": 1, "data": "processing"} + updated_state = await report_repository.update_report_partial_results(report_id, partial_result_1) + + assert updated_state.report_id == report_id + assert updated_state.status == ReportStatusEnum.RUNNING + assert updated_state.partial_agent_output == partial_result_1 + assert updated_state.final_report_json is None + assert updated_state.error_message is None + + async with async_session_factory() as session: + db_state = await session.execute( + select(ReportState).where(ReportState.report_id == report_id) + ) + db_state_obj = db_state.scalar_one() + assert db_state_obj.status == ReportStatusEnum.RUNNING + assert db_state_obj.partial_agent_output == partial_result_1 + + # Update with more partial results + partial_result_2 = {"step": 2, "data": "more processing"} + updated_state_2 = await report_repository.update_report_partial_results(report_id, partial_result_2) + assert updated_state_2.partial_agent_output == partial_result_2 + +@pytest.mark.asyncio +async def test_update_report_final_report_success(report_repository, async_session_factory): + report_id = "test_report_3" + await report_repository.save_report_initial_state(report_id) + await report_repository.update_report_partial_results(report_id, {"step": 1}) + + final_report_data = {"summary": "Final report data", "score": 95} + final_state = await report_repository.update_report_final_report( + report_id, + final_report_data, + ReportStatusEnum.COMPLETED + ) + + assert final_state.report_id == report_id + assert final_state.status == ReportStatusEnum.COMPLETED + assert final_state.final_report_json == final_report_data + assert final_state.error_message is None + + async with async_session_factory() as session: + db_state = await session.execute( + select(ReportState).where(ReportState.report_id == report_id) + ) + db_state_obj = db_state.scalar_one() + assert db_state_obj.status == ReportStatusEnum.COMPLETED + assert db_state_obj.final_report_json == final_report_data + +@pytest.mark.asyncio +async def test_update_report_final_report_failure(report_repository, async_session_factory): + report_id = "test_report_4" + await report_repository.save_report_initial_state(report_id) + + error_message = "An error occurred during report generation." + final_state = await report_repository.update_report_final_report( + report_id, + None, + ReportStatusEnum.FAILED, + error_message=error_message + ) + + assert final_state.report_id == report_id + assert final_state.status == ReportStatusEnum.FAILED + assert final_state.final_report_json is None + assert final_state.error_message == error_message + + async with async_session_factory() as session: + db_state = await session.execute( + select(ReportState).where(ReportState.report_id == report_id) + ) + db_state_obj = db_state.scalar_one() + assert db_state_obj.status == ReportStatusEnum.FAILED + assert db_state_obj.error_message == error_message + +@pytest.mark.asyncio +async def test_get_report_state(report_repository): + report_id = "test_report_5" + await report_repository.save_report_initial_state(report_id) + + # Get initial state + state = await report_repository.get_report_state(report_id) + assert state.report_id == report_id + assert state.status == ReportStatusEnum.PENDING + + # Update to running + await report_repository.update_report_partial_results(report_id, {"data": "step 1"}) + state = await report_repository.get_report_state(report_id) + assert state.status == ReportStatusEnum.RUNNING + assert state.partial_agent_output == {"data": "step 1"} + + # Update to completed + final_data = {"final": "report"} + await report_repository.update_report_final_report(report_id, final_data, ReportStatusEnum.COMPLETED) + state = await report_repository.get_report_state(report_id) + assert state.status == ReportStatusEnum.COMPLETED + assert state.final_report_json == final_data + +@pytest.mark.asyncio +async def test_report_state_transitions(report_repository): + report_id = "test_report_6" + + # 1. Initial state: PENDING + state_pending = await report_repository.save_report_initial_state(report_id) + assert state_pending.status == ReportStatusEnum.PENDING + + # 2. Transition to RUNNING with partial results + partial_data = {"progress": 50} + state_running = await report_repository.update_report_partial_results(report_id, partial_data) + assert state_running.partial_agent_output == partial_data + + # 3. Transition to COMPLETED with final report + final_data = {"result": "success"} + state_completed = await report_repository.update_report_final_report(report_id, final_data, ReportStatusEnum.COMPLETED) + assert state_completed.status == ReportStatusEnum.COMPLETED + assert state_completed.final_report_json == final_data + assert state_completed.error_message is None + + # Verify that updated_at changes + assert state_completed.updated_at > state_running.updated_at + + # Try to update a completed report (should not change status/final report) + # The repository methods should ideally prevent or handle invalid state transitions. + # For now, we'll check that the report remains completed and no new error is added. + original_updated_at = state_completed.updated_at + unchanged_state = await report_repository.update_report_partial_results(report_id, {"progress": 100}) + assert unchanged_state.status == ReportStatusEnum.COMPLETED + assert unchanged_state.final_report_json == final_data + assert unchanged_state.updated_at == original_updated_at # updated_at should not change if no update occurred + +@pytest.mark.asyncio +async def test_report_state_transitions_to_failed(report_repository): + report_id = "test_report_7" + + # 1. Initial state: PENDING + state_pending = await report_repository.save_report_initial_state(report_id) + assert state_pending.status == ReportStatusEnum.PENDING + + # 2. Transition to RUNNING with partial results + partial_data = {"progress": 25} + state_running = await report_repository.update_report_partial_results(report_id, partial_data) + assert state_running.status == ReportStatusEnum.RUNNING + assert state_running.partial_agent_output == partial_data + + # 3. Transition to FAILED with error message + error_msg = "Critical error during processing." + state_failed = await report_repository.update_report_final_report(report_id, None, ReportStatusEnum.FAILED, error_message=error_msg) + assert state_failed.status == ReportStatusEnum.FAILED + assert state_failed.final_report_json is None + assert state_failed.error_message == error_msg + + # Verify that updated_at changes + assert state_failed.updated_at > state_running.updated_at + + # Try to update a failed report (should not change status/final report) + original_updated_at = state_failed.updated_at + unchanged_state = await report_repository.update_report_partial_results(report_id, {"progress": 75}) + assert unchanged_state.status == ReportStatusEnum.FAILED + assert unchanged_state.error_message == error_msg + assert unchanged_state.updated_at == original_updated_at # updated_at should not change if no update occurred + +@pytest.mark.asyncio +async def test_report_not_found(report_repository): + report_id = "non_existent_report" + state = await report_repository.get_report_state(report_id) + assert state is None + + # Test updating a non-existent report + updated_state = await report_repository.update_report_partial_results(report_id, {"data": "test"}) + assert updated_state is None + + final_state = await report_repository.update_report_final_report(report_id, {"data": "final"}, ReportStatusEnum.COMPLETED) + assert final_state is None