Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_collection(
return collection


def test_collection_info_processing(db: Session):
def test_collection_info_processing(db: Session, client: TestClient):
headers = {"X-API-KEY": original_api_key}
user = get_user_from_api_key(db, headers)
collection = create_collection(db, user, status=CollectionStatus.processing)
Expand All @@ -57,7 +57,7 @@ def test_collection_info_processing(db: Session):
assert data["llm_service_name"] is None


def test_collection_info_successful(db: Session):
def test_collection_info_successful(db: Session, client: TestClient):
headers = {"X-API-KEY": original_api_key}
user = get_user_from_api_key(db, headers)
collection = create_collection(
Expand All @@ -78,7 +78,7 @@ def test_collection_info_successful(db: Session):
assert data["llm_service_name"] == "gpt-4o"


def test_collection_info_failed(db: Session):
def test_collection_info_failed(db: Session, client: TestClient):
headers = {"X-API-KEY": original_api_key}
user = get_user_from_api_key(db, headers)
collection = create_collection(db, user, status=CollectionStatus.failed)
Expand Down
70 changes: 33 additions & 37 deletions backend/app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,60 @@

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, delete
from sqlmodel import Session
from sqlalchemy import event

from app.core.config import settings
from app.core.db import engine, init_db
from app.core.db import engine
from app.api.deps import get_db
from app.main import app
from app.models import (
APIKey,
Assistant,
Organization,
Project,
ProjectUser,
User,
OpenAI_Thread,
Credential,
Collection,
)
from app.tests.utils.user import authentication_token_from_email
from app.tests.utils.utils import get_superuser_token_headers
from app.seed_data.seed_data import seed_database


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="function")
def db() -> Generator[Session, None, None]:
with Session(engine) as session:
init_db(session)
connection = engine.connect()
transaction = connection.begin()
session = Session(bind=connection)

nested = session.begin_nested()

@event.listens_for(session, "after_transaction_end")
def restart_savepoint(sess, trans):
if trans.nested and not trans._parent.nested:
sess.begin_nested()

try:
yield session
# Delete data in reverse dependency order
session.execute(delete(ProjectUser)) # Many-to-many relationship
session.execute(delete(Assistant))
session.execute(delete(Credential))
session.execute(delete(Project))
session.execute(delete(Organization))
session.execute(delete(APIKey))
session.execute(delete(User))
session.execute(delete(OpenAI_Thread))
session.execute(delete(Collection))
session.commit()
finally:
session.close()
transaction.rollback()
connection.close()


@pytest.fixture(scope="session", autouse=True)
def seed_baseline():
with Session(engine) as session:
seed_database(session) # deterministic baseline
yield


@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
@pytest.fixture(scope="function")
def client(db: Session):
app.dependency_overrides[get_db] = lambda: db
with TestClient(app) as c:
yield c


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def superuser_token_headers(client: TestClient) -> dict[str, str]:
return get_superuser_token_headers(client)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str]:
return authentication_token_from_email(
client=client, email=settings.EMAIL_TEST_USER, db=db
)


@pytest.fixture(scope="session", autouse=True)
def load_seed_data(db):
seed_database(db)
yield