From 7cce87362ced48f10f8f6a2e2ad152bd3f647b5c Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 9 Jul 2025 12:42:19 +0530 Subject: [PATCH 01/15] utility functions --- backend/app/crud/__init__.py | 2 + backend/app/tests/api/routes/test_api_key.py | 95 +++------ backend/app/tests/api/routes/test_creds.py | 195 +++++++----------- backend/app/tests/api/routes/test_org.py | 47 ++--- backend/app/tests/api/routes/test_project.py | 60 ++---- backend/app/tests/crud/test_api_key.py | 108 +++------- backend/app/tests/crud/test_credentials.py | 205 ++++++++----------- backend/app/tests/crud/test_org.py | 11 +- backend/app/tests/crud/test_project.py | 66 ++---- backend/app/tests/utils/collection.py | 18 +- backend/app/tests/utils/test_data.py | 114 +++++++++++ backend/app/tests/utils/utils.py | 4 + 12 files changed, 404 insertions(+), 521 deletions(-) create mode 100644 backend/app/tests/utils/test_data.py diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 6bf4d5da..5c6fd102 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -37,6 +37,8 @@ get_key_by_org, update_creds_for_org, remove_creds_for_org, + get_provider_credential, + remove_provider_credential, ) from .thread_results import upsert_thread_result, get_thread_result diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index 552530cc..7d018a66 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -1,51 +1,23 @@ -import uuid -import pytest from fastapi.testclient import TestClient from sqlmodel import Session + from app.main import app -from app.models import APIKey, User, Organization, Project +from app.models import APIKey from app.core.config import settings -from app.crud.api_key import create_api_key -from app.tests.utils.utils import random_email -from app.core.security import get_password_hash +from app.tests.utils.utils import get_non_existent_id +from app.tests.utils.user import create_random_user +from app.tests.utils.test_data import ( + create_test_api_key, + create_test_project, + create_test_organization, +) client = TestClient(app) -def create_test_user(db: Session) -> User: - user = User( - email=random_email(), - hashed_password=get_password_hash("password123"), - is_superuser=True, - ) - db.add(user) - db.commit() - db.refresh(user) - return user - - -def create_test_organization(db: Session) -> Organization: - org = Organization( - name=f"Test Organization {uuid.uuid4()}", description="Test Organization" - ) - db.add(org) - db.commit() - db.refresh(org) - return org - - -def create_test_project(db: Session, organization_id: int) -> Project: - project = Project(name="Test Project", organization_id=organization_id) - db.add(project) - db.commit() - db.refresh(project) - return project - - def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) + user = create_random_user(db) + project = create_test_project(db) response = client.post( f"{settings.API_V1_STR}/apikeys", @@ -57,14 +29,13 @@ def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): assert data["success"] is True assert "id" in data["data"] assert "key" in data["data"] - assert data["data"]["organization_id"] == org.id + assert data["data"]["organization_id"] == project.organization_id assert data["data"]["user_id"] == user.id def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) + user = create_random_user(db) + project = create_test_project(db) client.post( f"{settings.API_V1_STR}/apikeys", @@ -81,16 +52,11 @@ def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) - api_key = create_api_key( - db, organization_id=org.id, user_id=user.id, project_id=project.id - ) + api_key = create_test_api_key(db) response = client.get( f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id}, + params={"project_id": api_key.project_id}, headers=superuser_token_headers, ) assert response.status_code == 200 @@ -100,17 +66,12 @@ def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): assert len(data["data"]) > 0 first_key = data["data"][0] - assert first_key["organization_id"] == org.id - assert first_key["user_id"] == user.id + assert first_key["organization_id"] == api_key.organization_id + assert first_key["user_id"] == api_key.user_id def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) - api_key = create_api_key( - db, organization_id=org.id, user_id=user.id, project_id=project.id - ) + api_key = create_test_api_key(db) response = client.get( f"{settings.API_V1_STR}/apikeys/{api_key.id}", @@ -121,12 +82,14 @@ def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): assert data["success"] is True assert data["data"]["id"] == api_key.id assert data["data"]["organization_id"] == api_key.organization_id - assert data["data"]["user_id"] == user.id + assert data["data"]["user_id"] == api_key.user_id def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]): + api_key_id = get_non_existent_id(db, APIKey) + print(api_key_id) response = client.get( - f"{settings.API_V1_STR}/apikeys/999999", + f"{settings.API_V1_STR}/apikeys/{api_key_id}", headers=superuser_token_headers, ) assert response.status_code == 404 @@ -134,12 +97,7 @@ def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, organization_id=org.id) - api_key = create_api_key( - db, organization_id=org.id, user_id=user.id, project_id=project.id - ) + api_key = create_test_api_key(db) response = client.delete( f"{settings.API_V1_STR}/apikeys/{api_key.id}", @@ -154,11 +112,10 @@ def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): def test_revoke_nonexistent_api_key( db: Session, superuser_token_headers: dict[str, str] ): - user = create_test_user(db) - org = create_test_organization(db) + api_key_id = get_non_existent_id(db, APIKey) response = client.delete( - f"{settings.API_V1_STR}/apikeys/999999", + f"{settings.API_V1_STR}/apikeys/{api_key_id}", headers=superuser_token_headers, ) assert response.status_code == 404 diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 428d4c1c..5d0ff141 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -1,59 +1,37 @@ import pytest from fastapi.testclient import TestClient from sqlmodel import Session -import random -import string from app.main import app -from app.api.deps import get_db -from app.crud.credentials import set_creds_for_org -from app.models import CredsCreate, Organization, OrganizationCreate, Project +from app.models import Organization, Project from app.core.config import settings -from app.core.security import encrypt_api_key from app.core.providers import Provider from app.models.credentials import Credential from app.core.security import decrypt_credentials +from app.tests.utils.utils import generate_random_string, get_non_existent_id +from app.tests.utils.test_data import ( + create_test_credential, + create_test_organization, + create_test_project, + test_credential_data, +) client = TestClient(app) -def generate_random_string(length=10): - return "".join(random.choices(string.ascii_letters + string.digits, k=length)) - - @pytest.fixture -def create_organization_and_creds(db: Session): - unique_org_name = "Test Organization " + generate_random_string(5) - org = Organization(name=unique_org_name, is_active=True) - db.add(org) - db.commit() - db.refresh(org) - - api_key = "sk-" + generate_random_string(10) - creds_data = CredsCreate( - organization_id=org.id, - is_active=True, - credential={ - Provider.OPENAI.value: { - "api_key": api_key, - "model": "gpt-4", - "temperature": 0.7, - } - }, - ) - return org, creds_data +def create_test_credentials(db: Session): + return create_test_credential(db) -def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]): - org = Organization(name="Org for Set Creds", is_active=True) - db.add(org) - db.commit() - db.refresh(org) +def test_set_creds(db: Session, superuser_token_headers: dict[str, str]): + project = create_test_project(db) api_key = "sk-" + generate_random_string(10) creds_data = { - "organization_id": org.id, + "organization_id": project.organization_id, + "project_id": project.id, "is_active": True, "credential": { Provider.OPENAI.value: { @@ -74,7 +52,7 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) data = response.json()["data"] assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == org.id + assert data[0]["organization_id"] == project.organization_id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" @@ -82,19 +60,9 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) def test_set_creds_for_invalid_project_org_relationship( db: Session, superuser_token_headers: dict[str, str] ): - # Setup: Create two organizations and one project for each - org1 = Organization(name="Org 1", is_active=True) - org2 = Organization(name="Org 2", is_active=True) - db.add_all([org1, org2]) - db.commit() - db.refresh(org1) - db.refresh(org2) - - project2 = Project(name="Project Org2", organization_id=org2.id) - db.add(project2) - db.commit() - - # Invalid case: Organization mismatch (org1's creds for project2 of org2) + org1 = create_test_organization(db) + project2 = create_test_project(db) + creds_data_invalid = { "organization_id": org1.id, "is_active": True, @@ -119,15 +87,13 @@ def test_set_creds_for_project_not_found( db: Session, superuser_token_headers: dict[str, str] ): # Setup: Create an organization but no project - org = Organization(name="Org for Project Not Found", is_active=True) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) + non_existent_project_id = get_non_existent_id(db, Project) creds_data_invalid_project = { "organization_id": org.id, "is_active": True, - "project_id": 99999, + "project_id": non_existent_project_id, "credential": {Provider.OPENAI.value: {"api_key": "sk-123", "model": "gpt-4"}}, } @@ -142,13 +108,12 @@ def test_set_creds_for_project_not_found( def test_read_credentials_with_creds( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) - + creds = create_test_credentials[0] + print(creds) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", headers=superuser_token_headers, ) @@ -156,7 +121,7 @@ def test_read_credentials_with_creds( data = response.json()["data"] assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == org.id + assert data[0]["organization_id"] == creds.organization_id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" @@ -164,8 +129,9 @@ def test_read_credentials_with_creds( def test_read_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): + non_existent_creds_id = get_non_existent_id(db, Credential) response = client.get( - f"{settings.API_V1_STR}/credentials/999999", + f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", headers=superuser_token_headers, ) assert response.status_code == 404 @@ -173,13 +139,13 @@ def test_read_credentials_not_found( def test_read_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] + print(creds) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -190,9 +156,9 @@ def test_read_provider_credential( def test_read_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, _ = create_organization_and_creds + org = create_test_organization(db) response = client.get( f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", @@ -204,10 +170,9 @@ def test_read_provider_credential_not_found( def test_update_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] update_data = { "provider": Provider.OPENAI.value, @@ -219,7 +184,7 @@ def test_update_credentials( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", json=update_data, headers=superuser_token_headers, ) @@ -234,21 +199,14 @@ def test_update_credentials( def test_update_credentials_failed_update( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] - org_without_creds = Organization(name="Org Without Creds", is_active=True) - db.add(org_without_creds) - db.commit() - db.refresh(org_without_creds) + org_without_creds = create_test_organization(db) existing_creds = ( - db.query(Credential) - .filter(Credential.organization_id == org_without_creds.id) - .all() + db.query(Credential).filter(creds.organization_id == org_without_creds.id).all() ) assert len(existing_creds) == 0 @@ -262,7 +220,7 @@ def test_update_credentials_failed_update( } response_invalid_org = client.patch( - f"{settings.API_V1_STR}/credentials/{org_without_creds.id}", # Valid org id but no creds + f"{settings.API_V1_STR}/credentials/{org_without_creds.id}", json=update_data, headers=superuser_token_headers, ) @@ -276,8 +234,7 @@ def test_update_credentials_failed_update( def test_update_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - # Create a non-existent organization ID - non_existent_org_id = 999999 + non_existent_org_id = get_non_existent_id(db, Organization) update_data = { "provider": Provider.OPENAI.value, @@ -299,13 +256,12 @@ def test_update_credentials_not_found( def test_delete_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -315,9 +271,9 @@ def test_delete_provider_credential( def test_delete_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, _ = create_organization_and_creds + org = create_test_organization(db) response = client.delete( f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", @@ -330,13 +286,12 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - org, creds_data = create_organization_and_creds - set_creds_for_org(session=db, creds_add=creds_data) + creds = create_test_credentials[0] response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", headers=superuser_token_headers, ) @@ -346,7 +301,7 @@ def test_delete_all_credentials( # Verify the credentials are soft deleted response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", headers=superuser_token_headers, ) assert response.status_code == 404 # Expect 404 as credentials are soft deleted @@ -356,8 +311,9 @@ def test_delete_all_credentials( def test_delete_all_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): + non_existent_creds_id = get_non_existent_id(db, Credential) response = client.delete( - f"{settings.API_V1_STR}/credentials/999999", + f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", headers=superuser_token_headers, ) @@ -366,31 +322,33 @@ def test_delete_all_credentials_not_found( def test_duplicate_credential_creation( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, creds_data = create_organization_and_creds - # First create credentials + creds = test_credential_data(db) + response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) + print(response) assert response.status_code == 200 # Try to create the same credentials again response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) assert response.status_code == 400 + assert "already exist" in response.json()["error"] def test_multiple_provider_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, _ = create_organization_and_creds + org = create_test_organization(db) # Create OpenAI credentials openai_creds = { @@ -446,16 +404,14 @@ def test_multiple_provider_credentials( assert Provider.LANGFUSE.value in providers -def test_credential_encryption( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): - org, creds_data = create_organization_and_creds - original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] +def test_credential_encryption(db: Session, superuser_token_headers: dict[str, str]): + creds = test_credential_data(db) + original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 @@ -463,7 +419,7 @@ def test_credential_encryption( db_cred = ( db.query(Credential) .filter( - Credential.organization_id == org.id, + Credential.organization_id == creds.organization_id, Credential.provider == Provider.OPENAI.value, ) .first() @@ -479,22 +435,22 @@ def test_credential_encryption( def test_credential_encryption_consistency( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - org, creds_data = create_organization_and_creds - original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] + creds = test_credential_data(db) + original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data.dict(), + json=creds.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 # Fetch the credentials through the API response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 @@ -515,15 +471,14 @@ def test_credential_encryption_consistency( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{org.id}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}", json=update_data, headers=superuser_token_headers, ) assert response.status_code == 200 - # Verify the updated value is also properly encrypted/decrypted response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 diff --git a/backend/app/tests/api/routes/test_org.py b/backend/app/tests/api/routes/test_org.py index 709bb7f5..607e1aef 100644 --- a/backend/app/tests/api/routes/test_org.py +++ b/backend/app/tests/api/routes/test_org.py @@ -1,44 +1,25 @@ import pytest from fastapi.testclient import TestClient -from sqlmodel import Session, select +from sqlmodel import Session -from app import crud from app.core.config import settings -from app.core.security import verify_password -from app.models import User, UserCreate -from app.tests.utils.utils import random_email, random_lower_string -from app.models import Organization, OrganizationCreate, OrganizationUpdate -from app.api.deps import get_db +from app.models import Organization from app.main import app from app.crud.organization import create_organization, get_organization_by_id +from app.tests.utils.test_data import create_test_organization client = TestClient(app) @pytest.fixture def test_organization(db: Session, superuser_token_headers: dict[str, str]): - unique_name = f"TestOrg-{random_lower_string()}" - org_data = OrganizationCreate(name=unique_name, is_active=True) - organization = create_organization(session=db, org_create=org_data) - db.commit() - return organization - - -# Test retrieving organizations -def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]): - response = client.get( - f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers - ) - assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert isinstance(response_data["data"], list) + return create_test_organization(db) # Test creating an organization def test_create_organization(db: Session, superuser_token_headers: dict[str, str]): - unique_name = f"Org-{random_lower_string()}" - org_data = {"name": unique_name, "is_active": True} + org_name = "Test-Org" + org_data = {"name": org_name, "is_active": True} response = client.post( f"{settings.API_V1_STR}/organizations/", json=org_data, @@ -55,13 +36,25 @@ def test_create_organization(db: Session, superuser_token_headers: dict[str, str assert org.is_active == created_org_data["is_active"] +# Test retrieving organizations +def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]): + response = client.get( + f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers + ) + assert response.status_code == 200 + response_data = response.json() + assert "data" in response_data + assert isinstance(response_data["data"], list) + + +# Updating an organization def test_update_organization( db: Session, test_organization: Organization, superuser_token_headers: dict[str, str], ): - unique_name = f"UpdatedOrg-{random_lower_string()}" # Ensure a unique name - update_data = {"name": unique_name, "is_active": False} + updated_name = "UpdatedOrg" + update_data = {"name": updated_name, "is_active": False} response = client.patch( f"{settings.API_V1_STR}/organizations/{test_organization.id}", diff --git a/backend/app/tests/api/routes/test_project.py b/backend/app/tests/api/routes/test_project.py index 98d1f96d..69b11cea 100644 --- a/backend/app/tests/api/routes/test_project.py +++ b/backend/app/tests/api/routes/test_project.py @@ -1,61 +1,26 @@ import pytest from fastapi.testclient import TestClient from sqlmodel import Session -from app.core.security import decrypt_api_key, verify_password from app.main import app from app.core.config import settings -from app.models import Project, ProjectCreate, ProjectUpdate -from app.models import Organization, OrganizationCreate, ProjectUpdate -from app.api.deps import get_db -from app.tests.utils.utils import random_lower_string, random_email -from app.crud.project import create_project, get_project_by_id -from app.crud.organization import create_organization -from app.crud import api_key as api_key_crud +from app.models import Project, ProjectCreate +from app.tests.utils.test_data import create_test_organization, create_test_project + client = TestClient(app) @pytest.fixture -def test_project(db: Session, superuser_token_headers: dict[str, str]): - unique_org_name = f"TestOrg-{random_lower_string()}" - org_data = OrganizationCreate(name=unique_org_name, is_active=True) - organization = create_organization(session=db, org_create=org_data) - db.commit() - - unique_project_name = f"TestProject-{random_lower_string()}" - project_description = "This is a test project description." - project_data = ProjectCreate( - name=unique_project_name, - description=project_description, - is_active=True, - organization_id=organization.id, - ) - project = create_project(session=db, project_create=project_data) - db.commit() - - return project - - -# Test retrieving projects -def test_read_projects(db: Session, superuser_token_headers: dict[str, str]): - response = client.get( - f"{settings.API_V1_STR}/projects/", headers=superuser_token_headers - ) - assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert isinstance(response_data["data"], list) +def test_project(db: Session) -> Project: + return create_test_project(db) # Test creating a project def test_create_new_project(db: Session, superuser_token_headers: dict[str, str]): - unique_org_name = f"TestOrg-{random_lower_string()}" - org_data = OrganizationCreate(name=unique_org_name, is_active=True) - organization = create_organization(session=db, org_create=org_data) - db.commit() + organization = create_test_organization(db) - unique_project_name = f"TestProject-{random_lower_string()}" + unique_project_name = "TestProject" project_description = "This is a test project description." project_data = ProjectCreate( name=unique_project_name, @@ -82,6 +47,17 @@ def test_create_new_project(db: Session, superuser_token_headers: dict[str, str] assert created_project["data"]["organization_id"] == organization.id +# Test retrieving projects +def test_read_projects(db: Session, superuser_token_headers: dict[str, str]): + response = client.get( + f"{settings.API_V1_STR}/projects/", headers=superuser_token_headers + ) + assert response.status_code == 200 + response_data = response.json() + assert "data" in response_data + assert isinstance(response_data["data"], list) + + # Test updating a project def test_update_project( db: Session, test_project: Project, superuser_token_headers: dict[str, str] diff --git a/backend/app/tests/crud/test_api_key.py b/backend/app/tests/crud/test_api_key.py index 197756bd..a0f89ba3 100644 --- a/backend/app/tests/crud/test_api_key.py +++ b/backend/app/tests/crud/test_api_key.py @@ -1,86 +1,45 @@ -import uuid -import pytest -from datetime import datetime from sqlmodel import Session, select -from app.crud import api_key as api_key_crud -from app.models import APIKey, User, Organization, Project -from app.tests.utils.utils import random_email -from app.core.security import get_password_hash, verify_password, decrypt_api_key -from app.core.exception_handlers import HTTPException - - -# Helper function to create a user -def create_test_user(db: Session) -> User: - user = User(email=random_email(), hashed_password=get_password_hash("password123")) - db.add(user) - db.commit() - db.refresh(user) - return user - -# Helper function to create an organization with a random name -def create_test_organization(db: Session) -> Organization: - org = Organization( - name=f"Test Organization {uuid.uuid4()}", description="Test Organization" - ) - db.add(org) - db.commit() - db.refresh(org) - return org - - -def create_test_project(db: Session, organization_id: int) -> Project: - project = Project( - name=f"Test Project {uuid.uuid4()}", - description="Test project", - organization_id=organization_id, - is_active=True, - ) - db.add(project) - db.commit() - db.refresh(project) - return project +from app.crud import api_key as api_key_crud +from app.models import APIKey +from app.tests.utils.utils import get_non_existent_id +from app.tests.utils.user import create_random_user +from app.tests.utils.test_data import create_test_api_key, create_test_project def test_create_api_key(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) + user = create_random_user(db) + project = create_test_project(db) - api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + api_key = api_key_crud.create_api_key( + db, project.organization_id, user.id, project.id + ) assert api_key.key.startswith("ApiKey ") assert len(api_key.key) > 32 - assert api_key.organization_id == org.id + assert api_key.organization_id == project.organization_id assert api_key.user_id == user.id assert api_key.project_id == project.id def test_get_api_key(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) - - created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) - retrieved_key = api_key_crud.get_api_key(db, created_key.id) + api_key = create_test_api_key(db) + retrieved_key = api_key_crud.get_api_key(db, api_key.id) assert retrieved_key is not None - assert retrieved_key.id == created_key.id + assert retrieved_key.id == api_key.id assert retrieved_key.key.startswith("ApiKey ") - assert retrieved_key.project_id == project.id + assert retrieved_key.project_id == api_key.project_id def test_get_api_key_not_found(db: Session) -> None: - result = api_key_crud.get_api_key(db, 9999) # Non-existent ID + api_key_id = get_non_existent_id(db, APIKey) + result = api_key_crud.get_api_key(db, api_key_id) assert result is None def test_delete_api_key(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) - - api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + api_key = create_test_api_key(db) api_key_crud.delete_api_key(db, api_key.id) deleted_key = db.exec(select(APIKey).where(APIKey.id == api_key.id)).first() @@ -91,13 +50,7 @@ def test_delete_api_key(db: Session) -> None: def test_get_api_key_by_value(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) - - # Create an API key - api_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) - # Get the raw key that was returned during creation + api_key = create_test_api_key(db) raw_key = api_key.key # Test retrieving the API key by its value @@ -105,8 +58,8 @@ def test_get_api_key_by_value(db: Session) -> None: assert retrieved_key is not None assert retrieved_key.id == api_key.id - assert retrieved_key.organization_id == org.id - assert retrieved_key.user_id == user.id + assert retrieved_key.organization_id == api_key.organization_id + assert retrieved_key.user_id == api_key.user_id # The key should be in its original format assert retrieved_key.key == raw_key # Should be exactly the same key assert retrieved_key.key.startswith("ApiKey ") @@ -114,11 +67,12 @@ def test_get_api_key_by_value(db: Session) -> None: def test_get_api_key_by_project_user(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) + user = create_random_user(db) + project = create_test_project(db) - created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) + created_key = api_key_crud.create_api_key( + db, project.organization_id, user.id, project.id + ) retrieved_key = api_key_crud.get_api_key_by_project_user(db, project.id, user.id) assert retrieved_key is not None @@ -128,11 +82,13 @@ def test_get_api_key_by_project_user(db: Session) -> None: def test_get_api_keys_by_project(db: Session) -> None: - user = create_test_user(db) - org = create_test_organization(db) - project = create_test_project(db, org.id) + user = create_random_user(db) + project = create_test_project(db) + + created_key = api_key_crud.create_api_key( + db, project.organization_id, user.id, project.id + ) - created_key = api_key_crud.create_api_key(db, org.id, user.id, project.id) retrieved_keys = api_key_crud.get_api_keys_by_project(db, project.id) assert retrieved_keys is not None diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index f47ea8d6..1ff20105 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -1,39 +1,26 @@ -import uuid from sqlmodel import Session import pytest -from datetime import datetime -from app.crud import credentials as credentials_crud -from app.models import Credential, CredsCreate, CredsUpdate, Organization, Project -from app.tests.utils.utils import random_email -from app.core.security import get_password_hash - - -def create_organization_and_project(db: Session) -> tuple[Organization, Project]: - """Helper function to create an organization and a project.""" - organization = Organization( - name=f"Test Organization {uuid.uuid4()}", is_active=True - ) - db.add(organization) - db.commit() - db.refresh(organization) - - project = Project( - name=f"Test Project {uuid.uuid4()}", - description="A test project", - organization_id=organization.id, - is_active=True, - ) - db.add(project) - db.commit() - db.refresh(project) - - return organization, project +from app.crud import ( + set_creds_for_org, + get_creds_by_org, + get_provider_credential, + update_creds_for_org, + remove_provider_credential, + remove_creds_for_org, +) +from app.models import CredsCreate, CredsUpdate +from app.core.providers import Provider +from app.tests.utils.test_data import ( + create_test_project, + create_test_credential, + test_credential_data, +) def test_set_creds_for_org(db: Session) -> None: """Test setting credentials for an organization.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Test credentials for supported providers creds_data = { @@ -45,42 +32,26 @@ def test_set_creds_for_org(db: Session) -> None: }, } - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - - created_creds = credentials_crud.set_creds_for_org( - session=db, creds_add=creds_create + creds_create = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + credential=creds_data, ) + created_creds = set_creds_for_org(session=db, creds_add=creds_create) + assert len(created_creds) == 2 - assert all(cred.organization_id == organization.id for cred in created_creds) + assert all( + cred.organization_id == project.organization_id for cred in created_creds + ) + assert all(cred.project_id == project.id for cred in created_creds) assert all(cred.is_active for cred in created_creds) assert {cred.provider for cred in created_creds} == {"openai", "langfuse"} -def test_set_creds_for_org_with_project(db: Session) -> None: - """Test setting credentials for an organization with a specific project.""" - organization, project = create_organization_and_project(db) - - creds_data = {"openai": {"api_key": "test-openai-key"}} - - creds_create = CredsCreate( - organization_id=organization.id, project_id=project.id, credential=creds_data - ) - - created_creds = credentials_crud.set_creds_for_org( - session=db, creds_add=creds_create - ) - - assert len(created_creds) == 1 - assert created_creds[0].organization_id == organization.id - assert created_creds[0].project_id == project.id - assert created_creds[0].provider == "openai" - assert created_creds[0].is_active - - def test_get_creds_by_org(db: Session) -> None: """Test retrieving all credentials for an organization.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Set up test credentials creds_data = { @@ -92,101 +63,81 @@ def test_get_creds_by_org(db: Session) -> None: }, } - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + credential=creds_data, + ) + set_creds_for_org(session=db, creds_add=creds_create) # Test retrieving credentials - retrieved_creds = credentials_crud.get_creds_by_org( - session=db, org_id=organization.id - ) + retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) assert len(retrieved_creds) == 2 - assert all(cred.organization_id == organization.id for cred in retrieved_creds) + assert all( + cred.organization_id == project.organization_id for cred in retrieved_creds + ) assert {cred.provider for cred in retrieved_creds} == {"openai", "langfuse"} def test_get_provider_credential(db: Session) -> None: """Test retrieving credentials for a specific provider.""" - organization, _ = create_organization_and_project(db) - - # Set up test credentials - creds_data = {"openai": {"api_key": "test-openai-key"}} - - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = test_credential_data(db) + original_api_key = creds_create.credential[Provider.OPENAI.value]["api_key"] + set_creds_for_org(session=db, creds_add=creds_create) # Test retrieving specific provider credentials - retrieved_cred = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + retrieved_cred = get_provider_credential( + session=db, org_id=creds_create.organization_id, provider="openai" ) assert retrieved_cred is not None assert "api_key" in retrieved_cred - assert retrieved_cred["api_key"] == "test-openai-key" + assert retrieved_cred["api_key"] == original_api_key def test_update_creds_for_org(db: Session) -> None: """Test updating credentials for a provider.""" - organization, _ = create_organization_and_project(db) - - # Set up initial credentials - initial_creds = {"openai": {"api_key": "initial-key"}} - creds_create = CredsCreate( - organization_id=organization.id, credential=initial_creds - ) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds = create_test_credential(db)[0] # Update credentials updated_creds = {"api_key": "updated-key"} creds_update = CredsUpdate(provider="openai", credential=updated_creds) - updated = credentials_crud.update_creds_for_org( - session=db, org_id=organization.id, creds_in=creds_update + updated = update_creds_for_org( + session=db, org_id=creds.organization_id, creds_in=creds_update ) assert len(updated) == 1 assert updated[0].provider == "openai" - retrieved_cred = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + retrieved_cred = get_provider_credential( + session=db, org_id=creds.organization_id, provider="openai" ) assert retrieved_cred["api_key"] == "updated-key" def test_remove_provider_credential(db: Session) -> None: """Test removing credentials for a specific provider.""" - organization, _ = create_organization_and_project(db) - - # Set up test credentials - creds_data = { - "openai": {"api_key": "test-openai-key"}, - "langfuse": { - "public_key": "test-public-key", - "secret_key": "test-secret-key", - "host": "https://cloud.langfuse.com", - }, - } - - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds = create_test_credential(db)[0] # Remove one provider's credentials - removed = credentials_crud.remove_provider_credential( - session=db, org_id=organization.id, provider="openai" + removed = remove_provider_credential( + session=db, org_id=creds.organization_id, provider="openai" ) assert removed.is_active is False assert removed.updated_at is not None # Verify the credentials are no longer retrievable - retrieved_cred = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + retrieved_cred = get_provider_credential( + session=db, org_id=creds.organization_id, provider="openai" ) assert retrieved_cred is None def test_remove_creds_for_org(db: Session) -> None: """Test removing all credentials for an organization.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Set up test credentials creds_data = { @@ -198,49 +149,55 @@ def test_remove_creds_for_org(db: Session) -> None: }, } - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + credential=creds_data, + ) + set_creds_for_org(session=db, creds_add=creds_create) # Remove all credentials - removed = credentials_crud.remove_creds_for_org(session=db, org_id=organization.id) + removed = remove_creds_for_org(session=db, org_id=project.organization_id) assert len(removed) == 2 assert all(not cred.is_active for cred in removed) assert all(cred.updated_at is not None for cred in removed) # Verify no credentials are retrievable - retrieved_creds = credentials_crud.get_creds_by_org( - session=db, org_id=organization.id - ) + retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) assert len(retrieved_creds) == 0 def test_invalid_provider(db: Session) -> None: """Test handling of invalid provider names.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Test with unsupported provider creds_data = {"gemini": {"api_key": "test-key"}} - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + creds_create = CredsCreate( + organization_id=project.organization_id, credential=creds_data + ) with pytest.raises(ValueError, match="Unsupported provider"): - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=creds_create) def test_duplicate_provider_credentials(db: Session) -> None: """Test handling of duplicate provider credentials.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Set up initial credentials creds_data = {"openai": {"api_key": "test-key"}} - creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + creds_create = CredsCreate( + organization_id=project.organization_id, credential=creds_data + ) + set_creds_for_org(session=db, creds_add=creds_create) # Verify credentials exist and are active - existing_creds = credentials_crud.get_provider_credential( - session=db, org_id=organization.id, provider="openai" + existing_creds = get_provider_credential( + session=db, org_id=project.organization_id, provider="openai" ) assert existing_creds is not None assert "api_key" in existing_creds @@ -249,7 +206,7 @@ def test_duplicate_provider_credentials(db: Session) -> None: def test_langfuse_credential_validation(db: Session) -> None: """Test validation of Langfuse credentials structure.""" - organization, _ = create_organization_and_project(db) + project = create_test_project(db) # Test with missing required fields invalid_creds = { @@ -261,11 +218,11 @@ def test_langfuse_credential_validation(db: Session) -> None: } creds_create = CredsCreate( - organization_id=organization.id, credential=invalid_creds + organization_id=project.organization_id, credential=invalid_creds ) with pytest.raises(ValueError): - credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=creds_create) # Test with valid Langfuse credentials valid_creds = { @@ -276,10 +233,10 @@ def test_langfuse_credential_validation(db: Session) -> None: } } - creds_create = CredsCreate(organization_id=organization.id, credential=valid_creds) - - created_creds = credentials_crud.set_creds_for_org( - session=db, creds_add=creds_create + creds_create = CredsCreate( + organization_id=project.organization_id, credential=valid_creds ) + + created_creds = set_creds_for_org(session=db, creds_add=creds_create) assert len(created_creds) == 1 assert created_creds[0].provider == "langfuse" diff --git a/backend/app/tests/crud/test_org.py b/backend/app/tests/crud/test_org.py index 7efba0ec..052988ff 100644 --- a/backend/app/tests/crud/test_org.py +++ b/backend/app/tests/crud/test_org.py @@ -3,6 +3,7 @@ from app.crud.organization import create_organization, get_organization_by_id from app.models import Organization, OrganizationCreate from app.tests.utils.utils import random_lower_string, get_non_existent_id +from app.tests.utils.test_data import create_test_organization def test_create_organization(db: Session) -> None: @@ -13,14 +14,12 @@ def test_create_organization(db: Session) -> None: assert org.name == name assert org.id is not None - assert org.is_active is True # Default should be active + assert org.is_active is True def test_get_organization_by_id(db: Session) -> None: """Test retrieving an organization by ID.""" - name = random_lower_string() - org_in = OrganizationCreate(name=name) - org = create_organization(session=db, org_create=org_in) + org = create_test_organization(db) fetched_org = get_organization_by_id(session=db, org_id=org.id) assert fetched_org @@ -31,7 +30,5 @@ def test_get_organization_by_id(db: Session) -> None: def test_get_non_existent_organization(db: Session) -> None: """Test retrieving a non-existent organization should return None.""" org_id = get_non_existent_id(db, Organization) - fetched_org = get_organization_by_id( - session=db, org_id=org_id - ) # Assuming ID 999 does not exist + fetched_org = get_organization_by_id(session=db, org_id=org_id) assert fetched_org is None diff --git a/backend/app/tests/crud/test_project.py b/backend/app/tests/crud/test_project.py index c9b135ce..041821cc 100644 --- a/backend/app/tests/crud/test_project.py +++ b/backend/app/tests/crud/test_project.py @@ -2,22 +2,20 @@ from sqlmodel import Session from fastapi import HTTPException -from app.models import Project, ProjectCreate, Organization +from app.models import Project, ProjectCreate from app.crud.project import ( create_project, get_project_by_id, get_projects_by_organization, validate_project, ) -from app.tests.utils.utils import random_lower_string +from app.tests.utils.utils import random_lower_string, get_non_existent_id +from app.tests.utils.test_data import create_test_project, create_test_organization def test_create_project(db: Session) -> None: """Test creating a project linked to an organization.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) project_name = random_lower_string() project_data = ProjectCreate( @@ -37,17 +35,7 @@ def test_create_project(db: Session) -> None: def test_get_project_by_id(db: Session) -> None: """Test retrieving a project by ID.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) - - project_name = random_lower_string() - project_data = ProjectCreate( - name=project_name, description="Test", organization_id=org.id - ) - - project = create_project(session=db, project_create=project_data) + project = create_test_project(db) fetched_project = get_project_by_id(session=db, project_id=project.id) assert fetched_project is not None @@ -57,53 +45,43 @@ def test_get_project_by_id(db: Session) -> None: def test_get_projects_by_organization(db: Session) -> None: """Test retrieving all projects for an organization.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) project_1 = create_project( session=db, project_create=ProjectCreate( - name=random_lower_string(), organization_id=org.id + name="Project 1", + description="Test project 1", + is_active=True, + organization_id=org.id, ), ) + project_2 = create_project( session=db, project_create=ProjectCreate( - name=random_lower_string(), organization_id=org.id + name="Project 2", + description="Test project 2", + is_active=True, + organization_id=org.id, ), ) projects = get_projects_by_organization(session=db, org_id=org.id) - assert len(projects) == 2 assert project_1 in projects assert project_2 in projects def test_get_non_existent_project(db: Session) -> None: - """Test retrieving a non-existent project should return None.""" - fetched_project = get_project_by_id(session=db, project_id=999) + non_existent_project_id = get_non_existent_id(db, Project) + fetched_project = get_project_by_id(session=db, project_id=non_existent_project_id) assert fetched_project is None def test_validate_project_success(db: Session) -> None: """Test that a valid and active project passes validation.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) - - project = create_project( - session=db, - project_create=ProjectCreate( - name=random_lower_string(), - description="Valid project", - is_active=True, - organization_id=org.id, - ), - ) + project = create_test_project(db) validated_project = validate_project(session=db, project_id=project.id) assert validated_project.id == project.id @@ -111,16 +89,14 @@ def test_validate_project_success(db: Session) -> None: def test_validate_project_not_found(db: Session) -> None: """Test that validation fails when project does not exist.""" + non_existent_project_id = get_non_existent_id(db, Project) with pytest.raises(HTTPException, match="Project not found"): - validate_project(session=db, project_id=9999) + validate_project(session=db, project_id=non_existent_project_id) def test_validate_project_inactive(db: Session) -> None: """Test that validation fails when project is inactive.""" - org = Organization(name=random_lower_string()) - db.add(org) - db.commit() - db.refresh(org) + org = create_test_organization(db) inactive_project = create_project( session=db, diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index ded5c728..0a68125c 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -7,6 +7,7 @@ from app.core.config import settings from app.models import Collection, Organization, Project from app.tests.utils.utils import get_user_id_by_email +from app.tests.utils.test_data import create_test_project from app.crud import create_api_key @@ -24,19 +25,14 @@ def get_collection(db: Session, client=None): owner_id = get_user_id_by_email(db) # Step 1: Create real organization and project entries - organization = Organization(name=f"Test Org {uuid4()}") - db.add(organization) - db.commit() - db.refresh(organization) - - project = Project(name="Test Project {uuid4()}", organization_id=organization.id) - db.add(project) - db.commit() - db.refresh(project) + project = create_test_project(db) # Step 2: Create API key for user with valid foreign keys create_api_key( - db, organization_id=organization.id, user_id=owner_id, project_id=project.id + db, + organization_id=project.organization_id, + user_id=owner_id, + project_id=project.id, ) if client is None: @@ -51,7 +47,7 @@ def get_collection(db: Session, client=None): return Collection( owner_id=owner_id, - organization_id=organization.id, + organization_id=project.organization_id, project_id=project.id, llm_service_id=assistant.id, llm_service_name=constants.llm_service_name, diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py new file mode 100644 index 00000000..a623cd1e --- /dev/null +++ b/backend/app/tests/utils/test_data.py @@ -0,0 +1,114 @@ +from sqlmodel import Session +from app.models import ( + Organization, + Project, + APIKey, + Credential, + OrganizationCreate, + ProjectCreate, + APIKeyPublic, + CredsCreate, +) +from app.crud import ( + create_organization, + create_project, + create_api_key, + set_creds_for_org, +) +from app.core.providers import Provider +from app.tests.utils.user import create_random_user +from app.tests.utils.utils import random_lower_string, generate_random_string + + +def create_test_organization(db: Session) -> Organization: + """ + Creates and returns a test organization with a unique name. + + Persists the organization to the database. + """ + name = f"TestOrg-{random_lower_string()}" + org_in = OrganizationCreate(name=name, is_active=True) + return create_organization(session=db, org_create=org_in) + + +def create_test_project(db: Session) -> Project: + """ + Creates and returns a test project under a newly created test organization. + + Persists both the organization and the project to the database. + + """ + org = create_test_organization(db) + name = f"TestProject-{random_lower_string()}" + project_description = "This is a test project description." + project_in = ProjectCreate( + name=name, + description=project_description, + is_active=True, + organization_id=org.id, + ) + return create_project(session=db, project_create=project_in) + + +def create_test_api_key(db: Session) -> APIKey: + """ + Creates and returns an API key for a test project and test user. + + Persists a test user, organization, project, and API key to the database + """ + project = create_test_project(db) + user = create_random_user(db) + api_key = create_api_key( + db, + organization_id=project.organization_id, + user_id=user.id, + project_id=project.id, + ) + return api_key + + +def test_credential_data(db: Session) -> CredsCreate: + """ + Returns credential data for a test project in the form of a CredsCreate schema. + + Use this when you just need credential input data without persisting it to the database. + """ + project = create_test_project(db) + api_key = "sk-" + generate_random_string(10) + creds_data = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + is_active=True, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, + ) + return creds_data + + +def create_test_credential(db: Session) -> Credential: + """ + Creates and returns a test credential for a test project. + + Persists the organization, project, and credential to the database. + + """ + project = create_test_project(db) + api_key = "sk-" + generate_random_string(10) + creds_data = CredsCreate( + organization_id=project.organization_id, + project_id=project.id, + is_active=True, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, + ) + return set_creds_for_org(session=db, creds_add=creds_data) diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 9cba7aae..351444b4 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -26,6 +26,10 @@ def random_lower_string() -> str: return "".join(random.choices(string.ascii_lowercase, k=32)) +def generate_random_string(length=10): + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + def random_email() -> str: return f"{random_lower_string()}@{random_lower_string()}.com" From f8097906ac04515f1b4987c099c749eb55c3f31e Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 11:21:42 +0530 Subject: [PATCH 02/15] credentials --- backend/app/tests/api/routes/test_creds.py | 8 +++----- backend/app/tests/crud/test_credentials.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 5d0ff141..f314b2ac 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -25,7 +25,7 @@ def create_test_credentials(db: Session): return create_test_credential(db) -def test_set_creds(db: Session, superuser_token_headers: dict[str, str]): +def test_set_credential(db: Session, superuser_token_headers: dict[str, str]): project = create_test_project(db) api_key = "sk-" + generate_random_string(10) @@ -57,7 +57,7 @@ def test_set_creds(db: Session, superuser_token_headers: dict[str, str]): assert data[0]["credential"]["model"] == "gpt-4" -def test_set_creds_for_invalid_project_org_relationship( +def test_set_credentials_for_invalid_project_org_relationship( db: Session, superuser_token_headers: dict[str, str] ): org1 = create_test_organization(db) @@ -83,7 +83,7 @@ def test_set_creds_for_invalid_project_org_relationship( ) -def test_set_creds_for_project_not_found( +def test_set_credentials_for_project_not_found( db: Session, superuser_token_headers: dict[str, str] ): # Setup: Create an organization but no project @@ -142,7 +142,6 @@ def test_read_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): creds = create_test_credentials[0] - print(creds) response = client.get( f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", @@ -331,7 +330,6 @@ def test_duplicate_credential_creation( json=creds.dict(), headers=superuser_token_headers, ) - print(response) assert response.status_code == 200 # Try to create the same credentials again diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 1ff20105..bef52e1e 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -18,7 +18,7 @@ ) -def test_set_creds_for_org(db: Session) -> None: +def test_set_credentials_for_org(db: Session) -> None: """Test setting credentials for an organization.""" project = create_test_project(db) From e266d2379a7953bbdefa8f80fb872d69309758d2 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 14:32:50 +0530 Subject: [PATCH 03/15] get creds by provider --- backend/app/tests/api/routes/test_creds.py | 31 ++++++++++++++++------ backend/app/tests/crud/test_credentials.py | 9 +++++-- backend/app/tests/utils/utils.py | 14 +++++++++- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index f314b2ac..8b34bb66 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -8,7 +8,11 @@ from app.core.providers import Provider from app.models.credentials import Credential from app.core.security import decrypt_credentials -from app.tests.utils.utils import generate_random_string, get_non_existent_id +from app.tests.utils.utils import ( + generate_random_string, + get_non_existent_id, + get_credential_by_provider, +) from app.tests.utils.test_data import ( create_test_credential, create_test_organization, @@ -110,8 +114,10 @@ def test_set_credentials_for_project_not_found( def test_read_credentials_with_creds( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds = create_test_credentials[0] - print(creds) + creds_list = create_test_credentials + + creds = get_credential_by_provider(creds_list, "openai") + response = client.get( f"{settings.API_V1_STR}/credentials/{creds.organization_id}", headers=superuser_token_headers, @@ -141,7 +147,8 @@ def test_read_credentials_not_found( def test_read_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds = create_test_credentials[0] + creds_list = create_test_credentials + creds = get_credential_by_provider(creds_list, "openai") response = client.get( f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", @@ -171,7 +178,9 @@ def test_read_provider_credential_not_found( def test_update_credentials( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds = create_test_credentials[0] + creds_list = create_test_credentials + + creds = get_credential_by_provider(creds_list, "openai") update_data = { "provider": Provider.OPENAI.value, @@ -200,7 +209,9 @@ def test_update_credentials( def test_update_credentials_failed_update( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds = create_test_credentials[0] + creds_list = create_test_credentials + + creds = get_credential_by_provider(creds_list, "openai") org_without_creds = create_test_organization(db) @@ -257,7 +268,9 @@ def test_update_credentials_not_found( def test_delete_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds = create_test_credentials[0] + creds_list = create_test_credentials + + creds = get_credential_by_provider(creds_list, "openai") response = client.delete( f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", @@ -287,7 +300,9 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds = create_test_credentials[0] + creds_list = create_test_credentials + + creds = get_credential_by_provider(creds_list, "openai") response = client.delete( f"{settings.API_V1_STR}/credentials/{creds.organization_id}", diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index bef52e1e..1abd68b8 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -11,6 +11,7 @@ ) from app.models import CredsCreate, CredsUpdate from app.core.providers import Provider +from app.tests.utils.utils import get_credential_by_provider from app.tests.utils.test_data import ( create_test_project, create_test_credential, @@ -98,7 +99,9 @@ def test_get_provider_credential(db: Session) -> None: def test_update_creds_for_org(db: Session) -> None: """Test updating credentials for a provider.""" - creds = create_test_credential(db)[0] + creds_list = create_test_credential(db) + + creds = get_credential_by_provider(creds_list, "openai") # Update credentials updated_creds = {"api_key": "updated-key"} @@ -118,7 +121,9 @@ def test_update_creds_for_org(db: Session) -> None: def test_remove_provider_credential(db: Session) -> None: """Test removing credentials for a specific provider.""" - creds = create_test_credential(db)[0] + creds_list = create_test_credential(db) + + creds = get_credential_by_provider(creds_list, "openai") # Remove one provider's credentials removed = remove_provider_credential( diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 351444b4..ac94a183 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -1,6 +1,7 @@ import random import string from uuid import UUID +from typing import List import pytest from fastapi.testclient import TestClient @@ -9,7 +10,7 @@ from app.core.config import settings from app.crud.user import get_user_by_email -from app.models import APIKeyPublic +from app.models import APIKeyPublic, Credential from app.crud import create_api_key, get_api_key_by_value from uuid import uuid4 @@ -59,6 +60,17 @@ def get_user_from_api_key(db: Session, api_key_headers: dict[str, str]) -> APIKe return api_key +def get_credential_by_provider(creds: List[Credential], provider: str) -> Credential: + """ + From a list of credentials, return the one matching the given provider. + Raises ValueError if not found. + """ + for c in creds: + if c.provider == provider: + return c + raise ValueError(f"No credential found for provider: {provider}") + + def get_non_existent_id(session: Session, model: Type[T]) -> int: result = session.exec(select(model.id).order_by(model.id.desc())).first() return (result or 0) + 1 From b7f9cb5bdf24e86ad6b91e580585db5a847fc7e5 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 14:42:49 +0530 Subject: [PATCH 04/15] small fixes --- backend/app/tests/api/routes/test_api_key.py | 7 +------ backend/app/tests/api/routes/test_creds.py | 2 -- backend/app/tests/api/routes/test_org.py | 4 ++-- backend/app/tests/utils/test_data.py | 6 ++++-- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index 7d018a66..c4e63954 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -6,11 +6,7 @@ from app.core.config import settings from app.tests.utils.utils import get_non_existent_id from app.tests.utils.user import create_random_user -from app.tests.utils.test_data import ( - create_test_api_key, - create_test_project, - create_test_organization, -) +from app.tests.utils.test_data import create_test_api_key, create_test_project client = TestClient(app) @@ -87,7 +83,6 @@ def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]): api_key_id = get_non_existent_id(db, APIKey) - print(api_key_id) response = client.get( f"{settings.API_V1_STR}/apikeys/{api_key_id}", headers=superuser_token_headers, diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 8b34bb66..a9a730dd 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -80,7 +80,6 @@ def test_set_credentials_for_invalid_project_org_relationship( headers=superuser_token_headers, ) assert response_invalid.status_code == 400 - print(response_invalid.json()) assert ( response_invalid.json()["error"] == "Project does not belong to the specified organization" @@ -293,7 +292,6 @@ def test_delete_provider_credential_not_found( ) assert response.status_code == 404 - print(response.json()) assert response.json()["error"] == f"Provider credentials not found" diff --git a/backend/app/tests/api/routes/test_org.py b/backend/app/tests/api/routes/test_org.py index 607e1aef..2b518f34 100644 --- a/backend/app/tests/api/routes/test_org.py +++ b/backend/app/tests/api/routes/test_org.py @@ -5,14 +5,14 @@ from app.core.config import settings from app.models import Organization from app.main import app -from app.crud.organization import create_organization, get_organization_by_id +from app.crud.organization import get_organization_by_id from app.tests.utils.test_data import create_test_organization client = TestClient(app) @pytest.fixture -def test_organization(db: Session, superuser_token_headers: dict[str, str]): +def test_organization(db: Session): return create_test_organization(db) diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index a623cd1e..16ef0802 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -1,4 +1,7 @@ +from typing import List + from sqlmodel import Session + from app.models import ( Organization, Project, @@ -6,7 +9,6 @@ Credential, OrganizationCreate, ProjectCreate, - APIKeyPublic, CredsCreate, ) from app.crud import ( @@ -90,7 +92,7 @@ def test_credential_data(db: Session) -> CredsCreate: return creds_data -def create_test_credential(db: Session) -> Credential: +def create_test_credential(db: Session) -> List[Credential]: """ Creates and returns a test credential for a test project. From 51ecac548d298a0e8b696a019e6154c3dc6575b6 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 20:33:10 +0530 Subject: [PATCH 05/15] function for creds --- backend/app/crud/__init__.py | 1 + backend/app/crud/credentials.py | 20 +++ backend/app/tests/api/routes/test_creds.py | 134 ++++++++++++--------- backend/app/tests/crud/test_credentials.py | 28 +++-- backend/app/tests/utils/test_data.py | 2 +- 5 files changed, 121 insertions(+), 64 deletions(-) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 5c6fd102..95c1c21b 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -39,6 +39,7 @@ remove_creds_for_org, get_provider_credential, remove_provider_credential, + get_full_provider_credential, ) from .thread_results import upsert_thread_result, get_thread_result diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index ba3503e1..cbced836 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -107,6 +107,26 @@ def get_provider_credential( return None +def get_full_provider_credential( + *, session: Session, org_id: int, provider: str, project_id: Optional[int] = None +) -> Optional[Dict[str, Any]]: + """Fetches credentials for a specific provider of an organization.""" + validate_provider(provider) + + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.is_active == True, + Credential.project_id == project_id if project_id is not None else True, + ) + creds = session.exec(statement).first() + + if creds and creds.credential: + # Decrypt entire credentials object + return creds + return None + + def get_providers( *, session: Session, org_id: int, project_id: Optional[int] = None ) -> List[str]: diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index a9a730dd..54cdc91a 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -3,6 +3,7 @@ from sqlmodel import Session from app.main import app +from app.crud import get_full_provider_credential from app.models import Organization, Project from app.core.config import settings from app.core.providers import Provider @@ -33,7 +34,7 @@ def test_set_credential(db: Session, superuser_token_headers: dict[str, str]): project = create_test_project(db) api_key = "sk-" + generate_random_string(10) - creds_data = { + credential_data = { "organization_id": project.organization_id, "project_id": project.id, "is_active": True, @@ -48,7 +49,7 @@ def test_set_credential(db: Session, superuser_token_headers: dict[str, str]): response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data, + json=credential_data, headers=superuser_token_headers, ) @@ -67,7 +68,7 @@ def test_set_credentials_for_invalid_project_org_relationship( org1 = create_test_organization(db) project2 = create_test_project(db) - creds_data_invalid = { + credential_data_invalid = { "organization_id": org1.id, "is_active": True, "project_id": project2.id, # Invalid project for org1 @@ -76,7 +77,7 @@ def test_set_credentials_for_invalid_project_org_relationship( response_invalid = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data_invalid, + json=credential_data_invalid, headers=superuser_token_headers, ) assert response_invalid.status_code == 400 @@ -93,7 +94,7 @@ def test_set_credentials_for_project_not_found( org = create_test_organization(db) non_existent_project_id = get_non_existent_id(db, Project) - creds_data_invalid_project = { + credential_data_invalid_project = { "organization_id": org.id, "is_active": True, "project_id": non_existent_project_id, @@ -102,7 +103,7 @@ def test_set_credentials_for_project_not_found( response_invalid_project = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds_data_invalid_project, + json=credential_data_invalid_project, headers=superuser_token_headers, ) @@ -113,12 +114,17 @@ def test_set_credentials_for_project_not_found( def test_read_credentials_with_creds( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds_list = create_test_credentials + _, project = create_test_credentials - creds = get_credential_by_provider(creds_list, "openai") + credential = get_full_provider_credential( + session=db, + org_id=project.organization_id, + provider="openai", + project_id=project.id, + ) response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{credential.organization_id}", headers=superuser_token_headers, ) @@ -126,7 +132,7 @@ def test_read_credentials_with_creds( data = response.json()["data"] assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == creds.organization_id + assert data[0]["organization_id"] == project.organization_id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" @@ -146,11 +152,10 @@ def test_read_credentials_not_found( def test_read_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds_list = create_test_credentials - creds = get_credential_by_provider(creds_list, "openai") + _, project = create_test_credentials response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{project.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -177,9 +182,14 @@ def test_read_provider_credential_not_found( def test_update_credentials( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds_list = create_test_credentials + _, project = create_test_credentials - creds = get_credential_by_provider(creds_list, "openai") + credential = get_full_provider_credential( + session=db, + org_id=project.organization_id, + provider="openai", + project_id=project.id, + ) update_data = { "provider": Provider.OPENAI.value, @@ -191,7 +201,7 @@ def test_update_credentials( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{credential.organization_id}", json=update_data, headers=superuser_token_headers, ) @@ -208,16 +218,23 @@ def test_update_credentials( def test_update_credentials_failed_update( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds_list = create_test_credentials + _, project = create_test_credentials - creds = get_credential_by_provider(creds_list, "openai") + credential = get_full_provider_credential( + session=db, + org_id=project.organization_id, + provider="openai", + project_id=project.id, + ) - org_without_creds = create_test_organization(db) + org_without_credential = create_test_organization(db) - existing_creds = ( - db.query(Credential).filter(creds.organization_id == org_without_creds.id).all() + existing_credential = ( + db.query(Credential) + .filter(credential.organization_id == org_without_credential.id) + .all() ) - assert len(existing_creds) == 0 + assert len(existing_credential) == 0 update_data = { "provider": Provider.OPENAI.value, @@ -229,7 +246,7 @@ def test_update_credentials_failed_update( } response_invalid_org = client.patch( - f"{settings.API_V1_STR}/credentials/{org_without_creds.id}", + f"{settings.API_V1_STR}/credentials/{org_without_credential.id}", json=update_data, headers=superuser_token_headers, ) @@ -267,12 +284,17 @@ def test_update_credentials_not_found( def test_delete_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds_list = create_test_credentials + _, project = create_test_credentials - creds = get_credential_by_provider(creds_list, "openai") + credential = get_full_provider_credential( + session=db, + org_id=project.organization_id, + provider="openai", + project_id=project.id, + ) response = client.delete( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{credential.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) @@ -298,12 +320,16 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( db: Session, superuser_token_headers: dict[str, str], create_test_credentials ): - creds_list = create_test_credentials - - creds = get_credential_by_provider(creds_list, "openai") + _, project = create_test_credentials + credential = get_full_provider_credential( + session=db, + org_id=project.organization_id, + provider="openai", + project_id=project.id, + ) response = client.delete( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{credential.organization_id}", headers=superuser_token_headers, ) @@ -313,7 +339,7 @@ def test_delete_all_credentials( # Verify the credentials are soft deleted response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{credential.organization_id}", headers=superuser_token_headers, ) assert response.status_code == 404 # Expect 404 as credentials are soft deleted @@ -323,9 +349,9 @@ def test_delete_all_credentials( def test_delete_all_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - non_existent_creds_id = get_non_existent_id(db, Credential) + non_existent_credential_id = get_non_existent_id(db, Credential) response = client.delete( - f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", + f"{settings.API_V1_STR}/credentials/{non_existent_credential_id}", headers=superuser_token_headers, ) @@ -336,11 +362,11 @@ def test_delete_all_credentials_not_found( def test_duplicate_credential_creation( db: Session, superuser_token_headers: dict[str, str] ): - creds = test_credential_data(db) + credential = test_credential_data(db) response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=credential.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 @@ -348,7 +374,7 @@ def test_duplicate_credential_creation( # Try to create the same credentials again response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=credential.dict(), headers=superuser_token_headers, ) assert response.status_code == 400 @@ -362,7 +388,7 @@ def test_multiple_provider_credentials( org = create_test_organization(db) # Create OpenAI credentials - openai_creds = { + openai_credential = { "organization_id": org.id, "is_active": True, "credential": { @@ -375,7 +401,7 @@ def test_multiple_provider_credentials( } # Create Langfuse credentials - langfuse_creds = { + langfuse_credential = { "organization_id": org.id, "is_active": True, "credential": { @@ -390,14 +416,14 @@ def test_multiple_provider_credentials( # Create both credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=openai_creds, + json=openai_credential, headers=superuser_token_headers, ) assert response.status_code == 200 response = client.post( f"{settings.API_V1_STR}/credentials/", - json=langfuse_creds, + json=langfuse_credential, headers=superuser_token_headers, ) assert response.status_code == 200 @@ -416,52 +442,52 @@ def test_multiple_provider_credentials( def test_credential_encryption(db: Session, superuser_token_headers: dict[str, str]): - creds = test_credential_data(db) - original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] + credential = test_credential_data(db) + original_api_key = credential.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=credential.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 - db_cred = ( + db_credential = ( db.query(Credential) .filter( - Credential.organization_id == creds.organization_id, + Credential.organization_id == credential.organization_id, Credential.provider == Provider.OPENAI.value, ) .first() ) - assert db_cred is not None + assert db_credential is not None # Verify the stored credential is encrypted - assert db_cred.credential != original_api_key + assert db_credential.credential != original_api_key # Verify we can decrypt and get the original value - decrypted_creds = decrypt_credentials(db_cred.credential) + decrypted_creds = decrypt_credentials(db_credential.credential) assert decrypted_creds["api_key"] == original_api_key def test_credential_encryption_consistency( db: Session, superuser_token_headers: dict[str, str] ): - creds = test_credential_data(db) - original_api_key = creds.credential[Provider.OPENAI.value]["api_key"] + credentials = test_credential_data(db) + original_api_key = credentials.credential[Provider.OPENAI.value]["api_key"] # Create credentials response = client.post( f"{settings.API_V1_STR}/credentials/", - json=creds.dict(), + json=credentials.dict(), headers=superuser_token_headers, ) assert response.status_code == 200 # Fetch the credentials through the API response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{credentials.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 @@ -482,14 +508,14 @@ def test_credential_encryption_consistency( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}", + f"{settings.API_V1_STR}/credentials/{credentials.organization_id}", json=update_data, headers=superuser_token_headers, ) assert response.status_code == 200 response = client.get( - f"{settings.API_V1_STR}/credentials/{creds.organization_id}/{Provider.OPENAI.value}", + f"{settings.API_V1_STR}/credentials/{credentials.organization_id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 1abd68b8..82afffdd 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -8,6 +8,7 @@ update_creds_for_org, remove_provider_credential, remove_creds_for_org, + get_full_provider_credential, ) from app.models import CredsCreate, CredsUpdate from app.core.providers import Provider @@ -99,35 +100,44 @@ def test_get_provider_credential(db: Session) -> None: def test_update_creds_for_org(db: Session) -> None: """Test updating credentials for a provider.""" - creds_list = create_test_credential(db) - - creds = get_credential_by_provider(creds_list, "openai") + _, project = create_test_credential(db) + credential = get_full_provider_credential( + session=db, + org_id=project.organization_id, + provider="openai", + project_id=project.id, + ) # Update credentials updated_creds = {"api_key": "updated-key"} creds_update = CredsUpdate(provider="openai", credential=updated_creds) updated = update_creds_for_org( - session=db, org_id=creds.organization_id, creds_in=creds_update + session=db, org_id=credential.organization_id, creds_in=creds_update ) assert len(updated) == 1 assert updated[0].provider == "openai" retrieved_cred = get_provider_credential( - session=db, org_id=creds.organization_id, provider="openai" + session=db, org_id=credential.organization_id, provider="openai" ) assert retrieved_cred["api_key"] == "updated-key" def test_remove_provider_credential(db: Session) -> None: """Test removing credentials for a specific provider.""" - creds_list = create_test_credential(db) + _, project = create_test_credential(db) - creds = get_credential_by_provider(creds_list, "openai") + credential = get_full_provider_credential( + session=db, + org_id=project.organization_id, + provider="openai", + project_id=project.id, + ) # Remove one provider's credentials removed = remove_provider_credential( - session=db, org_id=creds.organization_id, provider="openai" + session=db, org_id=credential.organization_id, provider="openai" ) assert removed.is_active is False @@ -135,7 +145,7 @@ def test_remove_provider_credential(db: Session) -> None: # Verify the credentials are no longer retrievable retrieved_cred = get_provider_credential( - session=db, org_id=creds.organization_id, provider="openai" + session=db, org_id=credential.organization_id, provider="openai" ) assert retrieved_cred is None diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 16ef0802..568702d9 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -113,4 +113,4 @@ def create_test_credential(db: Session) -> List[Credential]: } }, ) - return set_creds_for_org(session=db, creds_add=creds_data) + return set_creds_for_org(session=db, creds_add=creds_data), project From 3b81ec73b777ff323b198f3c72aae659421beae0 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 20:40:51 +0530 Subject: [PATCH 06/15] naming better --- backend/app/tests/crud/test_credentials.py | 78 +++++++++++----------- backend/app/tests/crud/test_org.py | 20 +++--- backend/app/tests/crud/test_project.py | 18 ++--- 3 files changed, 58 insertions(+), 58 deletions(-) diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 82afffdd..e7a39823 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -25,7 +25,7 @@ def test_set_credentials_for_org(db: Session) -> None: project = create_test_project(db) # Test credentials for supported providers - creds_data = { + credentials_data = { "openai": {"api_key": "test-openai-key"}, "langfuse": { "public_key": "test-public-key", @@ -34,21 +34,21 @@ def test_set_credentials_for_org(db: Session) -> None: }, } - creds_create = CredsCreate( + credentials_create = CredsCreate( organization_id=project.organization_id, project_id=project.id, - credential=creds_data, + credential=credentials_data, ) - created_creds = set_creds_for_org(session=db, creds_add=creds_create) + created_credentials = set_creds_for_org(session=db, creds_add=credentials_create) - assert len(created_creds) == 2 + assert len(created_credentials) == 2 assert all( - cred.organization_id == project.organization_id for cred in created_creds + cred.organization_id == project.organization_id for cred in created_credentials ) - assert all(cred.project_id == project.id for cred in created_creds) - assert all(cred.is_active for cred in created_creds) - assert {cred.provider for cred in created_creds} == {"openai", "langfuse"} + assert all(cred.project_id == project.id for cred in created_credentials) + assert all(cred.is_active for cred in created_credentials) + assert {cred.provider for cred in created_credentials} == {"openai", "langfuse"} def test_get_creds_by_org(db: Session) -> None: @@ -56,7 +56,7 @@ def test_get_creds_by_org(db: Session) -> None: project = create_test_project(db) # Set up test credentials - creds_data = { + credentials_data = { "openai": {"api_key": "test-openai-key"}, "langfuse": { "public_key": "test-public-key", @@ -65,12 +65,12 @@ def test_get_creds_by_org(db: Session) -> None: }, } - creds_create = CredsCreate( + credentials_create = CredsCreate( organization_id=project.organization_id, project_id=project.id, - credential=creds_data, + credential=credentials_data, ) - set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=credentials_create) # Test retrieving credentials retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) @@ -84,13 +84,13 @@ def test_get_creds_by_org(db: Session) -> None: def test_get_provider_credential(db: Session) -> None: """Test retrieving credentials for a specific provider.""" - creds_create = test_credential_data(db) - original_api_key = creds_create.credential[Provider.OPENAI.value]["api_key"] + credentials_create = test_credential_data(db) + original_api_key = credentials_create.credential[Provider.OPENAI.value]["api_key"] - set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=credentials_create) # Test retrieving specific provider credentials retrieved_cred = get_provider_credential( - session=db, org_id=creds_create.organization_id, provider="openai" + session=db, org_id=credentials_create.organization_id, provider="openai" ) assert retrieved_cred is not None @@ -155,7 +155,7 @@ def test_remove_creds_for_org(db: Session) -> None: project = create_test_project(db) # Set up test credentials - creds_data = { + credentials_data = { "openai": {"api_key": "test-openai-key"}, "langfuse": { "public_key": "test-public-key", @@ -167,7 +167,7 @@ def test_remove_creds_for_org(db: Session) -> None: creds_create = CredsCreate( organization_id=project.organization_id, project_id=project.id, - credential=creds_data, + credential=credentials_data, ) set_creds_for_org(session=db, creds_add=creds_create) @@ -179,8 +179,8 @@ def test_remove_creds_for_org(db: Session) -> None: assert all(cred.updated_at is not None for cred in removed) # Verify no credentials are retrievable - retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) - assert len(retrieved_creds) == 0 + retrieved_credentials = get_creds_by_org(session=db, org_id=project.organization_id) + assert len(retrieved_credentials) == 0 def test_invalid_provider(db: Session) -> None: @@ -188,14 +188,14 @@ def test_invalid_provider(db: Session) -> None: project = create_test_project(db) # Test with unsupported provider - creds_data = {"gemini": {"api_key": "test-key"}} + credentials_data = {"gemini": {"api_key": "test-key"}} - creds_create = CredsCreate( - organization_id=project.organization_id, credential=creds_data + credentials_create = CredsCreate( + organization_id=project.organization_id, credential=credentials_data ) with pytest.raises(ValueError, match="Unsupported provider"): - set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=credentials_create) def test_duplicate_provider_credentials(db: Session) -> None: @@ -203,12 +203,12 @@ def test_duplicate_provider_credentials(db: Session) -> None: project = create_test_project(db) # Set up initial credentials - creds_data = {"openai": {"api_key": "test-key"}} + credentials_data = {"openai": {"api_key": "test-key"}} - creds_create = CredsCreate( - organization_id=project.organization_id, credential=creds_data + credentials_create = CredsCreate( + organization_id=project.organization_id, credential=credentials_data ) - set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=credentials_create) # Verify credentials exist and are active existing_creds = get_provider_credential( @@ -224,7 +224,7 @@ def test_langfuse_credential_validation(db: Session) -> None: project = create_test_project(db) # Test with missing required fields - invalid_creds = { + invalid_credentials = { "langfuse": { "public_key": "test-public-key", "secret_key": "test-secret-key" @@ -232,15 +232,15 @@ def test_langfuse_credential_validation(db: Session) -> None: } } - creds_create = CredsCreate( - organization_id=project.organization_id, credential=invalid_creds + credentials_create = CredsCreate( + organization_id=project.organization_id, credential=invalid_credentials ) with pytest.raises(ValueError): - set_creds_for_org(session=db, creds_add=creds_create) + set_creds_for_org(session=db, creds_add=credentials_create) # Test with valid Langfuse credentials - valid_creds = { + valid_credentials = { "langfuse": { "public_key": "test-public-key", "secret_key": "test-secret-key", @@ -248,10 +248,10 @@ def test_langfuse_credential_validation(db: Session) -> None: } } - creds_create = CredsCreate( - organization_id=project.organization_id, credential=valid_creds + credentials_create = CredsCreate( + organization_id=project.organization_id, credential=valid_credentials ) - created_creds = set_creds_for_org(session=db, creds_add=creds_create) - assert len(created_creds) == 1 - assert created_creds[0].provider == "langfuse" + created_credentials = set_creds_for_org(session=db, creds_add=credentials_create) + assert len(created_credentials) == 1 + assert created_credentials[0].provider == "langfuse" diff --git a/backend/app/tests/crud/test_org.py b/backend/app/tests/crud/test_org.py index 052988ff..77ee7f31 100644 --- a/backend/app/tests/crud/test_org.py +++ b/backend/app/tests/crud/test_org.py @@ -10,25 +10,25 @@ def test_create_organization(db: Session) -> None: """Test creating an organization.""" name = random_lower_string() org_in = OrganizationCreate(name=name) - org = create_organization(session=db, org_create=org_in) + organization = create_organization(session=db, org_create=org_in) - assert org.name == name - assert org.id is not None - assert org.is_active is True + assert organization.name == name + assert organization.id is not None + assert organization.is_active is True def test_get_organization_by_id(db: Session) -> None: """Test retrieving an organization by ID.""" - org = create_test_organization(db) + organization = create_test_organization(db) - fetched_org = get_organization_by_id(session=db, org_id=org.id) + fetched_org = get_organization_by_id(session=db, org_id=organization.id) assert fetched_org - assert fetched_org.id == org.id - assert fetched_org.name == org.name + assert fetched_org.id == organization.id + assert fetched_org.name == organization.name def test_get_non_existent_organization(db: Session) -> None: """Test retrieving a non-existent organization should return None.""" - org_id = get_non_existent_id(db, Organization) - fetched_org = get_organization_by_id(session=db, org_id=org_id) + organization_id = get_non_existent_id(db, Organization) + fetched_org = get_organization_by_id(session=db, org_id=organization_id) assert fetched_org is None diff --git a/backend/app/tests/crud/test_project.py b/backend/app/tests/crud/test_project.py index 041821cc..3f15a898 100644 --- a/backend/app/tests/crud/test_project.py +++ b/backend/app/tests/crud/test_project.py @@ -15,14 +15,14 @@ def test_create_project(db: Session) -> None: """Test creating a project linked to an organization.""" - org = create_test_organization(db) + organization = create_test_organization(db) project_name = random_lower_string() project_data = ProjectCreate( name=project_name, description="Test description", is_active=True, - organization_id=org.id, + organization_id=organization.id, ) project = create_project(session=db, project_create=project_data) @@ -30,7 +30,7 @@ def test_create_project(db: Session) -> None: assert project.id is not None assert project.name == project_name assert project.description == "Test description" - assert project.organization_id == org.id + assert project.organization_id == organization.id def test_get_project_by_id(db: Session) -> None: @@ -45,7 +45,7 @@ def test_get_project_by_id(db: Session) -> None: def test_get_projects_by_organization(db: Session) -> None: """Test retrieving all projects for an organization.""" - org = create_test_organization(db) + organization = create_test_organization(db) project_1 = create_project( session=db, @@ -53,7 +53,7 @@ def test_get_projects_by_organization(db: Session) -> None: name="Project 1", description="Test project 1", is_active=True, - organization_id=org.id, + organization_id=organization.id, ), ) @@ -63,11 +63,11 @@ def test_get_projects_by_organization(db: Session) -> None: name="Project 2", description="Test project 2", is_active=True, - organization_id=org.id, + organization_id=organization.id, ), ) - projects = get_projects_by_organization(session=db, org_id=org.id) + projects = get_projects_by_organization(session=db, org_id=organization.id) assert project_1 in projects assert project_2 in projects @@ -96,7 +96,7 @@ def test_validate_project_not_found(db: Session) -> None: def test_validate_project_inactive(db: Session) -> None: """Test that validation fails when project is inactive.""" - org = create_test_organization(db) + organization = create_test_organization(db) inactive_project = create_project( session=db, @@ -104,7 +104,7 @@ def test_validate_project_inactive(db: Session) -> None: name=random_lower_string(), description="Inactive project", is_active=False, - organization_id=org.id, + organization_id=organization.id, ), ) From 5d5746958a0c2c7e2081529ffa979a721417b2de Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 22:01:26 +0530 Subject: [PATCH 07/15] small fixes --- backend/app/crud/credentials.py | 1 + backend/app/tests/utils/test_data.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index cbced836..3cd7df52 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -115,6 +115,7 @@ def get_full_provider_credential( statement = select(Credential).where( Credential.organization_id == org_id, + Credential.project_id == project_id, Credential.provider == provider, Credential.is_active == True, Credential.project_id == project_id if project_id is not None else True, diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 568702d9..a6a7e2df 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -92,7 +92,7 @@ def test_credential_data(db: Session) -> CredsCreate: return creds_data -def create_test_credential(db: Session) -> List[Credential]: +def create_test_credential(db: Session) -> tuple[list[Credential], Project]: """ Creates and returns a test credential for a test project. From 7106272545310f0800fe6daf75f9b0d66ba0388d Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 22:06:56 +0530 Subject: [PATCH 08/15] project id fix --- backend/app/tests/crud/test_credentials.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index e7a39823..09809a62 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -191,7 +191,9 @@ def test_invalid_provider(db: Session) -> None: credentials_data = {"gemini": {"api_key": "test-key"}} credentials_create = CredsCreate( - organization_id=project.organization_id, credential=credentials_data + organization_id=project.organization_id, + project_id=project.id, + credential=credentials_data, ) with pytest.raises(ValueError, match="Unsupported provider"): @@ -206,7 +208,9 @@ def test_duplicate_provider_credentials(db: Session) -> None: credentials_data = {"openai": {"api_key": "test-key"}} credentials_create = CredsCreate( - organization_id=project.organization_id, credential=credentials_data + organization_id=project.organization_id, + project_id=project.id, + credential=credentials_data, ) set_creds_for_org(session=db, creds_add=credentials_create) @@ -233,7 +237,9 @@ def test_langfuse_credential_validation(db: Session) -> None: } credentials_create = CredsCreate( - organization_id=project.organization_id, credential=invalid_credentials + organization_id=project.organization_id, + project_id=project.id, + credential=invalid_credentials, ) with pytest.raises(ValueError): @@ -249,7 +255,9 @@ def test_langfuse_credential_validation(db: Session) -> None: } credentials_create = CredsCreate( - organization_id=project.organization_id, credential=valid_credentials + organization_id=project.organization_id, + project_id=project.id, + credential=valid_credentials, ) created_credentials = set_creds_for_org(session=db, creds_add=credentials_create) From c3cefc992b4d4bc6fa30d0262c15758998242692 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 22:08:08 +0530 Subject: [PATCH 09/15] repeated project id twice --- backend/app/crud/credentials.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index 3cd7df52..cbced836 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -115,7 +115,6 @@ def get_full_provider_credential( statement = select(Credential).where( Credential.organization_id == org_id, - Credential.project_id == project_id, Credential.provider == provider, Credential.is_active == True, Credential.project_id == project_id if project_id is not None else True, From 33b4f501612e40924525a5dcaf92f4f90578c3b4 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 22:12:08 +0530 Subject: [PATCH 10/15] small fix --- backend/app/crud/credentials.py | 2 +- backend/app/tests/utils/test_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index cbced836..ce8fed60 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -122,7 +122,7 @@ def get_full_provider_credential( creds = session.exec(statement).first() if creds and creds.credential: - # Decrypt entire credentials object + # Return the full Credential object return creds return None diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index a6a7e2df..6f76cd7d 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -1,4 +1,4 @@ -from typing import List +from typing import list from sqlmodel import Session From 927a4b01b5cd8bd6b6d534e19351c89e4e09299f Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 22:15:36 +0530 Subject: [PATCH 11/15] removed import list --- backend/app/tests/utils/test_data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 6f76cd7d..0d9c5d3b 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -1,5 +1,3 @@ -from typing import list - from sqlmodel import Session from app.models import ( From edcd00971866724f5dfd130d31adc2affb94f39d Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 11 Jul 2025 22:19:31 +0530 Subject: [PATCH 12/15] small fix --- backend/app/crud/credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index ce8fed60..a5826dcf 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -109,7 +109,7 @@ def get_provider_credential( def get_full_provider_credential( *, session: Session, org_id: int, provider: str, project_id: Optional[int] = None -) -> Optional[Dict[str, Any]]: +) -> Optional[Credential]: """Fetches credentials for a specific provider of an organization.""" validate_provider(provider) From b027651f23d47cae167d91bf9fedff302cf8a891 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Sat, 12 Jul 2025 16:57:20 +0530 Subject: [PATCH 13/15] removed redundent code --- backend/app/crud/__init__.py | 1 - backend/app/crud/credentials.py | 42 +++++++++------------- backend/app/tests/api/routes/test_creds.py | 17 +++++---- backend/app/tests/crud/test_credentials.py | 7 ++-- 4 files changed, 31 insertions(+), 36 deletions(-) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 95c1c21b..5c6fd102 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -39,7 +39,6 @@ remove_creds_for_org, get_provider_credential, remove_provider_credential, - get_full_provider_credential, ) from .thread_results import upsert_thread_result, get_thread_result diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index a5826dcf..14a7687b 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Union from sqlmodel import Session, select from sqlalchemy.exc import IntegrityError from datetime import datetime, timezone @@ -88,29 +88,20 @@ def get_creds_by_org( def get_provider_credential( - *, session: Session, org_id: int, provider: str, project_id: Optional[int] = None -) -> Optional[Dict[str, Any]]: - """Fetches credentials for a specific provider of an organization.""" - validate_provider(provider) - - statement = select(Credential).where( - Credential.organization_id == org_id, - Credential.provider == provider, - Credential.is_active == True, - Credential.project_id == project_id if project_id is not None else True, - ) - creds = session.exec(statement).first() - - if creds and creds.credential: - # Decrypt entire credentials object - return decrypt_credentials(creds.credential) - return None - - -def get_full_provider_credential( - *, session: Session, org_id: int, provider: str, project_id: Optional[int] = None -) -> Optional[Credential]: - """Fetches credentials for a specific provider of an organization.""" + *, + session: Session, + org_id: int, + provider: str, + project_id: Optional[int] = None, + full: bool = False, +) -> Optional[Union[Dict[str, Any], Credential]]: + """ + Fetches credentials for a specific provider of a project. + + Returns: + Optional[Union[Dict[str, Any], Credential]]: If full is True, returns full Credential object. + Otherwise returns just the decrypted credentials dict. + """ validate_provider(provider) statement = select(Credential).where( @@ -122,8 +113,7 @@ def get_full_provider_credential( creds = session.exec(statement).first() if creds and creds.credential: - # Return the full Credential object - return creds + return creds if full else decrypt_credentials(creds.credential) return None diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 54cdc91a..e462c610 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -3,7 +3,7 @@ from sqlmodel import Session from app.main import app -from app.crud import get_full_provider_credential +from app.crud import get_provider_credential from app.models import Organization, Project from app.core.config import settings from app.core.providers import Provider @@ -116,11 +116,12 @@ def test_read_credentials_with_creds( ): _, project = create_test_credentials - credential = get_full_provider_credential( + credential = get_provider_credential( session=db, org_id=project.organization_id, provider="openai", project_id=project.id, + full=True, ) response = client.get( @@ -184,11 +185,12 @@ def test_update_credentials( ): _, project = create_test_credentials - credential = get_full_provider_credential( + credential = get_provider_credential( session=db, org_id=project.organization_id, provider="openai", project_id=project.id, + full=True, ) update_data = { @@ -220,11 +222,12 @@ def test_update_credentials_failed_update( ): _, project = create_test_credentials - credential = get_full_provider_credential( + credential = get_provider_credential( session=db, org_id=project.organization_id, provider="openai", project_id=project.id, + full=True, ) org_without_credential = create_test_organization(db) @@ -286,11 +289,12 @@ def test_delete_provider_credential( ): _, project = create_test_credentials - credential = get_full_provider_credential( + credential = get_provider_credential( session=db, org_id=project.organization_id, provider="openai", project_id=project.id, + full=True, ) response = client.delete( @@ -322,11 +326,12 @@ def test_delete_all_credentials( ): _, project = create_test_credentials - credential = get_full_provider_credential( + credential = get_provider_credential( session=db, org_id=project.organization_id, provider="openai", project_id=project.id, + full=True, ) response = client.delete( f"{settings.API_V1_STR}/credentials/{credential.organization_id}", diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 09809a62..a3631b6a 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -8,7 +8,6 @@ update_creds_for_org, remove_provider_credential, remove_creds_for_org, - get_full_provider_credential, ) from app.models import CredsCreate, CredsUpdate from app.core.providers import Provider @@ -102,11 +101,12 @@ def test_update_creds_for_org(db: Session) -> None: """Test updating credentials for a provider.""" _, project = create_test_credential(db) - credential = get_full_provider_credential( + credential = get_provider_credential( session=db, org_id=project.organization_id, provider="openai", project_id=project.id, + full=True, ) # Update credentials updated_creds = {"api_key": "updated-key"} @@ -128,11 +128,12 @@ def test_remove_provider_credential(db: Session) -> None: """Test removing credentials for a specific provider.""" _, project = create_test_credential(db) - credential = get_full_provider_credential( + credential = get_provider_credential( session=db, org_id=project.organization_id, provider="openai", project_id=project.id, + full=True, ) # Remove one provider's credentials From f0f23d4360b28ec780fe0e3c5304d7826bbfd438 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Sat, 12 Jul 2025 17:30:57 +0530 Subject: [PATCH 14/15] small fixes --- backend/app/tests/api/routes/test_creds.py | 3 +-- backend/app/tests/crud/test_credentials.py | 1 - backend/app/tests/utils/utils.py | 11 ----------- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index e462c610..2b440eee 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -12,7 +12,6 @@ from app.tests.utils.utils import ( generate_random_string, get_non_existent_id, - get_credential_by_provider, ) from app.tests.utils.test_data import ( create_test_credential, @@ -318,7 +317,7 @@ def test_delete_provider_credential_not_found( ) assert response.status_code == 404 - assert response.json()["error"] == f"Provider credentials not found" + assert response.json()["error"] == "Provider credentials not found" def test_delete_all_credentials( diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index a3631b6a..3a87ce10 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -11,7 +11,6 @@ ) from app.models import CredsCreate, CredsUpdate from app.core.providers import Provider -from app.tests.utils.utils import get_credential_by_provider from app.tests.utils.test_data import ( create_test_project, create_test_credential, diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index ac94a183..a9eefe45 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -60,17 +60,6 @@ def get_user_from_api_key(db: Session, api_key_headers: dict[str, str]) -> APIKe return api_key -def get_credential_by_provider(creds: List[Credential], provider: str) -> Credential: - """ - From a list of credentials, return the one matching the given provider. - Raises ValueError if not found. - """ - for c in creds: - if c.provider == provider: - return c - raise ValueError(f"No credential found for provider: {provider}") - - def get_non_existent_id(session: Session, model: Type[T]) -> int: result = session.exec(select(model.id).order_by(model.id.desc())).first() return (result or 0) + 1 From 6d30776e551f4d0dabd9d5ed409bf8dcd32b1084 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Mon, 14 Jul 2025 16:19:14 +0530 Subject: [PATCH 15/15] formatting changes --- backend/app/crud/credentials.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index 14a7687b..f456d48e 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -96,11 +96,12 @@ def get_provider_credential( full: bool = False, ) -> Optional[Union[Dict[str, Any], Credential]]: """ - Fetches credentials for a specific provider of a project. + Fetch credentials for a specific provider within a project. Returns: - Optional[Union[Dict[str, Any], Credential]]: If full is True, returns full Credential object. - Otherwise returns just the decrypted credentials dict. + Optional[Union[Dict[str, Any], Credential]]: + - If `full` is True, returns the full Credential SQLModel object. + - Otherwise, returns the decrypted credentials as a dictionary. """ validate_provider(provider)