Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
# Adjust the import based on your actual structure
from app.models import Organization, Project, User, APIKey
from passlib.context import CryptContext # To hash passwords securely
from app.core.security import (
get_password_hash,
encrypt_api_key,
) # Add imports for encryption

# revision identifiers, used by Alembic.
revision = "77dc462dc6b0"
Expand Down Expand Up @@ -86,10 +90,15 @@ def create_user(session: Session, is_super: bool = True) -> User:

def create_api_key(session: Session, user: User, organization: Organization) -> APIKey:
"""Create and return an API key for the user and organization."""

token = secrets.token_urlsafe(32)
raw_key = "ApiKey " + token
encrypted_key = encrypt_api_key(raw_key) # Encrypt the raw key directly

api_key = APIKey(
user_id=user.id,
organization_id=organization.id,
key="ApiKey " + secrets.token_urlsafe(32),
key=encrypted_key,
)
session.add(api_key)
session.commit()
Expand Down
44 changes: 44 additions & 0 deletions backend/app/core/security.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from datetime import datetime, timedelta, timezone
from typing import Any
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

import jwt
from passlib.context import CryptContext
Expand All @@ -9,6 +13,30 @@
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


# Generate a key for API key encryption
def get_encryption_key() -> bytes:
"""Generate a key for API key encryption using the app's secret key."""
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=settings.SECRET_KEY.encode(),
iterations=100000,
)
return base64.urlsafe_b64encode(kdf.derive(settings.SECRET_KEY.encode()))


# Initialize Fernet with our encryption key
_fernet = None


def get_fernet() -> Fernet:
"""Get a Fernet instance with the encryption key."""
global _fernet
if _fernet is None:
_fernet = Fernet(get_encryption_key())
return _fernet


ALGORITHM = "HS256"


Expand All @@ -25,3 +53,19 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:

def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


def encrypt_api_key(api_key: str) -> str:
"""Encrypt an API key before storage."""
try:
return get_fernet().encrypt(api_key.encode()).decode()
except Exception as e:
raise ValueError(f"Failed to encrypt API key: {str(e)}")


def decrypt_api_key(encrypted_api_key: str) -> str:
"""Decrypt an API key when retrieving it."""
try:
return get_fernet().decrypt(encrypted_api_key.encode()).decode()
except Exception as e:
raise ValueError(f"Failed to decrypt API key: {str(e)}")
101 changes: 82 additions & 19 deletions backend/app/crud/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,43 @@
import secrets
from datetime import datetime
from sqlmodel import Session, select
from app.core.security import (
verify_password,
get_password_hash,
encrypt_api_key,
decrypt_api_key,
)
import base64
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from app.core import settings

from app.models import APIKey, APIKeyPublic
from app.models.api_key import APIKey, APIKeyPublic


def generate_api_key() -> tuple[str, str]:
"""Generate a new API key and its hash."""
raw_key = "ApiKey " + secrets.token_urlsafe(32)
hashed_key = get_password_hash(raw_key)
return raw_key, hashed_key


# Create API Key
def create_api_key(
session: Session, organization_id: uuid.UUID, user_id: uuid.UUID
) -> APIKeyPublic:
"""
Generates a new API key for an organization and associates it with a user.
Returns the API key details with the raw key (shown only once).
"""
# Generate raw key and its hash using the helper function
raw_key, hashed_key = generate_api_key()
encrypted_key = encrypt_api_key(
raw_key
) # Encrypt the raw key instead of hashed key

# Create API key record with encrypted raw key
api_key = APIKey(
key="ApiKey " + secrets.token_urlsafe(32),
key=encrypted_key, # Store the encrypted raw key
organization_id=organization_id,
user_id=user_id,
)
Expand All @@ -23,38 +47,59 @@
session.commit()
session.refresh(api_key)

return APIKeyPublic.model_validate(api_key)
# Set the raw key in the response (shown only once)
api_key_dict = api_key.model_dump()
api_key_dict["key"] = raw_key # Return the raw key to the user

return APIKeyPublic.model_validate(api_key_dict)


# Get API Key by ID
def get_api_key(session: Session, api_key_id: int) -> APIKeyPublic | None:
"""
Retrieves an API key by its ID if it exists and is not deleted.
Returns the API key in its original format.
"""
api_key = session.exec(
select(APIKey).where(APIKey.id == api_key_id, APIKey.is_deleted == False)
).first()

return APIKeyPublic.model_validate(api_key) if api_key else None
if api_key:
# Create a copy of the API key data
api_key_dict = api_key.model_dump()
# Decrypt the key
decrypted_key = decrypt_api_key(api_key.key)
api_key_dict["key"] = decrypted_key

return APIKeyPublic.model_validate(api_key_dict)
return None


# Get API Keys for an Organization
def get_api_keys_by_organization(
session: Session, organization_id: uuid.UUID
) -> list[APIKeyPublic]:
"""
Retrieves all active API keys associated with an organization.
Returns the API keys in their original format.
"""
api_keys = session.exec(
select(APIKey).where(
APIKey.organization_id == organization_id, APIKey.is_deleted == False
)
).all()

return [APIKeyPublic.model_validate(api_key) for api_key in api_keys]
raw_keys = []
for api_key in api_keys:
api_key_dict = api_key.model_dump()

decrypted_key = decrypt_api_key(api_key.key)

api_key_dict["key"] = decrypted_key

raw_keys.append(APIKeyPublic.model_validate(api_key_dict))

return raw_keys


# Soft Delete (Revoke) API Key
def delete_api_key(session: Session, api_key_id: int) -> None:
"""
Soft deletes (revokes) an API key by marking it as deleted.
Expand All @@ -71,24 +116,42 @@
session.commit()


def get_api_key_by_value(session: Session, api_key_value: str) -> APIKey | None:
def get_api_key_by_value(session: Session, api_key_value: str) -> APIKeyPublic | None:
"""
Retrieve an API Key record by its value.
Retrieve an API Key record by verifying the provided key against stored hashes.
Returns the API key in its original format.
"""
return session.exec(
select(APIKey).where(APIKey.key == api_key_value, APIKey.is_deleted == False)
).first()
# Get all active API keys
api_keys = session.exec(select(APIKey).where(APIKey.is_deleted == False)).all()

for api_key in api_keys:
decrypted_key = decrypt_api_key(api_key.key)
if api_key_value == decrypted_key:
api_key_dict = api_key.model_dump()

api_key_dict["key"] = decrypted_key

return APIKeyPublic.model_validate(api_key_dict)
return None

Check warning on line 135 in backend/app/crud/api_key.py

View check run for this annotation

Codecov / codecov/patch

backend/app/crud/api_key.py#L135

Added line #L135 was not covered by tests


def get_api_key_by_user_org(
session: Session, organization_id: int, user_id: str
db: Session, organization_id: int, user_id: int
) -> APIKey | None:
"""
Retrieve an API key for a specific user and organization.
"""
"""Get an API key by user and organization ID."""
statement = select(APIKey).where(
APIKey.organization_id == organization_id,
APIKey.user_id == user_id,
APIKey.is_deleted == False,
)
return session.exec(statement).first()
api_key = db.exec(statement).first()

if api_key:
api_key_dict = api_key.model_dump()

decrypted_key = decrypt_api_key(api_key.key)

api_key_dict["key"] = decrypted_key

return APIKey.model_validate(api_key_dict)
return None
2 changes: 2 additions & 0 deletions backend/app/tests/api/routes/test_project.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session
from app.core.security import decrypt_api_key, verify_password

from app.main import app
from app.core.config import settings
Expand All @@ -10,6 +11,7 @@
from app.tests.utils.utils import random_lower_string, random_email
from app.crud.project import create_project, get_project_by_id
from app.crud.organization import create_organization
from app.crud import api_key as api_key_crud

client = TestClient(app)

Expand Down
121 changes: 121 additions & 0 deletions backend/app/tests/core/test_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
from app.core.security import (
get_password_hash,
verify_password,
encrypt_api_key,
decrypt_api_key,
get_encryption_key,
)


def test_encrypt_decrypt_api_key():
"""Test that API key encryption and decryption works correctly."""
# Test data
test_key = "ApiKey test123456789"

# Encrypt the key
encrypted_key = encrypt_api_key(test_key)

# Verify encryption worked
assert encrypted_key is not None
assert encrypted_key != test_key
assert isinstance(encrypted_key, str)

# Decrypt the key
decrypted_key = decrypt_api_key(encrypted_key)

# Verify decryption worked
assert decrypted_key is not None
assert decrypted_key == test_key


def test_api_key_format_validation():
"""Test that API key format is validated correctly."""
# Test valid API key format
valid_key = "ApiKey test123456789"
encrypted_valid = encrypt_api_key(valid_key)
assert encrypted_valid is not None
assert decrypt_api_key(encrypted_valid) == valid_key

# Test invalid API key format (missing prefix)
invalid_key = "test123456789"
encrypted_invalid = encrypt_api_key(invalid_key)
assert encrypted_invalid is not None
assert decrypt_api_key(encrypted_invalid) == invalid_key


def test_encrypt_api_key_edge_cases():
"""Test edge cases for API key encryption."""
# Test empty string
empty_key = ""
encrypted_empty = encrypt_api_key(empty_key)
assert encrypted_empty is not None
assert decrypt_api_key(encrypted_empty) == empty_key

# Test whitespace only
whitespace_key = " "
encrypted_whitespace = encrypt_api_key(whitespace_key)
assert encrypted_whitespace is not None
assert decrypt_api_key(encrypted_whitespace) == whitespace_key

# Test very long input
long_key = "ApiKey " + "a" * 1000
encrypted_long = encrypt_api_key(long_key)
assert encrypted_long is not None
assert decrypt_api_key(encrypted_long) == long_key


def test_encrypt_api_key_type_validation():
"""Test type validation for API key encryption."""
# Test non-string inputs
invalid_inputs = [123, [], {}, True]
for invalid_input in invalid_inputs:
with pytest.raises(ValueError, match="Failed to encrypt API key"):
encrypt_api_key(invalid_input)


def test_encrypt_api_key_security():
"""Test security properties of API key encryption."""
# Test that same input produces different encrypted output
test_key = "ApiKey test123456789"
encrypted1 = encrypt_api_key(test_key)
encrypted2 = encrypt_api_key(test_key)
assert encrypted1 != encrypted2 # Different encrypted outputs for same input


def test_encrypt_api_key_error_handling():
"""Test error handling in encrypt_api_key."""
# Test with invalid input
with pytest.raises(ValueError, match="Failed to encrypt API key"):
encrypt_api_key(None)


def test_decrypt_api_key_error_handling():
"""Test error handling in decrypt_api_key."""
# Test with invalid input
with pytest.raises(ValueError, match="Failed to decrypt API key"):
decrypt_api_key(None)

# Test with various invalid encrypted data formats
invalid_encrypted_data = [
"invalid_encrypted_data", # Not base64
"not_a_base64_string", # Not base64
"a" * 44, # Wrong length
"!" * 44, # Invalid base64 chars
"aGVsbG8=", # Valid base64 but not encrypted
]
for invalid_data in invalid_encrypted_data:
with pytest.raises(ValueError, match="Failed to decrypt API key"):
decrypt_api_key(invalid_data)


def test_get_encryption_key():
"""Test that encryption key generation works correctly."""
# Get the encryption key
key = get_encryption_key()

# Verify the key
assert key is not None
assert isinstance(key, bytes)
# The key is base64 encoded, so it should be 44 bytes
assert len(key) == 44 # Base64 encoded Fernet key length is 44 bytes
Loading