diff --git a/backend/app/alembic/migrate_api_key.py b/backend/app/alembic/migrate_api_key.py new file mode 100644 index 00000000..ab48901d --- /dev/null +++ b/backend/app/alembic/migrate_api_key.py @@ -0,0 +1,198 @@ +""" +Migration script to convert encrypted API keys to hashed format. + +This script: +1. Decrypts existing API keys from the old encrypted format +2. Extracts the prefix and secret from the decrypted keys +3. Hashes the secret using bcrypt +4. Generates UUID4 for the new primary key +5. Stores the prefix, hash, and UUID in the new format for backward compatibility + +The format is: "ApiKey {12-char-prefix}{31-char-secret}" (total 43 chars) +""" + +import logging +import uuid +from sqlalchemy.orm import Session +from sqlalchemy import text +from passlib.context import CryptContext + +from app.core.security import decrypt_api_key + +logger = logging.getLogger(__name__) + +# Use the same hash algorithm as APIKeyManager +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# Old format constants +OLD_PREFIX_NAME = "ApiKey " +OLD_PREFIX_LENGTH = 12 +OLD_SECRET_LENGTH = 31 +OLD_KEY_LENGTH = 43 # Total: 12 + 31 + + +def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: + """ + Migrate all existing API keys from encrypted format to hashed format. + + This function: + 1. Fetches all API keys with the old 'key' column + 2. Decrypts each key + 3. Extracts prefix and secret + 4. Hashes the secret + 5. Generates UUID4 for new_id column if generate_uuid is True + 6. Updates key_prefix, key_hash, and optionally new_id columns + + Args: + session: SQLAlchemy database session + generate_uuid: Whether to generate and set UUID for new_id column + """ + logger.info( + "[migrate_api_keys] Starting API key migration from encrypted to hashed format" + ) + + try: + # Fetch all API keys that have the old 'key' column + result = session.execute( + text("SELECT id, key FROM apikey WHERE key IS NOT NULL") + ) + api_keys = result.fetchall() + + if not api_keys: + logger.info("[migrate_api_keys] No API keys found to migrate") + return + + logger.info(f"[migrate_api_keys] Found {len(api_keys)} API keys to migrate") + + migrated_count = 0 + failed_count = 0 + + for row in api_keys: + key_id = row[0] + encrypted_key = row[1] + + try: + # Decrypt the API key + decrypted_key = decrypt_api_key(encrypted_key) + + # Validate format + if not decrypted_key.startswith(OLD_PREFIX_NAME): + logger.error( + f"[migrate_api_keys] Invalid key format for ID {key_id}: " + f"does not start with '{OLD_PREFIX_NAME}'" + ) + failed_count += 1 + continue + + # Extract the key part (after "ApiKey ") + key_part = decrypted_key[len(OLD_PREFIX_NAME) :] + + if len(key_part) != OLD_KEY_LENGTH: + logger.error( + f"[migrate_api_keys] Invalid key length for ID {key_id}: " + f"expected {OLD_KEY_LENGTH}, got {len(key_part)}" + ) + failed_count += 1 + continue + + # Extract prefix and secret + key_prefix = key_part[:OLD_PREFIX_LENGTH] + secret_key = key_part[OLD_PREFIX_LENGTH:] + + # Hash the secret + key_hash = pwd_context.hash(secret_key) + + # Generate UUID if requested + if generate_uuid: + new_uuid = uuid.uuid4() + # Update the record with prefix, hash, and UUID + session.execute( + text( + "UPDATE apikey SET key_prefix = :prefix, key_hash = :hash, new_id = :new_id " + "WHERE id = :id" + ), + { + "prefix": key_prefix, + "hash": key_hash, + "new_id": new_uuid, + "id": key_id, + }, + ) + else: + # Update the record with prefix and hash only + session.execute( + text( + "UPDATE apikey SET key_prefix = :prefix, key_hash = :hash " + "WHERE id = :id" + ), + {"prefix": key_prefix, "hash": key_hash, "id": key_id}, + ) + + migrated_count += 1 + logger.info( + f"[migrate_api_keys] Successfully migrated key ID {key_id} " + f"with prefix {key_prefix[:4]}..." + ) + + except Exception as e: + logger.error( + f"[migrate_api_keys] Failed to migrate key ID {key_id}: {str(e)}", + exc_info=True, + ) + failed_count += 1 + continue + + logger.info( + f"[migrate_api_keys] Migration completed: " + f"{migrated_count} successful, {failed_count} failed" + ) + + except Exception as e: + logger.error( + f"[migrate_api_keys] Fatal error during migration: {str(e)}", exc_info=True + ) + raise + + +def verify_migration(session: Session) -> bool: + """ + Verify that all API keys have been migrated successfully. + + Args: + session: SQLAlchemy database session + + Returns: + bool: True if all keys are migrated, False otherwise + """ + try: + # Check for any keys with NULL key_prefix or key_hash + result = session.execute( + text( + "SELECT COUNT(*) FROM apikey " + "WHERE key_prefix IS NULL OR key_hash IS NULL" + ) + ) + null_count = result.scalar() + + if null_count > 0: + logger.warning( + f"[verify_migration] Found {null_count} API keys with NULL " + "key_prefix or key_hash" + ) + return False + + # Check total count + result = session.execute(text("SELECT COUNT(*) FROM apikey")) + total_count = result.scalar() + + logger.info( + f"[verify_migration] All {total_count} API keys have been " + "successfully migrated" + ) + return True + + except Exception as e: + logger.error( + f"[verify_migration] Error verifying migration: {str(e)}", exc_info=True + ) + return False diff --git a/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py new file mode 100644 index 00000000..42d5080c --- /dev/null +++ b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py @@ -0,0 +1,69 @@ +"""Refactor API key table + +Revision ID: e7c68e43ce6f +Revises: 27c271ab6dd0 +Create Date: 2025-10-16 13:06:51.777671 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.orm import Session +from app.alembic.migrate_api_key import migrate_api_keys, verify_migration + + +# revision identifiers, used by Alembic. +revision = "e7c68e43ce6f" +down_revision = "27c271ab6dd0" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Step 1: Add new columns as nullable to allow migration + op.add_column( + "apikey", + sa.Column("key_prefix", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + op.add_column( + "apikey", + sa.Column("key_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + + # Step 2: Add UUID column before migration + op.add_column("apikey", sa.Column("new_id", sa.Uuid(), nullable=True)) + + # Step 3: Migrate existing encrypted keys to the new hashed format and generate UUIDs + bind = op.get_bind() + with Session(bind=bind) as session: + migrate_api_keys(session, generate_uuid=True) + + # Step 4: Verify migration was successful + if not verify_migration(session): + raise Exception( + "API key migration verification failed. Please check the logs." + ) + + session.flush() + + # Step 5: Make the columns non-nullable after migration + op.alter_column("apikey", "key_prefix", nullable=False) + op.alter_column("apikey", "key_hash", nullable=False) + + # Step 6: Replace old PK with UUID-based PK + op.drop_constraint("apikey_pkey", "apikey", type_="primary") + op.drop_column("apikey", "id") + op.alter_column("apikey", "new_id", new_column_name="id", nullable=False) + op.create_primary_key("apikey_pkey", "apikey", ["id"]) + + # Step 7: Update indexes and drop old key column + op.drop_index("ix_apikey_key", table_name="apikey") + op.create_index(op.f("ix_apikey_key_prefix"), "apikey", ["key_prefix"], unique=True) + op.drop_column("apikey", "key") + # ### end Alembic commands ### + + +def downgrade(): + # instead of downgrade, will take a db snapshot and restore from that if needed + pass diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 7729dfdd..59678d2f 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,10 +1,9 @@ from collections.abc import Generator -from typing import Annotated, Optional +from typing import Annotated import jwt -from fastapi import Depends, HTTPException, status, Request, Header, Security -from fastapi.responses import JSONResponse -from fastapi.security import OAuth2PasswordBearer, APIKeyHeader +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError from pydantic import ValidationError from sqlmodel import Session, select @@ -12,18 +11,17 @@ from app.core import security from app.core.config import settings from app.core.db import engine -from app.utils import APIResponse +from app.core.security import api_key_manager from app.crud.organization import validate_organization -from app.crud.api_key import get_api_key_by_value from app.models import ( + AuthContext, TokenPayload, User, - UserProjectOrg, UserOrganization, - Project, - Organization, + UserProjectOrg, ) + reusable_oauth2 = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token", auto_error=False ) @@ -47,19 +45,16 @@ def get_current_user( """Authenticate user via API Key first, fallback to JWT token. Returns only User.""" if api_key: - api_key_record = get_api_key_by_value(session, api_key) + api_key_record = api_key_manager.verify(session, api_key) if not api_key_record: raise HTTPException(status_code=401, detail="Invalid API Key") - user = session.get(User, api_key_record.user_id) - if not user: - raise HTTPException( - status_code=404, detail="User linked to API Key not found" - ) + if not api_key_record.user.is_active: + raise HTTPException(status_code=403, detail="Inactive user") - return user # Return only User object + return api_key_record.user # Return only User object - if token: + elif token: try: payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] @@ -93,10 +88,10 @@ def get_current_user_org( organization_id = None api_key = request.headers.get("X-API-KEY") if api_key: - api_key_record = get_api_key_by_value(session, api_key) + api_key_record = api_key_manager.verify(session, api_key) if api_key_record: - validate_organization(session, api_key_record.organization_id) - organization_id = api_key_record.organization_id + validate_organization(session, api_key_record.organization.id) + organization_id = api_key_record.organization.id return UserOrganization( **current_user.model_dump(), organization_id=organization_id @@ -114,11 +109,11 @@ def get_current_user_org_project( project_id = None if api_key: - api_key_record = get_api_key_by_value(session, api_key) + api_key_record = api_key_manager.verify(session, api_key) if api_key_record: - validate_organization(session, api_key_record.organization_id) - organization_id = api_key_record.organization_id - project_id = api_key_record.project_id + validate_organization(session, api_key_record.organization.id) + organization_id = api_key_record.organization.id + project_id = api_key_record.project.id else: raise HTTPException(status_code=401, detail="Invalid API Key") @@ -147,3 +142,59 @@ def get_current_active_superuser_org(current_user: CurrentUserOrg) -> User: status_code=403, detail="The user doesn't have enough privileges" ) return current_user + + +def get_auth_context( + session: SessionDep, + token: TokenDep, + api_key: Annotated[str, Depends(api_key_header)], +) -> AuthContext: + """ + Verify valid authentication (API Key or JWT token) and return authenticated user context. + Returns AuthContext with user info, project_id, and organization_id. + Authorization logic should be handled in routes. + """ + if api_key: + auth_context = api_key_manager.verify(session, api_key) + if not auth_context: + raise HTTPException(status_code=401, detail="Invalid API Key") + + if not auth_context.user.is_active: + raise HTTPException(status_code=403, detail="Inactive user") + + if not auth_context.organization.is_active: + raise HTTPException(status_code=403, detail="Inactive Organization") + + if not auth_context.project.is_active: + raise HTTPException(status_code=403, detail="Inactive Project") + + return auth_context + + elif token: + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + token_data = TokenPayload(**payload) + except (InvalidTokenError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + + user = session.get(User, token_data.sub) + if not user: + raise HTTPException(status_code=404, detail="User not found") + if not user.is_active: + raise HTTPException(status_code=403, detail="Inactive user") + + auth_context = AuthContext( + user=user, + ) + return auth_context + + else: + raise HTTPException(status_code=401, detail="Invalid Authorization format") + + +AuthContextDep = Annotated[AuthContext, Depends(get_auth_context)] diff --git a/backend/app/api/permissions.py b/backend/app/api/permissions.py new file mode 100644 index 00000000..4142b7a4 --- /dev/null +++ b/backend/app/api/permissions.py @@ -0,0 +1,70 @@ +from enum import Enum +from typing import Annotated +from fastapi import Depends, HTTPException +from sqlmodel import Session + +from app.models import AuthContext +from app.api.deps import AuthContextDep, SessionDep + + +class Permission(str, Enum): + """Permission types for authorization checks""" + + SUPERUSER = "require_superuser" + REQUIRE_ORGANIZATION = "require_organization_id" + REQUIRE_PROJECT = "require_project_id" + + +def has_permission( + auth_context: AuthContext, + permission: Permission, + session: Session | None = None, +) -> bool: + """ + Check if the auth_context has the specified permission. + + Args: + user_context: The authenticated user context + permission: The permission to check (Permission enum) + session: Database session (optional) + + Returns: + bool: True if user has permission, False otherwise + """ + match permission: + case Permission.SUPERUSER: + return auth_context.user.is_superuser + case Permission.REQUIRE_ORGANIZATION: + return auth_context.organization is not None + case Permission.REQUIRE_PROJECT: + return auth_context.project is not None + case _: + return False + + +def require_permission(permission: Permission): + """ + Dependency factory for requiring specific permissions in FastAPI routes. + + Usage: + @app.get("/endpoint", dependencies=[Depends(require_permission(Permission.REQUIRE_ORGANIZATION))]) + def endpoint(auth_context: Annotated[AuthContext, Depends(get_user_context)]): + pass + """ + + def permission_checker( + auth_context: AuthContextDep, + session: SessionDep, + ): + if not has_permission(auth_context, permission, session): + error_messages = { + Permission.SUPERUSER: "Insufficient permissions - require superuser access.", + Permission.REQUIRE_ORGANIZATION: "Insufficient permissions - require organization access.", + Permission.REQUIRE_PROJECT: "Insufficient permissions - require project access.", + } + raise HTTPException( + status_code=403, + detail=error_messages.get(permission, "Insufficient permissions"), + ) + + return permission_checker diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index 125df075..d1821a35 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -1,116 +1,85 @@ -import logging -from fastapi import APIRouter, Depends, HTTPException -from sqlmodel import Session -from app.api.deps import get_db, get_current_active_superuser -from app.crud.api_key import ( - create_api_key, - get_api_key, - delete_api_key, - get_api_keys_by_project, - get_api_key_by_project_user, -) -from app.crud.project import validate_project -from app.models import APIKeyPublic, User +from uuid import UUID +from fastapi import APIRouter, Depends, Query + +from app.api.deps import SessionDep, AuthContextDep +from app.crud.api_key import APIKeyCrud +from app.models import APIKeyPublic, APIKeyCreateResponse, Message from app.utils import APIResponse -from app.core.exception_handlers import HTTPException +from app.api.permissions import Permission, require_permission -logger = logging.getLogger(__name__) router = APIRouter(prefix="/apikeys", tags=["API Keys"]) -@router.post("/", response_model=APIResponse[APIKeyPublic]) -def create_key( +@router.post( + "/", + response_model=APIResponse[APIKeyCreateResponse], + status_code=201, + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) +def create_api_key_route( project_id: int, user_id: int, - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), + current_user: AuthContextDep, + session: SessionDep, ): """ - Generate a new API key for the user's organization. - """ - project = validate_project(session, project_id) - - existing_api_key = get_api_key_by_project_user(session, project_id, user_id) - if existing_api_key: - logger.warning( - f"[create_key] API key already exists | project_id={project_id}, user_id={user_id}" - ) - raise HTTPException( - status_code=400, - detail="API Key already exists for this user and project.", - ) + Create a new API key for the project and user, Restricted to Superuser. - api_key = create_api_key( - session, - organization_id=project.organization_id, + The raw API key is returned only once during creation. + Store it securely as it cannot be retrieved again. + """ + api_key_crud = APIKeyCrud(session=session, project_id=project_id) + raw_key, api_key = api_key_crud.create( user_id=user_id, project_id=project_id, ) - return APIResponse.success_response(api_key) - - -@router.get("/", response_model=APIResponse[list[APIKeyPublic]]) -def list_keys( - project_id: int, - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), -): - """ - Retrieve all API keys for the given project. Superusers get all keys; - regular users get only their own. - """ - project = validate_project(session=session, project_id=project_id) - - if current_user.is_superuser: - api_keys = get_api_keys_by_project(session=session, project_id=project_id) - else: - user_api_key = get_api_key_by_project_user( - session=session, project_id=project_id, user_id=current_user.id - ) - api_keys = [user_api_key] if user_api_key else [] - if not api_keys: - logger.warning(f"[list_keys] No API keys found | project_id={project_id}") - raise HTTPException( - status_code=404, - detail="No API keys found for this project.", - ) - - return APIResponse.success_response(api_keys) + api_key = APIKeyCreateResponse(**api_key.model_dump(), key=raw_key) + return APIResponse.success_response( + data=api_key, + metadata={ + "message": "The raw API key is returned only once during creation. Store it securely as it cannot be retrieved again." + }, + ) -@router.get("/{api_key_id}", response_model=APIResponse[APIKeyPublic]) -def get_key( - api_key_id: int, - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), +@router.get( + "/", + response_model=APIResponse[list[APIKeyPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def list_api_keys_route( + current_user: AuthContextDep, + session: SessionDep, + skip: int = Query(0, ge=0, description="Number of records to skip"), + limit: int = Query(100, ge=1, le=100, description="Maximum records to return"), ): """ - Retrieve an API key by ID. + List all API keys for the current project. + + Returns key prefix for security - the full key is only shown during creation. + Supports pagination via skip and limit parameters. """ - api_key = get_api_key(session, api_key_id) - if not api_key: - logger.warning(f"[get_key] API key not found | api_key_id={api_key_id}") - raise HTTPException(404, "API Key does not exist") + crud = APIKeyCrud(session, current_user.project.id) + api_keys = crud.read_all(skip=skip, limit=limit) - return APIResponse.success_response(api_key) + return APIResponse.success_response(api_keys) -@router.delete("/{api_key_id}", response_model=APIResponse[dict]) -def revoke_key( - api_key_id: int, - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), +@router.delete( + "/{key_id}", + response_model=APIResponse[Message], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def delete_api_key_route( + key_id: UUID, + current_user: AuthContextDep, + session: SessionDep, ): """ - Soft delete an API key (revoke access). + Delete an API key by its ID. """ - api_key = get_api_key(session, api_key_id) - if not api_key: - logger.warning( - f"[apikey.revoke] API key not found or already deleted | api_key_id={api_key_id}" - ) - raise HTTPException(404, "API key not found or already deleted") + api_key_crud = APIKeyCrud(session=session, project_id=current_user.project.id) + api_key_crud.delete(key_id=key_id) - delete_api_key(session, api_key_id) - return APIResponse.success_response({"message": "API key revoked successfully"}) + return APIResponse.success_response(Message(message="API Key deleted successfully")) diff --git a/backend/app/core/security.py b/backend/app/core/security.py index ace78c3a..8807d705 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -7,19 +7,26 @@ - Credentials encryption/decryption """ -from datetime import datetime, timedelta, timezone -from typing import Any import base64 import json +import logging +import secrets +from datetime import datetime, timedelta, timezone +from typing import Any, Tuple import jwt from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from passlib.context import CryptContext +from sqlmodel import Session, and_, select +from app.models import APIKey, User, Organization, Project, AuthContext from app.core.config import settings + +logger = logging.getLogger(__name__) + # Password hashing configuration pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -179,3 +186,152 @@ def decrypt_credentials(encrypted_credentials: str) -> dict: return json.loads(decrypted_str) except Exception as e: raise ValueError(f"Failed to decrypt credentials: {e}") + + +class APIKeyManager: + """ + Handles secure API key generation and verification. + + Overview: + - **Old Format (Legacy)**: 43 chars after "ApiKey ", with 12-char prefix and 31-char secret. + - **New Format (Current)**: 65 chars after "ApiKey ", with 22-char prefix and 43-char secret. + - Generates cryptographically secure API keys with fixed lengths, + storing only the hashed secret using bcrypt while keeping the prefix in plaintext for quick lookup. + Raw keys are displayed only once during creation for security. + The system automatically verifies both old and new key formats to ensure backward compatibility. + + Compatibility: + Both old and new formats are supported automatically during verification. + """ + + # Configuration constants + PREFIX_NAME = "ApiKey " + PREFIX_BYTES = 16 # Generates 22 chars in urlsafe base64 + SECRET_BYTES = 32 # Generates 43 chars in urlsafe base64 + PREFIX_LENGTH = 22 + KEY_LENGTH = 65 # Total length: 22 (prefix) + 43 (secret) + HASH_ALGORITHM = "bcrypt" + + pwd_context = CryptContext(schemes=[HASH_ALGORITHM], deprecated="auto") + + @classmethod + def generate(cls) -> Tuple[str, str, str]: + """ + Generate a new API key with prefix and hashed value. + Ensures exact lengths: prefix=22 chars, secret=43 chars. + + Returns: + Tuple of (raw_key, key_prefix, key_hash) + """ + # Generate tokens and ensure exact length + secret_length = cls.KEY_LENGTH - cls.PREFIX_LENGTH + key_prefix = secrets.token_urlsafe(cls.PREFIX_BYTES)[: cls.PREFIX_LENGTH].ljust( + cls.PREFIX_LENGTH, "A" + ) + secret_key = secrets.token_urlsafe(cls.SECRET_BYTES)[:secret_length].ljust( + secret_length, "A" + ) + + # Construct raw key: "ApiKey {prefix}{secret}" + raw_key = f"{cls.PREFIX_NAME}{key_prefix}{secret_key}" + + key_hash = cls.pwd_context.hash(secret_key) + + return raw_key, key_prefix, key_hash + + @classmethod + def _extract_key_parts(cls, raw_key: str) -> Tuple[str, str] | None: + """ + Extract prefix and secret from an API key based on its format. + + Supports: + - New format: "ApiKey {22-char-prefix}{43-char-secret}" + - Old format: "ApiKey {12-char-prefix}{31-char-secret}" + + Returns: + Tuple[str, str] -> (key_prefix, secret_to_verify) + or None if invalid + """ + if not raw_key.startswith(cls.PREFIX_NAME): + return None + + key_part = raw_key[len(cls.PREFIX_NAME) :] + + if len(key_part) == cls.KEY_LENGTH: + key_prefix = key_part[: cls.PREFIX_LENGTH] + secret_key = key_part[cls.PREFIX_LENGTH :] + return key_prefix, secret_key + + old_key_length = 43 + old_prefix_length = 12 + if len(key_part) == old_key_length: + key_prefix = key_part[:old_prefix_length] + secret_key = key_part[old_prefix_length:] + return key_prefix, secret_key + + # Invalid format + return None + + @classmethod + def verify(cls, session: Session, raw_key: str) -> AuthContext | None: + """ + Verify an API key by checking its prefix and hashed value. + Supports both old (43 chars) and new ("ApiKey " + 65 chars) formats. + + Eagerly loads User, Organization, and Project in a single query. + + Args: + session: Database session + raw_key: The raw API key to verify + + Returns: + AuthContext if valid, None otherwise + """ + try: + key_parts = cls._extract_key_parts(raw_key) + + if not key_parts: + return None + + key_prefix, secret = key_parts + + # Single query to fetch APIKey with User, Organization, and Project + statement = ( + select(APIKey, User, Organization, Project) + .where( + and_( + APIKey.key_prefix == key_prefix, + APIKey.is_deleted.is_(False), + ) + ) + .join(User, User.id == APIKey.user_id) + .join(Organization, Organization.id == APIKey.organization_id) + .join(Project, Project.id == APIKey.project_id) + ) + + result = session.exec(statement).first() + + if not result: + return None + api_key_record, user, organization, project = result + auth_context = AuthContext( + user=user, + project=project, + organization=organization, + ) + + # Verify the secret hash + if cls.pwd_context.verify(secret, api_key_record.key_hash): + return auth_context + + return None + + except Exception as e: + logger.error( + f"[APIKeyManager.verify] Error verifying API key: {str(e)}", + exc_info=True, + ) + return None + + +api_key_manager = APIKeyManager() diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 098d9363..be5d36a4 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -26,16 +26,7 @@ validate_project, ) -from .api_key import ( - create_api_key, - generate_api_key, - get_api_key, - get_api_key_by_value, - get_api_keys_by_project, - get_api_key_by_project_user, - delete_api_key, - get_api_key_by_user_id, -) +from .api_key import APIKeyCrud, api_key_manager from .credentials import ( set_creds_for_org, diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 7a8b7c16..374b496e 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -1,182 +1,118 @@ -import uuid -import secrets import logging -from sqlmodel import Session, select -from app.core.security import ( - get_password_hash, - encrypt_api_key, - decrypt_api_key, -) -from app.core import settings -from app.core.util import now -from app.core.exception_handlers import HTTPException -from app.models.api_key import APIKey, APIKeyPublic - -logger = logging.getLogger(__name__) - - -def generate_api_key() -> tuple[str, str]: - """Generate a new API key and its hash.""" - raw_key = "ApiKey " + secrets.token_urlsafe(32) - hashed_key = get_password_hash(raw_key) - return raw_key, hashed_key +from uuid import UUID +from typing import Tuple +from sqlmodel import Session, select, and_ +from fastapi import HTTPException -def create_api_key( - session: Session, organization_id: int, user_id: int, project_id: int -) -> APIKeyPublic: - """ - Generates a new API key for an organization and associates it with a user. - Returns the API key details with the raw key (shown only once). - """ - # Generate raw key and its hash using the helper function - raw_key, hashed_key = generate_api_key() - encrypted_key = encrypt_api_key( - raw_key - ) # Encrypt the raw key instead of hashed key - - # Create API key record with encrypted raw key - api_key = APIKey( - key=encrypted_key, # Store the encrypted raw key - organization_id=organization_id, - user_id=user_id, - project_id=project_id, - ) - - session.add(api_key) - session.commit() - session.refresh(api_key) - - # Set the raw key in the response (shown only once) - api_key_dict = api_key.model_dump() - api_key_dict["key"] = raw_key # Return the raw key to the user - - logger.info( - f"[create_api_key] API key creation completed | {{'api_key_id': {api_key.id}, 'user_id': {user_id}, 'project_id': {project_id}}}" - ) - return APIKeyPublic.model_validate(api_key_dict) - - -def get_api_key(session: Session, api_key_id: int) -> APIKeyPublic | None: - """ - Retrieves an API key by its ID if it exists and is not deleted. - Returns the API key in its original format. - """ - api_key = session.exec( - select(APIKey).where(APIKey.id == api_key_id, APIKey.is_deleted == False) - ).first() - - if api_key: - # Create a copy of the API key data - api_key_dict = api_key.model_dump() - # Decrypt the key - decrypted_key = decrypt_api_key(api_key.key) - api_key_dict["key"] = decrypted_key - return APIKeyPublic.model_validate(api_key_dict) +from app.models import APIKey, User +from app.crud.project import get_project_by_id +from app.core.util import now +from app.core.security import api_key_manager - logger.warning(f"[get_api_key] API key not found | {{'api_key_id': {api_key_id}}}") - return None +logger = logging.getLogger(__name__) -def delete_api_key(session: Session, api_key_id: int) -> None: +class APIKeyCrud: """ - Soft deletes (revokes) an API key by marking it as deleted. + CRUD operations for API keys scoped to a project. """ - api_key = session.get(APIKey, api_key_id) - if not api_key: - logger.warning( - f"[delete_api_key] API key not found | {{'api_key_id': {api_key_id}}}" + def __init__(self, session: Session, project_id: int): + self.session = session + self.project_id = project_id + + def read_one(self, key_id: UUID) -> APIKey | None: + """ + Retrieve a single non-deleted API key by its id. + """ + statement = select(APIKey).where( + and_( + APIKey.id == key_id, + APIKey.project_id == self.project_id, + APIKey.is_deleted.is_(False), + ) + ) + return self.session.exec(statement).one_or_none() + + def read_all(self, skip: int = 0, limit: int = 100) -> list[APIKey]: + """ + Read all non-deleted API keys for the project. + """ + statement = ( + select(APIKey) + .where( + and_( + APIKey.project_id == self.project_id, + APIKey.is_deleted.is_(False), + ) + ) + .offset(skip) + .limit(limit) + ) + return self.session.exec(statement).all() + + def create(self, user_id: int, project_id: int) -> Tuple[str, APIKey]: + """ + Create a new API key for the project. + """ + project = get_project_by_id(session=self.session, project_id=project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + user = self.session.get(User, user_id) + + if not user: + raise HTTPException(status_code=404, detail="User not found") + + try: + raw_key, key_prefix, key_hash = api_key_manager.generate() + + api_key = APIKey( + key_prefix=key_prefix, + key_hash=key_hash, + user_id=user_id, + organization_id=project.organization_id, + project_id=project_id, + ) + + self.session.add(api_key) + self.session.commit() + self.session.refresh(api_key) + + logger.info( + f"[APIKeyCrud.create_api_key] API key created successfully | " + f"{{'api_key_id': '{api_key.id}', 'project_id': {project_id}, 'user_id': {user_id}}}" + ) + + return raw_key, api_key + + except Exception as e: + logger.error( + f"[APIKeyCrud.create_api_key] Failed to create API key | " + f"{{'project_id': {project_id}, 'user_id': {user_id}, 'error': '{str(e)}'}}", + exc_info=True, + ) + raise HTTPException( + status_code=500, detail=f"Failed to create API key: {str(e)}" + ) + + def delete(self, key_id: UUID) -> None: + """ + Soft delete an API key by marking it as deleted. + """ + api_key = self.read_one(key_id) + if not api_key: + raise HTTPException(status_code=404, detail="API Key not found") + + api_key.is_deleted = True + api_key.deleted_at = now() + api_key.updated_at = now() + self.session.add(api_key) + self.session.commit() + self.session.refresh(api_key) + + logger.info( + f"[APIKeyCrud.delete_api_key] API key deleted successfully | " + f"{{'api_key_id': '{api_key.id}', 'project_id': {self.project_id}}}" ) - return - - api_key.is_deleted = True - api_key.deleted_at = now() - api_key.updated_at = now() - - session.add(api_key) - session.commit() - logger.info( - f"[delete_api_key] API key soft deleted successfully | {{'api_key_id': {api_key_id}}}" - ) - - -def get_api_key_by_value(session: Session, api_key_value: str) -> APIKeyPublic | None: - """ - Retrieve an API Key record by verifying the provided key against stored hashes. - Returns the API key in its original format. - """ - # Get all active API keys - api_keys = session.exec(select(APIKey).where(APIKey.is_deleted == False)).all() - - for api_key in api_keys: - decrypted_key = decrypt_api_key(api_key.key) - if api_key_value == decrypted_key: - api_key_dict = api_key.model_dump() - api_key_dict["key"] = decrypted_key - return APIKeyPublic.model_validate(api_key_dict) - - logger.warning( - f"[get_api_key_by_value] API key not found | {{'action': 'not_found'}}" - ) - return None - - -def get_api_key_by_project_user( - session: Session, project_id: int, user_id: uuid.UUID -) -> APIKeyPublic | None: - """ - Retrieves the single API key associated with a project. - """ - statement = select(APIKey).where( - APIKey.user_id == user_id, - APIKey.project_id == project_id, - APIKey.is_deleted == False, - ) - api_key = session.exec(statement).first() - - if api_key: - api_key_dict = api_key.model_dump() - api_key_dict["key"] = decrypt_api_key(api_key.key) - return APIKeyPublic.model_validate(api_key_dict) - - logger.warning( - f"[get_api_key_by_project_user] API key not found | {{'project_id': {project_id}, 'user_id': '{user_id}'}}" - ) - return None - - -def get_api_keys_by_project(session: Session, project_id: int) -> list[APIKeyPublic]: - """ - Retrieves all API keys associated with a project. - """ - statement = select(APIKey).where( - APIKey.project_id == project_id, APIKey.is_deleted == False - ) - api_keys = session.exec(statement).all() - - result = [] - for key in api_keys: - key_dict = key.model_dump() - key_dict["key"] = decrypt_api_key(key.key) - result.append(APIKeyPublic.model_validate(key_dict)) - - return result - - -def get_api_key_by_user_id(session: Session, user_id: int) -> APIKeyPublic | None: - """ - Retrieves the API key associated with a user by their user_id. - """ - api_key = ( - session.query(APIKey) - .filter(APIKey.user_id == user_id, APIKey.is_deleted == False) - .first() - ) - - if not api_key: - return None - - key_dict = api_key.model_dump() - key_dict["key"] = decrypt_api_key(api_key.key) - return APIKeyPublic.model_validate(key_dict) diff --git a/backend/app/crud/onboarding.py b/backend/app/crud/onboarding.py index 8788b083..26777d42 100644 --- a/backend/app/crud/onboarding.py +++ b/backend/app/crud/onboarding.py @@ -4,7 +4,7 @@ from app.core.security import encrypt_api_key, encrypt_credentials, get_password_hash from app.crud import ( - generate_api_key, + api_key_manager, get_organization_by_name, get_project_by_name, get_user_by_email, @@ -90,15 +90,16 @@ def onboard_project( session.add(user) session.flush() - raw_key, _ = generate_api_key() - encrypted_key = encrypt_api_key(raw_key) + raw_key, key_prefix, key_hash = api_key_manager.generate() api_key = APIKey( - key=encrypted_key, # Store the encrypted raw key - organization_id=organization.id, + key_prefix=key_prefix, + key_hash=key_hash, user_id=user.id, + organization_id=project.organization_id, project_id=project.id, ) + session.add(api_key) credential = None diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 5d50cf24..15b61428 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,7 +1,9 @@ from sqlmodel import SQLModel -from .auth import Token, TokenPayload -from .api_key import APIKey, APIKeyBase, APIKeyPublic +from .auth import AuthContext, Token, TokenPayload + +from .api_key import APIKey, APIKeyBase, APIKeyPublic, APIKeyCreateResponse + from .assistants import Assistant, AssistantBase, AssistantCreate, AssistantUpdate from .collection import Collection, CollectionPublic diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index 22387e1b..1da56382 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -1,5 +1,6 @@ -import uuid +from uuid import UUID, uuid4 import secrets +import base64 from datetime import datetime from typing import Optional, List from sqlmodel import SQLModel, Field, Relationship @@ -15,24 +16,30 @@ class APIKeyBase(SQLModel): foreign_key="project.id", nullable=False, ondelete="CASCADE" ) user_id: int = Field(foreign_key="user.id", nullable=False, ondelete="CASCADE") - key: str = Field( - default_factory=lambda: secrets.token_urlsafe(32), unique=True, index=True - ) class APIKeyPublic(APIKeyBase): - id: int - inserted_at: datetime = Field(default_factory=now, nullable=False) + id: UUID + key_prefix: str # Expose key_id for display (partial key identifier) + inserted_at: datetime + updated_at: datetime + + +class APIKeyCreateResponse(APIKeyPublic): + """Response model when creating an API key includes the raw key only once""" + + key: str class APIKey(APIKeyBase, table=True): - id: int = Field(default=None, primary_key=True) + id: UUID = Field(default_factory=uuid4, primary_key=True) + + key_prefix: str = Field( + unique=True, index=True, nullable=False + ) # Unique identifier from the key + key_hash: str = Field(nullable=False) # bcrypt hash of the secret portion + inserted_at: datetime = Field(default_factory=now, nullable=False) updated_at: datetime = Field(default_factory=now, nullable=False) is_deleted: bool = Field(default=False, nullable=False) deleted_at: Optional[datetime] = Field(default=None, nullable=True) - - # Relationships - organization: "Organization" = Relationship(back_populates="api_keys") - project: "Project" = Relationship(back_populates="api_keys") - user: "User" = Relationship(back_populates="api_keys") diff --git a/backend/app/models/auth.py b/backend/app/models/auth.py index 7355c383..adb93aeb 100644 --- a/backend/app/models/auth.py +++ b/backend/app/models/auth.py @@ -1,4 +1,7 @@ from sqlmodel import Field, SQLModel +from app.models.user import User +from app.models.organization import Organization +from app.models.project import Project # JSON payload containing access token @@ -10,3 +13,9 @@ class Token(SQLModel): # Contents of JWT token class TokenPayload(SQLModel): sub: str | None = None + + +class AuthContext(SQLModel): + user: User + organization: Organization | None = None + project: Project | None = None diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 39342fda..4c729cdd 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -37,9 +37,6 @@ class Organization(OrganizationBase, table=True): updated_at: datetime = Field(default_factory=now, nullable=False) # Relationship back to Creds - api_keys: list["APIKey"] = Relationship( - back_populates="organization", cascade_delete=True - ) creds: list["Credential"] = Relationship( back_populates="organization", cascade_delete=True ) diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 00549b0c..2a1d346a 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -45,9 +45,6 @@ class Project(ProjectBase, table=True): assistants: list["Assistant"] = Relationship( back_populates="project", cascade_delete=True ) - api_keys: list["APIKey"] = Relationship( - back_populates="project", cascade_delete=True - ) organization: Optional["Organization"] = Relationship(back_populates="project") collections: list["Collection"] = Relationship( back_populates="project", cascade_delete=True diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 6328efaf..82a98262 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -49,8 +49,6 @@ class User(UserBase, table=True): id: int = Field(default=None, primary_key=True) hashed_password: str - api_keys: list["APIKey"] = Relationship(back_populates="user", cascade_delete=True) - class UserOrganization(UserBase): id: int diff --git a/backend/app/seed_data/seed_data.json b/backend/app/seed_data/seed_data.json index 106f1f67..6eb4860a 100644 --- a/backend/app/seed_data/seed_data.json +++ b/backend/app/seed_data/seed_data.json @@ -1,4 +1,5 @@ { + "_comment":"This data will be used in testing also, modifying this will affect tests. Please ensure to update tests/utils/auth.py if you change this data.", "organization": [ { "name": "Project Tech4dev", @@ -48,7 +49,7 @@ "organization_name": "Project Tech4dev", "user_email": "{{ADMIN_EMAIL}}", "project_name": "Dalgo", - "api_key": "ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j9", + "api_key": "ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j", "is_deleted": false, "deleted_at": null } diff --git a/backend/app/seed_data/seed_data.py b/backend/app/seed_data/seed_data.py index 4e5740b3..2f890625 100644 --- a/backend/app/seed_data/seed_data.py +++ b/backend/app/seed_data/seed_data.py @@ -9,7 +9,7 @@ from app.core.db import engine from app.core import settings -from app.core.security import encrypt_api_key, get_password_hash +from app.core.security import get_password_hash, encrypt_credentials from app.models import ( APIKey, Organization, @@ -49,7 +49,6 @@ class APIKeyData(BaseModel): api_key: str is_deleted: bool deleted_at: Optional[str] = None - created_at: Optional[str] = None class CredentialData(BaseModel): @@ -184,21 +183,34 @@ def create_api_key(session: Session, api_key_data_raw: dict) -> APIKey: ).first() if not user: raise ValueError(f"User '{api_key_data.user_email}' not found") - encrypted_api_key = encrypt_api_key(api_key_data.api_key) + + # Extract key_prefix from the provided API key and hash the full key + # API key format: "ApiKey {key_prefix}{random_key}" where key_prefix is 16 chars + raw_key = api_key_data.api_key + if not raw_key.startswith("ApiKey "): + raise ValueError(f"Invalid API key format: {raw_key}") + + # Extract the key_prefix (first 16 characters after "ApiKey ") + key_portion = raw_key[7:] # Remove "ApiKey " prefix + + key_prefix = key_portion[:12] # First 12 characters as prefix + + from passlib.context import CryptContext + + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + key_hash = pwd_context.hash(key_portion[12:]) + api_key = APIKey( organization_id=organization.id, project_id=project.id, user_id=user.id, - key=encrypted_api_key, + key_prefix=key_prefix, + key_hash=key_hash, is_deleted=api_key_data.is_deleted, deleted_at=api_key_data.deleted_at, ) - if api_key_data.created_at: - api_key.created_at = datetime.fromisoformat( - api_key_data.created_at.replace("Z", "+00:00") - ) session.add(api_key) - session.flush() # Ensure ID is assigned + session.flush() return api_key except Exception as e: logging.error(f"Error creating API key: {e}") @@ -229,8 +241,9 @@ def create_credential(session: Session, credential_data_raw: dict) -> Credential if not project: raise ValueError(f"Project '{credential_data.project_name}' not found") - # Encrypt the credential data - encrypted_credential = encrypt_api_key(credential_data.credential) + # Encrypt the credential data - convert string to dict first, then encrypt + credential_dict = json.loads(credential_data.credential) + encrypted_credential = encrypt_credentials(credential_dict) credential = Credential( is_active=credential_data.is_active, diff --git a/backend/app/tests/api/routes/documents/test_route_document_upload.py b/backend/app/tests/api/routes/documents/test_route_document_upload.py index 4b73a451..320c1992 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_upload.py +++ b/backend/app/tests/api/routes/documents/test_route_document_upload.py @@ -10,7 +10,6 @@ from sqlmodel import Session, select from fastapi.testclient import TestClient -from app.models import APIKeyPublic from app.core.cloud import AmazonCloudStorageClient from app.core.config import settings from app.models import Document @@ -19,6 +18,7 @@ WebCrawler, httpx_to_standard, ) +from app.tests.utils.auth import TestAuthContext class WebUploader(WebCrawler): @@ -77,7 +77,7 @@ def route(): @pytest.fixture -def uploader(client: TestClient, user_api_key: APIKeyPublic): +def uploader(client: TestClient, user_api_key: TestAuthContext): return WebUploader(client, user_api_key) diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index c4e63954..ee3231c0 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -1,117 +1,114 @@ +from uuid import uuid4 + from fastapi.testclient import TestClient from sqlmodel import Session -from app.main import app -from app.models import APIKey from app.core.config import settings -from app.tests.utils.utils import get_non_existent_id -from app.tests.utils.user import create_random_user +from app.tests.utils.auth import TestAuthContext from app.tests.utils.test_data import create_test_api_key, create_test_project +from app.tests.utils.user import create_random_user -client = TestClient(app) +def test_create_api_key_as_superuser( + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], +) -> None: + """Test creating an API key as a superuser.""" -def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): user = create_random_user(db) project = create_test_project(db) response = client.post( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id, "user_id": user.id}, + f"{settings.API_V1_STR}/apikeys/", headers=superuser_token_headers, + params={ + "project_id": project.id, + "user_id": user.id, + }, ) - assert response.status_code == 200 + assert response.status_code == 201 data = response.json() assert data["success"] is True - assert "id" in data["data"] + assert "data" in data assert "key" in data["data"] - assert data["data"]["organization_id"] == project.organization_id + assert "id" in data["data"] + assert "key_prefix" in data["data"] + assert data["data"]["project_id"] == project.id assert data["data"]["user_id"] == user.id + assert data["data"]["organization_id"] == project.organization_id + assert data["data"]["key"].startswith("ApiKey ") -def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_random_user(db) - project = create_test_project(db) - - client.post( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id, "user_id": user.id}, - headers=superuser_token_headers, - ) +def test_create_api_key_as_normal_user_forbidden( + client: TestClient, + normal_user_token_headers: dict[str, str], + user_api_key: TestAuthContext, +) -> None: + """Test that normal users cannot create API keys (superuser only).""" response = client.post( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id, "user_id": user.id}, - headers=superuser_token_headers, + f"{settings.API_V1_STR}/apikeys/", + headers=normal_user_token_headers, + params={ + "project_id": user_api_key.project_id, + "user_id": user_api_key.user_id, + }, ) - assert response.status_code == 400 - assert "API Key already exists" in response.json()["error"] - - -def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) + assert response.status_code == 403 + + +def test_list_api_keys( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test listing API keys as a normal user.""" + created_keys = [] + for _ in range(3): + key = create_test_api_key( + db=db, + project_id=user_api_key.project_id, + user_id=user_api_key.user_id, + ) + created_keys.append(key) response = client.get( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": api_key.project_id}, - headers=superuser_token_headers, + f"{settings.API_V1_STR}/apikeys/", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 200 data = response.json() assert data["success"] is True assert isinstance(data["data"], list) - assert len(data["data"]) > 0 - - first_key = data["data"][0] - assert first_key["organization_id"] == api_key.organization_id - assert first_key["user_id"] == api_key.user_id + # Verify we have at least the 3 created keys + the fixture key (4 total) + assert len(data["data"]) >= 4 -def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) +def test_delete_api_key( + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test deleting an API key by its owner.""" - response = client.get( - f"{settings.API_V1_STR}/apikeys/{api_key.id}", - headers=superuser_token_headers, + delete_response = client.delete( + f"{settings.API_V1_STR}/apikeys/{user_api_key.api_key_id}", + headers={"X-API-KEY": user_api_key.key}, ) - assert response.status_code == 200 - data = response.json() + assert delete_response.status_code == 200 + data = delete_response.json() assert data["success"] is True - assert data["data"]["id"] == api_key.id - assert data["data"]["organization_id"] == api_key.organization_id - assert data["data"]["user_id"] == api_key.user_id - - -def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key_id = get_non_existent_id(db, APIKey) - response = client.get( - f"{settings.API_V1_STR}/apikeys/{api_key_id}", - headers=superuser_token_headers, - ) - assert response.status_code == 404 - assert "API Key does not exist" in response.json()["error"] + assert "message" in data["data"] + assert "deleted successfully" in data["data"]["message"].lower() -def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) - +def test_delete_api_key_nonexistent( + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test deleting a non-existent API key.""" + fake_uuid = uuid4() response = client.delete( - f"{settings.API_V1_STR}/apikeys/{api_key.id}", - headers=superuser_token_headers, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is True - assert "API key revoked successfully" in data["data"]["message"] - - -def test_revoke_nonexistent_api_key( - db: Session, superuser_token_headers: dict[str, str] -): - api_key_id = get_non_existent_id(db, APIKey) - - response = client.delete( - f"{settings.API_V1_STR}/apikeys/{api_key_id}", - headers=superuser_token_headers, + f"{settings.API_V1_STR}/apikeys/{fake_uuid}", + headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 404 - assert "API key not found or already deleted" in response.json()["error"] diff --git a/backend/app/tests/api/routes/test_assistants.py b/backend/app/tests/api/routes/test_assistants.py index 8d236696..d4d2aadc 100644 --- a/backend/app/tests/api/routes/test_assistants.py +++ b/backend/app/tests/api/routes/test_assistants.py @@ -4,9 +4,9 @@ from fastapi import HTTPException from fastapi.testclient import TestClient from unittest.mock import patch -from app.crud.api_key import get_api_keys_by_project from app.tests.utils.openai import mock_openai_assistant from app.tests.utils.utils import get_assistant +from app.tests.utils.auth import TestAuthContext @pytest.fixture @@ -30,7 +30,7 @@ def assistant_id(): def test_ingest_assistant_success( mock_fetch_assistant, client: TestClient, - user_api_key_header: dict[str, str], + user_api_key: TestAuthContext, ): """Test successful assistant ingestion from OpenAI.""" mock_assistant = mock_openai_assistant() @@ -39,7 +39,7 @@ def test_ingest_assistant_success( response = client.post( f"/api/v1/assistant/{mock_assistant.id}/ingest", - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 201 @@ -53,7 +53,7 @@ def test_create_assistant_success( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - user_api_key_header: dict, + user_api_key: TestAuthContext, ): """Test successful assistant creation with OpenAI vector store ID verification.""" @@ -62,7 +62,7 @@ def test_create_assistant_success( response = client.post( "/api/v1/assistant", json=assistant_create_payload, - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 201 @@ -92,7 +92,7 @@ def test_create_assistant_invalid_vector_store( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - user_api_key_header: dict, + user_api_key: TestAuthContext, ): """Test failure when one or more vector store IDs are invalid.""" @@ -106,7 +106,7 @@ def test_create_assistant_invalid_vector_store( response = client.post( "/api/v1/assistant", json=payload, - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 400 @@ -117,6 +117,7 @@ def test_create_assistant_invalid_vector_store( def test_update_assistant_success( client: TestClient, db: Session, + user_api_key: TestAuthContext, ): """Test successful assistant update.""" update_payload = { @@ -127,13 +128,12 @@ def test_update_assistant_success( "max_num_results": 5, } - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.patch( f"/api/v1/assistant/{assistant.assistant_id}", json=update_payload, - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 @@ -151,6 +151,7 @@ def test_update_assistant_invalid_vector_store( mock_verify_vector_ids, client: TestClient, db: Session, + user_api_key: TestAuthContext, ): """Test failure when updating assistant with invalid vector store IDs.""" mock_verify_vector_ids.side_effect = HTTPException( @@ -159,13 +160,12 @@ def test_update_assistant_invalid_vector_store( update_payload = {"vector_store_ids_add": ["vs_invalid"]} - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.patch( f"/api/v1/assistant/{assistant.assistant_id}", json=update_payload, - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 400 @@ -175,7 +175,7 @@ def test_update_assistant_invalid_vector_store( def test_update_assistant_not_found( client: TestClient, - user_api_key_header: dict, + user_api_key: TestAuthContext, ): """Test failure when updating a non-existent assistant.""" update_payload = {"name": "Updated Assistant"} @@ -185,7 +185,7 @@ def test_update_assistant_not_found( response = client.patch( f"/api/v1/assistant/{non_existent_id}", json=update_payload, - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 404 @@ -196,14 +196,14 @@ def test_update_assistant_not_found( def test_get_assistant_success( client: TestClient, db: Session, + user_api_key: TestAuthContext, ): """Test successful retrieval of a single assistant.""" - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.get( f"/api/v1/assistant/{assistant.assistant_id}", - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 @@ -235,14 +235,14 @@ def test_get_assistant_not_found( def test_list_assistants_success( client: TestClient, db: Session, + user_api_key: TestAuthContext, ): """Test successful retrieval of assistants list.""" - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.get( "/api/v1/assistant/", - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 @@ -286,14 +286,14 @@ def test_list_assistants_invalid_pagination( def test_delete_assistant_success( client: TestClient, db: Session, + user_api_key: TestAuthContext, ): """Test successful soft deletion of an assistant.""" - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.delete( f"/api/v1/assistant/{assistant.assistant_id}", - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index d21cb6bd..34a897a5 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -1,8 +1,6 @@ -import pytest from fastapi.testclient import TestClient from sqlmodel import Session -from app.models import APIKeyPublic from app.core.config import settings from app.core.providers import Provider from app.models.credentials import Credential @@ -14,11 +12,12 @@ create_test_credential, test_credential_data, ) +from app.tests.utils.auth import TestAuthContext def test_set_credential( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): project_id = user_api_key.project_id org_id = user_api_key.organization_id @@ -62,7 +61,7 @@ def test_set_credential( def test_set_credentials_ignored_mismatched_ids( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Delete existing credentials first client.delete( @@ -88,7 +87,7 @@ def test_set_credentials_ignored_mismatched_ids( def test_read_credentials_with_creds( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Ensure at least one credential exists for current project api_key_value = "sk-" + generate_random_string(10) @@ -117,7 +116,7 @@ def test_read_credentials_with_creds( assert len(data) >= 1 -def test_read_credentials_not_found(client: TestClient, user_api_key: APIKeyPublic): +def test_read_credentials_not_found(client: TestClient, user_api_key: TestAuthContext): # Delete all first to ensure none remain client.delete( f"{settings.API_V1_STR}/credentials/", headers={"X-API-KEY": user_api_key.key} @@ -132,7 +131,7 @@ def test_read_credentials_not_found(client: TestClient, user_api_key: APIKeyPubl def test_read_provider_credential( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Seed data already has OpenAI credentials - just test GET response = client.get( @@ -148,7 +147,7 @@ def test_read_provider_credential( def test_read_provider_credential_not_found( - client: TestClient, user_api_key: APIKeyPublic + client: TestClient, user_api_key: TestAuthContext ): # Ensure none client.delete( @@ -165,7 +164,7 @@ def test_read_provider_credential_not_found( def test_update_credentials( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Update existing OpenAI credentials from seed data update_data = { @@ -193,7 +192,7 @@ def test_update_credentials( def test_update_credentials_not_found_for_provider( - client: TestClient, user_api_key: APIKeyPublic + client: TestClient, user_api_key: TestAuthContext ): # Ensure none exist client.delete( @@ -220,7 +219,7 @@ def test_update_credentials_not_found_for_provider( def test_delete_provider_credential( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Ensure exists client.delete( @@ -247,7 +246,7 @@ def test_delete_provider_credential( def test_delete_provider_credential_not_found( - client: TestClient, user_api_key: APIKeyPublic + client: TestClient, user_api_key: TestAuthContext ): # Ensure not exists client.delete( @@ -264,7 +263,7 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Delete existing credentials from seed data response = client.delete( @@ -290,7 +289,7 @@ def test_delete_all_credentials( def test_delete_all_credentials_not_found( - client: TestClient, user_api_key: APIKeyPublic + client: TestClient, user_api_key: TestAuthContext ): # Ensure already deleted client.delete( @@ -311,7 +310,7 @@ def test_delete_all_credentials_not_found( def test_duplicate_credential_creation( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Test verifies that the database unique constraint prevents duplicate credentials # for the same organization, project, and provider combination. @@ -335,7 +334,7 @@ def test_duplicate_credential_creation( def test_multiple_provider_credentials( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Ensure clean state for current org/project client.delete( @@ -401,7 +400,7 @@ def test_multiple_provider_credentials( def test_credential_encryption( db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): # Use existing credentials from seed data to verify encryption db_credential = ( @@ -426,7 +425,7 @@ def test_credential_encryption( def test_credential_encryption_consistency( - client: TestClient, user_api_key: APIKeyPublic + client: TestClient, user_api_key: TestAuthContext ): # Fetch existing seed data credentials response = client.get( diff --git a/backend/app/tests/api/routes/test_doc_transformation_job.py b/backend/app/tests/api/routes/test_doc_transformation_job.py index a808c430..92460058 100644 --- a/backend/app/tests/api/routes/test_doc_transformation_job.py +++ b/backend/app/tests/api/routes/test_doc_transformation_job.py @@ -3,13 +3,14 @@ from app.core.config import settings from app.crud.doc_transformation_job import DocTransformationJobCrud -from app.models import APIKeyPublic, TransformationStatus +from app.models import TransformationStatus from app.tests.utils.document import DocumentStore +from app.tests.utils.auth import TestAuthContext class TestGetTransformationJob: def test_get_existing_job_success( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test successfully retrieving an existing transformation job.""" document = DocumentStore(db, user_api_key.project_id).put() @@ -31,7 +32,7 @@ def test_get_existing_job_success( assert data["data"]["transformed_document_id"] is None def test_get_nonexistent_job_404( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test getting a non-existent transformation job returns 404.""" fake_uuid = "00000000-0000-0000-0000-000000000001" @@ -44,7 +45,7 @@ def test_get_nonexistent_job_404( assert response.status_code == 404 def test_get_job_invalid_uuid_422( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: TestAuthContext ): """Test getting a job with invalid UUID format returns 422.""" invalid_uuid = "not-a-uuid" @@ -60,8 +61,8 @@ def test_get_job_different_project_404( self, client: TestClient, db: Session, - user_api_key: APIKeyPublic, - superuser_api_key: APIKeyPublic, + user_api_key: TestAuthContext, + superuser_api_key: TestAuthContext, ): """Test that jobs from different projects are not accessible.""" store = DocumentStore(db, user_api_key.project_id) @@ -78,7 +79,7 @@ def test_get_job_different_project_404( assert response.status_code == 404 def test_get_completed_job_with_result( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test getting a completed job with transformation result.""" store = DocumentStore(db, user_api_key.project_id) @@ -105,7 +106,7 @@ def test_get_completed_job_with_result( assert data["data"]["transformed_document_id"] == str(transformed_document.id) def test_get_failed_job_with_error( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test getting a failed job with error message.""" store = DocumentStore(db, user_api_key.project_id) @@ -130,7 +131,7 @@ def test_get_failed_job_with_error( class TestGetMultipleTransformationJobs: def test_get_multiple_jobs_success( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test successfully retrieving multiple transformation jobs.""" store = DocumentStore(db, user_api_key.project_id) @@ -155,7 +156,7 @@ def test_get_multiple_jobs_success( assert returned_ids == expected_ids def test_get_mixed_existing_nonexisting_jobs( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test retrieving a mix of existing and non-existing jobs.""" store = DocumentStore(db, user_api_key.project_id) @@ -180,7 +181,7 @@ def test_get_mixed_existing_nonexisting_jobs( assert data["data"]["jobs_not_found"][0] == fake_uuid def test_get_jobs_with_empty_string( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: TestAuthContext ): """Test retrieving jobs with empty job_ids parameter.""" response = client.get( @@ -191,7 +192,7 @@ def test_get_jobs_with_empty_string( assert response.status_code == 422 def test_get_jobs_with_whitespace_only( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: TestAuthContext ): """Test retrieving jobs with whitespace-only job_ids.""" response = client.get( @@ -202,7 +203,7 @@ def test_get_jobs_with_whitespace_only( assert response.status_code == 422 def test_get_jobs_invalid_uuid_format_422( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: TestAuthContext ): """Test that invalid UUID format returns 422.""" invalid_uuid = "not-a-uuid" @@ -217,7 +218,7 @@ def test_get_jobs_invalid_uuid_format_422( assert "Input should be a valid UUID" in data["error"] def test_get_jobs_mixed_valid_invalid_uuid_422( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test that mixed valid/invalid UUIDs returns 422.""" store = DocumentStore(db, user_api_key.project_id) @@ -238,7 +239,7 @@ def test_get_jobs_mixed_valid_invalid_uuid_422( assert "job_ids" in data["error"] def test_get_jobs_missing_parameter_422( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: TestAuthContext ): """Test that missing job_ids parameter returns empty results.""" response = client.get( @@ -252,8 +253,8 @@ def test_get_jobs_different_project_not_found( self, client: TestClient, db: Session, - user_api_key: APIKeyPublic, - superuser_api_key: APIKeyPublic, + user_api_key: TestAuthContext, + superuser_api_key: TestAuthContext, ): """Test that jobs from different projects are not returned.""" store = DocumentStore(db, user_api_key.project_id) @@ -274,7 +275,7 @@ def test_get_jobs_different_project_not_found( assert data["data"]["jobs_not_found"][0] == str(job.id) def test_get_jobs_with_various_statuses( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test retrieving jobs with different statuses.""" store = DocumentStore(db, user_api_key.project_id) diff --git a/backend/app/tests/api/routes/test_openai_conversation.py b/backend/app/tests/api/routes/test_openai_conversation.py index dafc569f..500a467b 100644 --- a/backend/app/tests/api/routes/test_openai_conversation.py +++ b/backend/app/tests/api/routes/test_openai_conversation.py @@ -2,15 +2,15 @@ from fastapi.testclient import TestClient from app.crud.openai_conversation import create_conversation -from app.models import APIKeyPublic from app.models import OpenAIConversationCreate from app.tests.utils.openai import generate_openai_id +from app.tests.utils.auth import TestAuthContext def test_get_conversation_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test successful conversation retrieval.""" @@ -45,7 +45,7 @@ def test_get_conversation_success( def test_get_conversation_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation retrieval with non-existent ID.""" response = client.get( @@ -61,7 +61,7 @@ def test_get_conversation_not_found( def test_get_conversation_by_response_id_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test successful conversation retrieval by response ID.""" response_id = generate_openai_id("resp_", 40) @@ -96,7 +96,7 @@ def test_get_conversation_by_response_id_success( def test_get_conversation_by_response_id_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation retrieval with non-existent response ID.""" response = client.get( @@ -112,7 +112,7 @@ def test_get_conversation_by_response_id_not_found( def test_get_conversation_by_ancestor_id_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test successful conversation retrieval by ancestor ID.""" ancestor_response_id = generate_openai_id("resp_", 40) @@ -147,7 +147,7 @@ def test_get_conversation_by_ancestor_id_success( def test_get_conversation_by_ancestor_id_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation retrieval with non-existent ancestor ID.""" response = client.get( @@ -163,7 +163,7 @@ def test_get_conversation_by_ancestor_id_not_found( def test_list_conversations_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test successful conversation listing.""" conversation_data = OpenAIConversationCreate( @@ -199,7 +199,7 @@ def test_list_conversations_success( def test_list_conversations_with_pagination( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation listing with pagination.""" # Create multiple conversations @@ -262,7 +262,7 @@ def test_list_conversations_with_pagination( def test_list_conversations_pagination_metadata( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation listing pagination metadata.""" # Create 5 conversations @@ -320,7 +320,7 @@ def test_list_conversations_pagination_metadata( def test_list_conversations_default_pagination( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation listing with default pagination parameters.""" # Create a conversation @@ -360,7 +360,7 @@ def test_list_conversations_default_pagination( def test_list_conversations_edge_cases( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation listing edge cases for pagination.""" # Test with skip larger than total @@ -397,7 +397,7 @@ def test_list_conversations_edge_cases( def test_list_conversations_invalid_pagination( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation listing with invalid pagination parameters.""" response = client.get( @@ -411,7 +411,7 @@ def test_list_conversations_invalid_pagination( def test_delete_conversation_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test successful conversation deletion.""" conversation_data = OpenAIConversationCreate( @@ -453,7 +453,7 @@ def test_delete_conversation_success( def test_delete_conversation_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: TestAuthContext, ): """Test conversation deletion with non-existent ID.""" response = client.delete( diff --git a/backend/app/tests/api/test_deps.py b/backend/app/tests/api/test_deps.py new file mode 100644 index 00000000..925e8c8e --- /dev/null +++ b/backend/app/tests/api/test_deps.py @@ -0,0 +1,167 @@ +import pytest +from sqlmodel import Session +from fastapi import HTTPException +from app.api.deps import get_auth_context +from app.models import ( + User, + AuthContext, +) +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.user import authentication_token_from_email, create_random_user +from app.core.config import settings +from app.tests.utils.test_data import create_test_api_key + + +class TestGetAuthContext: + """Test suite for get_auth_context function""" + + def test_get_auth_context_with_valid_api_key( + self, db: Session, user_api_key: TestAuthContext + ): + """Test successful authentication with valid API key""" + auth_context = get_auth_context( + session=db, + token=None, + api_key=user_api_key.key, + ) + + assert isinstance(auth_context, AuthContext) + assert auth_context.user == user_api_key.user + assert auth_context.project == user_api_key.project + assert auth_context.organization == user_api_key.organization + + def test_get_auth_context_with_invalid_api_key(self, db: Session): + """Test authentication fails with invalid API key""" + invalid_api_key = "ApiKey InvalidKeyThatDoesNotExist123456789" + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=invalid_api_key, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid API Key" + + def test_get_auth_context_with_valid_token( + self, db: Session, normal_user_token_headers: dict[str, str] + ): + """Test successful authentication with valid token""" + token = normal_user_token_headers["Authorization"].replace("Bearer ", "") + auth_context = get_auth_context( + session=db, + token=token, + api_key=None, + ) + + # Assert + assert isinstance(auth_context, AuthContext) + assert auth_context.user.email == settings.EMAIL_TEST_USER + + def test_get_auth_context_with_invalid_token(self, db: Session): + """Test authentication fails with invalid token""" + invalid_token = "invalid.token" + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=invalid_token, + api_key=None, + ) + + assert exc_info.value.status_code == 403 + + def test_get_auth_context_with_no_credentials(self, db: Session): + """Test authentication fails when neither API key nor token is provided""" + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=None, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid Authorization format" + + def test_get_auth_context_with_inactive_user_via_api_key(self, db: Session): + """Test authentication fails when API key belongs to inactive user""" + api_key = create_test_api_key(db) + + user = db.get(User, api_key.user_id) + user.is_active = False + db.add(user) + db.commit() + db.refresh(user) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=api_key.key, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive user" + + def test_get_auth_context_with_inactive_user_via_token(self, db: Session, client): + """Test authentication fails when token belongs to inactive user""" + user = create_random_user(db) + token_headers = authentication_token_from_email( + client=client, email=user.email, db=db + ) + token = token_headers["Authorization"].replace("Bearer ", "") + + user.is_active = False + db.add(user) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=token, + api_key=None, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive user" + + def test_get_auth_context_with_inactive_organization( + self, db: Session, user_api_key: TestAuthContext + ): + """Test authentication fails when organization is inactive""" + organization = user_api_key.organization + organization.is_active = False + db.add(organization) + db.commit() + db.refresh(organization) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=user_api_key.key, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive Organization" + + def test_get_auth_context_with_inactive_project( + self, db: Session, user_api_key: TestAuthContext + ): + """Test authentication fails when project is inactive""" + project = user_api_key.project + project.is_active = False + db.add(project) + db.commit() + db.refresh(project) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=user_api_key.key, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive Project" diff --git a/backend/app/tests/api/test_permissions.py b/backend/app/tests/api/test_permissions.py new file mode 100644 index 00000000..2c9092ac --- /dev/null +++ b/backend/app/tests/api/test_permissions.py @@ -0,0 +1,148 @@ +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from app.models import User +from app.api.permissions import Permission, has_permission, require_permission +from app.api.deps import get_auth_context +from app.tests.utils.test_data import create_test_api_key + + +class TestHasPermission: + """Test suite for has_permission function""" + + def test_superuser_permission_with_superuser(self, db: Session): + """Test that superuser has SUPERUSER permission""" + api_key_response = create_test_api_key(db) + user = db.get(User, api_key_response.user_id) + user.is_superuser = True + db.add(user) + db.commit() + db.refresh(user) + + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + result = has_permission(auth_context, Permission.SUPERUSER, db) + + assert result is True + + def test_superuser_permission_with_regular_user(self, db: Session): + """Test that regular user does not have SUPERUSER permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + result = has_permission(auth_context, Permission.SUPERUSER, db) + + assert result is False + + def test_require_organization_permission_with_organization(self, db: Session): + """Test that user with organization has REQUIRE_ORGANIZATION permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + result = has_permission(auth_context, Permission.REQUIRE_ORGANIZATION, db) + + assert result is True + + def test_require_organization_permission_without_organization(self, db: Session): + """Test that user without organization does not have REQUIRE_ORGANIZATION permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + auth_context.organization = None + + result = has_permission(auth_context, Permission.REQUIRE_ORGANIZATION, db) + + assert result is False + + def test_require_project_permission_with_project(self, db: Session): + """Test that user with project has REQUIRE_PROJECT permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + result = has_permission(auth_context, Permission.REQUIRE_PROJECT, db) + + assert result is True + + def test_require_project_permission_without_project(self, db: Session): + """Test that user without project does not have REQUIRE_PROJECT permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + auth_context.project = None + + result = has_permission(auth_context, Permission.REQUIRE_PROJECT, db) + + assert result is False + + +class TestRequirePermission: + """Test suite for require_permission dependency factory""" + + def test_returns_valid_permission_checker(self): + """Test that require_permission returns a valid callable permission checker""" + permission_checker = require_permission(Permission.SUPERUSER) + + assert callable(permission_checker) + + def test_permission_checker_passes_with_valid_permission(self, db: Session): + """Test that permission checker passes when user has required permission""" + api_key_response = create_test_api_key(db) + user = db.get(User, api_key_response.user_id) + user.is_superuser = True + db.add(user) + db.commit() + db.refresh(user) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + permission_checker = require_permission(Permission.SUPERUSER) + permission_checker(auth_context, db) + + def test_permission_checker_raises_403_without_permission(self, db: Session): + """Test that permission checker raises HTTPException with 403 when user lacks permission""" + api_key_response = create_test_api_key(db) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) + + permission_checker = require_permission(Permission.SUPERUSER) + + with pytest.raises(HTTPException) as exc_info: + permission_checker(auth_context, db) + + assert exc_info.value.status_code == 403 + + +class TestPermissionEnum: + """Test suite for Permission enum""" + + def test_permission_enum_values(self): + """Test that Permission enum has expected values""" + assert Permission.SUPERUSER.value == "require_superuser" + assert Permission.REQUIRE_ORGANIZATION.value == "require_organization_id" + assert Permission.REQUIRE_PROJECT.value == "require_project_id" + + def test_permission_enum_is_string(self): + """Test that Permission enum members are strings""" + assert isinstance(Permission.SUPERUSER, str) + assert isinstance(Permission.REQUIRE_ORGANIZATION, str) + assert isinstance(Permission.REQUIRE_PROJECT, str) diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 8f14de4c..d396a435 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -14,9 +14,13 @@ from app.core.db import engine from app.api.deps import get_db from app.main import app -from app.models import APIKeyPublic from app.tests.utils.user import authentication_token_from_email -from app.tests.utils.utils import get_superuser_token_headers, get_api_key_by_email +from app.tests.utils.utils import get_superuser_token_headers +from app.tests.utils.auth import ( + get_superuser_test_auth_context, + get_user_test_auth_context, + TestAuthContext, +) from app.seed_data.seed_data import seed_database @@ -76,32 +80,25 @@ def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str] ) -@pytest.fixture(scope="function") +@pytest.fixture def superuser_api_key_header(db: Session) -> dict[str, str]: - api_key = get_api_key_by_email(db, settings.FIRST_SUPERUSER) - return {"X-API-KEY": api_key.key} + auth_ctx = get_superuser_test_auth_context(db) + return {"X-API-KEY": auth_ctx.key} -@pytest.fixture(scope="function") +@pytest.fixture def user_api_key_header(db: Session) -> dict[str, str]: - api_key = get_api_key_by_email(db, settings.EMAIL_TEST_USER) - return {"X-API-KEY": api_key.key} + auth_ctx = get_user_test_auth_context(db) + return {"X-API-KEY": auth_ctx.key} -@pytest.fixture(scope="function") -def superuser_api_key(db: Session) -> APIKeyPublic: - api_key = get_api_key_by_email(db, settings.FIRST_SUPERUSER) - return api_key +@pytest.fixture +def superuser_api_key(db: Session) -> TestAuthContext: + auth_ctx = get_superuser_test_auth_context(db) + return auth_ctx -@pytest.fixture(scope="function") -def user_api_key(db: Session) -> APIKeyPublic: - """ - Provides an API key for the test user. - - This API key is associated with the Dalgo project, which has both OpenAI - and Langfuse credentials pre-populated via seed data. - All tests can assume credentials exist for this user. - """ - api_key = get_api_key_by_email(db, settings.EMAIL_TEST_USER) - return api_key +@pytest.fixture +def user_api_key(db: Session) -> TestAuthContext: + auth_ctx = get_user_test_auth_context(db) + return auth_ctx diff --git a/backend/app/tests/core/doctransformer/test_service/conftest.py b/backend/app/tests/core/doctransformer/test_service/conftest.py index 7f3aeda0..75574be0 100644 --- a/backend/app/tests/core/doctransformer/test_service/conftest.py +++ b/backend/app/tests/core/doctransformer/test_service/conftest.py @@ -16,7 +16,7 @@ from app.core.config import settings from app.models import Document, Project, UserProjectOrg from app.tests.utils.document import DocumentStore -from app.tests.utils.test_data import create_test_api_key +from app.tests.utils.auth import TestAuthContext @pytest.fixture(scope="class") @@ -52,10 +52,10 @@ def fast_execute_job_func( @pytest.fixture -def current_user(db: Session) -> UserProjectOrg: +def current_user(db: Session, user_api_key: TestAuthContext) -> UserProjectOrg: """Create a test user for testing.""" - api_key = create_test_api_key(db) - user = db.get(User, api_key.user_id) + api_key = user_api_key + user = api_key.user return UserProjectOrg( **user.model_dump(), project_id=api_key.project_id, diff --git a/backend/app/tests/core/test_security.py b/backend/app/tests/core/test_security.py index 4f7f6861..59101375 100644 --- a/backend/app/tests/core/test_security.py +++ b/backend/app/tests/core/test_security.py @@ -1,11 +1,16 @@ import pytest +from uuid import uuid4 +from sqlmodel import Session from app.core.security import ( get_password_hash, verify_password, encrypt_api_key, decrypt_api_key, get_encryption_key, + APIKeyManager, ) +from app.models import APIKey, User, Organization, Project, AuthContext +from app.tests.utils.test_data import create_test_api_key def test_encrypt_decrypt_api_key(): @@ -119,3 +124,175 @@ def test_get_encryption_key(): assert isinstance(key, bytes) # The key is base64 encoded, so it should be 44 bytes assert len(key) == 44 # Base64 encoded Fernet key length is 44 bytes + + +class TestAPIKeyManager: + """Test suite for APIKeyManager class.""" + + def test_generate_returns_correct_tuple(self): + """Test that generate returns a tuple of (raw_key, key_prefix, key_hash).""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + assert isinstance(raw_key, str) + assert isinstance(key_prefix, str) + assert isinstance(key_hash, str) + + def test_generate_raw_key_format(self): + """Test that generated raw key has correct format.""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + # Should start with "ApiKey " + assert raw_key.startswith(APIKeyManager.PREFIX_NAME) + + # Should have correct length (7 for "ApiKey " + 22 for prefix + 43 for secret) + expected_length = len(APIKeyManager.PREFIX_NAME) + APIKeyManager.KEY_LENGTH + assert len(raw_key) == expected_length + + def test_generate_key_prefix_length(self): + """Test that generated key prefix has correct length.""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + assert len(key_prefix) == APIKeyManager.PREFIX_LENGTH + + def test_generate_unique_keys(self): + """Test that generate creates unique keys on each call.""" + raw_key1, prefix1, hash1 = APIKeyManager.generate() + raw_key2, prefix2, hash2 = APIKeyManager.generate() + + assert raw_key1 != raw_key2 + assert prefix1 != prefix2 + assert hash1 != hash2 + + def test_generate_hash_is_bcrypt(self): + """Test that the generated hash uses bcrypt format.""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + # bcrypt hashes start with $2b$ (or $2a$ or $2y$) + assert key_hash.startswith("$2") + + def test_extract_key_parts_new_format(self): + """Test extracting key parts from new format (65 chars).""" + raw_key, expected_prefix, _ = APIKeyManager.generate() + + result = APIKeyManager._extract_key_parts(raw_key) + + assert result is not None + extracted_prefix, secret = result + assert extracted_prefix == expected_prefix + assert len(secret) == APIKeyManager.KEY_LENGTH - APIKeyManager.PREFIX_LENGTH + + def test_extract_key_parts_old_format(self): + """Test extracting key parts from old format (43 chars).""" + old_prefix = "a" * 12 + old_secret = "b" * 31 + raw_key = f"{APIKeyManager.PREFIX_NAME}{old_prefix}{old_secret}" + + result = APIKeyManager._extract_key_parts(raw_key) + + assert result is not None + extracted_prefix, secret = result + assert extracted_prefix == old_prefix + assert secret == old_secret + + def test_extract_key_parts_invalid_prefix(self): + """Test that invalid prefix returns None.""" + invalid_key = "InvalidPrefix abcdefghij1234567890" + + result = APIKeyManager._extract_key_parts(invalid_key) + + assert result is None + + def test_extract_key_parts_invalid_length(self): + """Test that invalid length returns None.""" + invalid_key = f"{APIKeyManager.PREFIX_NAME}tooshort" + + result = APIKeyManager._extract_key_parts(invalid_key) + + assert result is None + + def test_verify_valid_key(self, db: Session): + """Test verifying a valid API key.""" + api_key = create_test_api_key(db) + + auth_context = APIKeyManager.verify(db, api_key.key) + + user = db.get(User, api_key.user_id) + organization = db.get(Organization, api_key.organization_id) + project = db.get(Project, api_key.project_id) + + assert auth_context is not None + assert isinstance(auth_context, AuthContext) + assert auth_context.user.id == api_key.user_id + assert auth_context.organization.id == api_key.organization_id + assert auth_context.project.id == api_key.project_id + assert auth_context.user == user + assert auth_context.organization == organization + assert auth_context.project == project + + def test_verify_invalid_key(self, db: Session): + """Test verifying an invalid API key.""" + # Generate a key but don't store it + raw_key, _, _ = APIKeyManager.generate() + + auth_context = APIKeyManager.verify(db, raw_key) + + assert auth_context is None + + def test_verify_wrong_secret(self, db: Session): + """Test verifying with correct prefix but wrong secret.""" + create_test_api_key(db) + + # Generate a different key to try verification + raw_key2, _, _ = APIKeyManager.generate() + + # Try to verify with key2 (wrong secret) + auth_context = APIKeyManager.verify(db, raw_key2) + + assert auth_context is None + + def test_verify_deleted_key(self, db: Session): + """Test that deleted API keys cannot be verified.""" + api_key_response = create_test_api_key(db) + raw_key = api_key_response.key + + api_key = db.get(APIKey, api_key_response.id) + api_key.is_deleted = True + db.commit() + + auth_context = APIKeyManager.verify(db, raw_key) + + assert auth_context is None + + def test_verify_malformed_key(self, db: Session): + """Test verifying with malformed key format.""" + malformed_keys = [ + "not_an_api_key", + "", + "ApiKey", + "ApiKey ", + None, + ] + + for malformed_key in malformed_keys: + if malformed_key is not None: + auth_context = APIKeyManager.verify(db, malformed_key) + assert auth_context is None + + def test_prefix_name_constant(self): + """Test that PREFIX_NAME is correct.""" + assert APIKeyManager.PREFIX_NAME == "ApiKey " + + def test_key_length_constants(self): + """Test that key length constants are correct.""" + assert APIKeyManager.PREFIX_LENGTH == 22 + assert APIKeyManager.KEY_LENGTH == 65 + assert APIKeyManager.KEY_LENGTH == APIKeyManager.PREFIX_LENGTH + 43 + + def test_generate_creates_verifiable_key(self, db: Session): + """Integration test: generated key can be verified.""" + api_key_response = create_test_api_key(db) + + auth_context = APIKeyManager.verify(db, api_key_response.key) + + assert auth_context is not None + assert auth_context.user.id == api_key_response.user_id diff --git a/backend/app/tests/crud/test_api_key.py b/backend/app/tests/crud/test_api_key.py index 9f4281dc..939837c4 100644 --- a/backend/app/tests/crud/test_api_key.py +++ b/backend/app/tests/crud/test_api_key.py @@ -1,117 +1,267 @@ -from sqlmodel import Session, select - -from app.crud import api_key as api_key_crud -from app.models import APIKey -from app.tests.utils.utils import get_non_existent_id +import pytest +from uuid import uuid4 +from sqlmodel import Session +from fastapi import HTTPException + +from app.crud import APIKeyCrud +from app.models import APIKey, Project, User +from app.tests.utils.test_data import create_test_project, create_test_api_key from app.tests.utils.user import create_random_user -from app.tests.utils.test_data import create_test_api_key, create_test_project +from app.tests.utils.utils import get_non_existent_id def test_create_api_key(db: Session) -> None: - user = create_random_user(db) + """Test creating a new API key""" project = create_test_project(db) + user = create_random_user(db) - api_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + raw_key, api_key = api_key_crud.create(user_id=user.id, project_id=project.id) - assert api_key.key.startswith("ApiKey ") - assert len(api_key.key) > 32 - assert api_key.organization_id == project.organization_id + assert api_key.id is not None assert api_key.user_id == user.id assert api_key.project_id == project.id + assert api_key.organization_id == project.organization_id + assert api_key.key_prefix is not None + assert api_key.key_hash is not None + assert api_key.is_deleted is False + assert api_key.deleted_at is None + assert raw_key is not None + assert len(raw_key) > 0 -def test_get_api_key(db: Session) -> None: - api_key = create_test_api_key(db) - retrieved_key = api_key_crud.get_api_key(db, api_key.id) +def test_create_api_key_with_nonexistent_project(db: Session) -> None: + """Test creating API key with a project that doesn't exist""" + user = create_random_user(db) + fake_project_id = get_non_existent_id(session=db, model=Project) - assert retrieved_key is not None - assert retrieved_key.id == api_key.id - assert retrieved_key.key.startswith("ApiKey ") - assert retrieved_key.project_id == api_key.project_id + api_key_crud = APIKeyCrud(session=db, project_id=fake_project_id) + with pytest.raises(HTTPException) as exc_info: + api_key_crud.create(user_id=user.id, project_id=fake_project_id) -def test_get_api_key_not_found(db: Session) -> None: - api_key_id = get_non_existent_id(db, APIKey) - result = api_key_crud.get_api_key(db, api_key_id) - assert result is None + assert exc_info.value.status_code == 404 + assert "Project not found" in str(exc_info.value.detail) -def test_delete_api_key(db: Session) -> None: - api_key = create_test_api_key(db) - api_key_crud.delete_api_key(db, api_key.id) +def test_create_api_key_with_nonexistent_user(db: Session) -> None: + """Test creating API key with a user that doesn't exist""" + project = create_test_project(db) + fake_user_id = get_non_existent_id(session=db, model=User) - deleted_key = db.exec(select(APIKey).where(APIKey.id == api_key.id)).first() + api_key_crud = APIKeyCrud(session=db, project_id=project.id) - assert deleted_key is not None - assert deleted_key.is_deleted is True - assert deleted_key.deleted_at is not None + with pytest.raises(HTTPException) as exc_info: + api_key_crud.create(user_id=fake_user_id, project_id=project.id) + assert exc_info.value.status_code == 404 + assert "User not found" in str(exc_info.value.detail) -def test_get_api_key_by_value(db: Session) -> None: - api_key = create_test_api_key(db) - raw_key = api_key.key - # Test retrieving the API key by its value - retrieved_key = api_key_crud.get_api_key_by_value(db, raw_key) +def test_read_one_api_key(db: Session) -> None: + """Test reading a single API key by ID""" + api_key = create_test_api_key(db=db) + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + retrieved_key = api_key_crud.read_one(key_id=api_key.id) assert retrieved_key is not None assert retrieved_key.id == api_key.id - assert retrieved_key.organization_id == api_key.organization_id - assert retrieved_key.user_id == api_key.user_id - # The key should be in its original format - assert retrieved_key.key == raw_key # Should be exactly the same key - assert retrieved_key.key.startswith("ApiKey ") - assert len(retrieved_key.key) > 32 + assert retrieved_key.key_prefix == api_key.key_prefix + assert retrieved_key.project_id == api_key.project_id -def test_get_api_key_by_project_user(db: Session) -> None: - user = create_random_user(db) +def test_read_one_api_key_nonexistent(db: Session) -> None: + """Test reading an API key that doesn't exist""" project = create_test_project(db) - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) - retrieved_key = api_key_crud.get_api_key_by_project_user(db, project.id, user.id) + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + fake_key_id = uuid4() - assert retrieved_key is not None - assert retrieved_key.id == created_key.id - assert retrieved_key.project_id == project.id - assert retrieved_key.key.startswith("ApiKey ") + retrieved_key = api_key_crud.read_one(key_id=fake_key_id) + assert retrieved_key is None -def test_get_api_keys_by_project(db: Session) -> None: + +def test_read_one_api_key_wrong_project(db: Session) -> None: + """Test that reading an API key from a different project returns None""" + + api_key = create_test_api_key(db=db) + project2 = create_test_project(db) + + # Try to read it from project2 scope + api_key_crud2 = APIKeyCrud(session=db, project_id=project2.id) + retrieved_key = api_key_crud2.read_one(key_id=api_key.id) + + assert retrieved_key is None + + +def test_read_one_deleted_api_key(db: Session) -> None: + """Test that reading a deleted API key returns None""" + api_key = create_test_api_key(db=db) + + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + + api_key_crud.delete(key_id=api_key.id) + retrieved_key = api_key_crud.read_one(key_id=api_key.id) + assert retrieved_key is None + + +def test_read_all_api_keys(db: Session) -> None: + """Test reading all API keys for a project""" + project = create_test_project(db) user = create_random_user(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + + # Create multiple API keys + key1 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key2 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key3 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + + # Read all keys + all_keys = api_key_crud.read_all() + + assert len(all_keys) == 3 + key_ids = {key.id for key in all_keys} + assert key1.id in key_ids + assert key2.id in key_ids + assert key3.id in key_ids + + +def test_read_all_api_keys_with_pagination(db: Session) -> None: + """Test reading API keys with skip and limit parameters""" project = create_test_project(db) + user = create_random_user(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + + # Create 5 API keys + for _ in range(5): + create_test_api_key(db=db, project_id=project.id, user_id=user.id) + + # Test pagination + page1 = api_key_crud.read_all(skip=0, limit=2) + assert len(page1) == 2 + + page2 = api_key_crud.read_all(skip=2, limit=2) + assert len(page2) == 2 - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id + page3 = api_key_crud.read_all(skip=4, limit=2) + assert len(page3) == 1 + + all_ids = ( + {key.id for key in page1} + | {key.id for key in page2} + | {key.id for key in page3} ) + assert len(all_ids) == 5 + + +def test_read_all_excludes_deleted_keys(db: Session) -> None: + """Test that read_all excludes deleted API keys""" + project = create_test_project(db) + user = create_random_user(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) - retrieved_keys = api_key_crud.get_api_keys_by_project(db, project.id) + # Create 3 API keys + key1 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key2 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key3 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) - assert retrieved_keys is not None - assert len(retrieved_keys) == 1 - retrieved_key = retrieved_keys[0] + # Delete one + api_key_crud.delete(key_id=key2.id) - assert retrieved_key.id == created_key.id - assert retrieved_key.project_id == project.id - assert retrieved_key.key.startswith("ApiKey ") + # Read all should only return 2 + all_keys = api_key_crud.read_all() + assert len(all_keys) == 2 + key_ids = {key.id for key in all_keys} + assert key1.id in key_ids + assert key2.id not in key_ids + assert key3.id in key_ids -def test_get_api_key_by_user_id(db: Session) -> None: +def test_read_all_scoped_to_project(db: Session) -> None: + """Test that read_all only returns keys for the specified project""" + project1 = create_test_project(db) + project2 = create_test_project(db) user = create_random_user(db) + + api_key_crud1 = APIKeyCrud(session=db, project_id=project1.id) + api_key_crud2 = APIKeyCrud(session=db, project_id=project2.id) + + # Create keys for both projects + api_key_crud1.create(user_id=user.id, project_id=project1.id) + api_key_crud1.create(user_id=user.id, project_id=project1.id) + api_key_crud2.create(user_id=user.id, project_id=project2.id) + + # Each project should only see its own keys + project1_keys = api_key_crud1.read_all() + project2_keys = api_key_crud2.read_all() + + assert len(project1_keys) == 2 + assert len(project2_keys) == 1 + assert all(key.project_id == project1.id for key in project1_keys) + assert all(key.project_id == project2.id for key in project2_keys) + + +def test_delete_api_key(db: Session) -> None: + """Test soft deleting an API key""" + api_key = create_test_api_key(db=db) + + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + + api_key_crud.delete(key_id=api_key.id) + + db_key = db.get(APIKey, api_key.id) + assert db_key is not None + assert db_key.is_deleted is True + assert db_key.deleted_at is not None + + retrieved_key = api_key_crud.read_one(key_id=api_key.id) + assert retrieved_key is None + + +def test_delete_nonexistent_api_key(db: Session) -> None: + """Test deleting an API key that doesn't exist""" project = create_test_project(db) - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + fake_key_id = uuid4() - retrieved_key = api_key_crud.get_api_key_by_user_id(db, user.id) + with pytest.raises(HTTPException) as exc_info: + api_key_crud.delete(key_id=fake_key_id) - assert retrieved_key is not None + assert exc_info.value.status_code == 404 + assert "API Key not found" in str(exc_info.value.detail) + + +def test_delete_api_key_from_wrong_project(db: Session) -> None: + """Test that deleting an API key from a different project fails""" + api_key = create_test_api_key(db=db) + project2 = create_test_project(db) + + api_key_crud2 = APIKeyCrud(session=db, project_id=project2.id) + with pytest.raises(HTTPException) as exc_info: + api_key_crud2.delete(key_id=api_key.id) + + assert exc_info.value.status_code == 404 + assert "API Key not found" in str(exc_info.value.detail) + + db_key = db.get(APIKey, api_key.id) + assert db_key is not None + assert db_key.is_deleted is False + + +def test_delete_already_deleted_api_key(db: Session) -> None: + """Test deleting an API key that's already deleted""" + api_key = create_test_api_key(db=db) + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + + api_key_crud.delete(key_id=api_key.id) + + with pytest.raises(HTTPException) as exc_info: + api_key_crud.delete(key_id=api_key.id) - assert retrieved_key.id == created_key.id - assert retrieved_key.user_id == user.id - assert retrieved_key.key.startswith("ApiKey ") + assert exc_info.value.status_code == 404 + assert "API Key not found" in str(exc_info.value.detail) diff --git a/backend/app/tests/crud/test_onboarding.py b/backend/app/tests/crud/test_onboarding.py index 3ff64bbf..613669d1 100644 --- a/backend/app/tests/crud/test_onboarding.py +++ b/backend/app/tests/crud/test_onboarding.py @@ -229,7 +229,6 @@ def test_onboard_project_api_key_generation(db: Session) -> None: ) ).first() assert api_key_record is not None - assert api_key_record.key != response.api_key def test_onboard_project_response_data_integrity(db: Session) -> None: diff --git a/backend/app/tests/utils/auth.py b/backend/app/tests/utils/auth.py new file mode 100644 index 00000000..1f35becb --- /dev/null +++ b/backend/app/tests/utils/auth.py @@ -0,0 +1,146 @@ +from uuid import UUID + +from sqlmodel import Session, select, SQLModel + +from app.models import User, Organization, Project, APIKey +from app.core.config import settings + + +class TestAuthContext(SQLModel): + """Authentication context for testing""" + + user_id: int + project_id: int + organization_id: int + key: str # The full unencrypted API key with "ApiKey " prefix + api_key_id: UUID # The UUID of the API key record + + # Complete nested objects + user: User + project: Project + organization: Organization + api_key: APIKey + + +def get_test_auth_context( + session: Session, + user_email: str, + project_name: str, + raw_key: str, + user_type: str = "User", +) -> TestAuthContext: + """ + Helper function to get authentication context from seeded data. + + Args: + session: Database session + user_email: Email of the user + project_name: Name of the project + raw_key: The full unencrypted API key with "ApiKey " prefix + user_type: Type of user for error messages (e.g., "Superuser", "User") + + Returns: + TestAuthContext with all IDs and keys from seeded data + + Raises: + ValueError: If the required data is not found in the database + """ + # Get user from seed data + user = session.exec(select(User).where(User.email == user_email)).first() + if not user: + raise ValueError( + f"{user_type} with email {user_email} not found. Ensure seed data is loaded." + ) + + # Get project from seed data + project = session.exec(select(Project).where(Project.name == project_name)).first() + if not project: + raise ValueError( + f"Project {project_name} not found. Ensure seed data is loaded." + ) + + # Get organization + org = session.exec( + select(Organization).where(Organization.id == project.organization_id) + ).first() + if not org: + raise ValueError(f"Organization for project {project_name} not found.") + + # Get API key for this user and project + api_key = session.exec( + select(APIKey) + .where(APIKey.user_id == user.id) + .where(APIKey.project_id == project.id) + .where(APIKey.is_deleted == False) + ).first() + if not api_key: + raise ValueError( + f"API key for {user_type.lower()} and project {project_name} not found." + ) + + # Return complete auth context + return TestAuthContext( + user_id=user.id, + project_id=project.id, + organization_id=org.id, + key=raw_key, + api_key_id=api_key.id, + user=user, + project=project, + organization=org, + api_key=api_key, + ) + + +def get_superuser_test_auth_context(session: Session) -> TestAuthContext: + """ + Get authentication context for superuser from seeded data. + + Uses SUPERUSER_EMAIL with Glific project based on seed_data.json: + - User: {{SUPERUSER_EMAIL}} (is_superuser: true) + - Project: Glific + - API Key: ApiKey No3x47A5qoIGhm0kVKjQ77dhCqEdWRIQZlEPzzzh7i8 + + Args: + session: Database session + + Returns: + TestAuthContext with all IDs and keys from seeded data + + Raises: + ValueError: If the required data is not found in the database + """ + return get_test_auth_context( + session=session, + user_email=settings.FIRST_SUPERUSER, + project_name="Glific", + raw_key="ApiKey No3x47A5qoIGhm0kVKjQ77dhCqEdWRIQZlEPzzzh7i8", + user_type="Superuser", + ) + + +def get_user_test_auth_context(session: Session) -> TestAuthContext: + """ + Get authentication context for normal user from seeded data. + + Uses ADMIN_EMAIL with Dalgo project based on seed_data.json: + - User: {{ADMIN_EMAIL}} (is_superuser: false) + - Project: Dalgo + - API Key: ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j9 + + Args: + session: Database session + + Returns: + TestAuthContext with all IDs and keys from seeded data + + Raises: + ValueError: If the required data is not found in the database + """ + return get_test_auth_context( + session=session, + user_email=settings.EMAIL_TEST_USER, + project_name="Dalgo", + raw_key="ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j", + user_type="User", + ) diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index e025f908..36e1ecf0 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -8,7 +8,7 @@ from app.models import Collection, Organization, Project from app.tests.utils.utils import get_user_id_by_email, get_project from app.tests.utils.test_data import create_test_project -from app.crud import create_api_key +from app.tests.utils.test_data import create_test_api_key class constants: diff --git a/backend/app/tests/utils/document.py b/backend/app/tests/utils/document.py index 5db59658..dddb1c2a 100644 --- a/backend/app/tests/utils/document.py +++ b/backend/app/tests/utils/document.py @@ -13,8 +13,9 @@ from app.core.config import settings from app.crud.project import get_project_by_id -from app.models import APIKeyPublic, Document, DocumentPublic, Project +from app.models import Document, DocumentPublic, Project from app.utils import APIResponse +from app.tests.utils.auth import TestAuthContext from .utils import SequentialUuidGenerator @@ -112,7 +113,7 @@ def append(self, doc: Document, suffix: str = None): @dataclass class WebCrawler: client: TestClient - user_api_key: APIKeyPublic + user_api_key: TestAuthContext def get(self, route: Route): return self.client.get( @@ -165,5 +166,5 @@ def to_public_dict(self) -> dict: @pytest.fixture -def crawler(client: TestClient, user_api_key: APIKeyPublic): +def crawler(client: TestClient, user_api_key: TestAuthContext): return WebCrawler(client, user_api_key=user_api_key) diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index bb6a93aa..c560bbca 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -4,6 +4,7 @@ Organization, Project, APIKey, + APIKeyCreateResponse, Credential, OrganizationCreate, ProjectCreate, @@ -17,10 +18,10 @@ from app.crud import ( create_organization, create_project, - create_api_key, set_creds_for_org, create_fine_tuning_job, create_model_evaluation, + APIKeyCrud, ) from app.core.providers import Provider from app.tests.utils.user import create_random_user @@ -62,23 +63,6 @@ def create_test_project(db: Session) -> Project: return create_project(session=db, project_create=project_in) -def create_test_api_key(db: Session) -> APIKey: - """ - Creates and returns an API key for a test project and test user. - - Persists a test user, organization, project, and API key to the database - """ - project = create_test_project(db) - user = create_random_user(db) - api_key = create_api_key( - db, - organization_id=project.organization_id, - user_id=user.id, - project_id=project.id, - ) - return api_key - - def test_credential_data(db: Session) -> CredsCreate: """ Returns credential data for a test project in the form of a CredsCreate schema. @@ -99,6 +83,29 @@ def test_credential_data(db: Session) -> CredsCreate: return creds_data +def create_test_api_key( + db: Session, + project_id: int | None = None, + user_id: int | None = None, +) -> APIKeyCreateResponse: + """ + Creates and returns a test API key for a specific project and user. + + Persists the API key to the database. + """ + if user_id is None: + user = create_random_user(db) + user_id = user.id + + if project_id is None: + project = create_test_project(db) + project_id = project.id + + api_key_crud = APIKeyCrud(session=db, project_id=project_id) + raw_key, api_key = api_key_crud.create(user_id=user_id, project_id=project_id) + return APIKeyCreateResponse(key=raw_key, **api_key.model_dump()) + + def create_test_credential(db: Session) -> tuple[list[Credential], Project]: """ Creates and returns test credentials (OpenAI and Langfuse) for a test project. diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 362b7434..0ea1e1d5 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -11,8 +11,7 @@ from app.core.config import settings from app.crud.user import get_user_by_email -from app.crud.api_key import get_api_key_by_value, get_api_key_by_user_id -from app.models import APIKeyPublic, Project, Assistant, Organization, Document +from app.models import Project, Assistant, Organization, Document T = TypeVar("T") @@ -42,26 +41,11 @@ def get_superuser_token_headers(client: TestClient) -> dict[str, str]: return headers -def get_api_key_by_email(db: Session, email: EmailStr) -> APIKeyPublic: - user = get_user_by_email(session=db, email=email) - api_key = get_api_key_by_user_id(db, user_id=user.id) - - return api_key - - def get_user_id_by_email(db: Session) -> int: user = get_user_by_email(session=db, email=settings.EMAIL_TEST_USER) return user.id -def get_user_from_api_key(db: Session, api_key_headers: dict[str, str]) -> APIKeyPublic: - key_value = api_key_headers["X-API-KEY"] - api_key = get_api_key_by_value(db, api_key_value=key_value) - if api_key is None: - raise ValueError("Invalid API Key") - return api_key - - def get_non_existent_id(session: Session, model: Type[T]) -> int: result = session.exec(select(model.id).order_by(model.id.desc())).first() return (result or 0) + 1 @@ -89,22 +73,24 @@ def get_project(session: Session, name: str | None = None) -> Project: return project -def get_assistant(session: Session, name: str | None = None) -> Assistant: +def get_assistant( + session: Session, project_id: int | None = None, name: str | None = None +) -> Assistant: """ Retrieve an active assistant from the database. If a assistant name is provided, fetch the active assistant with that name. If no name is provided, fetch any random assistant. """ + filters = [Assistant.is_deleted == False] + + if project_id is not None: + filters.append(Assistant.project_id == project_id) + if name: - statement = ( - select(Assistant) - .where(Assistant.name == name, Assistant.is_deleted == False) - .limit(1) - ) - else: - statement = select(Assistant).where(Assistant.is_deleted == False).limit(1) + filters.append(Assistant.name == name) + statement = select(Assistant).where(*filters).limit(1) assistant = session.exec(statement).first() if not assistant: