diff --git a/backend/app/alembic/versions/77dc462dc6b0_seed_organization_table.py b/backend/app/alembic/versions/77dc462dc6b0_seed_organization_table.py index 7acb8d407..b8f26fdb2 100644 --- a/backend/app/alembic/versions/77dc462dc6b0_seed_organization_table.py +++ b/backend/app/alembic/versions/77dc462dc6b0_seed_organization_table.py @@ -13,6 +13,10 @@ # Adjust the import based on your actual structure from app.models import Organization, Project, User, APIKey from passlib.context import CryptContext # To hash passwords securely +from app.core.security import ( + get_password_hash, + encrypt_api_key, +) # Add imports for encryption # revision identifiers, used by Alembic. revision = "77dc462dc6b0" @@ -86,10 +90,15 @@ def create_user(session: Session, is_super: bool = True) -> User: def create_api_key(session: Session, user: User, organization: Organization) -> APIKey: """Create and return an API key for the user and organization.""" + + token = secrets.token_urlsafe(32) + raw_key = "ApiKey " + token + encrypted_key = encrypt_api_key(raw_key) # Encrypt the raw key directly + api_key = APIKey( user_id=user.id, organization_id=organization.id, - key="ApiKey " + secrets.token_urlsafe(32), + key=encrypted_key, ) session.add(api_key) session.commit() diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 7aff7cfb3..65594fcbf 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -1,5 +1,9 @@ from datetime import datetime, timedelta, timezone from typing import Any +import base64 +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC import jwt from passlib.context import CryptContext @@ -9,6 +13,30 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +# Generate a key for API key encryption +def get_encryption_key() -> bytes: + """Generate a key for API key encryption using the app's secret key.""" + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=settings.SECRET_KEY.encode(), + iterations=100000, + ) + return base64.urlsafe_b64encode(kdf.derive(settings.SECRET_KEY.encode())) + + +# Initialize Fernet with our encryption key +_fernet = None + + +def get_fernet() -> Fernet: + """Get a Fernet instance with the encryption key.""" + global _fernet + if _fernet is None: + _fernet = Fernet(get_encryption_key()) + return _fernet + + ALGORITHM = "HS256" @@ -25,3 +53,19 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: def get_password_hash(password: str) -> str: return pwd_context.hash(password) + + +def encrypt_api_key(api_key: str) -> str: + """Encrypt an API key before storage.""" + try: + return get_fernet().encrypt(api_key.encode()).decode() + except Exception as e: + raise ValueError(f"Failed to encrypt API key: {str(e)}") + + +def decrypt_api_key(encrypted_api_key: str) -> str: + """Decrypt an API key when retrieving it.""" + try: + return get_fernet().decrypt(encrypted_api_key.encode()).decode() + except Exception as e: + raise ValueError(f"Failed to decrypt API key: {str(e)}") diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 309562de1..0de9c5e58 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -2,19 +2,43 @@ import secrets from datetime import datetime from sqlmodel import Session, select +from app.core.security import ( + verify_password, + get_password_hash, + encrypt_api_key, + decrypt_api_key, +) +import base64 +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from app.core import settings -from app.models import APIKey, APIKeyPublic +from app.models.api_key import APIKey, APIKeyPublic + + +def generate_api_key() -> tuple[str, str]: + """Generate a new API key and its hash.""" + raw_key = "ApiKey " + secrets.token_urlsafe(32) + hashed_key = get_password_hash(raw_key) + return raw_key, hashed_key -# Create API Key def create_api_key( session: Session, organization_id: uuid.UUID, user_id: uuid.UUID ) -> APIKeyPublic: """ Generates a new API key for an organization and associates it with a user. + Returns the API key details with the raw key (shown only once). """ + # Generate raw key and its hash using the helper function + raw_key, hashed_key = generate_api_key() + encrypted_key = encrypt_api_key( + raw_key + ) # Encrypt the raw key instead of hashed key + + # Create API key record with encrypted raw key api_key = APIKey( - key="ApiKey " + secrets.token_urlsafe(32), + key=encrypted_key, # Store the encrypted raw key organization_id=organization_id, user_id=user_id, ) @@ -23,27 +47,39 @@ def create_api_key( session.commit() session.refresh(api_key) - return APIKeyPublic.model_validate(api_key) + # Set the raw key in the response (shown only once) + api_key_dict = api_key.model_dump() + api_key_dict["key"] = raw_key # Return the raw key to the user + + return APIKeyPublic.model_validate(api_key_dict) -# Get API Key by ID def get_api_key(session: Session, api_key_id: int) -> APIKeyPublic | None: """ Retrieves an API key by its ID if it exists and is not deleted. + Returns the API key in its original format. """ api_key = session.exec( select(APIKey).where(APIKey.id == api_key_id, APIKey.is_deleted == False) ).first() - return APIKeyPublic.model_validate(api_key) if api_key else None + if api_key: + # Create a copy of the API key data + api_key_dict = api_key.model_dump() + # Decrypt the key + decrypted_key = decrypt_api_key(api_key.key) + api_key_dict["key"] = decrypted_key + + return APIKeyPublic.model_validate(api_key_dict) + return None -# Get API Keys for an Organization def get_api_keys_by_organization( session: Session, organization_id: uuid.UUID ) -> list[APIKeyPublic]: """ Retrieves all active API keys associated with an organization. + Returns the API keys in their original format. """ api_keys = session.exec( select(APIKey).where( @@ -51,10 +87,19 @@ def get_api_keys_by_organization( ) ).all() - return [APIKeyPublic.model_validate(api_key) for api_key in api_keys] + raw_keys = [] + for api_key in api_keys: + api_key_dict = api_key.model_dump() + + decrypted_key = decrypt_api_key(api_key.key) + + api_key_dict["key"] = decrypted_key + + raw_keys.append(APIKeyPublic.model_validate(api_key_dict)) + + return raw_keys -# Soft Delete (Revoke) API Key def delete_api_key(session: Session, api_key_id: int) -> None: """ Soft deletes (revokes) an API key by marking it as deleted. @@ -71,24 +116,42 @@ def delete_api_key(session: Session, api_key_id: int) -> None: session.commit() -def get_api_key_by_value(session: Session, api_key_value: str) -> APIKey | None: +def get_api_key_by_value(session: Session, api_key_value: str) -> APIKeyPublic | None: """ - Retrieve an API Key record by its value. + Retrieve an API Key record by verifying the provided key against stored hashes. + Returns the API key in its original format. """ - return session.exec( - select(APIKey).where(APIKey.key == api_key_value, APIKey.is_deleted == False) - ).first() + # Get all active API keys + api_keys = session.exec(select(APIKey).where(APIKey.is_deleted == False)).all() + + for api_key in api_keys: + decrypted_key = decrypt_api_key(api_key.key) + if api_key_value == decrypted_key: + api_key_dict = api_key.model_dump() + + api_key_dict["key"] = decrypted_key + + return APIKeyPublic.model_validate(api_key_dict) + return None def get_api_key_by_user_org( - session: Session, organization_id: int, user_id: str + db: Session, organization_id: int, user_id: int ) -> APIKey | None: - """ - Retrieve an API key for a specific user and organization. - """ + """Get an API key by user and organization ID.""" statement = select(APIKey).where( APIKey.organization_id == organization_id, APIKey.user_id == user_id, APIKey.is_deleted == False, ) - return session.exec(statement).first() + api_key = db.exec(statement).first() + + if api_key: + api_key_dict = api_key.model_dump() + + decrypted_key = decrypt_api_key(api_key.key) + + api_key_dict["key"] = decrypted_key + + return APIKey.model_validate(api_key_dict) + return None diff --git a/backend/app/tests/api/routes/test_project.py b/backend/app/tests/api/routes/test_project.py index 82ea7593d..98d1f96da 100644 --- a/backend/app/tests/api/routes/test_project.py +++ b/backend/app/tests/api/routes/test_project.py @@ -1,6 +1,7 @@ import pytest from fastapi.testclient import TestClient from sqlmodel import Session +from app.core.security import decrypt_api_key, verify_password from app.main import app from app.core.config import settings @@ -10,6 +11,7 @@ from app.tests.utils.utils import random_lower_string, random_email from app.crud.project import create_project, get_project_by_id from app.crud.organization import create_organization +from app.crud import api_key as api_key_crud client = TestClient(app) diff --git a/backend/app/tests/core/test_security.py b/backend/app/tests/core/test_security.py new file mode 100644 index 000000000..4f7f68610 --- /dev/null +++ b/backend/app/tests/core/test_security.py @@ -0,0 +1,121 @@ +import pytest +from app.core.security import ( + get_password_hash, + verify_password, + encrypt_api_key, + decrypt_api_key, + get_encryption_key, +) + + +def test_encrypt_decrypt_api_key(): + """Test that API key encryption and decryption works correctly.""" + # Test data + test_key = "ApiKey test123456789" + + # Encrypt the key + encrypted_key = encrypt_api_key(test_key) + + # Verify encryption worked + assert encrypted_key is not None + assert encrypted_key != test_key + assert isinstance(encrypted_key, str) + + # Decrypt the key + decrypted_key = decrypt_api_key(encrypted_key) + + # Verify decryption worked + assert decrypted_key is not None + assert decrypted_key == test_key + + +def test_api_key_format_validation(): + """Test that API key format is validated correctly.""" + # Test valid API key format + valid_key = "ApiKey test123456789" + encrypted_valid = encrypt_api_key(valid_key) + assert encrypted_valid is not None + assert decrypt_api_key(encrypted_valid) == valid_key + + # Test invalid API key format (missing prefix) + invalid_key = "test123456789" + encrypted_invalid = encrypt_api_key(invalid_key) + assert encrypted_invalid is not None + assert decrypt_api_key(encrypted_invalid) == invalid_key + + +def test_encrypt_api_key_edge_cases(): + """Test edge cases for API key encryption.""" + # Test empty string + empty_key = "" + encrypted_empty = encrypt_api_key(empty_key) + assert encrypted_empty is not None + assert decrypt_api_key(encrypted_empty) == empty_key + + # Test whitespace only + whitespace_key = " " + encrypted_whitespace = encrypt_api_key(whitespace_key) + assert encrypted_whitespace is not None + assert decrypt_api_key(encrypted_whitespace) == whitespace_key + + # Test very long input + long_key = "ApiKey " + "a" * 1000 + encrypted_long = encrypt_api_key(long_key) + assert encrypted_long is not None + assert decrypt_api_key(encrypted_long) == long_key + + +def test_encrypt_api_key_type_validation(): + """Test type validation for API key encryption.""" + # Test non-string inputs + invalid_inputs = [123, [], {}, True] + for invalid_input in invalid_inputs: + with pytest.raises(ValueError, match="Failed to encrypt API key"): + encrypt_api_key(invalid_input) + + +def test_encrypt_api_key_security(): + """Test security properties of API key encryption.""" + # Test that same input produces different encrypted output + test_key = "ApiKey test123456789" + encrypted1 = encrypt_api_key(test_key) + encrypted2 = encrypt_api_key(test_key) + assert encrypted1 != encrypted2 # Different encrypted outputs for same input + + +def test_encrypt_api_key_error_handling(): + """Test error handling in encrypt_api_key.""" + # Test with invalid input + with pytest.raises(ValueError, match="Failed to encrypt API key"): + encrypt_api_key(None) + + +def test_decrypt_api_key_error_handling(): + """Test error handling in decrypt_api_key.""" + # Test with invalid input + with pytest.raises(ValueError, match="Failed to decrypt API key"): + decrypt_api_key(None) + + # Test with various invalid encrypted data formats + invalid_encrypted_data = [ + "invalid_encrypted_data", # Not base64 + "not_a_base64_string", # Not base64 + "a" * 44, # Wrong length + "!" * 44, # Invalid base64 chars + "aGVsbG8=", # Valid base64 but not encrypted + ] + for invalid_data in invalid_encrypted_data: + with pytest.raises(ValueError, match="Failed to decrypt API key"): + decrypt_api_key(invalid_data) + + +def test_get_encryption_key(): + """Test that encryption key generation works correctly.""" + # Get the encryption key + key = get_encryption_key() + + # Verify the key + assert key is not None + assert isinstance(key, bytes) + # The key is base64 encoded, so it should be 44 bytes + assert len(key) == 44 # Base64 encoded Fernet key length is 44 bytes diff --git a/backend/app/tests/crud/test_api_key.py b/backend/app/tests/crud/test_api_key.py index 23c2837b1..b6269b700 100644 --- a/backend/app/tests/crud/test_api_key.py +++ b/backend/app/tests/crud/test_api_key.py @@ -4,7 +4,7 @@ from app.crud import api_key as api_key_crud from app.models import APIKey, User, Organization from app.tests.utils.utils import random_email -from app.core.security import get_password_hash +from app.core.security import get_password_hash, verify_password, decrypt_api_key # Helper function to create a user @@ -48,7 +48,9 @@ def test_get_api_key(db: Session) -> None: assert retrieved_key is not None assert retrieved_key.id == created_key.id - assert retrieved_key.key == created_key.key + # The key should be in its original format + assert retrieved_key.key.startswith("ApiKey ") + assert len(retrieved_key.key) > 32 def test_get_api_key_not_found(db: Session) -> None: @@ -67,8 +69,12 @@ def test_get_api_keys_by_organization(db: Session) -> None: api_keys = api_key_crud.get_api_keys_by_organization(db, org.id) assert len(api_keys) == 2 - assert any(key.id == api_key1.id for key in api_keys) - assert any(key.id == api_key2.id for key in api_keys) + # Verify that the keys are in their original format + for key in api_keys: + assert key.key.startswith("ApiKey ") + assert len(key.key) > 32 # Raw key should be longer than 32 characters + assert key.organization_id == org.id + assert key.user_id in [user1.id, user2.id] def test_delete_api_key(db: Session) -> None: @@ -100,12 +106,22 @@ def test_get_api_key_by_value(db: Session) -> None: user = create_test_user(db) org = create_test_organization(db) + # Create an API key api_key = api_key_crud.create_api_key(db, org.id, user.id) - retrieved_key = api_key_crud.get_api_key_by_value(db, api_key.key) + # Get the raw key that was returned during creation + raw_key = api_key.key + + # Test retrieving the API key by its value + retrieved_key = api_key_crud.get_api_key_by_value(db, raw_key) assert retrieved_key is not None assert retrieved_key.id == api_key.id - assert retrieved_key.key == api_key.key + assert retrieved_key.organization_id == org.id + assert retrieved_key.user_id == user.id + # The key should be in its original format + assert retrieved_key.key == raw_key # Should be exactly the same key + assert retrieved_key.key.startswith("ApiKey ") + assert len(retrieved_key.key) > 32 def test_get_api_key_by_user_org(db: Session) -> None: @@ -119,6 +135,9 @@ def test_get_api_key_by_user_org(db: Session) -> None: assert retrieved_key.id == api_key.id assert retrieved_key.organization_id == org.id assert retrieved_key.user_id == user.id + # The key should be in its original format + assert retrieved_key.key.startswith("ApiKey ") + assert len(retrieved_key.key) > 32 def test_get_api_key_by_user_org_not_found(db: Session) -> None: