diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index 0091e8e0..ae267e88 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -1,8 +1,6 @@ -from typing import List - from fastapi import APIRouter, Depends -from app.api.deps import SessionDep, get_current_active_superuser +from app.api.deps import SessionDep, get_current_user_org_project from app.crud.credentials import ( get_creds_by_org, get_provider_credential, @@ -11,10 +9,7 @@ update_creds_for_org, remove_provider_credential, ) -from app.crud import validate_organization, validate_project -from app.models import CredsCreate, CredsPublic, CredsUpdate -from app.models.organization import Organization -from app.models.project import Project +from app.models import CredsCreate, CredsPublic, CredsUpdate, UserProjectOrg from app.utils import APIResponse from app.core.providers import validate_provider from app.core.exception_handlers import HTTPException @@ -24,31 +19,25 @@ @router.post( "/", - dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[List[CredsPublic]], - summary="Create new credentials for an organization and project", - description="Creates new credentials for a specific organization and project combination. This endpoint requires superuser privileges. Each organization can have different credentials for different providers and projects. Only one credential per provider is allowed per organization-project combination.", + response_model=APIResponse[list[CredsPublic]], + summary="Create new credentials for the current organization and project", + description="Creates new credentials for the caller's organization and project. Each organization can have different credentials for different providers and projects. Only one credential per provider is allowed per organization-project combination.", ) -def create_new_credential(*, session: SessionDep, creds_in: CredsCreate): - # Validate organization - validate_organization(session, creds_in.organization_id) - - # Validate project if provided - if creds_in.project_id: - project = validate_project(session, creds_in.project_id) - if project.organization_id != creds_in.organization_id: - raise HTTPException( - status_code=400, - detail="Project does not belong to the specified organization", - ) +def create_new_credential( + *, + session: SessionDep, + creds_in: CredsCreate, + _current_user: UserProjectOrg = Depends(get_current_user_org_project), +): + # Project comes from API key context; no cross-org check needed here # Prevent duplicate credentials for provider in creds_in.credential.keys(): existing_cred = get_provider_credential( session=session, - org_id=creds_in.organization_id, + org_id=_current_user.organization_id, provider=provider, - project_id=creds_in.project_id, + project_id=_current_user.project_id, ) if existing_cred: raise HTTPException( @@ -60,22 +49,34 @@ def create_new_credential(*, session: SessionDep, creds_in: CredsCreate): ) # Create credentials - new_creds = set_creds_for_org(session=session, creds_add=creds_in) - if not new_creds: + created_creds = set_creds_for_org( + session=session, + creds_add=creds_in, + organization_id=_current_user.organization_id, + project_id=_current_user.project_id, + ) + if not created_creds: raise Exception(status_code=500, detail="Failed to create credentials") - return APIResponse.success_response([cred.to_public() for cred in new_creds]) + return APIResponse.success_response([cred.to_public() for cred in created_creds]) @router.get( - "/{org_id}", - dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[List[CredsPublic]], - summary="Get all credentials for an organization and project", - description="Retrieves all provider credentials associated with a specific organization and project combination. If project_id is not provided, returns credentials for the organization level. This endpoint requires superuser privileges.", + "/", + response_model=APIResponse[list[CredsPublic]], + summary="Get all credentials for current org and project", + description="Retrieves all provider credentials associated with the caller's organization and project.", ) -def read_credential(*, session: SessionDep, org_id: int, project_id: int | None = None): - creds = get_creds_by_org(session=session, org_id=org_id, project_id=project_id) +def read_credential( + *, + session: SessionDep, + _current_user: UserProjectOrg = Depends(get_current_user_org_project), +): + creds = get_creds_by_org( + session=session, + org_id=_current_user.organization_id, + project_id=_current_user.project_id, + ) if not creds: raise HTTPException(status_code=404, detail="Credentials not found") @@ -83,72 +84,86 @@ def read_credential(*, session: SessionDep, org_id: int, project_id: int | None @router.get( - "/{org_id}/{provider}", - dependencies=[Depends(get_current_active_superuser)], + "/provider/{provider}", response_model=APIResponse[dict], - summary="Get specific provider credentials for an organization and project", - description="Retrieves credentials for a specific provider (e.g., 'openai', 'anthropic') for a given organization and project combination. If project_id is not provided, returns organization-level credentials. This endpoint requires superuser privileges.", + summary="Get specific provider credentials for current org and project", + description="Retrieves credentials for a specific provider (e.g., 'openai', 'anthropic') for the caller's organization and project.", ) def read_provider_credential( - *, session: SessionDep, org_id: int, provider: str, project_id: int | None = None + *, + session: SessionDep, + provider: str, + _current_user: UserProjectOrg = Depends(get_current_user_org_project), ): provider_enum = validate_provider(provider) - provider_creds = get_provider_credential( + credential = get_provider_credential( session=session, - org_id=org_id, + org_id=_current_user.organization_id, provider=provider_enum, - project_id=project_id, + project_id=_current_user.project_id, ) - if provider_creds is None: + if credential is None: raise HTTPException(status_code=404, detail="Provider credentials not found") - return APIResponse.success_response(provider_creds) + return APIResponse.success_response(credential) @router.patch( - "/{org_id}", - dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[List[CredsPublic]], - summary="Update organization and project credentials", - description="Updates credentials for a specific organization and project combination. Can update specific provider credentials or add new providers. If project_id is provided in the update, credentials will be moved to that project. This endpoint requires superuser privileges.", + "/", + response_model=APIResponse[list[CredsPublic]], + summary="Update credentials for current org and project", + description="Updates credentials for a specific provider of the caller's organization and project.", ) -def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate): - validate_organization(session, org_id) +def update_credential( + *, + session: SessionDep, + creds_in: CredsUpdate, + _current_user: UserProjectOrg = Depends(get_current_user_org_project), +): if not creds_in or not creds_in.provider or not creds_in.credential: raise HTTPException( status_code=400, detail="Provider and credential must be provided" ) - updated_creds = update_creds_for_org( - session=session, org_id=org_id, creds_in=creds_in + # Pass project_id directly to the CRUD function since CredsUpdate no longer has this field + updated_credential = update_creds_for_org( + session=session, + org_id=_current_user.organization_id, + creds_in=creds_in, + project_id=_current_user.project_id, ) - return APIResponse.success_response([cred.to_public() for cred in updated_creds]) + return APIResponse.success_response( + [cred.to_public() for cred in updated_credential] + ) @router.delete( - "/{org_id}/{provider}", - dependencies=[Depends(get_current_active_superuser)], + "/provider/{provider}", response_model=APIResponse[dict], - summary="Delete specific provider credentials for an organization and project", + summary="Delete specific provider credentials for current org and project", ) def delete_provider_credential( - *, session: SessionDep, org_id: int, provider: str, project_id: int | None = None + *, + session: SessionDep, + provider: str, + _current_user: UserProjectOrg = Depends(get_current_user_org_project), ): provider_enum = validate_provider(provider) - if not provider_enum: - raise HTTPException(status_code=400, detail="Invalid provider") provider_creds = get_provider_credential( session=session, - org_id=org_id, + org_id=_current_user.organization_id, provider=provider_enum, - project_id=project_id, + project_id=_current_user.project_id, ) if provider_creds is None: raise HTTPException(status_code=404, detail="Provider credentials not found") - updated_creds = remove_provider_credential( - session=session, org_id=org_id, provider=provider_enum, project_id=project_id + remove_provider_credential( + session=session, + org_id=_current_user.organization_id, + provider=provider_enum, + project_id=_current_user.project_id, ) return APIResponse.success_response( @@ -157,19 +172,26 @@ def delete_provider_credential( @router.delete( - "/{org_id}", - dependencies=[Depends(get_current_active_superuser)], + "/", response_model=APIResponse[dict], - summary="Delete all credentials for an organization and project", - description="Removes all credentials for a specific organization and project combination. If project_id is provided, only removes credentials for that project. This is a soft delete operation that marks credentials as inactive. This endpoint requires superuser privileges.", + summary="Delete all credentials for current org and project", + description="Removes all credentials for the caller's organization and project. This is a soft delete operation that marks credentials as inactive.", ) def delete_all_credentials( - *, session: SessionDep, org_id: int, project_id: int | None = None + *, + session: SessionDep, + _current_user: UserProjectOrg = Depends(get_current_user_org_project), ): - creds = remove_creds_for_org(session=session, org_id=org_id, project_id=project_id) + creds = remove_creds_for_org( + session=session, + org_id=_current_user.organization_id, + project_id=_current_user.project_id, + ) if not creds: raise HTTPException( - status_code=404, detail="Credentials for organization not found" + status_code=404, detail="Credentials for organization/project not found" ) - return APIResponse.success_response({"message": "Credentials deleted successfully"}) + return APIResponse.success_response( + {"message": "All credentials deleted successfully"} + ) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index f456d48e..d23a6164 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -14,7 +14,9 @@ from app.core.exception_handlers import HTTPException -def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> List[Credential]: +def set_creds_for_org( + *, session: Session, creds_add: CredsCreate, organization_id: int, project_id: int +) -> List[Credential]: """Set credentials for an organization. Creates a separate row for each provider.""" created_credentials = [] @@ -31,8 +33,8 @@ def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> List[Crede # Create a row for each provider credential = Credential( - organization_id=creds_add.organization_id, - project_id=creds_add.project_id, + organization_id=organization_id, + project_id=project_id, is_active=creds_add.is_active, provider=provider, credential=encrypted_credentials, @@ -127,7 +129,11 @@ def get_providers( def update_creds_for_org( - *, session: Session, org_id: int, creds_in: CredsUpdate + *, + session: Session, + org_id: int, + creds_in: CredsUpdate, + project_id: Optional[int] = None, ) -> List[Credential]: """Updates credentials for a specific provider of an organization.""" if not creds_in.provider or not creds_in.credential: @@ -143,9 +149,7 @@ def update_creds_for_org( Credential.organization_id == org_id, Credential.provider == creds_in.provider, Credential.is_active == True, - Credential.project_id == creds_in.project_id - if creds_in.project_id is not None - else True, + Credential.project_id == project_id if project_id is not None else True, ) creds = session.exec(statement).first() if creds is None: diff --git a/backend/app/models/credentials.py b/backend/app/models/credentials.py index 7096ef29..9785bd4f 100644 --- a/backend/app/models/credentials.py +++ b/backend/app/models/credentials.py @@ -16,12 +16,13 @@ class CredsBase(SQLModel): is_active: bool = True -class CredsCreate(CredsBase): +class CredsCreate(SQLModel): """Create new credentials for an organization. The credential field should be a dictionary mapping provider names to their credentials. Example: {"openai": {"api_key": "..."}, "langfuse": {"public_key": "..."}} """ + is_active: bool = True credential: Dict[str, Any] = Field( default=None, description="Dictionary mapping provider names to their credentials", @@ -42,9 +43,6 @@ class CredsUpdate(SQLModel): is_active: Optional[bool] = Field( default=None, description="Whether the credentials are active" ) - project_id: Optional[int] = Field( - default=None, description="Project ID to associate with these credentials" - ) class Credential(CredsBase, table=True): diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 52bf8429..8625cf9e 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -2,40 +2,41 @@ from fastapi.testclient import TestClient from sqlmodel import Session -from app.main import app -from app.crud import get_provider_credential -from app.models import Organization, Project +from app.models import APIKeyPublic from app.core.config import settings from app.core.providers import Provider from app.models.credentials import Credential from app.core.security import decrypt_credentials from app.tests.utils.utils import ( generate_random_string, - get_non_existent_id, ) from app.tests.utils.test_data import ( create_test_credential, - create_test_organization, - create_test_project, test_credential_data, ) -client = TestClient(app) - - @pytest.fixture def create_test_credentials(db: Session): return create_test_credential(db) -def test_set_credential(db: Session, superuser_token_headers: dict[str, str]): - project = create_test_project(db) +def test_set_credential( + client: TestClient, + user_api_key: APIKeyPublic, +): + project_id = user_api_key.project_id + org_id = user_api_key.organization_id api_key = "sk-" + generate_random_string(10) + # Ensure clean state for provider + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) credential_data = { - "organization_id": project.organization_id, - "project_id": project.id, + "organization_id": org_id, + "project_id": project_id, "is_active": True, "credential": { Provider.OPENAI.value: { @@ -49,7 +50,7 @@ def test_set_credential(db: Session, superuser_token_headers: dict[str, str]): response = client.post( f"{settings.API_V1_STR}/credentials/", json=credential_data, - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 @@ -57,107 +58,107 @@ def test_set_credential(db: Session, superuser_token_headers: dict[str, str]): assert isinstance(data, list) assert len(data) == 1 - assert data[0]["organization_id"] == project.organization_id + assert data[0]["organization_id"] == org_id assert data[0]["provider"] == Provider.OPENAI.value assert data[0]["credential"]["model"] == "gpt-4" -def test_set_credentials_for_invalid_project_org_relationship( - db: Session, superuser_token_headers: dict[str, str] +def test_set_credentials_ignored_mismatched_ids( + client: TestClient, + user_api_key: APIKeyPublic, ): - org1 = create_test_organization(db) - project2 = create_test_project(db) - - credential_data_invalid = { - "organization_id": org1.id, + # Even if mismatched IDs are sent, route uses API key context + # Ensure clean state for provider + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) + credential_data = { + "organization_id": 999999, + "project_id": 999999, "is_active": True, - "project_id": project2.id, "credential": {Provider.OPENAI.value: {"api_key": "sk-123", "model": "gpt-4"}}, } response_invalid = client.post( f"{settings.API_V1_STR}/credentials/", - json=credential_data_invalid, - headers=superuser_token_headers, - ) - assert response_invalid.status_code == 400 - assert ( - response_invalid.json()["error"] - == "Project does not belong to the specified organization" + json=credential_data, + headers={"X-API-KEY": user_api_key.key}, ) + assert response_invalid.status_code == 200 -def test_set_credentials_for_project_not_found( - db: Session, superuser_token_headers: dict[str, str] +def test_read_credentials_with_creds( + client: TestClient, + user_api_key: APIKeyPublic, ): - # Setup: Create an organization but no project - org = create_test_organization(db) - non_existent_project_id = get_non_existent_id(db, Project) - - credential_data_invalid_project = { - "organization_id": org.id, + # Ensure at least one credential exists for current project + api_key_value = "sk-" + generate_random_string(10) + payload = { + "organization_id": user_api_key.organization_id, + "project_id": user_api_key.project_id, "is_active": True, - "project_id": non_existent_project_id, - "credential": {Provider.OPENAI.value: {"api_key": "sk-123", "model": "gpt-4"}}, + "credential": { + Provider.OPENAI.value: {"api_key": api_key_value, "model": "gpt-4"} + }, } - - response_invalid_project = client.post( + client.post( f"{settings.API_V1_STR}/credentials/", - json=credential_data_invalid_project, - headers=superuser_token_headers, - ) - - assert response_invalid_project.status_code == 404 - assert response_invalid_project.json()["error"] == "Project not found" - - -def test_read_credentials_with_creds( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials -): - _, project = create_test_credentials - - credential = get_provider_credential( - session=db, - org_id=project.organization_id, - provider="openai", - project_id=project.id, - full=True, + json=payload, + headers={"X-API-KEY": user_api_key.key}, ) response = client.get( - f"{settings.API_V1_STR}/credentials/{credential.organization_id}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 data = response.json()["data"] assert isinstance(data, list) - assert len(data) == 1 - assert data[0]["organization_id"] == project.organization_id - assert data[0]["provider"] == Provider.OPENAI.value - assert data[0]["credential"]["model"] == "gpt-4" + assert len(data) >= 1 def test_read_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str] + client: TestClient, db: Session, user_api_key: APIKeyPublic ): - non_existent_creds_id = get_non_existent_id(db, Credential) + # Delete all first to ensure none remain + client.delete( + f"{settings.API_V1_STR}/credentials/", headers={"X-API-KEY": user_api_key.key} + ) response = client.get( - f"{settings.API_V1_STR}/credentials/{non_existent_creds_id}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 404 assert "Credentials not found" in response.json()["error"] def test_read_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + client: TestClient, + user_api_key: APIKeyPublic, ): - _, project = create_test_credentials + # Ensure exists + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) + client.post( + f"{settings.API_V1_STR}/credentials/", + json={ + "organization_id": user_api_key.organization_id, + "project_id": user_api_key.project_id, + "is_active": True, + "credential": { + Provider.OPENAI.value: {"api_key": "sk-xyz", "model": "gpt-4"} + }, + }, + headers={"X-API-KEY": user_api_key.key}, + ) response = client.get( - f"{settings.API_V1_STR}/credentials/{project.organization_id}/{Provider.OPENAI.value}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 @@ -167,13 +168,15 @@ def test_read_provider_credential( def test_read_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str] + client: TestClient, db: Session, user_api_key: APIKeyPublic ): - org = create_test_organization(db) - + # Ensure none + client.delete( + f"{settings.API_V1_STR}/credentials/", headers={"X-API-KEY": user_api_key.key} + ) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 404 @@ -181,16 +184,25 @@ def test_read_provider_credential_not_found( def test_update_credentials( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + client: TestClient, + user_api_key: APIKeyPublic, ): - _, project = create_test_credentials - - credential = get_provider_credential( - session=db, - org_id=project.organization_id, - provider="openai", - project_id=project.id, - full=True, + # Ensure exists + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) + client.post( + f"{settings.API_V1_STR}/credentials/", + json={ + "organization_id": user_api_key.organization_id, + "project_id": user_api_key.project_id, + "is_active": True, + "credential": { + Provider.OPENAI.value: {"api_key": "sk-abc", "model": "gpt-4"} + }, + }, + headers={"X-API-KEY": user_api_key.key}, ) update_data = { @@ -203,9 +215,9 @@ def test_update_credentials( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{credential.organization_id}", + f"{settings.API_V1_STR}/credentials/", json=update_data, - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 @@ -217,28 +229,14 @@ def test_update_credentials( assert data[0]["updated_at"] is not None -def test_update_credentials_failed_update( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials +def test_update_credentials_not_found_for_provider( + client: TestClient, db: Session, user_api_key: APIKeyPublic ): - _, project = create_test_credentials - - credential = get_provider_credential( - session=db, - org_id=project.organization_id, - provider="openai", - project_id=project.id, - full=True, + # Ensure none exist + client.delete( + f"{settings.API_V1_STR}/credentials/", headers={"X-API-KEY": user_api_key.key} ) - org_without_credential = create_test_organization(db) - - existing_credential = ( - db.query(Credential) - .filter(credential.organization_id == org_without_credential.id) - .all() - ) - assert len(existing_credential) == 0 - update_data = { "provider": Provider.OPENAI.value, "credential": { @@ -248,73 +246,53 @@ def test_update_credentials_failed_update( }, } - response_invalid_org = client.patch( - f"{settings.API_V1_STR}/credentials/{org_without_credential.id}", - json=update_data, - headers=superuser_token_headers, - ) - assert response_invalid_org.status_code == 404 - assert ( - response_invalid_org.json()["error"] - == "Credentials not found for this provider" - ) - - -def test_update_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str] -): - non_existent_org_id = get_non_existent_id(db, Organization) - - update_data = { - "provider": Provider.OPENAI.value, - "credential": { - "api_key": "sk-" + generate_random_string(), - "model": "gpt-4", - "temperature": 0.7, - }, - } - - response = client.patch( - f"{settings.API_V1_STR}/credentials/{non_existent_org_id}", + response_invalid = client.patch( + f"{settings.API_V1_STR}/credentials/", json=update_data, - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) - - assert response.status_code == 404 # Expect 404 for non-existent organization - assert "Organization not found" in response.json()["error"] + assert response_invalid.status_code == 404 + assert response_invalid.json()["error"] == "Credentials not found for this provider" def test_delete_provider_credential( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + client: TestClient, + user_api_key: APIKeyPublic, ): - _, project = create_test_credentials - - credential = get_provider_credential( - session=db, - org_id=project.organization_id, - provider="openai", - project_id=project.id, - full=True, + # Ensure exists + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, ) - - response = client.delete( - f"{settings.API_V1_STR}/credentials/{credential.organization_id}/{Provider.OPENAI.value}", - headers=superuser_token_headers, + client.post( + f"{settings.API_V1_STR}/credentials/", + json={ + "organization_id": user_api_key.organization_id, + "project_id": user_api_key.project_id, + "is_active": True, + "credential": { + Provider.OPENAI.value: {"api_key": "sk-abc", "model": "gpt-4"} + }, + }, + headers={"X-API-KEY": user_api_key.key}, ) - assert response.status_code == 200 - data = response.json()["data"] - assert data["message"] == "Provider credentials removed successfully" + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) def test_delete_provider_credential_not_found( - db: Session, superuser_token_headers: dict[str, str] + client: TestClient, db: Session, user_api_key: APIKeyPublic ): - org = create_test_organization(db) - + # Ensure not exists + client.delete( + f"{settings.API_V1_STR}/credentials/", headers={"X-API-KEY": user_api_key.key} + ) response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 404 @@ -322,57 +300,76 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( - db: Session, superuser_token_headers: dict[str, str], create_test_credentials + client: TestClient, + user_api_key: APIKeyPublic, ): - _, project = create_test_credentials - - credential = get_provider_credential( - session=db, - org_id=project.organization_id, - provider="openai", - project_id=project.id, - full=True, + # Ensure exists + client.delete( + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, + ) + client.post( + f"{settings.API_V1_STR}/credentials/", + json={ + "organization_id": user_api_key.organization_id, + "project_id": user_api_key.project_id, + "is_active": True, + "credential": { + Provider.OPENAI.value: {"api_key": "sk-abc", "model": "gpt-4"} + }, + }, + headers={"X-API-KEY": user_api_key.key}, ) response = client.delete( - f"{settings.API_V1_STR}/credentials/{credential.organization_id}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 # Expect 200 for successful deletion data = response.json()["data"] - assert data["message"] == "Credentials deleted successfully" + assert data["message"] == "All credentials deleted successfully" # Verify the credentials are soft deleted response = client.get( - f"{settings.API_V1_STR}/credentials/{credential.organization_id}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 404 # Expect 404 as credentials are soft deleted assert response.json()["error"] == "Credentials not found" def test_delete_all_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str] + client: TestClient, db: Session, user_api_key: APIKeyPublic ): - non_existent_credential_id = get_non_existent_id(db, Credential) + # Ensure already deleted + client.delete( + f"{settings.API_V1_STR}/credentials/", headers={"X-API-KEY": user_api_key.key} + ) response = client.delete( - f"{settings.API_V1_STR}/credentials/{non_existent_credential_id}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 404 - assert "Credentials for organization not found" in response.json()["error"] + assert "Credentials for organization/project not found" in response.json()["error"] def test_duplicate_credential_creation( - db: Session, superuser_token_headers: dict[str, str] + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, ): credential = test_credential_data(db) + # Ensure clean state for provider + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) response = client.post( f"{settings.API_V1_STR}/credentials/", json=credential.dict(), - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 @@ -380,7 +377,7 @@ def test_duplicate_credential_creation( response = client.post( f"{settings.API_V1_STR}/credentials/", json=credential.dict(), - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 400 @@ -388,14 +385,19 @@ def test_duplicate_credential_creation( def test_multiple_provider_credentials( - db: Session, superuser_token_headers: dict[str, str] + client: TestClient, + user_api_key: APIKeyPublic, ): - project = create_test_project(db) + # Ensure clean state for current org/project + client.delete( + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, + ) # Create OpenAI credentials openai_credential = { - "organization_id": project.organization_id, - "project_id": project.id, + "organization_id": user_api_key.organization_id, + "project_id": user_api_key.project_id, "is_active": True, "credential": { Provider.OPENAI.value: { @@ -408,8 +410,8 @@ def test_multiple_provider_credentials( # Create Langfuse credentials langfuse_credential = { - "organization_id": project.organization_id, - "project_id": project.id, + "organization_id": user_api_key.organization_id, + "project_id": user_api_key.project_id, "is_active": True, "credential": { Provider.LANGFUSE.value: { @@ -424,21 +426,21 @@ def test_multiple_provider_credentials( response = client.post( f"{settings.API_V1_STR}/credentials/", json=openai_credential, - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 response = client.post( f"{settings.API_V1_STR}/credentials/", json=langfuse_credential, - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 # Fetch all credentials response = client.get( - f"{settings.API_V1_STR}/credentials/{project.organization_id}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 data = response.json()["data"] @@ -448,22 +450,32 @@ def test_multiple_provider_credentials( assert Provider.LANGFUSE.value in providers -def test_credential_encryption(db: Session, superuser_token_headers: dict[str, str]): +def test_credential_encryption( + client: TestClient, + db: Session, + user_api_key: APIKeyPublic, +): credential = test_credential_data(db) original_api_key = credential.credential[Provider.OPENAI.value]["api_key"] # Create credentials + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) response = client.post( f"{settings.API_V1_STR}/credentials/", json=credential.dict(), - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 db_credential = ( db.query(Credential) .filter( - Credential.organization_id == credential.organization_id, + Credential.organization_id == user_api_key.organization_id, + Credential.project_id == user_api_key.project_id, + Credential.is_active == True, Credential.provider == Provider.OPENAI.value, ) .first() @@ -475,27 +487,31 @@ def test_credential_encryption(db: Session, superuser_token_headers: dict[str, s # Verify we can decrypt and get the original value decrypted_creds = decrypt_credentials(db_credential.credential) - assert decrypted_creds["api_key"] == original_api_key + assert decrypted_creds.get("api_key") == original_api_key def test_credential_encryption_consistency( - db: Session, superuser_token_headers: dict[str, str] + client: TestClient, db: Session, user_api_key: APIKeyPublic ): credentials = test_credential_data(db) original_api_key = credentials.credential[Provider.OPENAI.value]["api_key"] # Create credentials + client.delete( + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, + ) response = client.post( f"{settings.API_V1_STR}/credentials/", json=credentials.dict(), - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 # Fetch the credentials through the API response = client.get( - f"{settings.API_V1_STR}/credentials/{credentials.organization_id}/{Provider.OPENAI.value}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 data = response.json()["data"] @@ -515,15 +531,15 @@ def test_credential_encryption_consistency( } response = client.patch( - f"{settings.API_V1_STR}/credentials/{credentials.organization_id}", + f"{settings.API_V1_STR}/credentials/", json=update_data, - headers=superuser_token_headers, + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 response = client.get( - f"{settings.API_V1_STR}/credentials/{credentials.organization_id}/{Provider.OPENAI.value}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 data = response.json()["data"] diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 827c7e01..8eeeda28 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -32,12 +32,16 @@ def test_set_credentials_for_org(db: Session) -> None: }, } credentials_create = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, + is_active=True, credential=credentials_data, ) - created_credentials = set_creds_for_org(session=db, creds_add=credentials_create) + created_credentials = set_creds_for_org( + session=db, + creds_add=credentials_create, + organization_id=project.organization_id, + project_id=project.id, + ) assert len(created_credentials) == 2 assert all( @@ -63,11 +67,15 @@ def test_get_creds_by_org(db: Session) -> None: } credentials_create = CredsCreate( + is_active=True, + credential=credentials_data, + ) + set_creds_for_org( + session=db, + creds_add=credentials_create, organization_id=project.organization_id, project_id=project.id, - credential=credentials_data, ) - set_creds_for_org(session=db, creds_add=credentials_create) # Test retrieving credentials retrieved_creds = get_creds_by_org(session=db, org_id=project.organization_id) @@ -84,10 +92,16 @@ def test_get_provider_credential(db: Session) -> None: credentials_create = test_credential_data(db) original_api_key = credentials_create.credential[Provider.OPENAI.value]["api_key"] - set_creds_for_org(session=db, creds_add=credentials_create) + project = create_test_project(db) + set_creds_for_org( + session=db, + creds_add=credentials_create, + organization_id=project.organization_id, + project_id=project.id, + ) # Test retrieving specific provider credentials retrieved_cred = get_provider_credential( - session=db, org_id=credentials_create.organization_id, provider="openai" + session=db, org_id=project.organization_id, provider="openai" ) assert retrieved_cred is not None @@ -108,12 +122,13 @@ def test_update_creds_for_org(db: Session) -> None: ) # Update credentials updated_creds = {"api_key": "updated-key"} - creds_update = CredsUpdate( - project_id=project.id, provider="openai", credential=updated_creds - ) + creds_update = CredsUpdate(provider="openai", credential=updated_creds) updated = update_creds_for_org( - session=db, org_id=credential.organization_id, creds_in=creds_update + session=db, + org_id=credential.organization_id, + creds_in=creds_update, + project_id=project.id, ) assert len(updated) == 1 @@ -166,11 +181,15 @@ def test_remove_creds_for_org(db: Session) -> None: } creds_create = CredsCreate( + is_active=True, + credential=credentials_data, + ) + set_creds_for_org( + session=db, + creds_add=creds_create, organization_id=project.organization_id, project_id=project.id, - credential=credentials_data, ) - set_creds_for_org(session=db, creds_add=creds_create) # Remove all credentials removed = remove_creds_for_org(session=db, org_id=project.organization_id) @@ -191,13 +210,17 @@ def test_invalid_provider(db: Session) -> None: # Test with unsupported provider credentials_data = {"gemini": {"api_key": "test-key"}} credentials_create = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, + is_active=True, credential=credentials_data, ) with pytest.raises(ValueError, match="Unsupported provider"): - set_creds_for_org(session=db, creds_add=credentials_create) + set_creds_for_org( + session=db, + creds_add=credentials_create, + organization_id=project.organization_id, + project_id=project.id, + ) def test_duplicate_provider_credentials(db: Session) -> None: @@ -208,11 +231,15 @@ def test_duplicate_provider_credentials(db: Session) -> None: credentials_data = {"openai": {"api_key": "test-key"}} credentials_create = CredsCreate( + is_active=True, + credential=credentials_data, + ) + set_creds_for_org( + session=db, + creds_add=credentials_create, organization_id=project.organization_id, project_id=project.id, - credential=credentials_data, ) - set_creds_for_org(session=db, creds_add=credentials_create) # Verify credentials exist and are active existing_creds = get_provider_credential( @@ -236,13 +263,17 @@ def test_langfuse_credential_validation(db: Session) -> None: } } credentials_create = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, + is_active=True, credential=invalid_credentials, ) with pytest.raises(ValueError): - set_creds_for_org(session=db, creds_add=credentials_create) + set_creds_for_org( + session=db, + creds_add=credentials_create, + organization_id=project.organization_id, + project_id=project.id, + ) # Test with valid Langfuse credentials valid_credentials = { @@ -254,11 +285,15 @@ def test_langfuse_credential_validation(db: Session) -> None: } credentials_create = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, + is_active=True, credential=valid_credentials, ) - created_credentials = set_creds_for_org(session=db, creds_add=credentials_create) + created_credentials = set_creds_for_org( + session=db, + creds_add=credentials_create, + organization_id=project.organization_id, + project_id=project.id, + ) assert len(created_credentials) == 1 assert created_credentials[0].provider == "langfuse" diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 0d9c5d3b..302e6ac4 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -76,8 +76,6 @@ def test_credential_data(db: Session) -> CredsCreate: project = create_test_project(db) api_key = "sk-" + generate_random_string(10) creds_data = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, is_active=True, credential={ Provider.OPENAI.value: { @@ -100,8 +98,6 @@ def create_test_credential(db: Session) -> tuple[list[Credential], Project]: project = create_test_project(db) api_key = "sk-" + generate_random_string(10) creds_data = CredsCreate( - organization_id=project.organization_id, - project_id=project.id, is_active=True, credential={ Provider.OPENAI.value: { @@ -111,4 +107,12 @@ def create_test_credential(db: Session) -> tuple[list[Credential], Project]: } }, ) - return set_creds_for_org(session=db, creds_add=creds_data), project + return ( + set_creds_for_org( + session=db, + creds_add=creds_data, + organization_id=project.organization_id, + project_id=project.id, + ), + project, + )