Skip to content
Merged
2 changes: 2 additions & 0 deletions backend/app/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
get_key_by_org,
update_creds_for_org,
remove_creds_for_org,
get_provider_credential,
remove_provider_credential,
)

from .thread_results import upsert_thread_result, get_thread_result
Expand Down
23 changes: 17 additions & 6 deletions backend/app/crud/credentials.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Union
from sqlmodel import Session, select
from sqlalchemy.exc import IntegrityError
from datetime import datetime, timezone
Expand Down Expand Up @@ -88,9 +88,21 @@ def get_creds_by_org(


def get_provider_credential(
*, session: Session, org_id: int, provider: str, project_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""Fetches credentials for a specific provider of an organization."""
*,
session: Session,
org_id: int,
provider: str,
project_id: Optional[int] = None,
full: bool = False,
) -> Optional[Union[Dict[str, Any], Credential]]:
"""
Fetch credentials for a specific provider within a project.

Returns:
Optional[Union[Dict[str, Any], Credential]]:
- If `full` is True, returns the full Credential SQLModel object.
- Otherwise, returns the decrypted credentials as a dictionary.
"""
validate_provider(provider)

statement = select(Credential).where(
Expand All @@ -102,8 +114,7 @@ def get_provider_credential(
creds = session.exec(statement).first()

if creds and creds.credential:
# Decrypt entire credentials object
return decrypt_credentials(creds.credential)
return creds if full else decrypt_credentials(creds.credential)
return None


Expand Down
90 changes: 21 additions & 69 deletions backend/app/tests/api/routes/test_api_key.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,19 @@
import uuid
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session

from app.main import app
from app.models import APIKey, User, Organization, Project
from app.models import APIKey
from app.core.config import settings
from app.crud.api_key import create_api_key
from app.tests.utils.utils import random_email
from app.core.security import get_password_hash
from app.tests.utils.utils import get_non_existent_id
from app.tests.utils.user import create_random_user
from app.tests.utils.test_data import create_test_api_key, create_test_project

client = TestClient(app)


def create_test_user(db: Session) -> User:
user = User(
email=random_email(),
hashed_password=get_password_hash("password123"),
is_superuser=True,
)
db.add(user)
db.commit()
db.refresh(user)
return user


def create_test_organization(db: Session) -> Organization:
org = Organization(
name=f"Test Organization {uuid.uuid4()}", description="Test Organization"
)
db.add(org)
db.commit()
db.refresh(org)
return org


def create_test_project(db: Session, organization_id: int) -> Project:
project = Project(name="Test Project", organization_id=organization_id)
db.add(project)
db.commit()
db.refresh(project)
return project


def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
user = create_random_user(db)
project = create_test_project(db)

response = client.post(
f"{settings.API_V1_STR}/apikeys",
Expand All @@ -57,14 +25,13 @@ def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]):
assert data["success"] is True
assert "id" in data["data"]
assert "key" in data["data"]
assert data["data"]["organization_id"] == org.id
assert data["data"]["organization_id"] == project.organization_id
assert data["data"]["user_id"] == user.id


def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
user = create_random_user(db)
project = create_test_project(db)

client.post(
f"{settings.API_V1_STR}/apikeys",
Expand All @@ -81,16 +48,11 @@ def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str


def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
api_key = create_api_key(
db, organization_id=org.id, user_id=user.id, project_id=project.id
)
api_key = create_test_api_key(db)

response = client.get(
f"{settings.API_V1_STR}/apikeys",
params={"project_id": project.id},
params={"project_id": api_key.project_id},
headers=superuser_token_headers,
)
assert response.status_code == 200
Expand All @@ -100,17 +62,12 @@ def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]):
assert len(data["data"]) > 0

first_key = data["data"][0]
assert first_key["organization_id"] == org.id
assert first_key["user_id"] == user.id
assert first_key["organization_id"] == api_key.organization_id
assert first_key["user_id"] == api_key.user_id


def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
api_key = create_api_key(
db, organization_id=org.id, user_id=user.id, project_id=project.id
)
api_key = create_test_api_key(db)

response = client.get(
f"{settings.API_V1_STR}/apikeys/{api_key.id}",
Expand All @@ -121,25 +78,21 @@ def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]):
assert data["success"] is True
assert data["data"]["id"] == api_key.id
assert data["data"]["organization_id"] == api_key.organization_id
assert data["data"]["user_id"] == user.id
assert data["data"]["user_id"] == api_key.user_id


def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]):
api_key_id = get_non_existent_id(db, APIKey)
response = client.get(
f"{settings.API_V1_STR}/apikeys/999999",
f"{settings.API_V1_STR}/apikeys/{api_key_id}",
headers=superuser_token_headers,
)
assert response.status_code == 404
assert "API Key does not exist" in response.json()["error"]


def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
api_key = create_api_key(
db, organization_id=org.id, user_id=user.id, project_id=project.id
)
api_key = create_test_api_key(db)

response = client.delete(
f"{settings.API_V1_STR}/apikeys/{api_key.id}",
Expand All @@ -154,11 +107,10 @@ def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]):
def test_revoke_nonexistent_api_key(
db: Session, superuser_token_headers: dict[str, str]
):
user = create_test_user(db)
org = create_test_organization(db)
api_key_id = get_non_existent_id(db, APIKey)

response = client.delete(
f"{settings.API_V1_STR}/apikeys/999999",
f"{settings.API_V1_STR}/apikeys/{api_key_id}",
headers=superuser_token_headers,
)
assert response.status_code == 404
Expand Down
Loading