diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 762a6654..f6ada599 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -38,7 +38,7 @@ def create_collection( return collection -def test_collection_info_processing(db: Session): +def test_collection_info_processing(db: Session, client: TestClient): headers = {"X-API-KEY": original_api_key} user = get_user_from_api_key(db, headers) collection = create_collection(db, user, status=CollectionStatus.processing) @@ -57,7 +57,7 @@ def test_collection_info_processing(db: Session): assert data["llm_service_name"] is None -def test_collection_info_successful(db: Session): +def test_collection_info_successful(db: Session, client: TestClient): headers = {"X-API-KEY": original_api_key} user = get_user_from_api_key(db, headers) collection = create_collection( @@ -78,7 +78,7 @@ def test_collection_info_successful(db: Session): assert data["llm_service_name"] == "gpt-4o" -def test_collection_info_failed(db: Session): +def test_collection_info_failed(db: Session, client: TestClient): headers = {"X-API-KEY": original_api_key} user = get_user_from_api_key(db, headers) collection = create_collection(db, user, status=CollectionStatus.failed) diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 9f6fe76c..5d68c3f1 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -2,64 +2,60 @@ import pytest from fastapi.testclient import TestClient -from sqlmodel import Session, delete +from sqlmodel import Session +from sqlalchemy import event from app.core.config import settings -from app.core.db import engine, init_db +from app.core.db import engine +from app.api.deps import get_db from app.main import app -from app.models import ( - APIKey, - Assistant, - Organization, - Project, - ProjectUser, - User, - OpenAI_Thread, - Credential, - Collection, -) from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers from app.seed_data.seed_data import seed_database -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="function") def db() -> Generator[Session, None, None]: - with Session(engine) as session: - init_db(session) + connection = engine.connect() + transaction = connection.begin() + session = Session(bind=connection) + + nested = session.begin_nested() + + @event.listens_for(session, "after_transaction_end") + def restart_savepoint(sess, trans): + if trans.nested and not trans._parent.nested: + sess.begin_nested() + + try: yield session - # Delete data in reverse dependency order - session.execute(delete(ProjectUser)) # Many-to-many relationship - session.execute(delete(Assistant)) - session.execute(delete(Credential)) - session.execute(delete(Project)) - session.execute(delete(Organization)) - session.execute(delete(APIKey)) - session.execute(delete(User)) - session.execute(delete(OpenAI_Thread)) - session.execute(delete(Collection)) - session.commit() + finally: + session.close() + transaction.rollback() + connection.close() + + +@pytest.fixture(scope="session", autouse=True) +def seed_baseline(): + with Session(engine) as session: + seed_database(session) # deterministic baseline + yield -@pytest.fixture(scope="module") -def client() -> Generator[TestClient, None, None]: +@pytest.fixture(scope="function") +def client(db: Session): + app.dependency_overrides[get_db] = lambda: db with TestClient(app) as c: yield c -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def superuser_token_headers(client: TestClient) -> dict[str, str]: return get_superuser_token_headers(client) -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str]: return authentication_token_from_email( client=client, email=settings.EMAIL_TEST_USER, db=db ) - - -@pytest.fixture(scope="session", autouse=True) -def load_seed_data(db): - seed_database(db) - yield