From 98ef8983e81c29ae0de1a0468b3cfbbac320b002 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 9 Jul 2025 00:23:33 +0530 Subject: [PATCH 1/3] utility functions and seeding fixture --- backend/app/crud/__init__.py | 2 + .../collections/test_collection_info.py | 11 - .../collections/test_create_collections.py | 12 +- backend/app/tests/api/routes/test_api_key.py | 95 +++----- backend/app/tests/api/routes/test_creds.py | 195 +++++++---------- backend/app/tests/api/routes/test_org.py | 47 ++-- backend/app/tests/api/routes/test_project.py | 60 ++--- .../app/tests/api/routes/test_responses.py | 9 - backend/app/tests/conftest.py | 7 + backend/app/tests/crud/test_api_key.py | 108 +++------ backend/app/tests/crud/test_credentials.py | 205 +++++++----------- backend/app/tests/crud/test_org.py | 11 +- backend/app/tests/crud/test_project.py | 66 ++---- backend/app/tests/utils/test_data.py | 114 ++++++++++ backend/app/tests/utils/utils.py | 4 + 15 files changed, 405 insertions(+), 541 deletions(-) create mode 100644 backend/app/tests/utils/test_data.py diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 6bf4d5da..5c6fd102 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -37,6 +37,8 @@ get_key_by_org, update_creds_for_org, remove_creds_for_org, + get_provider_credential, + remove_provider_credential, ) from .thread_results import upsert_thread_result, get_thread_result diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 26e7ef91..762a6654 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -1,26 +1,15 @@ -import pytest from uuid import uuid4 from datetime import datetime, timezone from fastapi.testclient import TestClient from sqlmodel import Session from app.core.config import settings from app.models import Collection -from app.crud.collection import CollectionCrud from app.main import app from app.tests.utils.utils import get_user_from_api_key -from app.seed_data.seed_data import seed_database from app.models.collection import CollectionStatus client = TestClient(app) - -@pytest.fixture(scope="function", autouse=True) -def load_seed_data(db): - """Load seed data before each test.""" - seed_database(db) - yield - - original_api_key = "ApiKey No3x47A5qoIGhm0kVKjQ77dhCqEdWRIQZlEPzzzh7i8" diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 6fb0855c..2a97fb45 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -3,29 +3,19 @@ import io import openai_responses -from sqlmodel import Session, select +from sqlmodel import Session from fastapi.testclient import TestClient -from openai import OpenAIError from app.core.config import settings from app.tests.utils.document import DocumentStore from app.tests.utils.utils import openai_credentials, get_user_from_api_key from app.main import app from app.crud.collection import CollectionCrud -from app.api.routes.collections import CreationRequest, ResponsePayload -from app.seed_data.seed_data import seed_database from app.models.collection import CollectionStatus client = TestClient(app) -@pytest.fixture(scope="function", autouse=True) -def load_seed_data(db): - """Load seed data before each test.""" - seed_database(db) - yield - - @pytest.fixture(autouse=True) def mock_s3(monkeypatch): class FakeStorage: diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index 552530cc..7d018a66 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -1,51 +1,23 @@ -import uuid -import pytest from fastapi.testclient import TestClient from sqlmodel import Session + from app.main import app -from app.models import APIKey, User, Organization, Project +from app.models import APIKey from app.core.config import settings -from app.crud.api_key import create_api_key -from app.tests.utils.utils import random_email -from app.core.security import get_password_hash +from app.tests.utils.utils import get_non_existent_id +from app.tests.utils.user import create_random_user +from app.tests.utils.test_data import ( + create_test_api_key, + create_test_project, + create_test_organization, +) client = TestClient(app) -def create_test_user(db: Session) -> User: - user = User( - email=random_email(), - hashed_password=get_password_hash("password123"), - is_superuser=True, - ) - db.add(user) - db.commit() - db.refresh(user) - return user - - -def create_test_organization(db: Session) -> Organization: - org = Organization( - name=f"Test Organization {uuid.uuid4()}", description="Test Organization" - ) - db.add(org) - db.commit() - db.refresh(org) - return org - - -def create_test_project(db: Session, organization_id: int) -> Project: - project = Project(name="Test Project", organization_id=organization_id) - db.add(project) - db.commit() - db.refresh(project) - return project - - def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) + user = create_random_user(db) + project = create_test_project(db) response = client.post( f"{settings.API_V1_STR}/apikeys", @@ -57,14 +29,13 @@ def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): assert data["success"] is True assert "id" in data["data"] assert "key" in data["data"] - assert data["data"]["organization_id"] == org.id + assert data["data"]["organization_id"] == project.organization_id assert data["data"]["user_id"] == user.id def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) + user = create_random_user(db) + project = create_test_project(db) client.post( f"{settings.API_V1_STR}/apikeys", @@ -81,16 +52,11 @@ def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) - api_key = create_api_key( - db, organization_id=org.id, user_id=user.id, project_id=project.id - ) + api_key = create_test_api_key(db) response = client.get( f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id}, + params={"project_id": api_key.project_id}, headers=superuser_token_headers, ) assert response.status_code == 200 @@ -100,17 +66,12 @@ def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): assert len(data["data"]) > 0 first_key = data["data"][0] - assert first_key["organization_id"] == org.id - assert first_key["user_id"] == user.id + assert first_key["organization_id"] == api_key.organization_id + assert first_key["user_id"] == api_key.user_id def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) - api_key = create_api_key( - db, organization_id=org.id, user_id=user.id, project_id=project.id - ) + api_key = create_test_api_key(db) response = client.get( f"{settings.API_V1_STR}/apikeys/{api_key.id}", @@ -121,12 +82,14 @@ def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): assert data["success"] is True assert data["data"]["id"] == api_key.id assert data["data"]["organization_id"] == api_key.organization_id - assert data["data"]["user_id"] == user.id + assert data["data"]["user_id"] == api_key.user_id def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]): + api_key_id = get_non_existent_id(db, APIKey) + print(api_key_id) response = client.get( - f"{settings.API_V1_STR}/apikeys/999999", + f"{settings.API_V1_STR}/apikeys/{api_key_id}", headers=superuser_token_headers, ) assert response.status_code == 404 @@ -134,12 +97,7 @@ def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) - api_key = create_api_key( - db, organization_id=org.id, user_id=user.id, project_id=project.id - ) + api_key = create_test_api_key(db) response = client.delete( f"{settings.API_V1_STR}/apikeys/{api_key.id}", @@ -154,11 +112,10 @@ def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): def test_revoke_nonexistent_api_key( db: Session, superuser_token_headers: dict[str, str] ): - user = create_test_user(db) - org = create_test_organization(db) + api_key_id = get_non_existent_id(db, APIKey) response = client.delete( - f"{settings.API_V1_STR}/apikeys/999999", + f"{settings.API_V1_STR}/apikeys/{api_key_id}", headers=superuser_token_headers, ) assert response.status_code == 404 diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 428d4c1c..5d0ff141 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -1,59 +1,37 @@ import pytest from fastapi.testclient import TestClient from sqlmodel import Session -import random -import string from app.main import app -from app.api.deps import get_db -from app.crud.credentials import set_creds_for_org -from app.models import CredsCreate, Organization, OrganizationCreate, Project +from app.models import Organization, Project from app.core.config import settings -from app.core.security import encrypt_api_key from app.core.providers import Provider from app.models.credentials import Credential from app.core.security import decrypt_credentials +from app.tests.utils.utils import generate_random_string, get_non_existent_id +from app.tests.utils.test_data import ( + create_test_credential, + create_test_organization, + create_test_project, + test_credential_data, +) client = TestClient(app) -def generate_random_string(length=10): - return "".join(random.choices(string.ascii_letters + string.digits, k=length)) - - @pytest.fixture -def create_organization_and_creds(db: Session): - unique_org_name = "Test Organization " + generate_random_string(5) - org = Organization(name=unique_org_name, is_active=True) - db.add(org) - db.commit() - db.refresh(org) - - api_key = "sk-" + generate_random_string(10) - creds_data = CredsCreate( - organization_id=org.id, - is_active=True, - credential={ - Provider.OPENAI.value: { - "api_key": api_key, - "model": "gpt-4", - "temperature": 0.7, - } - }, - ) - return org, creds_data +def create_test_credentials(db: Session): + return create_test_credential(db) -def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]): - org = Organization(name="Org for Set Creds", is_active=True) - db.add(org) - db.commit() - db.refresh(org) +def test_set_creds(db: Session, superuser_token_headers: dict[str, str]): + project = create_test_project(db) api_key = "sk-" + generate_random_string(10) creds_data = { - "organization_id": org.id, + "organization_id": project.organization_id, + "project_id": project.id, "is_active": True, "credential": { Provider.OPENAI.value: { @@ -74,7 +52,7 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) data = response.json()["data"] assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == org.id + assert data[0]["organization_id"] == project.organization_id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" @@ -82,19 +60,9 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) def test_set_creds_for_invalid_project_org_relationship( db: Session, superuser_token_headers: dict[str, str] ): - # Setup: Create two organizations and one project for each - org1 = Organization(name="Org 1", is_active=True) - org2 = Organization(name="Org 2", is_active=True) - db.add_all([org1, org2]) - db.commit() - db.refresh(org1) - db.refresh(org2) - - project2 = Project(name="Project Org2", organization_id=org2.id) - db.add(project2) - db.commit() - - # Invalid case: Organization mismatch (org1's creds for project2 of org2) + org1 = create_test_organization(db) + project2 = create_test_project(db) + creds_data_invalid = { "organization_id": org1.id, "is_active": True, @@ -119,15 +87,13 @@ def test_set_creds_for_project_not_found( db: Session, superuser_token_headers: dict[str, str] ): # Setup: Create an organization but no project - org = Organization(name="Org for Project Not Found", is_active=True) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) + non_existent_project_id = get_non_existent_id(db, Project) creds_data_invalid_project = { "organization_id": org.id, "is_active": True, - "project_id": 99999, + "project_id": non_existent_project_id, "credential": {Provider.OPENAI.value: {"api_key": "sk-123", "model": "gpt-4"}}, } @@ -142,13 +108,12 @@ def test_set_creds_for_project_not_found( def test_read_credentials_with_creds( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) - + creds = create_test_credentials[0] + print(creds) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", headers=superuser_token_headers, ) @@ -156,7 +121,7 @@ def test_read_credentials_with_creds( data = response.json()["data"] assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == org.id + assert data[0]["organization_id"] == creds.organization_id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" @@ -164,8 +129,9 @@ def test_read_credentials_with_creds( def test_read_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): + non_existent_creds_id = get_non_existent_id(db, Credential) response = client.get( - f"{settings.API_V1_STR}/credentials/999999", + f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", headers=superuser_token_headers, ) assert response.status_code == 404 @@ -173,13 +139,13 @@ def test_read_credentials_not_found( def test_read_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] + print(creds) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -190,9 +156,9 @@ def test_read_provider_credential( def test_read_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, _ = create_organization_and_creds + org = create_test_organization(db) response = client.get( f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", @@ -204,10 +170,9 @@ def test_read_provider_credential_not_found( def test_update_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] update_data = { "provider": Provider.OPENAI.value, @@ -219,7 +184,7 @@ def test_update_credentials( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", json=update_data, headers=superuser_token_headers, ) @@ -234,21 +199,14 @@ def test_update_credentials( def test_update_credentials_failed_update( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] - org_without_creds = Organization(name="Org Without Creds", is_active=True) - db.add(org_without_creds) - db.commit() - db.refresh(org_without_creds) + org_without_creds = create_test_organization(db) existing_creds = ( - db.query(Credential) - .filter(Credential.organization_id == org_without_creds.id) - .all() + db.query(Credential).filter(creds.organization_id == org_without_creds.id).all() ) assert len(existing_creds) == 0 @@ -262,7 +220,7 @@ def test_update_credentials_failed_update( } response_invalid_org = client.patch( - f"{settings.API_V1_STR}/credentials/{org_without_creds.id}", # Valid org id but no creds + f"{settings.API_V1_STR}/credentials/{org_without_creds.id}", json=update_data, headers=superuser_token_headers, ) @@ -276,8 +234,7 @@ def test_update_credentials_failed_update( def test_update_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - # Create a non-existent organization ID - non_existent_org_id = 999999 + non_existent_org_id = get_non_existent_id(db, Organization) update_data = { "provider": Provider.OPENAI.value, @@ -299,13 +256,12 @@ def test_update_credentials_not_found( def test_delete_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -315,9 +271,9 @@ def test_delete_provider_credential( def test_delete_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, _ = create_organization_and_creds + org = create_test_organization(db) response = client.delete( f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", @@ -330,13 +286,12 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", headers=superuser_token_headers, ) @@ -346,7 +301,7 @@ def test_delete_all_credentials( # Verify the credentials are soft deleted response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", headers=superuser_token_headers, ) assert response.status_code == 404 # Expect 404 as credentials are soft deleted @@ -356,8 +311,9 @@ def test_delete_all_credentials( def test_delete_all_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): + non_existent_creds_id = get_non_existent_id(db, Credential) response = client.delete( - f"{settings.API_V1_STR}/credentials/999999", + f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", headers=superuser_token_headers, ) @@ -366,31 +322,33 @@ def test_delete_all_credentials_not_found( def test_duplicate_credential_creation( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, creds_data = create_organization_and_creds - # First create credentials + creds = test_credential_data(db) + response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) + print(response) assert response.status_code == 200 # Try to create the same credentials again response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) assert response.status_code == 400 + assert "already exist" in response.json()["error"] def test_multiple_provider_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, _ = create_organization_and_creds + org = create_test_organization(db) # Create OpenAI credentials openai_creds = { @@ -446,16 +404,14 @@ def test_multiple_provider_credentials( assert Provider.LANGFUSE.value in providers -def test_credential_encryption( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): - org, creds_data = create_organization_and_creds - original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] +def test_credential_encryption(db: Session, superuser_token_headers: dict[str, str]): + creds = test_credential_data(db) + original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 @@ -463,7 +419,7 @@ def test_credential_encryption( db_cred = ( db.query(Credential) .filter( - Credential.organization_id == org.id, + Credential.organization_id == creds.organization_id, Credential.provider == Provider.OPENAI.value, ) .first() @@ -479,22 +435,22 @@ def test_credential_encryption( def test_credential_encryption_consistency( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, creds_data = create_organization_and_creds - original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] + creds = test_credential_data(db) + original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 # Fetch the credentials through the API response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 @@ -515,15 +471,14 @@ def test_credential_encryption_consistency( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", json=update_data, headers=superuser_token_headers, ) assert response.status_code == 200 - # Verify the updated value is also properly encrypted/decrypted response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 diff --git a/backend/app/tests/api/routes/test_org.py b/backend/app/tests/api/routes/test_org.py index 709bb7f5..607e1aef 100644 --- a/backend/app/tests/api/routes/test_org.py +++ b/backend/app/tests/api/routes/test_org.py @@ -1,44 +1,25 @@ import pytest from fastapi.testclient import TestClient -from sqlmodel import Session, select +from sqlmodel import Session -from app import crud from app.core.config import settings -from app.core.security import verify_password -from app.models import User, UserCreate -from app.tests.utils.utils import random_email, random_lower_string -from app.models import Organization, OrganizationCreate, OrganizationUpdate -from app.api.deps import get_db +from app.models import Organization from app.main import app from app.crud.organization import create_organization, get_organization_by_id +from app.tests.utils.test_data import create_test_organization client = TestClient(app) @pytest.fixture def test_organization(db: Session, superuser_token_headers: dict[str, str]): - unique_name = f"TestOrg-{random_lower_string()}" - org_data = OrganizationCreate(name=unique_name, is_active=True) - organization = create_organization(session=db, org_create=org_data) - db.commit() - return organization - - -# Test retrieving organizations -def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]): - response = client.get( - f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers - ) - assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert isinstance(response_data["data"], list) + return create_test_organization(db) # Test creating an organization def test_create_organization(db: Session, superuser_token_headers: dict[str, str]): - unique_name = f"Org-{random_lower_string()}" - org_data = {"name": unique_name, "is_active": True} + org_name = "Test-Org" + org_data = {"name": org_name, "is_active": True} response = client.post( f"{settings.API_V1_STR}/organizations/", json=org_data, @@ -55,13 +36,25 @@ def test_create_organization(db: Session, superuser_token_headers: dict[str, str assert org.is_active == created_org_data["is_active"] +# Test retrieving organizations +def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]): + response = client.get( + f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers + ) + assert response.status_code == 200 + response_data = response.json() + assert "data" in response_data + assert isinstance(response_data["data"], list) + + +# Updating an organization def test_update_organization( db: Session, test_organization: Organization, superuser_token_headers: dict[str, str], ): - unique_name = f"UpdatedOrg-{random_lower_string()}" # Ensure a unique name - update_data = {"name": unique_name, "is_active": False} + updated_name = "UpdatedOrg" + update_data = {"name": updated_name, "is_active": False} response = client.patch( f"{settings.API_V1_STR}/organizations/{test_organization.id}", diff --git a/backend/app/tests/api/routes/test_project.py b/backend/app/tests/api/routes/test_project.py index 98d1f96d..69b11cea 100644 --- a/backend/app/tests/api/routes/test_project.py +++ b/backend/app/tests/api/routes/test_project.py @@ -1,61 +1,26 @@ import pytest from fastapi.testclient import TestClient from sqlmodel import Session -from app.core.security import decrypt_api_key, verify_password from app.main import app from app.core.config import settings -from app.models import Project, ProjectCreate, ProjectUpdate -from app.models import Organization, OrganizationCreate, ProjectUpdate -from app.api.deps import get_db -from app.tests.utils.utils import random_lower_string, random_email -from app.crud.project import create_project, get_project_by_id -from app.crud.organization import create_organization -from app.crud import api_key as api_key_crud +from app.models import Project, ProjectCreate +from app.tests.utils.test_data import create_test_organization, create_test_project + client = TestClient(app) @pytest.fixture -def test_project(db: Session, superuser_token_headers: dict[str, str]): - unique_org_name = f"TestOrg-{random_lower_string()}" - org_data = OrganizationCreate(name=unique_org_name, is_active=True) - organization = create_organization(session=db, org_create=org_data) - db.commit() - - unique_project_name = f"TestProject-{random_lower_string()}" - project_description = "This is a test project description." - project_data = ProjectCreate( - name=unique_project_name, - description=project_description, - is_active=True, - organization_id=organization.id, - ) - project = create_project(session=db, project_create=project_data) - db.commit() - - return project - - -# Test retrieving projects -def test_read_projects(db: Session, superuser_token_headers: dict[str, str]): - response = client.get( - f"{settings.API_V1_STR}/projects/", headers=superuser_token_headers - ) - assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert isinstance(response_data["data"], list) +def test_project(db: Session) -> Project: + return create_test_project(db) # Test creating a project def test_create_new_project(db: Session, superuser_token_headers: dict[str, str]): - unique_org_name = f"TestOrg-{random_lower_string()}" - org_data = OrganizationCreate(name=unique_org_name, is_active=True) - organization = create_organization(session=db, org_create=org_data) - db.commit() + organization = create_test_organization(db) - unique_project_name = f"TestProject-{random_lower_string()}" + unique_project_name = "TestProject" project_description = "This is a test project description." project_data = ProjectCreate( name=unique_project_name, @@ -82,6 +47,17 @@ def test_create_new_project(db: Session, superuser_token_headers: dict[str, str] assert created_project["data"]["organization_id"] == organization.id +# Test retrieving projects +def test_read_projects(db: Session, superuser_token_headers: dict[str, str]): + response = client.get( + f"{settings.API_V1_STR}/projects/", headers=superuser_token_headers + ) + assert response.status_code == 200 + response_data = response.json() + assert "data" in response_data + assert isinstance(response_data["data"], list) + + # Test updating a project def test_update_project( db: Session, test_project: Project, superuser_token_headers: dict[str, str] diff --git a/backend/app/tests/api/routes/test_responses.py b/backend/app/tests/api/routes/test_responses.py index 257d2585..698cd1c6 100644 --- a/backend/app/tests/api/routes/test_responses.py +++ b/backend/app/tests/api/routes/test_responses.py @@ -6,7 +6,6 @@ from app.api.routes.responses import router from app.models import Project -from app.seed_data.seed_data import seed_database # Wrap the router in a FastAPI app instance app = FastAPI() @@ -14,14 +13,6 @@ client = TestClient(app) -@pytest.fixture(scope="function", autouse=True) -def load_seed_data(db): - """Load seed data before each test.""" - seed_database(db) - yield - # Cleanup is handled by the db fixture in conftest.py - - @patch("app.api.routes.responses.OpenAI") @patch("app.api.routes.responses.get_provider_credential") def test_responses_endpoint_success( diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index e2a6464d..9f6fe76c 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -20,6 +20,7 @@ ) from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers +from app.seed_data.seed_data import seed_database @pytest.fixture(scope="session", autouse=True) @@ -56,3 +57,9 @@ def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str] return authentication_token_from_email( client=client, email=settings.EMAIL_TEST_USER, db=db ) + + +@pytest.fixture(scope="session", autouse=True) +def load_seed_data(db): + seed_database(db) + yield diff --git a/backend/app/tests/crud/test_api_key.py b/backend/app/tests/crud/test_api_key.py index 197756bd..a0f89ba3 100644 --- a/backend/app/tests/crud/test_api_key.py +++ b/backend/app/tests/crud/test_api_key.py @@ -1,86 +1,45 @@ -import uuid -import pytest -from datetime import datetime from sqlmodel import Session, select -from app.crud import api_key as api_key_crud -from app.models import APIKey, User, Organization, Project -from app.tests.utils.utils import random_email -from app.core.security import get_password_hash, verify_password, decrypt_api_key -from app.core.exception_handlers import HTTPException - - -# Helper function to create a user -def create_test_user(db: Session) -> User: - user = User(email=random_email(), hashed_password=get_password_hash("password123")) - db.add(user) - db.commit() - db.refresh(user) - return user - -# Helper function to create an organization with a random name -def create_test_organization(db: Session) -> Organization: - org = Organization( - name=f"Test Organization {uuid.uuid4()}", description="Test Organization" - ) - db.add(org) - db.commit() - db.refresh(org) - return org - - -def create_test_project(db: Session, organization_id: int) -> Project: - project = Project( - name=f"Test Project {uuid.uuid4()}", - description="Test project", - organization_id=organization_id, - is_active=True, - ) - db.add(project) - db.commit() - db.refresh(project) - return project +from app.crud import api_key as api_key_crud +from app.models import APIKey +from app.tests.utils.utils import get_non_existent_id +from app.tests.utils.user import create_random_user +from app.tests.utils.test_data import create_test_api_key, create_test_project def test_create_api_key(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) + user = create_random_user(db) + project = create_test_project(db) - api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + api_key = api_key_crud.create_api_key( + db, project.organization_id, user.id, project.id + ) assert api_key.key.startswith("ApiKey ") assert len(api_key.key) > 32 - assert api_key.organization_id == org.id + assert api_key.organization_id == project.organization_id assert api_key.user_id == user.id assert api_key.project_id == project.id def test_get_api_key(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) - - created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) - retrieved_key = api_key_crud.get_api_key(db, created_key.id) + api_key = create_test_api_key(db) + retrieved_key = api_key_crud.get_api_key(db, api_key.id) assert retrieved_key is not None - assert retrieved_key.id == created_key.id + assert retrieved_key.id == api_key.id assert retrieved_key.key.startswith("ApiKey ") - assert retrieved_key.project_id == project.id + assert retrieved_key.project_id == api_key.project_id def test_get_api_key_not_found(db: Session) -> None: - result = api_key_crud.get_api_key(db, 9999) # Non-existent ID + api_key_id = get_non_existent_id(db, APIKey) + result = api_key_crud.get_api_key(db, api_key_id) assert result is None def test_delete_api_key(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) - - api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + api_key = create_test_api_key(db) api_key_crud.delete_api_key(db, api_key.id) deleted_key = db.exec(select(APIKey).where(APIKey.id == api_key.id)).first() @@ -91,13 +50,7 @@ def test_delete_api_key(db: Session) -> None: def test_get_api_key_by_value(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) - - # Create an API key - api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) - # Get the raw key that was returned during creation + api_key = create_test_api_key(db) raw_key = api_key.key # Test retrieving the API key by its value @@ -105,8 +58,8 @@ def test_get_api_key_by_value(db: Session) -> None: assert retrieved_key is not None assert retrieved_key.id == api_key.id - assert retrieved_key.organization_id == org.id - assert retrieved_key.user_id == user.id + assert retrieved_key.organization_id == api_key.organization_id + assert retrieved_key.user_id == api_key.user_id # The key should be in its original format assert retrieved_key.key == raw_key # Should be exactly the same key assert retrieved_key.key.startswith("ApiKey ") @@ -114,11 +67,12 @@ def test_get_api_key_by_value(db: Session) -> None: def test_get_api_key_by_project_user(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) + user = create_random_user(db) + project = create_test_project(db) - created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + created_key = api_key_crud.create_api_key( + db, project.organization_id, user.id, project.id + ) retrieved_key = api_key_crud.get_api_key_by_project_user(db, project.id, user.id) assert retrieved_key is not None @@ -128,11 +82,13 @@ def test_get_api_key_by_project_user(db: Session) -> None: def test_get_api_keys_by_project(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) + user = create_random_user(db) + project = create_test_project(db) + + created_key = api_key_crud.create_api_key( + db, project.organization_id, user.id, project.id + ) - created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) retrieved_keys = api_key_crud.get_api_keys_by_project(db, project.id) assert retrieved_keys is not None diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index f47ea8d6..1ff20105 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -1,39 +1,26 @@ -import uuid from sqlmodel import Session import pytest -from datetime import datetime -from app.crud import credentials as credentials_crud -from app.models import Credential, CredsCreate, CredsUpdate, Organization, Project -from app.tests.utils.utils import random_email -from app.core.security import get_password_hash - - -def create_organization_and_project(db: Session) -> tuple[Organization, Project]: - """Helper function to create an organization and a project.""" - organization = Organization( - name=f"Test Organization {uuid.uuid4()}", is_active=True - ) - db.add(organization) - db.commit() - db.refresh(organization) - - project = Project( - name=f"Test Project {uuid.uuid4()}", - description="A test project", - organization_id=organization.id, - is_active=True, - ) - db.add(project) - db.commit() - db.refresh(project) - - return organization, project +from app.crud import ( + set_creds_for_org, + get_creds_by_org, + get_provider_credential, + update_creds_for_org, + remove_provider_credential, + remove_creds_for_org, +) +from app.models import CredsCreate, CredsUpdate +from app.core.providers import Provider +from app.tests.utils.test_data import ( + create_test_project, + create_test_credential, + test_credential_data, +) def test_set_creds_for_org(db: Session) -> None: """Test setting credentials for an organization.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Test credentials for supported providers creds_data = { @@ -45,42 +32,26 @@ def test_set_creds_for_org(db: Session) -> None: }, } - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - - created_creds = credentials_crud.set_creds_for_org( - session=db, creds_add=creds_create + creds_create = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + credential=creds_data, ) + created_creds = set_creds_for_org(session=db, creds_add=creds_create) + assert len(created_creds) == 2 - assert all(cred.organization_id == organization.id for cred in created_creds) + assert all( + cred.organization_id == project.organization_id for cred in created_creds + ) + assert all(cred.project_id == project.id for cred in created_creds) assert all(cred.is_active for cred in created_creds) assert {cred.provider for cred in created_creds} == {"openai", "langfuse"} -def test_set_creds_for_org_with_project(db: Session) -> None: - """Test setting credentials for an organization with a specific project.""" - organization, project = create_organization_and_project(db) - - creds_data = {"openai": {"api_key": "test-openai-key"}} - - creds_create = CredsCreate( - organization_id=organization.id, project_id=project.id, credential=creds_data - ) - - created_creds = credentials_crud.set_creds_for_org( - session=db, creds_add=creds_create - ) - - assert len(created_creds) == 1 - assert created_creds[0].organization_id == organization.id - assert created_creds[0].project_id == project.id - assert created_creds[0].provider == "openai" - assert created_creds[0].is_active - - def test_get_creds_by_org(db: Session) -> None: """Test retrieving all credentials for an organization.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Set up test credentials creds_data = { @@ -92,101 +63,81 @@ def test_get_creds_by_org(db: Session) -> None: }, } - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + credential=creds_data, + ) + set_creds_for_org(session=db, creds_add=creds_create) # Test retrieving credentials - retrieved_creds = credentials_crud.get_creds_by_org( - session=db, org_id=organization.id - ) + retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) assert len(retrieved_creds) == 2 - assert all(cred.organization_id == organization.id for cred in retrieved_creds) + assert all( + cred.organization_id == project.organization_id for cred in retrieved_creds + ) assert {cred.provider for cred in retrieved_creds} == {"openai", "langfuse"} def test_get_provider_credential(db: Session) -> None: """Test retrieving credentials for a specific provider.""" - organization, _ = create_organization_and_project(db) - - # Set up test credentials - creds_data = {"openai": {"api_key": "test-openai-key"}} - - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = test_credential_data(db) + original_api_key = creds_create.credential[Provider.OPENAI.value]["api_key"] + set_creds_for_org(session=db, creds_add=creds_create) # Test retrieving specific provider credentials - retrieved_cred = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + retrieved_cred = get_provider_credential( + session=db, org_id=creds_create.organization_id, provider="openai" ) assert retrieved_cred is not None assert "api_key" in retrieved_cred - assert retrieved_cred["api_key"] == "test-openai-key" + assert retrieved_cred["api_key"] == original_api_key def test_update_creds_for_org(db: Session) -> None: """Test updating credentials for a provider.""" - organization, _ = create_organization_and_project(db) - - # Set up initial credentials - initial_creds = {"openai": {"api_key": "initial-key"}} - creds_create = CredsCreate( - organization_id=organization.id, credential=initial_creds - ) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds = create_test_credential(db)[0] # Update credentials updated_creds = {"api_key": "updated-key"} creds_update = CredsUpdate(provider="openai", credential=updated_creds) - updated = credentials_crud.update_creds_for_org( - session=db, org_id=organization.id, creds_in=creds_update + updated = update_creds_for_org( + session=db, org_id=creds.organization_id, creds_in=creds_update ) assert len(updated) == 1 assert updated[0].provider == "openai" - retrieved_cred = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + retrieved_cred = get_provider_credential( + session=db, org_id=creds.organization_id, provider="openai" ) assert retrieved_cred["api_key"] == "updated-key" def test_remove_provider_credential(db: Session) -> None: """Test removing credentials for a specific provider.""" - organization, _ = create_organization_and_project(db) - - # Set up test credentials - creds_data = { - "openai": {"api_key": "test-openai-key"}, - "langfuse": { - "public_key": "test-public-key", - "secret_key": "test-secret-key", - "host": "https://cloud.langfuse.com", - }, - } - - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds = create_test_credential(db)[0] # Remove one provider's credentials - removed = credentials_crud.remove_provider_credential( - session=db, org_id=organization.id, provider="openai" + removed = remove_provider_credential( + session=db, org_id=creds.organization_id, provider="openai" ) assert removed.is_active is False assert removed.updated_at is not None # Verify the credentials are no longer retrievable - retrieved_cred = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + retrieved_cred = get_provider_credential( + session=db, org_id=creds.organization_id, provider="openai" ) assert retrieved_cred is None def test_remove_creds_for_org(db: Session) -> None: """Test removing all credentials for an organization.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Set up test credentials creds_data = { @@ -198,49 +149,55 @@ def test_remove_creds_for_org(db: Session) -> None: }, } - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + credential=creds_data, + ) + set_creds_for_org(session=db, creds_add=creds_create) # Remove all credentials - removed = credentials_crud.remove_creds_for_org(session=db, org_id=organization.id) + removed = remove_creds_for_org(session=db, org_id=project.organization_id) assert len(removed) == 2 assert all(not cred.is_active for cred in removed) assert all(cred.updated_at is not None for cred in removed) # Verify no credentials are retrievable - retrieved_creds = credentials_crud.get_creds_by_org( - session=db, org_id=organization.id - ) + retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) assert len(retrieved_creds) == 0 def test_invalid_provider(db: Session) -> None: """Test handling of invalid provider names.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Test with unsupported provider creds_data = {"gemini": {"api_key": "test-key"}} - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + creds_create = CredsCreate( + organization_id=project.organization_id, credential=creds_data + ) with pytest.raises(ValueError, match="Unsupported provider"): - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=creds_create) def test_duplicate_provider_credentials(db: Session) -> None: """Test handling of duplicate provider credentials.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Set up initial credentials creds_data = {"openai": {"api_key": "test-key"}} - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate( + organization_id=project.organization_id, credential=creds_data + ) + set_creds_for_org(session=db, creds_add=creds_create) # Verify credentials exist and are active - existing_creds = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + existing_creds = get_provider_credential( + session=db, org_id=project.organization_id, provider="openai" ) assert existing_creds is not None assert "api_key" in existing_creds @@ -249,7 +206,7 @@ def test_duplicate_provider_credentials(db: Session) -> None: def test_langfuse_credential_validation(db: Session) -> None: """Test validation of Langfuse credentials structure.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Test with missing required fields invalid_creds = { @@ -261,11 +218,11 @@ def test_langfuse_credential_validation(db: Session) -> None: } creds_create = CredsCreate( - organization_id=organization.id, credential=invalid_creds + organization_id=project.organization_id, credential=invalid_creds ) with pytest.raises(ValueError): - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=creds_create) # Test with valid Langfuse credentials valid_creds = { @@ -276,10 +233,10 @@ def test_langfuse_credential_validation(db: Session) -> None: } } - creds_create = CredsCreate(organization_id=organization.id, credential=valid_creds) - - created_creds = credentials_crud.set_creds_for_org( - session=db, creds_add=creds_create + creds_create = CredsCreate( + organization_id=project.organization_id, credential=valid_creds ) + + created_creds = set_creds_for_org(session=db, creds_add=creds_create) assert len(created_creds) == 1 assert created_creds[0].provider == "langfuse" diff --git a/backend/app/tests/crud/test_org.py b/backend/app/tests/crud/test_org.py index 7efba0ec..052988ff 100644 --- a/backend/app/tests/crud/test_org.py +++ b/backend/app/tests/crud/test_org.py @@ -3,6 +3,7 @@ from app.crud.organization import create_organization, get_organization_by_id from app.models import Organization, OrganizationCreate from app.tests.utils.utils import random_lower_string, get_non_existent_id +from app.tests.utils.test_data import create_test_organization def test_create_organization(db: Session) -> None: @@ -13,14 +14,12 @@ def test_create_organization(db: Session) -> None: assert org.name == name assert org.id is not None - assert org.is_active is True # Default should be active + assert org.is_active is True def test_get_organization_by_id(db: Session) -> None: """Test retrieving an organization by ID.""" - name = random_lower_string() - org_in = OrganizationCreate(name=name) - org = create_organization(session=db, org_create=org_in) + org = create_test_organization(db) fetched_org = get_organization_by_id(session=db, org_id=org.id) assert fetched_org @@ -31,7 +30,5 @@ def test_get_organization_by_id(db: Session) -> None: def test_get_non_existent_organization(db: Session) -> None: """Test retrieving a non-existent organization should return None.""" org_id = get_non_existent_id(db, Organization) - fetched_org = get_organization_by_id( - session=db, org_id=org_id - ) # Assuming ID 999 does not exist + fetched_org = get_organization_by_id(session=db, org_id=org_id) assert fetched_org is None diff --git a/backend/app/tests/crud/test_project.py b/backend/app/tests/crud/test_project.py index c9b135ce..041821cc 100644 --- a/backend/app/tests/crud/test_project.py +++ b/backend/app/tests/crud/test_project.py @@ -2,22 +2,20 @@ from sqlmodel import Session from fastapi import HTTPException -from app.models import Project, ProjectCreate, Organization +from app.models import Project, ProjectCreate from app.crud.project import ( create_project, get_project_by_id, get_projects_by_organization, validate_project, ) -from app.tests.utils.utils import random_lower_string +from app.tests.utils.utils import random_lower_string, get_non_existent_id +from app.tests.utils.test_data import create_test_project, create_test_organization def test_create_project(db: Session) -> None: """Test creating a project linked to an organization.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) project_name = random_lower_string() project_data = ProjectCreate( @@ -37,17 +35,7 @@ def test_create_project(db: Session) -> None: def test_get_project_by_id(db: Session) -> None: """Test retrieving a project by ID.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) - - project_name = random_lower_string() - project_data = ProjectCreate( - name=project_name, description="Test", organization_id=org.id - ) - - project = create_project(session=db, project_create=project_data) + project = create_test_project(db) fetched_project = get_project_by_id(session=db, project_id=project.id) assert fetched_project is not None @@ -57,53 +45,43 @@ def test_get_project_by_id(db: Session) -> None: def test_get_projects_by_organization(db: Session) -> None: """Test retrieving all projects for an organization.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) project_1 = create_project( session=db, project_create=ProjectCreate( - name=random_lower_string(), organization_id=org.id + name="Project 1", + description="Test project 1", + is_active=True, + organization_id=org.id, ), ) + project_2 = create_project( session=db, project_create=ProjectCreate( - name=random_lower_string(), organization_id=org.id + name="Project 2", + description="Test project 2", + is_active=True, + organization_id=org.id, ), ) projects = get_projects_by_organization(session=db, org_id=org.id) - assert len(projects) == 2 assert project_1 in projects assert project_2 in projects def test_get_non_existent_project(db: Session) -> None: - """Test retrieving a non-existent project should return None.""" - fetched_project = get_project_by_id(session=db, project_id=999) + non_existent_project_id = get_non_existent_id(db, Project) + fetched_project = get_project_by_id(session=db, project_id=non_existent_project_id) assert fetched_project is None def test_validate_project_success(db: Session) -> None: """Test that a valid and active project passes validation.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) - - project = create_project( - session=db, - project_create=ProjectCreate( - name=random_lower_string(), - description="Valid project", - is_active=True, - organization_id=org.id, - ), - ) + project = create_test_project(db) validated_project = validate_project(session=db, project_id=project.id) assert validated_project.id == project.id @@ -111,16 +89,14 @@ def test_validate_project_success(db: Session) -> None: def test_validate_project_not_found(db: Session) -> None: """Test that validation fails when project does not exist.""" + non_existent_project_id = get_non_existent_id(db, Project) with pytest.raises(HTTPException, match="Project not found"): - validate_project(session=db, project_id=9999) + validate_project(session=db, project_id=non_existent_project_id) def test_validate_project_inactive(db: Session) -> None: """Test that validation fails when project is inactive.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) inactive_project = create_project( session=db, diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py new file mode 100644 index 00000000..a623cd1e --- /dev/null +++ b/backend/app/tests/utils/test_data.py @@ -0,0 +1,114 @@ +from sqlmodel import Session +from app.models import ( + Organization, + Project, + APIKey, + Credential, + OrganizationCreate, + ProjectCreate, + APIKeyPublic, + CredsCreate, +) +from app.crud import ( + create_organization, + create_project, + create_api_key, + set_creds_for_org, +) +from app.core.providers import Provider +from app.tests.utils.user import create_random_user +from app.tests.utils.utils import random_lower_string, generate_random_string + + +def create_test_organization(db: Session) -> Organization: + """ + Creates and returns a test organization with a unique name. + + Persists the organization to the database. + """ + name = f"TestOrg-{random_lower_string()}" + org_in = OrganizationCreate(name=name, is_active=True) + return create_organization(session=db, org_create=org_in) + + +def create_test_project(db: Session) -> Project: + """ + Creates and returns a test project under a newly created test organization. + + Persists both the organization and the project to the database. + + """ + org = create_test_organization(db) + name = f"TestProject-{random_lower_string()}" + project_description = "This is a test project description." + project_in = ProjectCreate( + name=name, + description=project_description, + is_active=True, + organization_id=org.id, + ) + return create_project(session=db, project_create=project_in) + + +def create_test_api_key(db: Session) -> APIKey: + """ + Creates and returns an API key for a test project and test user. + + Persists a test user, organization, project, and API key to the database + """ + project = create_test_project(db) + user = create_random_user(db) + api_key = create_api_key( + db, + organization_id=project.organization_id, + user_id=user.id, + project_id=project.id, + ) + return api_key + + +def test_credential_data(db: Session) -> CredsCreate: + """ + Returns credential data for a test project in the form of a CredsCreate schema. + + Use this when you just need credential input data without persisting it to the database. + """ + project = create_test_project(db) + api_key = "sk-" + generate_random_string(10) + creds_data = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + is_active=True, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, + ) + return creds_data + + +def create_test_credential(db: Session) -> Credential: + """ + Creates and returns a test credential for a test project. + + Persists the organization, project, and credential to the database. + + """ + project = create_test_project(db) + api_key = "sk-" + generate_random_string(10) + creds_data = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + is_active=True, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, + ) + return set_creds_for_org(session=db, creds_add=creds_data) diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 9cba7aae..351444b4 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -26,6 +26,10 @@ def random_lower_string() -> str: return "".join(random.choices(string.ascii_lowercase, k=32)) +def generate_random_string(length=10): + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + def random_email() -> str: return f"{random_lower_string()}@{random_lower_string()}.com" From 5d2a54736bb322f616497a18ae166ab37fc6ed38 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 9 Jul 2025 00:45:53 +0530 Subject: [PATCH 2/3] collections util --- backend/app/tests/utils/collection.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index ded5c728..0a68125c 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -7,6 +7,7 @@ from app.core.config import settings from app.models import Collection, Organization, Project from app.tests.utils.utils import get_user_id_by_email +from app.tests.utils.test_data import create_test_project from app.crud import create_api_key @@ -24,19 +25,14 @@ def get_collection(db: Session, client=None): owner_id = get_user_id_by_email(db) # Step 1: Create real organization and project entries - organization = Organization(name=f"Test Org {uuid4()}") - db.add(organization) - db.commit() - db.refresh(organization) - - project = Project(name="Test Project {uuid4()}", organization_id=organization.id) - db.add(project) - db.commit() - db.refresh(project) + project = create_test_project(db) # Step 2: Create API key for user with valid foreign keys create_api_key( - db, organization_id=organization.id, user_id=owner_id, project_id=project.id + db, + organization_id=project.organization_id, + user_id=owner_id, + project_id=project.id, ) if client is None: @@ -51,7 +47,7 @@ def get_collection(db: Session, client=None): return Collection( owner_id=owner_id, - organization_id=organization.id, + organization_id=project.organization_id, project_id=project.id, llm_service_id=assistant.id, llm_service_name=constants.llm_service_name, From bdc7244438f83d2a3ae484f68bc85daafd19e477 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 9 Jul 2025 13:00:31 +0530 Subject: [PATCH 3/3] keeping only seeding fixture changes --- backend/app/crud/__init__.py | 2 - backend/app/tests/api/routes/test_api_key.py | 95 ++++++--- backend/app/tests/api/routes/test_creds.py | 195 +++++++++++------- backend/app/tests/api/routes/test_org.py | 47 +++-- backend/app/tests/api/routes/test_project.py | 60 ++++-- backend/app/tests/crud/test_api_key.py | 108 +++++++--- backend/app/tests/crud/test_credentials.py | 205 +++++++++++-------- backend/app/tests/crud/test_org.py | 11 +- backend/app/tests/crud/test_project.py | 66 ++++-- backend/app/tests/utils/collection.py | 18 +- backend/app/tests/utils/test_data.py | 114 ----------- backend/app/tests/utils/utils.py | 4 - 12 files changed, 521 insertions(+), 404 deletions(-) delete mode 100644 backend/app/tests/utils/test_data.py diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 5c6fd102..6bf4d5da 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -37,8 +37,6 @@ get_key_by_org, update_creds_for_org, remove_creds_for_org, - get_provider_credential, - remove_provider_credential, ) from .thread_results import upsert_thread_result, get_thread_result diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index 7d018a66..552530cc 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -1,23 +1,51 @@ +import uuid +import pytest from fastapi.testclient import TestClient from sqlmodel import Session - from app.main import app -from app.models import APIKey +from app.models import APIKey, User, Organization, Project from app.core.config import settings -from app.tests.utils.utils import get_non_existent_id -from app.tests.utils.user import create_random_user -from app.tests.utils.test_data import ( - create_test_api_key, - create_test_project, - create_test_organization, -) +from app.crud.api_key import create_api_key +from app.tests.utils.utils import random_email +from app.core.security import get_password_hash client = TestClient(app) +def create_test_user(db: Session) -> User: + user = User( + email=random_email(), + hashed_password=get_password_hash("password123"), + is_superuser=True, + ) + db.add(user) + db.commit() + db.refresh(user) + return user + + +def create_test_organization(db: Session) -> Organization: + org = Organization( + name=f"Test Organization {uuid.uuid4()}", description="Test Organization" + ) + db.add(org) + db.commit() + db.refresh(org) + return org + + +def create_test_project(db: Session, organization_id: int) -> Project: + project = Project(name="Test Project", organization_id=organization_id) + db.add(project) + db.commit() + db.refresh(project) + return project + + def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_random_user(db) - project = create_test_project(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, organization_id=org.id) response = client.post( f"{settings.API_V1_STR}/apikeys", @@ -29,13 +57,14 @@ def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): assert data["success"] is True assert "id" in data["data"] assert "key" in data["data"] - assert data["data"]["organization_id"] == project.organization_id + assert data["data"]["organization_id"] == org.id assert data["data"]["user_id"] == user.id def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_random_user(db) - project = create_test_project(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, organization_id=org.id) client.post( f"{settings.API_V1_STR}/apikeys", @@ -52,11 +81,16 @@ def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, organization_id=org.id) + api_key = create_api_key( + db, organization_id=org.id, user_id=user.id, project_id=project.id + ) response = client.get( f"{settings.API_V1_STR}/apikeys", - params={"project_id": api_key.project_id}, + params={"project_id": project.id}, headers=superuser_token_headers, ) assert response.status_code == 200 @@ -66,12 +100,17 @@ def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): assert len(data["data"]) > 0 first_key = data["data"][0] - assert first_key["organization_id"] == api_key.organization_id - assert first_key["user_id"] == api_key.user_id + assert first_key["organization_id"] == org.id + assert first_key["user_id"] == user.id def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, organization_id=org.id) + api_key = create_api_key( + db, organization_id=org.id, user_id=user.id, project_id=project.id + ) response = client.get( f"{settings.API_V1_STR}/apikeys/{api_key.id}", @@ -82,14 +121,12 @@ def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): assert data["success"] is True assert data["data"]["id"] == api_key.id assert data["data"]["organization_id"] == api_key.organization_id - assert data["data"]["user_id"] == api_key.user_id + assert data["data"]["user_id"] == user.id def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key_id = get_non_existent_id(db, APIKey) - print(api_key_id) response = client.get( - f"{settings.API_V1_STR}/apikeys/{api_key_id}", + f"{settings.API_V1_STR}/apikeys/999999", headers=superuser_token_headers, ) assert response.status_code == 404 @@ -97,7 +134,12 @@ def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, organization_id=org.id) + api_key = create_api_key( + db, organization_id=org.id, user_id=user.id, project_id=project.id + ) response = client.delete( f"{settings.API_V1_STR}/apikeys/{api_key.id}", @@ -112,10 +154,11 @@ def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): def test_revoke_nonexistent_api_key( db: Session, superuser_token_headers: dict[str, str] ): - api_key_id = get_non_existent_id(db, APIKey) + user = create_test_user(db) + org = create_test_organization(db) response = client.delete( - f"{settings.API_V1_STR}/apikeys/{api_key_id}", + f"{settings.API_V1_STR}/apikeys/999999", headers=superuser_token_headers, ) assert response.status_code == 404 diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 5d0ff141..428d4c1c 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -1,37 +1,59 @@ import pytest from fastapi.testclient import TestClient from sqlmodel import Session +import random +import string from app.main import app -from app.models import Organization, Project +from app.api.deps import get_db +from app.crud.credentials import set_creds_for_org +from app.models import CredsCreate, Organization, OrganizationCreate, Project from app.core.config import settings +from app.core.security import encrypt_api_key from app.core.providers import Provider from app.models.credentials import Credential from app.core.security import decrypt_credentials -from app.tests.utils.utils import generate_random_string, get_non_existent_id -from app.tests.utils.test_data import ( - create_test_credential, - create_test_organization, - create_test_project, - test_credential_data, -) client = TestClient(app) +def generate_random_string(length=10): + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + @pytest.fixture -def create_test_credentials(db: Session): - return create_test_credential(db) +def create_organization_and_creds(db: Session): + unique_org_name = "Test Organization " + generate_random_string(5) + org = Organization(name=unique_org_name, is_active=True) + db.add(org) + db.commit() + db.refresh(org) + + api_key = "sk-" + generate_random_string(10) + creds_data = CredsCreate( + organization_id=org.id, + is_active=True, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, + ) + return org, creds_data -def test_set_creds(db: Session, superuser_token_headers: dict[str, str]): - project = create_test_project(db) +def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]): + org = Organization(name="Org for Set Creds", is_active=True) + db.add(org) + db.commit() + db.refresh(org) api_key = "sk-" + generate_random_string(10) creds_data = { - "organization_id": project.organization_id, - "project_id": project.id, + "organization_id": org.id, "is_active": True, "credential": { Provider.OPENAI.value: { @@ -52,7 +74,7 @@ def test_set_creds(db: Session, superuser_token_headers: dict[str, str]): data = response.json()["data"] assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == project.organization_id + assert data[0]["organization_id"] == org.id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" @@ -60,9 +82,19 @@ def test_set_creds(db: Session, superuser_token_headers: dict[str, str]): def test_set_creds_for_invalid_project_org_relationship( db: Session, superuser_token_headers: dict[str, str] ): - org1 = create_test_organization(db) - project2 = create_test_project(db) - + # Setup: Create two organizations and one project for each + org1 = Organization(name="Org 1", is_active=True) + org2 = Organization(name="Org 2", is_active=True) + db.add_all([org1, org2]) + db.commit() + db.refresh(org1) + db.refresh(org2) + + project2 = Project(name="Project Org2", organization_id=org2.id) + db.add(project2) + db.commit() + + # Invalid case: Organization mismatch (org1's creds for project2 of org2) creds_data_invalid = { "organization_id": org1.id, "is_active": True, @@ -87,13 +119,15 @@ def test_set_creds_for_project_not_found( db: Session, superuser_token_headers: dict[str, str] ): # Setup: Create an organization but no project - org = create_test_organization(db) - non_existent_project_id = get_non_existent_id(db, Project) + org = Organization(name="Org for Project Not Found", is_active=True) + db.add(org) + db.commit() + db.refresh(org) creds_data_invalid_project = { "organization_id": org.id, "is_active": True, - "project_id": non_existent_project_id, + "project_id": 99999, "credential": {Provider.OPENAI.value: {"api_key": "sk-123", "model": "gpt-4"}}, } @@ -108,12 +142,13 @@ def test_set_creds_for_project_not_found( def test_read_credentials_with_creds( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = create_test_credentials[0] - print(creds) + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) + response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers, ) @@ -121,7 +156,7 @@ def test_read_credentials_with_creds( data = response.json()["data"] assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == creds.organization_id + assert data[0]["organization_id"] == org.id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" @@ -129,9 +164,8 @@ def test_read_credentials_with_creds( def test_read_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - non_existent_creds_id = get_non_existent_id(db, Credential) response = client.get( - f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", + f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, ) assert response.status_code == 404 @@ -139,13 +173,13 @@ def test_read_credentials_not_found( def test_read_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = create_test_credentials[0] - print(creds) + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -156,9 +190,9 @@ def test_read_provider_credential( def test_read_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str] + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - org = create_test_organization(db) + org, _ = create_organization_and_creds response = client.get( f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", @@ -170,9 +204,10 @@ def test_read_provider_credential_not_found( def test_update_credentials( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = create_test_credentials[0] + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) update_data = { "provider": Provider.OPENAI.value, @@ -184,7 +219,7 @@ def test_update_credentials( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{org.id}", json=update_data, headers=superuser_token_headers, ) @@ -199,14 +234,21 @@ def test_update_credentials( def test_update_credentials_failed_update( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = create_test_credentials[0] + org, creds_data = create_organization_and_creds + + set_creds_for_org(session=db, creds_add=creds_data) - org_without_creds = create_test_organization(db) + org_without_creds = Organization(name="Org Without Creds", is_active=True) + db.add(org_without_creds) + db.commit() + db.refresh(org_without_creds) existing_creds = ( - db.query(Credential).filter(creds.organization_id == org_without_creds.id).all() + db.query(Credential) + .filter(Credential.organization_id == org_without_creds.id) + .all() ) assert len(existing_creds) == 0 @@ -220,7 +262,7 @@ def test_update_credentials_failed_update( } response_invalid_org = client.patch( - f"{settings.API_V1_STR}/credentials/{org_without_creds.id}", + f"{settings.API_V1_STR}/credentials/{org_without_creds.id}", # Valid org id but no creds json=update_data, headers=superuser_token_headers, ) @@ -234,7 +276,8 @@ def test_update_credentials_failed_update( def test_update_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - non_existent_org_id = get_non_existent_id(db, Organization) + # Create a non-existent organization ID + non_existent_org_id = 999999 update_data = { "provider": Provider.OPENAI.value, @@ -256,12 +299,13 @@ def test_update_credentials_not_found( def test_delete_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = create_test_credentials[0] + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) response = client.delete( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -271,9 +315,9 @@ def test_delete_provider_credential( def test_delete_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str] + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - org = create_test_organization(db) + org, _ = create_organization_and_creds response = client.delete( f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", @@ -286,12 +330,13 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = create_test_credentials[0] + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) response = client.delete( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers, ) @@ -301,7 +346,7 @@ def test_delete_all_credentials( # Verify the credentials are soft deleted response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers, ) assert response.status_code == 404 # Expect 404 as credentials are soft deleted @@ -311,9 +356,8 @@ def test_delete_all_credentials( def test_delete_all_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - non_existent_creds_id = get_non_existent_id(db, Credential) response = client.delete( - f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", + f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, ) @@ -322,33 +366,31 @@ def test_delete_all_credentials_not_found( def test_duplicate_credential_creation( - db: Session, superuser_token_headers: dict[str, str] + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = test_credential_data(db) - + org, creds_data = create_organization_and_creds + # First create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=creds_data.dict(), headers=superuser_token_headers, ) - print(response) assert response.status_code == 200 # Try to create the same credentials again response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=creds_data.dict(), headers=superuser_token_headers, ) assert response.status_code == 400 - assert "already exist" in response.json()["error"] def test_multiple_provider_credentials( - db: Session, superuser_token_headers: dict[str, str] + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - org = create_test_organization(db) + org, _ = create_organization_and_creds # Create OpenAI credentials openai_creds = { @@ -404,14 +446,16 @@ def test_multiple_provider_credentials( assert Provider.LANGFUSE.value in providers -def test_credential_encryption(db: Session, superuser_token_headers: dict[str, str]): - creds = test_credential_data(db) - original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] +def test_credential_encryption( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): + org, creds_data = create_organization_and_creds + original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=creds_data.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 @@ -419,7 +463,7 @@ def test_credential_encryption(db: Session, superuser_token_headers: dict[str, s db_cred = ( db.query(Credential) .filter( - Credential.organization_id == creds.organization_id, + Credential.organization_id == org.id, Credential.provider == Provider.OPENAI.value, ) .first() @@ -435,22 +479,22 @@ def test_credential_encryption(db: Session, superuser_token_headers: dict[str, s def test_credential_encryption_consistency( - db: Session, superuser_token_headers: dict[str, str] + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - creds = test_credential_data(db) - original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] + org, creds_data = create_organization_and_creds + original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=creds_data.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 # Fetch the credentials through the API response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 @@ -471,14 +515,15 @@ def test_credential_encryption_consistency( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{org.id}", json=update_data, headers=superuser_token_headers, ) assert response.status_code == 200 + # Verify the updated value is also properly encrypted/decrypted response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 diff --git a/backend/app/tests/api/routes/test_org.py b/backend/app/tests/api/routes/test_org.py index 607e1aef..709bb7f5 100644 --- a/backend/app/tests/api/routes/test_org.py +++ b/backend/app/tests/api/routes/test_org.py @@ -1,25 +1,44 @@ import pytest from fastapi.testclient import TestClient -from sqlmodel import Session +from sqlmodel import Session, select +from app import crud from app.core.config import settings -from app.models import Organization +from app.core.security import verify_password +from app.models import User, UserCreate +from app.tests.utils.utils import random_email, random_lower_string +from app.models import Organization, OrganizationCreate, OrganizationUpdate +from app.api.deps import get_db from app.main import app from app.crud.organization import create_organization, get_organization_by_id -from app.tests.utils.test_data import create_test_organization client = TestClient(app) @pytest.fixture def test_organization(db: Session, superuser_token_headers: dict[str, str]): - return create_test_organization(db) + unique_name = f"TestOrg-{random_lower_string()}" + org_data = OrganizationCreate(name=unique_name, is_active=True) + organization = create_organization(session=db, org_create=org_data) + db.commit() + return organization + + +# Test retrieving organizations +def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]): + response = client.get( + f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers + ) + assert response.status_code == 200 + response_data = response.json() + assert "data" in response_data + assert isinstance(response_data["data"], list) # Test creating an organization def test_create_organization(db: Session, superuser_token_headers: dict[str, str]): - org_name = "Test-Org" - org_data = {"name": org_name, "is_active": True} + unique_name = f"Org-{random_lower_string()}" + org_data = {"name": unique_name, "is_active": True} response = client.post( f"{settings.API_V1_STR}/organizations/", json=org_data, @@ -36,25 +55,13 @@ def test_create_organization(db: Session, superuser_token_headers: dict[str, str assert org.is_active == created_org_data["is_active"] -# Test retrieving organizations -def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]): - response = client.get( - f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers - ) - assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert isinstance(response_data["data"], list) - - -# Updating an organization def test_update_organization( db: Session, test_organization: Organization, superuser_token_headers: dict[str, str], ): - updated_name = "UpdatedOrg" - update_data = {"name": updated_name, "is_active": False} + unique_name = f"UpdatedOrg-{random_lower_string()}" # Ensure a unique name + update_data = {"name": unique_name, "is_active": False} response = client.patch( f"{settings.API_V1_STR}/organizations/{test_organization.id}", diff --git a/backend/app/tests/api/routes/test_project.py b/backend/app/tests/api/routes/test_project.py index 69b11cea..98d1f96d 100644 --- a/backend/app/tests/api/routes/test_project.py +++ b/backend/app/tests/api/routes/test_project.py @@ -1,26 +1,61 @@ import pytest from fastapi.testclient import TestClient from sqlmodel import Session +from app.core.security import decrypt_api_key, verify_password from app.main import app from app.core.config import settings -from app.models import Project, ProjectCreate -from app.tests.utils.test_data import create_test_organization, create_test_project - +from app.models import Project, ProjectCreate, ProjectUpdate +from app.models import Organization, OrganizationCreate, ProjectUpdate +from app.api.deps import get_db +from app.tests.utils.utils import random_lower_string, random_email +from app.crud.project import create_project, get_project_by_id +from app.crud.organization import create_organization +from app.crud import api_key as api_key_crud client = TestClient(app) @pytest.fixture -def test_project(db: Session) -> Project: - return create_test_project(db) +def test_project(db: Session, superuser_token_headers: dict[str, str]): + unique_org_name = f"TestOrg-{random_lower_string()}" + org_data = OrganizationCreate(name=unique_org_name, is_active=True) + organization = create_organization(session=db, org_create=org_data) + db.commit() + + unique_project_name = f"TestProject-{random_lower_string()}" + project_description = "This is a test project description." + project_data = ProjectCreate( + name=unique_project_name, + description=project_description, + is_active=True, + organization_id=organization.id, + ) + project = create_project(session=db, project_create=project_data) + db.commit() + + return project + + +# Test retrieving projects +def test_read_projects(db: Session, superuser_token_headers: dict[str, str]): + response = client.get( + f"{settings.API_V1_STR}/projects/", headers=superuser_token_headers + ) + assert response.status_code == 200 + response_data = response.json() + assert "data" in response_data + assert isinstance(response_data["data"], list) # Test creating a project def test_create_new_project(db: Session, superuser_token_headers: dict[str, str]): - organization = create_test_organization(db) + unique_org_name = f"TestOrg-{random_lower_string()}" + org_data = OrganizationCreate(name=unique_org_name, is_active=True) + organization = create_organization(session=db, org_create=org_data) + db.commit() - unique_project_name = "TestProject" + unique_project_name = f"TestProject-{random_lower_string()}" project_description = "This is a test project description." project_data = ProjectCreate( name=unique_project_name, @@ -47,17 +82,6 @@ def test_create_new_project(db: Session, superuser_token_headers: dict[str, str] assert created_project["data"]["organization_id"] == organization.id -# Test retrieving projects -def test_read_projects(db: Session, superuser_token_headers: dict[str, str]): - response = client.get( - f"{settings.API_V1_STR}/projects/", headers=superuser_token_headers - ) - assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert isinstance(response_data["data"], list) - - # Test updating a project def test_update_project( db: Session, test_project: Project, superuser_token_headers: dict[str, str] diff --git a/backend/app/tests/crud/test_api_key.py b/backend/app/tests/crud/test_api_key.py index a0f89ba3..197756bd 100644 --- a/backend/app/tests/crud/test_api_key.py +++ b/backend/app/tests/crud/test_api_key.py @@ -1,45 +1,86 @@ +import uuid +import pytest +from datetime import datetime from sqlmodel import Session, select - from app.crud import api_key as api_key_crud -from app.models import APIKey -from app.tests.utils.utils import get_non_existent_id -from app.tests.utils.user import create_random_user -from app.tests.utils.test_data import create_test_api_key, create_test_project +from app.models import APIKey, User, Organization, Project +from app.tests.utils.utils import random_email +from app.core.security import get_password_hash, verify_password, decrypt_api_key +from app.core.exception_handlers import HTTPException -def test_create_api_key(db: Session) -> None: - user = create_random_user(db) - project = create_test_project(db) +# Helper function to create a user +def create_test_user(db: Session) -> User: + user = User(email=random_email(), hashed_password=get_password_hash("password123")) + db.add(user) + db.commit() + db.refresh(user) + return user + - api_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id +# Helper function to create an organization with a random name +def create_test_organization(db: Session) -> Organization: + org = Organization( + name=f"Test Organization {uuid.uuid4()}", description="Test Organization" ) + db.add(org) + db.commit() + db.refresh(org) + return org + + +def create_test_project(db: Session, organization_id: int) -> Project: + project = Project( + name=f"Test Project {uuid.uuid4()}", + description="Test project", + organization_id=organization_id, + is_active=True, + ) + db.add(project) + db.commit() + db.refresh(project) + return project + + +def test_create_api_key(db: Session) -> None: + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, org.id) + + api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) assert api_key.key.startswith("ApiKey ") assert len(api_key.key) > 32 - assert api_key.organization_id == project.organization_id + assert api_key.organization_id == org.id assert api_key.user_id == user.id assert api_key.project_id == project.id def test_get_api_key(db: Session) -> None: - api_key = create_test_api_key(db) - retrieved_key = api_key_crud.get_api_key(db, api_key.id) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, org.id) + + created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + retrieved_key = api_key_crud.get_api_key(db, created_key.id) assert retrieved_key is not None - assert retrieved_key.id == api_key.id + assert retrieved_key.id == created_key.id assert retrieved_key.key.startswith("ApiKey ") - assert retrieved_key.project_id == api_key.project_id + assert retrieved_key.project_id == project.id def test_get_api_key_not_found(db: Session) -> None: - api_key_id = get_non_existent_id(db, APIKey) - result = api_key_crud.get_api_key(db, api_key_id) + result = api_key_crud.get_api_key(db, 9999) # Non-existent ID assert result is None def test_delete_api_key(db: Session) -> None: - api_key = create_test_api_key(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, org.id) + + api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) api_key_crud.delete_api_key(db, api_key.id) deleted_key = db.exec(select(APIKey).where(APIKey.id == api_key.id)).first() @@ -50,7 +91,13 @@ def test_delete_api_key(db: Session) -> None: def test_get_api_key_by_value(db: Session) -> None: - api_key = create_test_api_key(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, org.id) + + # Create an API key + api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + # Get the raw key that was returned during creation raw_key = api_key.key # Test retrieving the API key by its value @@ -58,8 +105,8 @@ def test_get_api_key_by_value(db: Session) -> None: assert retrieved_key is not None assert retrieved_key.id == api_key.id - assert retrieved_key.organization_id == api_key.organization_id - assert retrieved_key.user_id == api_key.user_id + assert retrieved_key.organization_id == org.id + assert retrieved_key.user_id == user.id # The key should be in its original format assert retrieved_key.key == raw_key # Should be exactly the same key assert retrieved_key.key.startswith("ApiKey ") @@ -67,12 +114,11 @@ def test_get_api_key_by_value(db: Session) -> None: def test_get_api_key_by_project_user(db: Session) -> None: - user = create_random_user(db) - project = create_test_project(db) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, org.id) - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) + created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) retrieved_key = api_key_crud.get_api_key_by_project_user(db, project.id, user.id) assert retrieved_key is not None @@ -82,13 +128,11 @@ def test_get_api_key_by_project_user(db: Session) -> None: def test_get_api_keys_by_project(db: Session) -> None: - user = create_random_user(db) - project = create_test_project(db) - - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) + user = create_test_user(db) + org = create_test_organization(db) + project = create_test_project(db, org.id) + created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) retrieved_keys = api_key_crud.get_api_keys_by_project(db, project.id) assert retrieved_keys is not None diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 1ff20105..f47ea8d6 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -1,26 +1,39 @@ +import uuid from sqlmodel import Session import pytest +from datetime import datetime -from app.crud import ( - set_creds_for_org, - get_creds_by_org, - get_provider_credential, - update_creds_for_org, - remove_provider_credential, - remove_creds_for_org, -) -from app.models import CredsCreate, CredsUpdate -from app.core.providers import Provider -from app.tests.utils.test_data import ( - create_test_project, - create_test_credential, - test_credential_data, -) +from app.crud import credentials as credentials_crud +from app.models import Credential, CredsCreate, CredsUpdate, Organization, Project +from app.tests.utils.utils import random_email +from app.core.security import get_password_hash + + +def create_organization_and_project(db: Session) -> tuple[Organization, Project]: + """Helper function to create an organization and a project.""" + organization = Organization( + name=f"Test Organization {uuid.uuid4()}", is_active=True + ) + db.add(organization) + db.commit() + db.refresh(organization) + + project = Project( + name=f"Test Project {uuid.uuid4()}", + description="A test project", + organization_id=organization.id, + is_active=True, + ) + db.add(project) + db.commit() + db.refresh(project) + + return organization, project def test_set_creds_for_org(db: Session) -> None: """Test setting credentials for an organization.""" - project = create_test_project(db) + organization, _ = create_organization_and_project(db) # Test credentials for supported providers creds_data = { @@ -32,26 +45,42 @@ def test_set_creds_for_org(db: Session) -> None: }, } - creds_create = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, - credential=creds_data, - ) + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - created_creds = set_creds_for_org(session=db, creds_add=creds_create) + created_creds = credentials_crud.set_creds_for_org( + session=db, creds_add=creds_create + ) assert len(created_creds) == 2 - assert all( - cred.organization_id == project.organization_id for cred in created_creds - ) - assert all(cred.project_id == project.id for cred in created_creds) + assert all(cred.organization_id == organization.id for cred in created_creds) assert all(cred.is_active for cred in created_creds) assert {cred.provider for cred in created_creds} == {"openai", "langfuse"} +def test_set_creds_for_org_with_project(db: Session) -> None: + """Test setting credentials for an organization with a specific project.""" + organization, project = create_organization_and_project(db) + + creds_data = {"openai": {"api_key": "test-openai-key"}} + + creds_create = CredsCreate( + organization_id=organization.id, project_id=project.id, credential=creds_data + ) + + created_creds = credentials_crud.set_creds_for_org( + session=db, creds_add=creds_create + ) + + assert len(created_creds) == 1 + assert created_creds[0].organization_id == organization.id + assert created_creds[0].project_id == project.id + assert created_creds[0].provider == "openai" + assert created_creds[0].is_active + + def test_get_creds_by_org(db: Session) -> None: """Test retrieving all credentials for an organization.""" - project = create_test_project(db) + organization, _ = create_organization_and_project(db) # Set up test credentials creds_data = { @@ -63,81 +92,101 @@ def test_get_creds_by_org(db: Session) -> None: }, } - creds_create = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, - credential=creds_data, - ) - set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) # Test retrieving credentials - retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) + retrieved_creds = credentials_crud.get_creds_by_org( + session=db, org_id=organization.id + ) assert len(retrieved_creds) == 2 - assert all( - cred.organization_id == project.organization_id for cred in retrieved_creds - ) + assert all(cred.organization_id == organization.id for cred in retrieved_creds) assert {cred.provider for cred in retrieved_creds} == {"openai", "langfuse"} def test_get_provider_credential(db: Session) -> None: """Test retrieving credentials for a specific provider.""" - creds_create = test_credential_data(db) - original_api_key = creds_create.credential[Provider.OPENAI.value]["api_key"] + organization, _ = create_organization_and_project(db) + + # Set up test credentials + creds_data = {"openai": {"api_key": "test-openai-key"}} + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) - set_creds_for_org(session=db, creds_add=creds_create) # Test retrieving specific provider credentials - retrieved_cred = get_provider_credential( - session=db, org_id=creds_create.organization_id, provider="openai" + retrieved_cred = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" ) assert retrieved_cred is not None assert "api_key" in retrieved_cred - assert retrieved_cred["api_key"] == original_api_key + assert retrieved_cred["api_key"] == "test-openai-key" def test_update_creds_for_org(db: Session) -> None: """Test updating credentials for a provider.""" - creds = create_test_credential(db)[0] + organization, _ = create_organization_and_project(db) + + # Set up initial credentials + initial_creds = {"openai": {"api_key": "initial-key"}} + creds_create = CredsCreate( + organization_id=organization.id, credential=initial_creds + ) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) # Update credentials updated_creds = {"api_key": "updated-key"} creds_update = CredsUpdate(provider="openai", credential=updated_creds) - updated = update_creds_for_org( - session=db, org_id=creds.organization_id, creds_in=creds_update + updated = credentials_crud.update_creds_for_org( + session=db, org_id=organization.id, creds_in=creds_update ) assert len(updated) == 1 assert updated[0].provider == "openai" - retrieved_cred = get_provider_credential( - session=db, org_id=creds.organization_id, provider="openai" + retrieved_cred = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" ) assert retrieved_cred["api_key"] == "updated-key" def test_remove_provider_credential(db: Session) -> None: """Test removing credentials for a specific provider.""" - creds = create_test_credential(db)[0] + organization, _ = create_organization_and_project(db) + + # Set up test credentials + creds_data = { + "openai": {"api_key": "test-openai-key"}, + "langfuse": { + "public_key": "test-public-key", + "secret_key": "test-secret-key", + "host": "https://cloud.langfuse.com", + }, + } + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) # Remove one provider's credentials - removed = remove_provider_credential( - session=db, org_id=creds.organization_id, provider="openai" + removed = credentials_crud.remove_provider_credential( + session=db, org_id=organization.id, provider="openai" ) assert removed.is_active is False assert removed.updated_at is not None # Verify the credentials are no longer retrievable - retrieved_cred = get_provider_credential( - session=db, org_id=creds.organization_id, provider="openai" + retrieved_cred = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" ) assert retrieved_cred is None def test_remove_creds_for_org(db: Session) -> None: """Test removing all credentials for an organization.""" - project = create_test_project(db) + organization, _ = create_organization_and_project(db) # Set up test credentials creds_data = { @@ -149,55 +198,49 @@ def test_remove_creds_for_org(db: Session) -> None: }, } - creds_create = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, - credential=creds_data, - ) - set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) # Remove all credentials - removed = remove_creds_for_org(session=db, org_id=project.organization_id) + removed = credentials_crud.remove_creds_for_org(session=db, org_id=organization.id) assert len(removed) == 2 assert all(not cred.is_active for cred in removed) assert all(cred.updated_at is not None for cred in removed) # Verify no credentials are retrievable - retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) + retrieved_creds = credentials_crud.get_creds_by_org( + session=db, org_id=organization.id + ) assert len(retrieved_creds) == 0 def test_invalid_provider(db: Session) -> None: """Test handling of invalid provider names.""" - project = create_test_project(db) + organization, _ = create_organization_and_project(db) # Test with unsupported provider creds_data = {"gemini": {"api_key": "test-key"}} - creds_create = CredsCreate( - organization_id=project.organization_id, credential=creds_data - ) + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) with pytest.raises(ValueError, match="Unsupported provider"): - set_creds_for_org(session=db, creds_add=creds_create) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) def test_duplicate_provider_credentials(db: Session) -> None: """Test handling of duplicate provider credentials.""" - project = create_test_project(db) + organization, _ = create_organization_and_project(db) # Set up initial credentials creds_data = {"openai": {"api_key": "test-key"}} - creds_create = CredsCreate( - organization_id=project.organization_id, credential=creds_data - ) - set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) # Verify credentials exist and are active - existing_creds = get_provider_credential( - session=db, org_id=project.organization_id, provider="openai" + existing_creds = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" ) assert existing_creds is not None assert "api_key" in existing_creds @@ -206,7 +249,7 @@ def test_duplicate_provider_credentials(db: Session) -> None: def test_langfuse_credential_validation(db: Session) -> None: """Test validation of Langfuse credentials structure.""" - project = create_test_project(db) + organization, _ = create_organization_and_project(db) # Test with missing required fields invalid_creds = { @@ -218,11 +261,11 @@ def test_langfuse_credential_validation(db: Session) -> None: } creds_create = CredsCreate( - organization_id=project.organization_id, credential=invalid_creds + organization_id=organization.id, credential=invalid_creds ) with pytest.raises(ValueError): - set_creds_for_org(session=db, creds_add=creds_create) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) # Test with valid Langfuse credentials valid_creds = { @@ -233,10 +276,10 @@ def test_langfuse_credential_validation(db: Session) -> None: } } - creds_create = CredsCreate( - organization_id=project.organization_id, credential=valid_creds - ) + creds_create = CredsCreate(organization_id=organization.id, credential=valid_creds) - created_creds = set_creds_for_org(session=db, creds_add=creds_create) + created_creds = credentials_crud.set_creds_for_org( + session=db, creds_add=creds_create + ) assert len(created_creds) == 1 assert created_creds[0].provider == "langfuse" diff --git a/backend/app/tests/crud/test_org.py b/backend/app/tests/crud/test_org.py index 052988ff..7efba0ec 100644 --- a/backend/app/tests/crud/test_org.py +++ b/backend/app/tests/crud/test_org.py @@ -3,7 +3,6 @@ from app.crud.organization import create_organization, get_organization_by_id from app.models import Organization, OrganizationCreate from app.tests.utils.utils import random_lower_string, get_non_existent_id -from app.tests.utils.test_data import create_test_organization def test_create_organization(db: Session) -> None: @@ -14,12 +13,14 @@ def test_create_organization(db: Session) -> None: assert org.name == name assert org.id is not None - assert org.is_active is True + assert org.is_active is True # Default should be active def test_get_organization_by_id(db: Session) -> None: """Test retrieving an organization by ID.""" - org = create_test_organization(db) + name = random_lower_string() + org_in = OrganizationCreate(name=name) + org = create_organization(session=db, org_create=org_in) fetched_org = get_organization_by_id(session=db, org_id=org.id) assert fetched_org @@ -30,5 +31,7 @@ def test_get_organization_by_id(db: Session) -> None: def test_get_non_existent_organization(db: Session) -> None: """Test retrieving a non-existent organization should return None.""" org_id = get_non_existent_id(db, Organization) - fetched_org = get_organization_by_id(session=db, org_id=org_id) + fetched_org = get_organization_by_id( + session=db, org_id=org_id + ) # Assuming ID 999 does not exist assert fetched_org is None diff --git a/backend/app/tests/crud/test_project.py b/backend/app/tests/crud/test_project.py index 041821cc..c9b135ce 100644 --- a/backend/app/tests/crud/test_project.py +++ b/backend/app/tests/crud/test_project.py @@ -2,20 +2,22 @@ from sqlmodel import Session from fastapi import HTTPException -from app.models import Project, ProjectCreate +from app.models import Project, ProjectCreate, Organization from app.crud.project import ( create_project, get_project_by_id, get_projects_by_organization, validate_project, ) -from app.tests.utils.utils import random_lower_string, get_non_existent_id -from app.tests.utils.test_data import create_test_project, create_test_organization +from app.tests.utils.utils import random_lower_string def test_create_project(db: Session) -> None: """Test creating a project linked to an organization.""" - org = create_test_organization(db) + org = Organization(name=random_lower_string()) + db.add(org) + db.commit() + db.refresh(org) project_name = random_lower_string() project_data = ProjectCreate( @@ -35,7 +37,17 @@ def test_create_project(db: Session) -> None: def test_get_project_by_id(db: Session) -> None: """Test retrieving a project by ID.""" - project = create_test_project(db) + org = Organization(name=random_lower_string()) + db.add(org) + db.commit() + db.refresh(org) + + project_name = random_lower_string() + project_data = ProjectCreate( + name=project_name, description="Test", organization_id=org.id + ) + + project = create_project(session=db, project_create=project_data) fetched_project = get_project_by_id(session=db, project_id=project.id) assert fetched_project is not None @@ -45,43 +57,53 @@ def test_get_project_by_id(db: Session) -> None: def test_get_projects_by_organization(db: Session) -> None: """Test retrieving all projects for an organization.""" - org = create_test_organization(db) + org = Organization(name=random_lower_string()) + db.add(org) + db.commit() + db.refresh(org) project_1 = create_project( session=db, project_create=ProjectCreate( - name="Project 1", - description="Test project 1", - is_active=True, - organization_id=org.id, + name=random_lower_string(), organization_id=org.id ), ) - project_2 = create_project( session=db, project_create=ProjectCreate( - name="Project 2", - description="Test project 2", - is_active=True, - organization_id=org.id, + name=random_lower_string(), organization_id=org.id ), ) projects = get_projects_by_organization(session=db, org_id=org.id) + assert len(projects) == 2 assert project_1 in projects assert project_2 in projects def test_get_non_existent_project(db: Session) -> None: - non_existent_project_id = get_non_existent_id(db, Project) - fetched_project = get_project_by_id(session=db, project_id=non_existent_project_id) + """Test retrieving a non-existent project should return None.""" + fetched_project = get_project_by_id(session=db, project_id=999) assert fetched_project is None def test_validate_project_success(db: Session) -> None: """Test that a valid and active project passes validation.""" - project = create_test_project(db) + org = Organization(name=random_lower_string()) + db.add(org) + db.commit() + db.refresh(org) + + project = create_project( + session=db, + project_create=ProjectCreate( + name=random_lower_string(), + description="Valid project", + is_active=True, + organization_id=org.id, + ), + ) validated_project = validate_project(session=db, project_id=project.id) assert validated_project.id == project.id @@ -89,14 +111,16 @@ def test_validate_project_success(db: Session) -> None: def test_validate_project_not_found(db: Session) -> None: """Test that validation fails when project does not exist.""" - non_existent_project_id = get_non_existent_id(db, Project) with pytest.raises(HTTPException, match="Project not found"): - validate_project(session=db, project_id=non_existent_project_id) + validate_project(session=db, project_id=9999) def test_validate_project_inactive(db: Session) -> None: """Test that validation fails when project is inactive.""" - org = create_test_organization(db) + org = Organization(name=random_lower_string()) + db.add(org) + db.commit() + db.refresh(org) inactive_project = create_project( session=db, diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index 0a68125c..ded5c728 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -7,7 +7,6 @@ from app.core.config import settings from app.models import Collection, Organization, Project from app.tests.utils.utils import get_user_id_by_email -from app.tests.utils.test_data import create_test_project from app.crud import create_api_key @@ -25,14 +24,19 @@ def get_collection(db: Session, client=None): owner_id = get_user_id_by_email(db) # Step 1: Create real organization and project entries - project = create_test_project(db) + organization = Organization(name=f"Test Org {uuid4()}") + db.add(organization) + db.commit() + db.refresh(organization) + + project = Project(name="Test Project {uuid4()}", organization_id=organization.id) + db.add(project) + db.commit() + db.refresh(project) # Step 2: Create API key for user with valid foreign keys create_api_key( - db, - organization_id=project.organization_id, - user_id=owner_id, - project_id=project.id, + db, organization_id=organization.id, user_id=owner_id, project_id=project.id ) if client is None: @@ -47,7 +51,7 @@ def get_collection(db: Session, client=None): return Collection( owner_id=owner_id, - organization_id=project.organization_id, + organization_id=organization.id, project_id=project.id, llm_service_id=assistant.id, llm_service_name=constants.llm_service_name, diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py deleted file mode 100644 index a623cd1e..00000000 --- a/backend/app/tests/utils/test_data.py +++ /dev/null @@ -1,114 +0,0 @@ -from sqlmodel import Session -from app.models import ( - Organization, - Project, - APIKey, - Credential, - OrganizationCreate, - ProjectCreate, - APIKeyPublic, - CredsCreate, -) -from app.crud import ( - create_organization, - create_project, - create_api_key, - set_creds_for_org, -) -from app.core.providers import Provider -from app.tests.utils.user import create_random_user -from app.tests.utils.utils import random_lower_string, generate_random_string - - -def create_test_organization(db: Session) -> Organization: - """ - Creates and returns a test organization with a unique name. - - Persists the organization to the database. - """ - name = f"TestOrg-{random_lower_string()}" - org_in = OrganizationCreate(name=name, is_active=True) - return create_organization(session=db, org_create=org_in) - - -def create_test_project(db: Session) -> Project: - """ - Creates and returns a test project under a newly created test organization. - - Persists both the organization and the project to the database. - - """ - org = create_test_organization(db) - name = f"TestProject-{random_lower_string()}" - project_description = "This is a test project description." - project_in = ProjectCreate( - name=name, - description=project_description, - is_active=True, - organization_id=org.id, - ) - return create_project(session=db, project_create=project_in) - - -def create_test_api_key(db: Session) -> APIKey: - """ - Creates and returns an API key for a test project and test user. - - Persists a test user, organization, project, and API key to the database - """ - project = create_test_project(db) - user = create_random_user(db) - api_key = create_api_key( - db, - organization_id=project.organization_id, - user_id=user.id, - project_id=project.id, - ) - return api_key - - -def test_credential_data(db: Session) -> CredsCreate: - """ - Returns credential data for a test project in the form of a CredsCreate schema. - - Use this when you just need credential input data without persisting it to the database. - """ - project = create_test_project(db) - api_key = "sk-" + generate_random_string(10) - creds_data = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, - is_active=True, - credential={ - Provider.OPENAI.value: { - "api_key": api_key, - "model": "gpt-4", - "temperature": 0.7, - } - }, - ) - return creds_data - - -def create_test_credential(db: Session) -> Credential: - """ - Creates and returns a test credential for a test project. - - Persists the organization, project, and credential to the database. - - """ - project = create_test_project(db) - api_key = "sk-" + generate_random_string(10) - creds_data = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, - is_active=True, - credential={ - Provider.OPENAI.value: { - "api_key": api_key, - "model": "gpt-4", - "temperature": 0.7, - } - }, - ) - return set_creds_for_org(session=db, creds_add=creds_data) diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 351444b4..9cba7aae 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -26,10 +26,6 @@ def random_lower_string() -> str: return "".join(random.choices(string.ascii_lowercase, k=32)) -def generate_random_string(length=10): - return "".join(random.choices(string.ascii_letters + string.digits, k=length)) - - def random_email() -> str: return f"{random_lower_string()}@{random_lower_string()}.com"